diff --git a/cmd/w3k/go.mod b/cmd/w3k/go.mod deleted file mode 100644 index 4317db0e..00000000 --- a/cmd/w3k/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/cisco-open/wasm-kernel-module/cmd/w3k - -go 1.18 - -require github.com/cristalhq/acmd v0.11.1 diff --git a/cmd/w3k/go.sum b/cmd/w3k/go.sum deleted file mode 100644 index b4f1b710..00000000 --- a/cmd/w3k/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/cristalhq/acmd v0.11.1 h1:DJ4fh2Pv0nPKmqT646IU/0Vh5FNdGblxvF+3/W3NAUI= -github.com/cristalhq/acmd v0.11.1/go.mod h1:LG5oa43pE/BbxtfMoImHCQN++0Su7dzipdgBjMCBVDQ= diff --git a/cmd/w3k/main.go b/cmd/w3k/main.go index d6bab178..f4f348ca 100644 --- a/cmd/w3k/main.go +++ b/cmd/w3k/main.go @@ -34,15 +34,26 @@ import ( "syscall" cli "github.com/cristalhq/acmd" + + "github.com/cisco-open/wasm-kernel-module/pkg/tls" ) +type CommandContext struct { + UID int `json:"uid,omitempty"` + GID int `json:"gid,omitempty"` + PID int `json:"pid,omitempty"` + CommandName string `json:"command_name,omitempty"` + CommandPath string `json:"command_path,omitempty"` +} + type Command struct { - ID string `json:"id,omitempty"` - Command string `json:"command"` - Name string `json:"name,omitempty"` - Code []byte `json:"code,omitempty"` - Entrypoint string `json:"entrypoint,omitempty"` - Data string `json:"data,omitempty"` + Context CommandContext `json:"context,omitempty"` + ID string `json:"id,omitempty"` + Command string `json:"command"` + Name string `json:"name,omitempty"` + Code []byte `json:"code,omitempty"` + Entrypoint string `json:"entrypoint,omitempty"` + Data string `json:"data,omitempty"` } type Answer struct { @@ -66,6 +77,10 @@ func AcceptOk(Command) (string, error) { return "ok", nil } +func ConnectOk(Command) (string, error) { + return "ok", nil +} + type loadFlags struct { File string Name string @@ -77,6 +92,18 @@ func (c *loadFlags) Flags() *flag.FlagSet { fs.StringVar(&c.File, "file", "my-module.wasm", "the file path of the loaded Wasm module") fs.StringVar(&c.Name, "name", "", "how to name the loaded Wasm module") fs.StringVar(&c.Entrypoint, "entrypoint", "", "initial function to invoke after loading the Wasm module") + + return fs +} + +type serverFlags struct { + CAPemFileName string +} + +func (c *serverFlags) Flags() *flag.FlagSet { + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.StringVar(&c.CAPemFileName, "ca-pem-filename", "ca.pem", "root CA pem location for CA signer") + return fs } @@ -86,7 +113,8 @@ var commandHandlers map[string]CommandHandler func init() { commandHandlers = map[string]CommandHandler{ - "accept": CommandHandlerFunc(AcceptOk), + "accept": CommandHandlerFunc(AcceptOk), + "connect": CommandHandlerFunc(ConnectOk), } } @@ -144,7 +172,59 @@ var cmds = []cli.Command{ Name: "server", Description: "run the support server for the kernel module", Alias: "s", + FlagSet: &serverFlags{}, ExecFunc: func(ctx context.Context, args []string) error { + var cfg serverFlags + if err := cfg.Flags().Parse(args); err != nil { + return err + } + + signerCA, err := tls.NewSignerCA(cfg.CAPemFileName) + if err != nil { + return err + } + _ = signerCA.Certificate + + commandHandlers["csr_sign"] = CommandHandlerFunc(func(c Command) (string, error) { + var data struct { + CSR string `json:"csr"` + } + + if err := json.Unmarshal([]byte(c.Data), &data); err != nil { + return "jsonerror", err + } + + containers, err := tls.ParsePEMs([]byte(data.CSR)) + if err != nil { + return "error", err + } + + if len(containers) != 1 { + return "error", errors.New("invalid csr") + } + + certificate, err := signerCA.SignCertificateRequest(containers[0].GetX509CertificateRequest().CertificateRequest) + if err != nil { + return "error", err + } + + caCertificate := signerCA.GetCaCertificate() + + var response struct { + Certificate *tls.X509Certificate `json:"certificate"` + TrustAnchors []*tls.X509Certificate `json:"trust_anchors"` + } + + response.Certificate = certificate + response.TrustAnchors = append(response.TrustAnchors, caCertificate) + + j, err := json.Marshal(response) + if err != nil { + return "error", err + } + + return string(j), nil + }) dev, err := os.OpenFile("/dev/wasm", os.O_RDWR, 0666) if err != nil { @@ -165,7 +245,7 @@ var cmds = []cli.Command{ return err } - log.Printf("received command: %+v", command) + log.Printf("received command: (%s) %+v", scanner.Bytes(), command) if handler, ok := commandHandlers[command.Command]; ok { answer, err = handler.HandleCommand(command) diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..8ad445a0 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/cisco-open/wasm-kernel-module + +go 1.18 + +require ( + emperror.dev/errors v0.8.1 + github.com/cristalhq/acmd v0.11.1 +) + +require ( + github.com/pkg/errors v0.9.1 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.6.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..606e1f62 --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +emperror.dev/errors v0.8.1 h1:UavXZ5cSX/4u9iyvH6aDcuGkVjeexUGJ7Ij7G4VfQT0= +emperror.dev/errors v0.8.1/go.mod h1:YcRvLPh626Ubn2xqtoprejnA5nFha+TJ+2vew48kWuE= +github.com/cristalhq/acmd v0.11.1 h1:DJ4fh2Pv0nPKmqT646IU/0Vh5FNdGblxvF+3/W3NAUI= +github.com/cristalhq/acmd v0.11.1/go.mod h1:LG5oa43pE/BbxtfMoImHCQN++0Su7dzipdgBjMCBVDQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= diff --git a/go.work b/go.work index f877c820..34ff9565 100644 --- a/go.work +++ b/go.work @@ -1,5 +1,5 @@ go 1.20 -use ./cmd/w3k +use . use ./samples/dns-go diff --git a/pkg/tls/parser.go b/pkg/tls/parser.go new file mode 100644 index 00000000..0a589c33 --- /dev/null +++ b/pkg/tls/parser.go @@ -0,0 +1,481 @@ +/* + * The MIT License (MIT) + * Copyright (c) 2023 Cisco and/or its affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software + * and associated documentation files (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial + * portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package tls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "math/big" + "net" + "net/url" + "os" + "time" + + "emperror.dev/errors" +) + +type SupportedPEMType string + +const ( + CertificateRequestSupportedPEMType SupportedPEMType = "CERTIFICATE REQUEST" + CertificateSupportedPEMType SupportedPEMType = "CERTIFICATE" + PublicKeySupportedPEMType SupportedPEMType = "PUBLIC KEY" + PrivateKeySupportedPEMType SupportedPEMType = "PRIVATE KEY" + ECPrivateKeySupportedPEMType SupportedPEMType = "EC PRIVATE KEY" +) + +type X509Certificate struct { + *CertificateCommon `json:",inline"` + IsCA bool `json:"isCA,omitempty"` + + Certificate *x509.Certificate `json:"-"` +} + +type X509CertificateRequest struct { + *CertificateCommon `json:",inline"` + + CertificateRequest *x509.CertificateRequest `json:"-"` +} + +type CertificateCommon struct { + PublicKey *PublicKey `json:"publicKey,omitempty"` + + SerialNumber string `json:"serialNumber,omitempty"` + NotBefore *time.Time `json:"notBefore,omitempty"` + NotAfter *time.Time `json:"notAfter,omitempty"` + NotBeforeUnix uint64 `json:"notBeforeUnix,omitempty"` + NotAfterUnix uint64 `json:"notAfterUnix,omitempty"` + + Subject string `json:"subject,omitempty"` + Issuer string `json:"issuer,omitempty"` + + DNSNames []string `json:"dnsNames,omitempty"` + EmailAddresses []string `json:"emailAddresses,omitempty"` + IPAddresses []string `json:"ipAddresses,omitempty"` + URIs []string `json:"urIs,omitempty"` + + Signature []byte `json:"signature,omitempty"` + SignatureAlgorithm string `json:"signatureAlgorithm,omitempty"` + + Raw []byte `json:"raw,omitempty"` + RawSubject []byte `json:"rawSubject,omitempty"` + RawIssuer []byte `json:"rawIssuer,omitempty"` +} + +func (c *CertificateCommon) copyIPsAndURIs(ipAddresses []net.IP, uris []*url.URL) { + if l := len(ipAddresses); l > 0 { + c.IPAddresses = make([]string, l) + for k, v := range ipAddresses { + c.IPAddresses[k] = v.String() + } + } + if l := len(uris); l > 0 { + c.URIs = make([]string, l) + for k, v := range uris { + c.URIs[k] = v.String() + } + } +} + +type PublicKey struct { + Type string `json:"type,omitempty"` + BitSize int32 `json:"bitSize,omitempty"` + + RSA_N []byte `json:"RSA_N,omitempty"` + RSA_E []byte `json:"RSA_E,omitempty"` + + Curve string `json:"curve,omitempty"` + EC_Q []byte `json:"EC_Q,omitempty"` + + Raw []byte `json:"raw,omitempty"` + Key any `json:"-"` +} + +type PrivateKey struct { + Type string `json:"type,omitempty"` + Size int `json:"size,omitempty"` + + RSA_P []byte `json:"RSA_P,omitempty"` + RSA_Q []byte `json:"RSA_Q,omitempty"` + RSA_DP []byte `json:"RSA_DP,omitempty"` + RSA_DQ []byte `json:"RSA_DQ,omitempty"` + RSA_IQ []byte `json:"RSA_IQ,omitempty"` + + Curve string `json:"curve,omitempty"` + EC_D []byte `json:"EC_D,omitempty"` + + PublicKey *PublicKey `json:"publicKey,omitempty"` + + Raw []byte `json:"raw,omitempty"` + Key any `json:"-"` +} + +type ContainerType string + +const ( + X509CertificateContainerType ContainerType = "X509Certificate" + X509CertificateRequestContainerType ContainerType = "X509CertificateRequest" + PublicKeyContainerType ContainerType = "PublicKey" + PrivateKeyContainerType ContainerType = "PrivateKey" +) + +type Container struct { + Type ContainerType `json:"type,omitempty"` + Object any `json:"object,omitempty"` +} + +func (c *Container) GetX509Certificate() *X509Certificate { + if o, ok := c.Object.(*X509Certificate); ok { + return o + } + + return &X509Certificate{} +} + +func (c *Container) GetX509CertificateRequest() *X509CertificateRequest { + if o, ok := c.Object.(*X509CertificateRequest); ok { + return o + } + + return &X509CertificateRequest{} +} + +func (c *Container) GetPublicKey() *PublicKey { + if o, ok := c.Object.(*PublicKey); ok { + return o + } + + return &PublicKey{} +} + +func (c *Container) GetPrivateKey() *PrivateKey { + if o, ok := c.Object.(*PrivateKey); ok { + return o + } + + return &PrivateKey{} +} + +func ParsePEMFromFile(pemFileName ...string) ([]*Container, error) { + var content []byte + + for _, file := range pemFileName { + fileContent, err := os.ReadFile(file) + if err != nil { + return nil, err + } + + content = append(content, fileContent...) + } + + return ParsePEMs(content) +} + +func ParsePEMs(content []byte) ([]*Container, error) { + containers := []*Container{} + + var block *pem.Block + var rest []byte + var multierr error + for { + block, rest = pem.Decode(content) + if block == nil { + break + } + + content = rest + + container, err := ParsePEM(block) + if err != nil { + multierr = errors.Append(multierr, err) + continue + } + + containers = append(containers, container) + + if len(rest) == 0 { + break + } + } + + return containers, multierr +} + +func ParsePEM(block *pem.Block) (*Container, error) { + switch SupportedPEMType(block.Type) { + case CertificateRequestSupportedPEMType: + obj, err := ParseX509CertificateRequestFromDER(block.Bytes) + + return &Container{ + Type: X509CertificateRequestContainerType, + Object: obj, + }, err + case CertificateSupportedPEMType: + obj, err := ParseX509CertificateFromDER(block.Bytes) + + return &Container{ + Type: X509CertificateContainerType, + Object: obj, + }, err + case PublicKeySupportedPEMType: + obj, err := ParseX509PublicKey(block.Bytes) + + return &Container{ + Type: PublicKeyContainerType, + Object: obj, + }, err + case PrivateKeySupportedPEMType, ECPrivateKeySupportedPEMType: + obj, err := ParseX509PrivateKey(block.Bytes) + + return &Container{ + Type: PrivateKeyContainerType, + Object: obj, + }, err + default: + return nil, errors.New("unsupported PEM type") + } +} + +func ParseX509CertificateFromDER(der []byte) (*X509Certificate, error) { + x509cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + return ConvertX509Certificate(x509cert) +} + +func ConvertX509Certificate(x509cert *x509.Certificate) (*X509Certificate, error) { + cert := &X509Certificate{ + Certificate: x509cert, + IsCA: x509cert.IsCA, + CertificateCommon: &CertificateCommon{ + SerialNumber: hex.EncodeToString(x509cert.SerialNumber.Bytes()), + NotBefore: &x509cert.NotBefore, + NotAfter: &x509cert.NotAfter, + NotBeforeUnix: uint64(x509cert.NotBefore.Unix()), + NotAfterUnix: uint64(x509cert.NotAfter.Unix()), + + Subject: x509cert.Subject.String(), + Issuer: x509cert.Issuer.String(), + + DNSNames: x509cert.DNSNames, + EmailAddresses: x509cert.EmailAddresses, + + Signature: x509cert.Signature, + SignatureAlgorithm: x509cert.SignatureAlgorithm.String(), + + Raw: x509cert.Raw, + RawSubject: x509cert.RawSubject, + RawIssuer: x509cert.RawIssuer, + }, + } + + cert.CertificateCommon.copyIPsAndURIs(x509cert.IPAddresses, x509cert.URIs) + + var err error + cert.PublicKey, err = parseX509PublicKey(x509cert.PublicKey) + if err != nil { + return nil, err + } + cert.PublicKey.Raw = x509cert.RawSubjectPublicKeyInfo + + return cert, nil +} + +func ParseX509CertificateRequestFromDER(der []byte) (*X509CertificateRequest, error) { + x509req, err := x509.ParseCertificateRequest(der) + if err != nil { + return nil, err + } + + return ConvertX509CertificateRequest(x509req) +} + +func ConvertX509CertificateRequest(x509req *x509.CertificateRequest) (*X509CertificateRequest, error) { + req := &X509CertificateRequest{ + CertificateRequest: x509req, + CertificateCommon: &CertificateCommon{ + Subject: x509req.Subject.String(), + DNSNames: x509req.DNSNames, + EmailAddresses: x509req.EmailAddresses, + + Signature: x509req.Signature, + SignatureAlgorithm: x509req.SignatureAlgorithm.String(), + + RawSubject: x509req.RawSubject, + Raw: x509req.Raw, + }, + } + + req.CertificateCommon.copyIPsAndURIs(x509req.IPAddresses, x509req.URIs) + + var err error + req.PublicKey, err = parseX509PublicKey(x509req.PublicKey) + if err != nil { + return nil, err + } + req.PublicKey.Raw = x509req.RawSubjectPublicKeyInfo + + return req, nil +} + +func ParseX509PublicKey(der []byte) (*PublicKey, error) { + x509key, err := parsePublicKey(der) + if err != nil { + return nil, err + } + + pk, err := parseX509PublicKey(x509key) + if err != nil { + return nil, err + } + + pk.Raw = der + + return pk, nil +} + +func parseX509PublicKey(pk any) (*PublicKey, error) { + switch key := pk.(type) { + case *rsa.PublicKey: + return convertRSAPublicKey(*key), nil + + case *ecdsa.PublicKey: + return convertECDSPublicKey(*key), nil + } + + return nil, errors.New("unsupported public key") +} + +func parsePublicKey(der []byte) (crypto.PublicKey, error) { + if key, err := x509.ParsePKCS1PublicKey(der); err == nil { + return key, nil + } + + if key, err := x509.ParsePKIXPublicKey(der); err == nil { + return key, nil + } + + return nil, errors.New("tls: failed to parse public key") +} + +func convertRSAPublicKey(key rsa.PublicKey) *PublicKey { + return &PublicKey{ + Type: "RSA", + BitSize: int32(key.Size() * 8), + + RSA_N: key.N.Bytes(), + RSA_E: big.NewInt(0).SetInt64(int64(key.E)).Bytes(), + + Key: key, + } +} + +func convertECDSPublicKey(key ecdsa.PublicKey) *PublicKey { + return &PublicKey{ + Type: "ECDSA", + BitSize: int32(key.Params().BitSize), + + Curve: key.Params().Name, + EC_Q: append([]byte{0x04}, append(key.X.Bytes(), key.Y.Bytes()...)...), + + Key: key, + } +} + +func ParseX509PrivateKey(der []byte) (*PrivateKey, error) { + x509key, err := parsePrivateKey(der) + if err != nil { + return nil, err + } + + pk, err := parseX509PrivateKey(x509key) + if err != nil { + return nil, err + } + + return pk, nil +} + +func parseX509PrivateKey(pk any) (*PrivateKey, error) { + switch key := pk.(type) { + case *rsa.PrivateKey: + return convertRSAPrivateKey(*key), nil + case *ecdsa.PrivateKey: + return convertECDSPrivateKey(*key), nil + } + + return nil, errors.New("unsupported private key") +} + +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, errors.New("found unknown private key type in PKCS#8 wrapping") + } + } + + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("unsupported private key") +} + +func convertRSAPrivateKey(key rsa.PrivateKey) *PrivateKey { + return &PrivateKey{ + Type: "RSA", + Size: key.Size() * 8, + RSA_P: key.Primes[0].Bytes(), + RSA_Q: key.Primes[1].Bytes(), + RSA_DP: key.Precomputed.Dp.Bytes(), + RSA_DQ: key.Precomputed.Dq.Bytes(), + RSA_IQ: key.Precomputed.Qinv.Bytes(), + + PublicKey: convertRSAPublicKey(key.PublicKey), + + Key: key, + } +} + +func convertECDSPrivateKey(key ecdsa.PrivateKey) *PrivateKey { + return &PrivateKey{ + Type: "ECDSA", + Curve: key.Params().Name, + EC_D: key.D.Bytes(), + + PublicKey: convertECDSPublicKey(key.PublicKey), + + Key: key, + } +} diff --git a/pkg/tls/signer.go b/pkg/tls/signer.go new file mode 100644 index 00000000..fe9d0843 --- /dev/null +++ b/pkg/tls/signer.go @@ -0,0 +1,100 @@ +/* + * The MIT License (MIT) + * Copyright (c) 2023 Cisco and/or its affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software + * and associated documentation files (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial + * portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + * TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "math/big" + "time" + + "emperror.dev/errors" +) + +type SignerCA struct { + PrivateKey *PrivateKey + Certificate *X509Certificate +} + +func NewSignerCA(caPEMFileName string) (*SignerCA, error) { + signer := &SignerCA{} + + containers, err := ParsePEMFromFile(caPEMFileName) + if err != nil { + return nil, err + } + + if len(containers) != 2 { + return nil, errors.New("invalid PEM content") + } + + for _, v := range containers { + switch v.Type { + case PrivateKeyContainerType: + signer.PrivateKey = v.GetPrivateKey() + case X509CertificateContainerType: + signer.Certificate = v.GetX509Certificate() + } + } + + if signer.PrivateKey == nil { + return nil, errors.New("missing CA private key") + } + + if signer.Certificate == nil { + return nil, errors.New("missing CA certificate") + } + + return signer, nil +} + +func (s *SignerCA) GetCaCertificate() *X509Certificate { + return s.Certificate +} + +func (s *SignerCA) SignCertificateRequest(req *x509.CertificateRequest) (*X509Certificate, error) { + serial, err := rand.Int(rand.Reader, (&big.Int{}).Exp(big.NewInt(2), big.NewInt(159), nil)) + if err != nil { + return nil, err + } + + pkey := s.PrivateKey.Key.(rsa.PrivateKey) + + certByte, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{ + SerialNumber: serial, + Subject: req.Subject, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + IsCA: false, + DNSNames: req.DNSNames, + IPAddresses: req.IPAddresses, + EmailAddresses: req.EmailAddresses, + URIs: req.URIs, + ExtraExtensions: req.ExtraExtensions, + }, s.Certificate.Certificate, req.PublicKey, &pkey) + if err != nil { + return nil, err + } + + return ParseX509CertificateFromDER(certByte) +}