diff --git a/pkg/pm/installer/installer.go b/pkg/pm/installer/installer.go index 851e832..ce0f06c 100644 --- a/pkg/pm/installer/installer.go +++ b/pkg/pm/installer/installer.go @@ -21,7 +21,6 @@ import ( "go.wpm.so/cli/pkg/archive" "go.wpm.so/cli/pkg/pm/registry" - "go.wpm.so/cli/pkg/pm/signatures" "go.wpm.so/cli/pkg/pm/wpmjson/types" "go.wpm.so/cli/pkg/pm/wpmjson/validator" ) @@ -40,7 +39,6 @@ type Installer struct { client registry.Client extractSem chan struct{} - keysJson signatures.KeysJson logger func(format string, args ...any) } @@ -106,12 +104,6 @@ func sweepStaleRunDirs(tmpDir string) { } func (i *Installer) InstallAll(ctx context.Context, plan []Action, progressFn func(Action)) error { - keys, err := i.client.GetKeysJson(ctx) - if err != nil { - return fmt.Errorf("failed to fetch public keys for signature verification: %w", err) - } - i.keysJson = keys - g, ctx := errgroup.WithContext(ctx) g.SetLimit(i.concurrency) @@ -151,21 +143,6 @@ func (i *Installer) install(ctx context.Context, action Action) error { } func (i *Installer) installOrUpdate(ctx context.Context, action Action, targetDir string) error { - sigs := action.Signatures - if len(sigs) == 0 { - return fmt.Errorf("no signatures found for package %s@%s", action.Name, action.Version) - } - - err := signatures.Verify( - i.keysJson, - sigs[0].KeyID, - sigs[0].Sig, - fmt.Appendf(nil, "%s:%s:%s", action.Name, action.Version, action.Digest), - ) - if err != nil { - return fmt.Errorf("signature verification failed for package %s@%s: %w", action.Name, action.Version, err) - } - path := tarballPath(action.Name, action.Version) resp, err := i.client.DownloadTarball(ctx, path) if err != nil { diff --git a/pkg/pm/installer/plan.go b/pkg/pm/installer/plan.go index 899d677..bc0b6f2 100644 --- a/pkg/pm/installer/plan.go +++ b/pkg/pm/installer/plan.go @@ -6,7 +6,6 @@ import ( "go.wpm.so/cli/pkg/pm/resolution" "go.wpm.so/cli/pkg/pm/wpmjson" - "go.wpm.so/cli/pkg/pm/wpmjson/manifest" "go.wpm.so/cli/pkg/pm/wpmjson/types" "go.wpm.so/cli/pkg/pm/wpmlock" ) @@ -21,12 +20,11 @@ const ( // Action represents a single operation to be performed on the filesystem type Action struct { - Type ActionType - Name string - Version string - Signatures []manifest.Signature - Digest string // Sha256 digest - PkgType types.PackageType + Type ActionType + Name string + Version string + Digest string // Sha256 digest + PkgType types.PackageType } // CalculatePlan determines filesystem operations based on lockfile, resolved tree, and flags. @@ -109,34 +107,31 @@ func resolveAction(name string, node resolution.Node, lock *wpmlock.Lockfile, ex oldPkg, inLock := lock.Packages[name] if !inLock { return Action{ - Type: ActionInstall, - Name: name, - Version: node.Version, - Signatures: node.Signatures, - Digest: node.Digest, - PkgType: node.Type, + Type: ActionInstall, + Name: name, + Version: node.Version, + Digest: node.Digest, + PkgType: node.Type, }, true } if oldPkg.Version != node.Version || oldPkg.Digest != node.Digest { return Action{ - Type: ActionUpdate, - Name: name, - Version: node.Version, - Signatures: node.Signatures, - Digest: node.Digest, - PkgType: node.Type, + Type: ActionUpdate, + Name: name, + Version: node.Version, + Digest: node.Digest, + PkgType: node.Type, }, true } if !exists { return Action{ - Type: ActionInstall, - Name: name, - Version: node.Version, - Signatures: node.Signatures, - Digest: node.Digest, - PkgType: node.Type, + Type: ActionInstall, + Name: name, + Version: node.Version, + Digest: node.Digest, + PkgType: node.Type, }, true } diff --git a/pkg/pm/registry/client.go b/pkg/pm/registry/client.go index 3593177..8acc554 100644 --- a/pkg/pm/registry/client.go +++ b/pkg/pm/registry/client.go @@ -28,7 +28,7 @@ type client struct { // registry type Client interface { Whoami(ctx context.Context, token string) (string, error) - GetKeysJson(ctx context.Context) (signatures.KeysJson, error) + GetKeysJson(ctx context.Context) (signatures.Keys, error) DownloadTarball(ctx context.Context, url string) (io.ReadCloser, error) PutPackage(ctx context.Context, data *manifest.Package, tarball io.Reader) error GetPackageManifest(ctx context.Context, packageName, versionOrTag string, force bool) (*manifest.Package, error) @@ -149,8 +149,8 @@ func (c *client) Whoami(ctx context.Context, token string) (string, error) { } // GetKeysJson retrieves the public keys from the registry -func (c *client) GetKeysJson(ctx context.Context) (signatures.KeysJson, error) { - var keys signatures.KeysJson +func (c *client) GetKeysJson(ctx context.Context) (signatures.Keys, error) { + var keys signatures.Keys err := c.restClient.DoWithContext( ctx, diff --git a/pkg/pm/resolution/resolver.go b/pkg/pm/resolution/resolver.go index 793f3f1..44d4f2d 100644 --- a/pkg/pm/resolution/resolver.go +++ b/pkg/pm/resolution/resolver.go @@ -11,6 +11,7 @@ import ( "golang.org/x/sync/errgroup" "go.wpm.so/cli/pkg/pm/registry" + "go.wpm.so/cli/pkg/pm/signatures" "go.wpm.so/cli/pkg/pm/wpmjson" "go.wpm.so/cli/pkg/pm/wpmjson/manifest" "go.wpm.so/cli/pkg/pm/wpmjson/types" @@ -37,6 +38,7 @@ type Resolver struct { rootConfig *wpmjson.Config lockfile *wpmlock.Lockfile client registry.Client + verifier *signatures.Verifier } func New(rootConfig *wpmjson.Config, lockfile *wpmlock.Lockfile, client registry.Client) *Resolver { @@ -59,6 +61,12 @@ type fetchResult struct { } func (r *Resolver) Resolve(ctx context.Context, progress ProgressReporter, w io.Writer) (map[string]Node, error) { + keys, err := r.client.GetKeysJson(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch signing keys: %w", err) + } + r.verifier = signatures.New(keys) + resolved := make(map[string]Node) queue := r.seedQueue() @@ -94,7 +102,15 @@ func (r *Resolver) Resolve(ctx context.Context, progress ProgressReporter, w io. } func (r *Resolver) seedQueue() []dependencyRequest { - var queue []dependencyRequest + n := 0 + if r.rootConfig.Dependencies != nil { + n += len(*r.rootConfig.Dependencies) + } + if r.rootConfig.DevDependencies != nil { + n += len(*r.rootConfig.DevDependencies) + } + + queue := make([]dependencyRequest, 0, n) if r.rootConfig.Dependencies != nil { for name, version := range *r.rootConfig.Dependencies { queue = append(queue, dependencyRequest{name: name, version: version, requestor: ""}) @@ -108,21 +124,26 @@ func (r *Resolver) seedQueue() []dependencyRequest { return queue } +type requestKey struct { + name string + version string +} + // dedupeRequests drops requests already satisfied at the same version // and folds identical name@version pairs in this iteration into one entry. -func dedupeRequests(queue []dependencyRequest, resolved map[string]Node) map[string]dependencyRequest { - uniqueRequests := make(map[string]dependencyRequest) +func dedupeRequests(queue []dependencyRequest, resolved map[string]Node) map[requestKey]dependencyRequest { + unique := make(map[requestKey]dependencyRequest, len(queue)) for _, req := range queue { if exists, ok := resolved[req.name]; ok && exists.Version == req.version { continue } - uniqueRequests[req.name+"@"+req.version] = req + unique[requestKey{req.name, req.version}] = req } - return uniqueRequests + return unique } // fetchAll fetches metadata for every request concurrently and returns the collected results. -func (r *Resolver) fetchAll(ctx context.Context, requests map[string]dependencyRequest, progress ProgressReporter, w io.Writer) ([]fetchResult, error) { +func (r *Resolver) fetchAll(ctx context.Context, requests map[requestKey]dependencyRequest, progress ProgressReporter, w io.Writer) ([]fetchResult, error) { results := make(chan fetchResult, len(requests)) g, gtx := errgroup.WithContext(ctx) g.SetLimit(16) @@ -137,6 +158,9 @@ func (r *Resolver) fetchAll(ctx context.Context, requests map[string]dependencyR if err != nil { return fmt.Errorf("failed to fetch metadata for %s@%s required by %s: %w", req.name, req.version, req.requestor, err) } + if err := r.verifier.Verify(manifest); err != nil { + return fmt.Errorf("signature verification failed for %s@%s required by %s: %w", req.name, req.version, req.requestor, err) + } results <- fetchResult{req: req, manifest: manifest} return nil }) diff --git a/pkg/pm/signatures/signatures.go b/pkg/pm/signatures/signatures.go index b752ec5..f4cd7af 100644 --- a/pkg/pm/signatures/signatures.go +++ b/pkg/pm/signatures/signatures.go @@ -9,72 +9,185 @@ import ( "errors" "fmt" "math/big" + "sort" + "unicode/utf8" + + "go.wpm.so/cli/pkg/pm/wpmjson/manifest" ) -const signingAlgorithm = "ECDSA_SHA_256" +const ( + maxPayloadBytes = 4096 + signingAlgorithm = "ECDSA_SHA_256" +) -type sig struct { +type ecdsaSignature struct { R, S *big.Int } -type keyJson struct { +type key struct { Expires string `json:"expires"` Type string `json:"type"` KeyID string `json:"keyid"` PubKey string `json:"pubkey"` } -type KeysJson []keyJson +// Keys is the set of trusted public keys from keys.json. +type Keys []key + +type Verifier struct { + keys map[string]*ecdsa.PublicKey +} -// Verify verifies a Base64 encoded ASN.1 DER signature against a message using a PEM encoded Public Key. -func Verify(keys KeysJson, keyId, signatureBase64 string, originalMessage []byte) error { - var rawPublicKeyBase64, keyType string - for _, key := range keys { - if key.KeyID == keyId { - keyType = key.Type - rawPublicKeyBase64 = key.PubKey - break +// New returns a Verifier backed by keys. +func New(keys Keys) *Verifier { + parsed := make(map[string]*ecdsa.PublicKey, len(keys)) + for _, k := range keys { + if k.Type != signingAlgorithm { + continue } + pub, err := parseECDSAKey(k.PubKey) + if err != nil { + continue + } + parsed[k.KeyID] = pub + } + return &Verifier{keys: parsed} +} + +// Verify checks the manifest's signature against the trusted keys. +func (v *Verifier) Verify(m *manifest.Package) error { + sigs := m.Dist.Signatures + if len(sigs) == 0 { + return errors.New("no signatures found") } - if rawPublicKeyBase64 == "" { - return fmt.Errorf("public key with KeyID %s not found", keyId) + pub, ok := v.keys[sigs[0].KeyID] + if !ok { + return fmt.Errorf("no trusted key for KeyID %s", sigs[0].KeyID) + } + + var deps map[string]string + if m.Dependencies != nil { + deps = *m.Dependencies } - if keyType != signingAlgorithm { - return fmt.Errorf("unsupported signing algorithm: %s", keyType) + msg, err := payload(m.Name, m.Version, m.Dist.Digest, deps) + if err != nil { + return err } - keyBytes, err := base64.StdEncoding.DecodeString(rawPublicKeyBase64) + return verifyECDSA(pub, sigs[0].Sig, msg) +} + +func parseECDSAKey(pubKeyBase64 string) (*ecdsa.PublicKey, error) { + der, err := base64.StdEncoding.DecodeString(pubKeyBase64) if err != nil { - return fmt.Errorf("failed to decode base64 public key: %w", err) + return nil, err } - genericPublicKey, err := x509.ParsePKIXPublicKey(keyBytes) + pub, err := x509.ParsePKIXPublicKey(der) if err != nil { - return fmt.Errorf("failed to parse PKIX public key: %v", err) + return nil, err } - publicKey, ok := genericPublicKey.(*ecdsa.PublicKey) + ec, ok := pub.(*ecdsa.PublicKey) if !ok { - return errors.New("public key is not of type ECDSA") + return nil, errors.New("public key is not ECDSA") } + return ec, nil +} +func verifyECDSA(pub *ecdsa.PublicKey, signatureBase64 string, msg []byte) error { sigBytes, err := base64.StdEncoding.DecodeString(signatureBase64) if err != nil { - return fmt.Errorf("failed to decode base64 signature: %v", err) + return fmt.Errorf("failed to decode base64 signature: %w", err) } - var s sig + var s ecdsaSignature if _, err := asn1.Unmarshal(sigBytes, &s); err != nil { - return fmt.Errorf("failed to unmarshal ASN.1 signature: %v", err) + return fmt.Errorf("failed to unmarshal ASN.1 signature: %w", err) } - hash := sha256.Sum256(originalMessage) - valid := ecdsa.Verify(publicKey, hash[:], s.R, s.S) - if !valid { + hash := sha256.Sum256(msg) + if !ecdsa.Verify(pub, hash[:], s.R, s.S) { return errors.New("signature verification failed: invalid signature") } - return nil } + +// payload builds the message the registry signs: +// +// name:version:digest (no dependencies) +// name:version:digest:deps_digest (with dependencies) +// +// deps_digest is base64(sha256(canonicalDependencies(deps))). digest is used +// verbatim, including its "sha256:" prefix. +func payload(name, version, digest string, deps map[string]string) ([]byte, error) { + msg := name + ":" + version + ":" + digest + if len(deps) > 0 { + sum := sha256.Sum256(canonicalDependencies(deps)) + msg += ":" + base64.StdEncoding.EncodeToString(sum[:]) + } + + if len(msg) >= maxPayloadBytes { + return nil, fmt.Errorf("signature payload must be < %d bytes", maxPayloadBytes) + } + + return []byte(msg), nil +} + +// canonicalDependencies serializes deps into the canonical form used in the +// signing payload: sorted keys, no spaces, JSON-escaped. For the package-name +// and semver charset this equals RFC 8785 (JCS). +func canonicalDependencies(deps map[string]string) []byte { + keys := make([]string, 0, len(deps)) + for k := range deps { + keys = append(keys, k) + } + sort.Strings(keys) + + b := make([]byte, 0, 2+len(keys)*16) + b = append(b, '{') + for i, k := range keys { + if i > 0 { + b = append(b, ',') + } + b = appendJSONString(b, k) + b = append(b, ':') + b = appendJSONString(b, deps[k]) + } + return append(b, '}') +} + +// appendJSONString writes s as a JSON string with the minimal escaping RFC 8785 +// requires. Go's encoding/json also escapes <, >, &, and U+2028/U+2029; those +// must stay literal here to keep the output canonical. +func appendJSONString(dst []byte, s string) []byte { + const hex = "0123456789abcdef" + dst = append(dst, '"') + for _, r := range s { + switch r { + case '"': + dst = append(dst, '\\', '"') + case '\\': + dst = append(dst, '\\', '\\') + case '\b': + dst = append(dst, '\\', 'b') + case '\f': + dst = append(dst, '\\', 'f') + case '\n': + dst = append(dst, '\\', 'n') + case '\r': + dst = append(dst, '\\', 'r') + case '\t': + dst = append(dst, '\\', 't') + default: + if r < 0x20 { + dst = append(dst, '\\', 'u', '0', '0', hex[r>>4], hex[r&0xf]) + } else { + dst = utf8.AppendRune(dst, r) + } + } + } + return append(dst, '"') +} diff --git a/pkg/pm/signatures/signatures_test.go b/pkg/pm/signatures/signatures_test.go new file mode 100644 index 0000000..4ad65ad --- /dev/null +++ b/pkg/pm/signatures/signatures_test.go @@ -0,0 +1,65 @@ +package signatures + +import "testing" + +func TestCanonicalDependencies(t *testing.T) { + cases := []struct { + name string + deps map[string]string + want string + }{ + {"empty", map[string]string{}, "{}"}, + {"single", map[string]string{"elementor": "4.1.0"}, `{"elementor":"4.1.0"}`}, + {"sorted_keys", map[string]string{"elementor": "4.1.0", "akismet": "5.0.0"}, `{"akismet":"5.0.0","elementor":"4.1.0"}`}, + {"ascii_case_order", map[string]string{"b": "1.0.0", "a": "1.0.0", "B": "1.0.0"}, `{"B":"1.0.0","a":"1.0.0","b":"1.0.0"}`}, + // Out of the package-name domain, but pins the escaping/ordering contract: + {"html_chars_unescaped", map[string]string{"x": "<3&>"}, `{"x":"<3&>"}`}, + {"non_ascii_key_order", map[string]string{"é": "1.0.0", "z": "1.0.0"}, `{"z":"1.0.0","é":"1.0.0"}`}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := string(canonicalDependencies(tc.deps)); got != tc.want { + t.Fatalf("canonicalDependencies = %s, want %s", got, tc.want) + } + }) + } +} + +func TestPayload(t *testing.T) { + const digest = "sha256:hmBCVLbYU0UkrrgEE3xDhAd9jmVq57tMgICXB0XZrGA=" + const want3 = "elementor:4.1.0:" + digest + + t.Run("nil_deps_is_three_fields", func(t *testing.T) { + got, err := payload("elementor", "4.1.0", digest, nil) + if err != nil { + t.Fatal(err) + } + if string(got) != want3 { + t.Fatalf("Payload = %q, want %q", got, want3) + } + }) + + t.Run("empty_deps_is_three_fields", func(t *testing.T) { + got, err := payload("elementor", "4.1.0", digest, map[string]string{}) + if err != nil { + t.Fatal(err) + } + if string(got) != want3 { + t.Fatalf("Payload = %q, want %q", got, want3) + } + }) + + t.Run("deps_append_base64_sha256", func(t *testing.T) { + got, err := payload("elementor", "4.1.0", digest, map[string]string{ + "elementor": "4.1.0", "akismet": "5.0.0", + }) + if err != nil { + t.Fatal(err) + } + // deps_digest golden value from the registry JS. + want := want3 + ":F334q11N6Ds7005RwbHApqUBXjUrfjpI4NSo9hOPznQ=" + if string(got) != want { + t.Fatalf("Payload = %q, want %q", got, want) + } + }) +}