Skip to content

Commit

Permalink
Merge pull request #34 from khalkie/main
Browse files Browse the repository at this point in the history
  • Loading branch information
adityasaky committed Feb 14, 2023
2 parents 154aa5b + 60bd7fd commit c32a0fd
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 52 deletions.
11 changes: 6 additions & 5 deletions dsse/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ https://github.com/secure-systems-lab/dsse
package dsse

import (
"context"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -77,7 +78,7 @@ using the current algorithm, and the key used (if applicable).
For an example see EcdsaSigner in sign_test.go.
*/
type Signer interface {
Sign(data []byte) ([]byte, error)
Sign(ctx context.Context, data []byte) ([]byte, error)
KeyID() (string, error)
}

Expand Down Expand Up @@ -143,7 +144,7 @@ Returned is an envelope as defined here:
https://github.com/secure-systems-lab/dsse/blob/master/envelope.md
One signature will be added for each Signer in the EnvelopeSigner.
*/
func (es *EnvelopeSigner) SignPayload(payloadType string, body []byte) (*Envelope, error) {
func (es *EnvelopeSigner) SignPayload(ctx context.Context, payloadType string, body []byte) (*Envelope, error) {
var e = Envelope{
Payload: base64.StdEncoding.EncodeToString(body),
PayloadType: payloadType,
Expand All @@ -152,7 +153,7 @@ func (es *EnvelopeSigner) SignPayload(payloadType string, body []byte) (*Envelop
paeEnc := PAE(payloadType, body)

for _, signer := range es.providers {
sig, err := signer.Sign(paeEnc)
sig, err := signer.Sign(ctx, paeEnc)
if err != nil {
return nil, err
}
Expand All @@ -176,8 +177,8 @@ Any domain specific validation such as parsing the decoded body and
validating the payload type is left out to the caller.
Verify returns a list of accepted keys each including a keyid, public and signiture of the accepted provider keys.
*/
func (es *EnvelopeSigner) Verify(e *Envelope) ([]AcceptedKey, error) {
return es.ev.Verify(e)
func (es *EnvelopeSigner) Verify(ctx context.Context, e *Envelope) ([]AcceptedKey, error) {
return es.ev.Verify(ctx, e)
}

/*
Expand Down
81 changes: 41 additions & 40 deletions dsse/sign_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dsse

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
Expand Down Expand Up @@ -40,11 +41,11 @@ func TestPAE(t *testing.T) {

type nilsigner int

func (n nilsigner) Sign(data []byte) ([]byte, error) {
func (n nilsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (n nilsigner) Verify(data, sig []byte) error {
func (n nilsigner) Verify(ctx context.Context, data, sig []byte) error {
if len(data) != len(sig) {
return errLength
}
Expand All @@ -68,11 +69,11 @@ func (n nilsigner) Public() crypto.PublicKey {

type nullsigner int

func (n nullsigner) Sign(data []byte) ([]byte, error) {
func (n nullsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (n nullsigner) Verify(data, sig []byte) error {
func (n nullsigner) Verify(ctx context.Context, data, sig []byte) error {
if len(data) != len(sig) {
return errLength
}
Expand All @@ -96,11 +97,11 @@ func (n nullsigner) Public() crypto.PublicKey {

type errsigner int

func (n errsigner) Sign(data []byte) ([]byte, error) {
func (n errsigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return nil, fmt.Errorf("signing error")
}

func (n errsigner) Verify(data, sig []byte) error {
func (n errsigner) Verify(ctx context.Context, data, sig []byte) error {
return errVerify
}

Expand All @@ -117,11 +118,11 @@ type errverifier int
var errVerify = fmt.Errorf("accepted signatures do not match threshold, Found: 0, Expected 1")
var errThreshold = fmt.Errorf("invalid threshold")

func (n errverifier) Sign(data []byte) ([]byte, error) {
func (n errverifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (n errverifier) Verify(data, sig []byte) error {
func (n errverifier) Verify(ctx context.Context, data, sig []byte) error {
return errVerify
}

Expand All @@ -135,11 +136,11 @@ func (n errverifier) Public() crypto.PublicKey {

type badverifier int

func (n badverifier) Sign(data []byte) ([]byte, error) {
func (n badverifier) Sign(ctx context.Context, data []byte) ([]byte, error) {
return append(data, byte(0)), nil
}

func (n badverifier) Verify(data, sig []byte) error {
func (n badverifier) Verify(ctx context.Context, data, sig []byte) error {

if len(data) != len(sig) {
return errLength
Expand Down Expand Up @@ -199,7 +200,7 @@ func TestNilSign(t *testing.T) {
signer, err := NewEnvelopeSigner(ns)
assert.Nil(t, err, "unexpected error")

got, err := signer.SignPayload(payloadType, []byte(payload))
got, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")
assert.Equal(t, &want, got, "bad signature")
}
Expand All @@ -209,7 +210,7 @@ func TestSignError(t *testing.T) {
signer, err := NewEnvelopeSigner(es)
assert.Nil(t, err, "unexpected error")

got, err := signer.SignPayload("t", []byte("d"))
got, err := signer.SignPayload(context.TODO(), "t", []byte("d"))
assert.Nil(t, got, "expected nil")
assert.NotNil(t, err, "error expected")
assert.Equal(t, "signing error", err.Error(), "wrong error")
Expand Down Expand Up @@ -252,7 +253,7 @@ type EcdsaSigner struct {
verified bool
}

func (es *EcdsaSigner) Sign(data []byte) ([]byte, error) {
func (es *EcdsaSigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
// Data is complete message, hash it and sign the digest
digest := sha256.Sum256(data)
r, s, err := rfc6979.SignECDSA(es.key, digest[:], sha256.New)
Expand All @@ -268,7 +269,7 @@ func (es *EcdsaSigner) Sign(data []byte) ([]byte, error) {
return rawSig, nil
}

func (es *EcdsaSigner) Verify(data, sig []byte) error {
func (es *EcdsaSigner) Verify(ctx context.Context, data, sig []byte) error {
var r big.Int
var s big.Int
digest := sha256.Sum256(data)
Expand Down Expand Up @@ -319,12 +320,12 @@ func TestEcdsaSign(t *testing.T) {
signer, err := NewEnvelopeSigner(ecdsa)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "unexpected error")
assert.Equal(t, &want, env, "Wrong envelope generated")

// Now verify
acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.True(t, ecdsa.verified, "verify was not called")
assert.Len(t, acceptedKeys, 1, "unexpected keys")
Expand Down Expand Up @@ -384,10 +385,10 @@ func TestVerifyOneProvider(t *testing.T) {
signer, err := NewEnvelopeSigner(ns)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.Len(t, acceptedKeys, 1, "unexpected keys")
assert.Equal(t, acceptedKeys[0].KeyID, "nil", "unexpected keyid")
Expand All @@ -402,10 +403,10 @@ func TestVerifyMultipleProvider(t *testing.T) {
signer, err := NewEnvelopeSigner(ns, null)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.Len(t, acceptedKeys, 2, "unexpected keys")
}
Expand All @@ -418,10 +419,10 @@ func TestVerifyMultipleProviderThreshold(t *testing.T) {
var null nullsigner
signer, err := NewMultiEnvelopeSigner(2, ns, null)
assert.Nil(t, err)
env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "unexpected error")
assert.Len(t, acceptedKeys, 2, "unexpected keys")
}
Expand All @@ -443,10 +444,10 @@ func TestVerifyErr(t *testing.T) {
signer, err := NewEnvelopeSigner(errv)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.Equal(t, errVerify, err, "wrong error")
}

Expand All @@ -458,10 +459,10 @@ func TestBadVerifier(t *testing.T) {
signer, err := NewEnvelopeSigner(badv)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.NotNil(t, err, "expected error")
}

Expand All @@ -472,7 +473,7 @@ func TestVerifyNoSig(t *testing.T) {

env := &Envelope{}

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.Equal(t, ErrNoSignature, err, "wrong error")
}

Expand All @@ -489,7 +490,7 @@ func TestVerifyBadBase64(t *testing.T) {
},
}

_, err := signer.Verify(env)
_, err := signer.Verify(context.TODO(), env)
assert.IsType(t, base64.CorruptInputError(0), err, "wrong error")
})

Expand All @@ -503,7 +504,7 @@ func TestVerifyBadBase64(t *testing.T) {
},
}

_, err := signer.Verify(env)
_, err := signer.Verify(context.TODO(), env)
assert.IsType(t, base64.CorruptInputError(0), err, "wrong error")
})
}
Expand All @@ -527,7 +528,7 @@ func TestVerifyNoMatch(t *testing.T) {
},
}

_, err = signer.Verify(env)
_, err = signer.Verify(context.TODO(), env)
assert.NotNil(t, err, "expected error")
}

Expand All @@ -537,11 +538,11 @@ type interceptSigner struct {
verifyCalled bool
}

func (i *interceptSigner) Sign(data []byte) ([]byte, error) {
func (i *interceptSigner) Sign(ctx context.Context, data []byte) ([]byte, error) {
return data, nil
}

func (i *interceptSigner) Verify(data, sig []byte) error {
func (i *interceptSigner) Verify(ctx context.Context, data, sig []byte) error {
i.verifyCalled = true

if i.verifyRes {
Expand Down Expand Up @@ -573,10 +574,10 @@ func TestVerifyOneFail(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
assert.True(t, s1.verifyCalled, "verify not called")
assert.True(t, s2.verifyCalled, "verify not called")
Expand All @@ -599,10 +600,10 @@ func TestVerifySameKeyID(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
assert.True(t, s1.verifyCalled, "verify not called")
assert.True(t, s2.verifyCalled, "verify not called")
Expand All @@ -627,10 +628,10 @@ func TestVerifyEmptyKeyID(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
// assert.True(t, s1.verifyCalled, "verify not called")
// assert.True(t, s2.verifyCalled, "verify not called")
Expand Down Expand Up @@ -658,10 +659,10 @@ func TestVerifyPublicKeyID(t *testing.T) {
signer, err := NewEnvelopeSigner(s1, s2)
assert.Nil(t, err, "unexpected error")

env, err := signer.SignPayload(payloadType, []byte(payload))
env, err := signer.SignPayload(context.TODO(), payloadType, []byte(payload))
assert.Nil(t, err, "sign failed")

acceptedKeys, err := signer.Verify(env)
acceptedKeys, err := signer.Verify(context.TODO(), env)
assert.Nil(t, err, "expected error")
assert.Len(t, acceptedKeys, 1, "unexpected keys")
assert.Equal(t, acceptedKeys[0].KeyID, keyID, "unexpected keyid")
Expand Down
7 changes: 4 additions & 3 deletions dsse/verify.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dsse

import (
"context"
"crypto"
"errors"
"fmt"
Expand All @@ -15,7 +16,7 @@ must perform the same steps.
If KeyID returns successfully, only signature matching the key ID will be verified.
*/
type Verifier interface {
Verify(data, sig []byte) error
Verify(ctx context.Context, data, sig []byte) error
KeyID() (string, error)
Public() crypto.PublicKey
}
Expand All @@ -31,7 +32,7 @@ type AcceptedKey struct {
Sig Signature
}

func (ev *EnvelopeVerifier) Verify(e *Envelope) ([]AcceptedKey, error) {
func (ev *EnvelopeVerifier) Verify(ctx context.Context, e *Envelope) ([]AcceptedKey, error) {
if e == nil {
return nil, errors.New("cannot verify a nil envelope")
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func (ev *EnvelopeVerifier) Verify(e *Envelope) ([]AcceptedKey, error) {
continue
}

err = v.Verify(paeEnc, sig)
err = v.Verify(ctx, paeEnc, sig)
if err != nil {
continue
}
Expand Down
Loading

0 comments on commit c32a0fd

Please sign in to comment.