diff --git a/data/encoder.go b/data/encoder.go index b7e90fc..47b657d 100644 --- a/data/encoder.go +++ b/data/encoder.go @@ -124,12 +124,9 @@ func writeRaw(w io.Writer, value interface{}, ignoreSigningFields bool) error { func encode(w io.Writer, value interface{}, ignoreSigningFields bool) error { v := reflect.Indirect(reflect.ValueOf(value)) - fields := getFields(&v, 0) + fields := getFields(&v, 0, ignoreSigningFields) // fmt.Println(fields.String()) return fields.Each(func(e enc, v interface{}) error { - if ignoreSigningFields && e.SigningField() { - return nil - } if err := writeEncoding(w, e); err != nil { return err } @@ -164,7 +161,7 @@ func (s *fieldSlice) Append(e enc, v interface{}, children fieldSlice) { *s = append(*s, field{e, v, children}) } -func getFields(v *reflect.Value, depth int) fieldSlice { +func getFields(v *reflect.Value, depth int, ignoreSigningFields bool) fieldSlice { // fmt.Println(v, v.Kind(), v.Type().Name()) length := v.NumField() fields := make(fieldSlice, 0, length) @@ -179,6 +176,9 @@ func getFields(v *reflect.Value, depth int) fieldSlice { continue } encoding := reverseEncodings[fieldName] + if ignoreSigningFields && encoding.SigningField() { + continue + } f := v.Field(i) // fmt.Println(fieldName, encoding, f, f.Kind()) if f.Kind() == reflect.Interface { @@ -199,16 +199,16 @@ func getFields(v *reflect.Value, depth int) fieldSlice { var children fieldSlice for i := 0; i < f.Len(); i++ { f2 := f.Index(i) - children = append(children, getFields(&f2, depth+1)...) + children = append(children, getFields(&f2, depth+1, ignoreSigningFields)...) } children.Append(reverseEncodings["EndOfArray"], nil, nil) fields.Append(encoding, nil, children) case ST_OBJECT: - children := getFields(&f, depth+1) + children := getFields(&f, depth+1, ignoreSigningFields) children.Append(reverseEncodings["EndOfObject"], nil, nil) fields.Append(encoding, nil, children) default: - fields = append(fields, getFields(&f, depth+1)...) + fields = append(fields, getFields(&f, depth+1, ignoreSigningFields)...) } } fields.Sort() diff --git a/data/format.go b/data/format.go index 9f48124..e00c758 100644 --- a/data/format.go +++ b/data/format.go @@ -295,7 +295,7 @@ func init() { signingFields = make(map[enc]struct{}) for e, name := range encodings { reverseEncodings[name] = e - if strings.Contains(name, "Signature") { + if strings.Contains(name, "Signature") || name == "Signers" { signingFields[e] = struct{}{} } } diff --git a/data/interface.go b/data/interface.go index 3d85545..a2d749f 100644 --- a/data/interface.go +++ b/data/interface.go @@ -25,6 +25,7 @@ type MultiSignable interface { GetPublicKey() *PublicKey GetSignature() *VariableLength SetSigners([]Signer) + GetSigners() []Signer } type Router interface { diff --git a/data/signing.go b/data/signing.go index eed2f89..a849385 100644 --- a/data/signing.go +++ b/data/signing.go @@ -1,6 +1,7 @@ package data import ( + "fmt" "sort" "github.com/rubblelabs/ripple/crypto" @@ -40,7 +41,6 @@ func MultiSign(s MultiSignable, key crypto.Key, sequence *uint32, account Accoun if err != nil { return err } - // msg = append(s.MultiSigningPrefix().Bytes(), msg...) msg = append(msg, account.Bytes()...) @@ -68,3 +68,33 @@ func SetSigners(s MultiSignable, signers ...Signer) error { copy(s.GetHash().Bytes(), hash.Bytes()) return nil } + +func CheckMultiSignature(s MultiSignable) (bool, []Signer, error) { + if len(s.GetSigners()) == 0 { + return false, nil, fmt.Errorf("no signers in the multi-signable transaction") + } + signers := s.GetSigners() + invalidSigners := make([]Signer, 0) + for _, signer := range signers { + account := signer.Signer.Account + pubKey := signer.Signer.SigningPubKey + signature := signer.Signer.TxnSignature + + hash, msg, err := MultiSigningHash(s, account) + if err != nil { + return false, nil, err + } + msg = append(s.MultiSigningPrefix().Bytes(), msg...) + msg = append(msg, account.Bytes()...) + + valid, err := crypto.Verify(pubKey.Bytes(), hash.Bytes(), msg, signature.Bytes()) + if err != nil { + return false, nil, err + } + if !valid { + invalidSigners = append(invalidSigners, signer) + } + } + + return len(invalidSigners) == 0, invalidSigners, nil +} diff --git a/data/signing_test.go b/data/signing_test.go new file mode 100644 index 0000000..3368b1b --- /dev/null +++ b/data/signing_test.go @@ -0,0 +1,111 @@ +package data + +import ( + "reflect" + "testing" + + "github.com/rubblelabs/ripple/crypto" +) + +func TestMultiSignWithVerification(t *testing.T) { + // generate signers seeds + seed1 := genSeedFromPassword(t, "password1") + seed2 := genSeedFromPassword(t, "password2") + seq := uint32(0) + key1 := seed1.Key(ECDSA) + account1 := seed1.AccountId(ECDSA, &seq) + key2 := seed2.Key(ECDSA) + account2 := seed2.AccountId(ECDSA, &seq) + + // prepare first signature + tx := buildPaymentTxForTheMultiSigning(t) + if err := MultiSign(tx, key1, &seq, account1); err != nil { + t.Fatal(err) + } + signer1 := Signer{ + Signer: SignerItem{ + Account: account1, + TxnSignature: tx.TxnSignature, + SigningPubKey: tx.SigningPubKey, + }, + } + // prepare second signature + tx = buildPaymentTxForTheMultiSigning(t) + if err := MultiSign(tx, key2, &seq, account2); err != nil { + t.Fatal(err) + } + signer2 := Signer{ + Signer: SignerItem{ + Account: account2, + TxnSignature: tx.TxnSignature, + SigningPubKey: tx.SigningPubKey, + }, + } + // rebuild the tx and set signers + tx = buildPaymentTxForTheMultiSigning(t) + if err := SetSigners(tx, signer1, signer2); err != nil { + t.Fatal(err) + } + // check that signature is valid + valid, invalidSigners, err := CheckMultiSignature(tx) + if !valid { + t.Fatal(err) + } + if len(invalidSigners) != 0 { + t.Fatal(err) + } + + // update one signature + tx.Signers[0].Signer.TxnSignature = tx.Signers[1].Signer.TxnSignature + valid, invalidSigners, err = CheckMultiSignature(tx) + if valid { + t.Fatal(err) + } + if len(invalidSigners) != 1 { + t.Fatal(err) + } + if !reflect.DeepEqual(invalidSigners[0], tx.Signers[0]) { + t.Fatal(err) + } + + // update tx data to check that both signers are invalid now + tx.Sequence = 123 + valid, invalidSigners, err = CheckMultiSignature(tx) + if valid { + t.Fatal(err) + } + if len(invalidSigners) != 2 { + t.Fatal(err) + } +} + +func buildPaymentTxForTheMultiSigning(t *testing.T) *Payment { + amount, err := NewAmount("1") + if err != nil { + t.Fatal(err) + } + tx := Payment{ + Amount: *amount, + TxBase: TxBase{ + Account: zeroAccount, + Sequence: 1, + TransactionType: PAYMENT, + }, + } + // important for the multi-signing + tx.TxBase.SigningPubKey = &PublicKey{} + return &tx +} + +func genSeedFromPassword(t *testing.T, password string) *Seed { + seedFromPass, err := crypto.GenerateFamilySeed(password) + if err != nil { + t.Fatal(err) + } + seed, err := NewSeedFromAddress(seedFromPass.String()) + if err != nil { + t.Fatal(err) + } + + return seed +} diff --git a/data/transaction.go b/data/transaction.go index ca9fdf5..4d17241 100644 --- a/data/transaction.go +++ b/data/transaction.go @@ -264,6 +264,7 @@ func (t *TxBase) GetSignature() *VariableLength { return t.TxnSignature } func (t *TxBase) SigningPrefix() HashPrefix { return HP_TRANSACTION_SIGN } func (t *TxBase) MultiSigningPrefix() HashPrefix { return HP_TRANSACTION_MULTISIGN } func (t *TxBase) SetSigners(signers []Signer) { t.Signers = signers } +func (t *TxBase) GetSigners() []Signer { return t.Signers } func (t *TxBase) PathSet() PathSet { return PathSet(nil) } func (t *TxBase) GetHash() *Hash256 { return &t.Hash }