diff --git a/CHANGELOG.md b/CHANGELOG.md index b37778d..04aafa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,10 +44,16 @@ Ref: https://keepachangelog.com/en/1.0.0/ ### Improvements +* [\#48](https://github.com/bianjieai/nft-transfer/pull/48) apply audit suggestion. +* [\#17](https://github.com/bianjieai/nft-transfer/pull/17) replace param proposal with MsgUpdateParams. +* [\#15](https://github.com/bianjieai/nft-transfer/pull/15) solve the problem of "/" parsing error in classID. + ### Features ### Bug Fixes +* [\#29](https://github.com/bianjieai/nft-transfer/pull/29) bump up Cosmos-SDK and IBC-Go. + ## [v1.1.2] ### API Breaking diff --git a/ibc_module.go b/ibc_module.go index 39ceff4..f496c65 100644 --- a/ibc_module.go +++ b/ibc_module.go @@ -1,7 +1,6 @@ package nfttransfer import ( - "fmt" "math" "strings" @@ -185,23 +184,25 @@ func (im IBCModule) OnRecvPacket( var ( ack = channeltypes.NewResultAcknowledgement([]byte{byte(1)}) data types.NonFungibleTokenPacketData - err error + ackErr error ) - if err = types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil { + if err := types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil { ack = channeltypes.NewErrorAcknowledgement( errorsmod.Wrapf(sdkerrors.ErrInvalidType, "cannot unmarshal ICS-721 nft-transfer packet data"), ) + ackErr = err } // only attempt the application logic if the packet data // was successfully decoded if ack.Success() { - if err = im.keeper.OnRecvPacket(ctx, packet, data); err != nil { + if err := im.keeper.OnRecvPacket(ctx, packet, data); err != nil { ack = channeltypes.NewErrorAcknowledgement(err) + ackErr = err } } - keeper.EmitAcknowledgementEvent(ctx, data, ack, err) + keeper.EmitAcknowledgementEvent(ctx, data, ack, ackErr) // NOTE: acknowledgement will be written synchronously during IBC handler execution. return ack } @@ -237,7 +238,7 @@ func (im IBCModule) OnAcknowledgementPacket( sdk.NewAttribute(types.AttributeKeyReceiver, data.Receiver), sdk.NewAttribute(types.AttributeKeyClassID, data.ClassId), sdk.NewAttribute(types.AttributeKeyTokenIDs, strings.Join(data.TokenIds, ",")), - sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack.Success())), + sdk.NewAttribute(types.AttributeKeyAck, ack.String()), ), ) diff --git a/keeper/grpc_query.go b/keeper/grpc_query.go index 6de5a2c..59e46eb 100644 --- a/keeper/grpc_query.go +++ b/keeper/grpc_query.go @@ -27,7 +27,7 @@ func (k Keeper) ClassTrace(c context.Context, hash, err := types.ParseHexHash(strings.TrimPrefix(req.Hash, "ibc/")) if err != nil { - return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid denom trace hash: %s, error: %s", hash.String(), err)) + return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid class trace hash: %s, error: %s", req.Hash, err)) } ctx := sdk.UnwrapSDKContext(c) diff --git a/keeper/relay.go b/keeper/relay.go index a973641..9fea092 100644 --- a/keeper/relay.go +++ b/keeper/relay.go @@ -76,14 +76,10 @@ func (k Keeper) SendTransfer( packet, err := k.createOutgoingPacket(ctx, sourcePort, sourceChannel, - destinationPort, - destinationChannel, classID, tokenIDs, sender, receiver, - timeoutHeight, - timeoutTimestamp, memo, ) if err != nil { @@ -172,8 +168,8 @@ func (k Keeper) refundPacketToken(ctx sdk.Context, packet channeltypes.Packet, d return err } if types.IsAwayFromOrigin(packet.GetSourcePort(), packet.GetSourceChannel(), data.ClassId) { - for _, tokenID := range data.TokenIds { - if err := k.nftKeeper.Transfer(ctx, voucherClassID, tokenID, "", sender); err != nil { + for i, tokenID := range data.TokenIds { + if err := k.nftKeeper.Transfer(ctx, voucherClassID, tokenID, types.GetIfExist(i, data.TokenData), sender); err != nil { return err } } @@ -201,14 +197,10 @@ func (k Keeper) refundPacketToken(ctx sdk.Context, packet channeltypes.Packet, d func (k Keeper) createOutgoingPacket(ctx sdk.Context, sourcePort, sourceChannel, - destinationPort, - destinationChannel, classID string, tokenIDs []string, sender sdk.AccAddress, receiver string, - timeoutHeight clienttypes.Height, - timeoutTimestamp uint64, memo string, ) (types.NonFungibleTokenPacketData, error) { class, exist := k.nftKeeper.GetClass(ctx, classID) @@ -220,8 +212,8 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context, // NOTE: class and hex hash correctness checked during msg.ValidateBasic fullClassPath = classID err error - tokenURIs []string - tokenData []string + tokenURIs = make([]string, len(tokenIDs)) + tokenData = make([]string, len(tokenIDs)) ) // deconstruct the token denomination into the denomination trace info @@ -235,7 +227,7 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context, isAwayFromOrigin := types.IsAwayFromOrigin(sourcePort, sourceChannel, fullClassPath) - for _, tokenID := range tokenIDs { + for i, tokenID := range tokenIDs { nft, exist := k.nftKeeper.GetNFT(ctx, classID, tokenID) if !exist { return types.NonFungibleTokenPacketData{}, errorsmod.Wrap(types.ErrInvalidTokenID, "tokenId not exist") @@ -246,13 +238,13 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context, return types.NonFungibleTokenPacketData{}, errorsmod.Wrap(sdkerrors.ErrUnauthorized, "not token owner") } - tokenURIs = append(tokenURIs, nft.GetURI()) - tokenData = append(tokenData, nft.GetData()) + tokenURIs[i] = nft.GetURI() + tokenData[i] = nft.GetData() if isAwayFromOrigin { // create the escrow address for the tokens escrowAddress := types.GetEscrowAddress(sourcePort, sourceChannel) - if err := k.nftKeeper.Transfer(ctx, classID, tokenID, "", escrowAddress); err != nil { + if err := k.nftKeeper.Transfer(ctx, classID, tokenID, nft.GetData(), escrowAddress); err != nil { return types.NonFungibleTokenPacketData{}, err } } else { @@ -273,12 +265,7 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context, tokenData, memo, ) - - // check packet - if err := packetData.ValidateBasic(); err != nil { - return types.NonFungibleTokenPacketData{}, err - } - return packetData, nil + return packetData, packetData.ValidateBasic() } // processReceivedPacket will mint the tokens to receiver account diff --git a/keeper/trace.go b/keeper/trace.go index 57e4c6c..0e94394 100644 --- a/keeper/trace.go +++ b/keeper/trace.go @@ -18,8 +18,8 @@ func (k Keeper) GetClassTrace(ctx sdk.Context, classTraceHash tmbytes.HexBytes) return types.ClassTrace{}, false } - denomTrace := k.MustUnmarshalClassTrace(bz) - return denomTrace, true + classTrace := k.MustUnmarshalClassTrace(bz) + return classTrace, true } // GetAllClassTraces returns the trace information for all the class. @@ -67,16 +67,16 @@ func (k Keeper) ClassPathFromHash(ctx sdk.Context, classID string) (string, erro } // HasClassTrace checks if a the key with the given denomination trace hash exists on the store. -func (k Keeper) HasClassTrace(ctx sdk.Context, denomTraceHash tmbytes.HexBytes) bool { +func (k Keeper) HasClassTrace(ctx sdk.Context, classTraceHash tmbytes.HexBytes) bool { store := prefix.NewStore(ctx.KVStore(k.storeKey), types.ClassTraceKey) - return store.Has(denomTraceHash) + return store.Has(classTraceHash) } // SetClassTrace sets a new {trace hash -> class trace} pair to the store. -func (k Keeper) SetClassTrace(ctx sdk.Context, denomTrace types.ClassTrace) { +func (k Keeper) SetClassTrace(ctx sdk.Context, classTrace types.ClassTrace) { store := prefix.NewStore(ctx.KVStore(k.storeKey), types.ClassTraceKey) - bz := k.MustMarshalClassTrace(denomTrace) - store.Set(denomTrace.Hash(), bz) + bz := k.MustMarshalClassTrace(classTrace) + store.Set(classTrace.Hash(), bz) } // MustUnmarshalClassTrace attempts to decode and return an ClassTrace object from diff --git a/types/codec.go b/types/codec.go index 6385fdb..6d6f0c5 100644 --- a/types/codec.go +++ b/types/codec.go @@ -15,12 +15,16 @@ import ( // on the provided LegacyAmino codec. These types are used for Amino JSON serialization. func RegisterLegacyAminoCodec(cdc *codec.LegacyAmino) { cdc.RegisterConcrete(&MsgTransfer{}, "cosmos-sdk/MsgTransferNFT", nil) + cdc.RegisterConcrete(&MsgUpdateParams{}, "cosmos-sdk/MsgUpdateParams", nil) } // RegisterInterfaces register the ibc nft-transfer module interfaces to protobuf // Any. func RegisterInterfaces(registry codectypes.InterfaceRegistry) { - registry.RegisterImplementations((*sdk.Msg)(nil), &MsgTransfer{}) + registry.RegisterImplementations((*sdk.Msg)(nil), + &MsgTransfer{}, + &MsgUpdateParams{}, + ) msgservice.RegisterMsgServiceDesc(registry, &_Msg_serviceDesc) } diff --git a/types/events.go b/types/events.go index 146adaa..5d413ab 100644 --- a/types/events.go +++ b/types/events.go @@ -12,6 +12,7 @@ const ( AttributeKeyReceiver = "receiver" AttributeKeyClassID = "classID" AttributeKeyTokenIDs = "tokenIDs" + AttributeKeyAck = "acknowledgement" AttributeKeyAckSuccess = "success" AttributeKeyAckError = "error" AttributeKeyTraceHash = "trace_hash" diff --git a/types/msgs.go b/types/msgs.go index d026b09..ec63eea 100644 --- a/types/msgs.go +++ b/types/msgs.go @@ -67,10 +67,15 @@ func (msg MsgTransfer) ValidateBasic() error { return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank") } - for _, tokenID := range msg.TokenIds { - if strings.TrimSpace(tokenID) == "" { + seen := make(map[string]int64) + for i, id := range msg.TokenIds { + if strings.TrimSpace(id) == "" { return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank") } + if j, exist := seen[id]; exist { + return errorsmod.Wrapf(ErrInvalidTokenID, "the tokenId at positions %d and %d in the array are repeated", i, j) + } + seen[id] = int64(i) } // NOTE: sender format must be validated as it is required by the GetSigners function. @@ -116,6 +121,9 @@ func (msg MsgUpdateParams) GetSignBytes() []byte { // GetSigners returns the expected signers for a MsgUpdateParams. func (msg MsgUpdateParams) GetSigners() []sdk.AccAddress { - authority, _ := sdk.AccAddressFromBech32(msg.Authority) + authority, err := sdk.AccAddressFromBech32(msg.Authority) + if err != nil { + panic(err) + } return []sdk.AccAddress{authority} } diff --git a/types/packet.go b/types/packet.go index 21720e0..e33c8bb 100644 --- a/types/packet.go +++ b/types/packet.go @@ -55,10 +55,15 @@ func (nftpd NonFungibleTokenPacketData) ValidateBasic() error { return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be empty") } - for _, id := range nftpd.TokenIds { + seen := make(map[string]int64) + for i, id := range nftpd.TokenIds { if strings.TrimSpace(id) == "" { return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank") } + if j, exist := seen[id]; exist { + return errorsmod.Wrapf(ErrInvalidTokenID, "the tokenId at positions %d and %d in the array are repeated", i, j) + } + seen[id] = int64(i) } if (len(nftpd.TokenUris) != 0) && len(nftpd.TokenIds) != len(nftpd.TokenUris) { diff --git a/types/packet_test.go b/types/packet_test.go index b29886a..aefa612 100644 --- a/types/packet_test.go +++ b/types/packet_test.go @@ -27,6 +27,11 @@ func TestNonFungibleTokenPacketData_ValidateBasic(t *testing.T) { packet: NonFungibleTokenPacketData{"cryptoCat", "uri", "", []string{}, []string{"kitty_uri"}, tokenData, sender, receiver, "memo"}, wantErr: true, }, + { + name: "invalid packet with repeated tokenIds", + packet: NonFungibleTokenPacketData{"cryptoCat", "uri", "", []string{"kitty","kitty"}, []string{"kitty_uri","kitty_uri"}, tokenData, sender, receiver, "memo"}, + wantErr: true, + }, { name: "valid packet with empty tokenUris", packet: NonFungibleTokenPacketData{"cryptoCat", "uri", "", []string{"kitty"}, []string{}, tokenData, sender, receiver, "memo"}, diff --git a/types/trace.go b/types/trace.go index 1768fc0..163573d 100644 --- a/types/trace.go +++ b/types/trace.go @@ -11,21 +11,20 @@ import ( tmbytes "github.com/cometbft/cometbft/libs/bytes" tmtypes "github.com/cometbft/cometbft/types" + channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" host "github.com/cosmos/ibc-go/v7/modules/core/24-host" ) // ParseHexHash parses a hex hash in string format to bytes and validates its correctness. func ParseHexHash(hexHash string) (tmbytes.HexBytes, error) { + if strings.TrimSpace(hexHash) == "" { + return nil, fmt.Errorf("empty hex hash") + } hash, err := hex.DecodeString(hexHash) if err != nil { return nil, err } - - if err := tmtypes.ValidateHash(hash); err != nil { - return nil, err - } - - return hash, nil + return hash, tmtypes.ValidateHash(hash) } // GetClassPrefix returns the receiving class prefix @@ -63,6 +62,7 @@ func IsAwayFromOrigin(sourcePort, sourceChannel, fullClassPath string) bool { // Examples: // // - "port-1/channel-1/class-1" => ClassTrace{Path: "port-1/channel-1", BaseClassId: "class-1"} +// - "port-1/channel-1/class/1" => ClassTrace{Path: "port-1/channel-1", BaseClassId: "class/1"} // - "class-1" => ClassTrace{Path: "", BaseClassId: "class-1"} func ParseClassTrace(rawClassID string) ClassTrace { classSplit := strings.Split(rawClassID, "/") @@ -74,9 +74,10 @@ func ParseClassTrace(rawClassID string) ClassTrace { } } + path, baseClassId := extractPathAndBaseFromFullClassID(classSplit) return ClassTrace{ - Path: strings.Join(classSplit[:len(classSplit)-1], "/"), - BaseClassId: classSplit[len(classSplit)-1], + Path: path, + BaseClassId: baseClassId, } } @@ -139,7 +140,7 @@ func validateTraceIdentifiers(identifiers []string) error { return errorsmod.Wrapf(err, "invalid port ID at position %d", i) } if err := host.ChannelIdentifierValidator(identifiers[i+1]); err != nil { - return errorsmod.Wrapf(err, "invalid channel ID at position %d", i) + return errorsmod.Wrapf(err, "invalid channel ID at position %d", i+1) } } return nil @@ -181,3 +182,32 @@ func (t Traces) Sort() Traces { sort.Sort(t) return t } + +// extractPathAndBaseFromFullClassID returns the trace path and the base classID from +// the elements that constitute the complete classID. +func extractPathAndBaseFromFullClassID(fullClassIdItems []string) (string, string) { + var ( + pathSlice []string + baseClassIdSlice []string + ) + + length := len(fullClassIdItems) + for i := 0; i < length; i += 2 { + // The IBC specification does not guarantee the expected format of the + // destination port or destination channel identifier. A short term solution + // to determine base classID is to expect the channel identifier to be the + // one ibc-go specifies. A longer term solution is to separate the path and base + // denomination in the ICS721 packet. + if i < length-1 && length > 2 && channeltypes.IsValidChannelID(fullClassIdItems[i+1]) { + pathSlice = append(pathSlice, fullClassIdItems[i], fullClassIdItems[i+1]) + } else { + baseClassIdSlice = fullClassIdItems[i:] + break + } + } + + path := strings.Join(pathSlice, "/") + baseClassID := strings.Join(baseClassIdSlice, "/") + + return path, baseClassID +} diff --git a/types/trace_test.go b/types/trace_test.go index c668526..f2b1a7d 100644 --- a/types/trace_test.go +++ b/types/trace_test.go @@ -16,12 +16,12 @@ func TestIsAwayFromOrigin(t *testing.T) { args args want bool }{ - {"transfer forward by origin chain", args{"p1", "c1", "kitty"}, true}, - {"transfer forward by relay chain", args{"p3", "c3", "p2/c2/kitty"}, true}, - {"transfer forward by relay chain", args{"p5", "c5", "p4/c4/p2/c2/kitty"}, true}, - {"transfer back by relay chain", args{"p6", "c6", "p6/c6/p4/c4/p2/c2/kitty"}, false}, - {"transfer back by relay chain", args{"p4", "c4", "p4/c4/p2/c2/kitty"}, false}, - {"transfer back by relay chain", args{"p2", "c2", "p2/c2/kitty"}, false}, + {"transfer forward by origin chain", args{"port-1", "channel-1", "kitty"}, true}, + {"transfer forward by relay chain", args{"port-3", "channel-3", "port-2/channel-2/kitty"}, true}, + {"transfer forward by relay chain", args{"port-5", "channel-5", "port-4/channel-4/port-2/channel-2/kitty"}, true}, + {"transfer back by relay chain", args{"port-6", "channel-6", "port-6/channel-6/port-4/channel-4/port-2/channel-2/kitty"}, false}, + {"transfer back by relay chain", args{"port-4", "channel-4", "port-4/channel-4/port-2/channel-2/kitty"}, false}, + {"transfer back by relay chain", args{"port-2", "channel-2", "port-2/channel-2/kitty"}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -42,14 +42,18 @@ func TestParseClassTrace(t *testing.T) { want ClassTrace }{ {"native class", args{"kitty"}, ClassTrace{Path: "", BaseClassId: "kitty"}}, - {"transfer to (p2,c2)", args{"p2/c2/kitty"}, ClassTrace{Path: "p2/c2", BaseClassId: "kitty"}}, - {"transfer to (p4,c4)", args{"p4/c4/p2/c2/kitty"}, ClassTrace{Path: "p4/c4/p2/c2", BaseClassId: "kitty"}}, - {"transfer to (p6,c6)", args{"p6/c6/p4/c4/p2/c2/kitty"}, ClassTrace{Path: "p6/c6/p4/c4/p2/c2", BaseClassId: "kitty"}}, + {"transfer to (port-2,channel-2)", args{"port-2/channel-2/kitty"}, ClassTrace{Path: "port-2/channel-2", BaseClassId: "kitty"}}, + {"transfer to (port-4,channel-4)", args{"port-4/channel-4/port-2/channel-2/kitty"}, ClassTrace{Path: "port-4/channel-4/port-2/channel-2", BaseClassId: "kitty"}}, + {"transfer to (port-6,channel-6)", args{"port-6/channel-6/port-4/channel-4/port-2/channel-2/kitty"}, ClassTrace{Path: "port-6/channel-6/port-4/channel-4/port-2/channel-2", BaseClassId: "kitty"}}, + {"native class with /", args{"cat/kitty"}, ClassTrace{Path: "", BaseClassId: "cat/kitty"}}, + {"transfer to (port-2,channel-2) with /", args{"port-2/channel-2/cat/kitty"}, ClassTrace{Path: "port-2/channel-2", BaseClassId: "cat/kitty"}}, + {"transfer to (port-4,channel-4) with /", args{"port-4/channel-4/port-2/channel-2/cat/kitty"}, ClassTrace{Path: "port-4/channel-4/port-2/channel-2", BaseClassId: "cat/kitty"}}, + {"transfer to (port-6,channel-6) with /", args{"port-6/channel-6/port-4/channel-4/port-2/channel-2/cat/kitty"}, ClassTrace{Path: "port-6/channel-6/port-4/channel-4/port-2/channel-2", BaseClassId: "cat/kitty"}}, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := ParseClassTrace(tt.args.rawClassID); !reflect.DeepEqual(got, tt.want) { - t.Errorf("ParseClassTrace() = %v, want %v", got, tt.want) + for i := range tests { + t.Run(tests[i].name, func(t *testing.T) { + if got := ParseClassTrace(tests[i].args.rawClassID); !reflect.DeepEqual(got, tests[i].want) { + t.Errorf("ParseClassTrace() = %v, want %v", got, tests[i].want) } }) } @@ -62,9 +66,9 @@ func TestClassTrace_GetFullClassPath(t *testing.T) { want string }{ {"native class", ClassTrace{Path: "", BaseClassId: "kitty"}, "kitty"}, - {"first tranfer", ClassTrace{Path: "p2/c2", BaseClassId: "kitty"}, "p2/c2/kitty"}, - {"second tranfer", ClassTrace{Path: "p4/c4/p2/c2", BaseClassId: "kitty"}, "p4/c4/p2/c2/kitty"}, - {"third tranfer", ClassTrace{Path: "p6/c6/p4/c4/p2/c2", BaseClassId: "kitty"}, "p6/c6/p4/c4/p2/c2/kitty"}, + {"first tranfer", ClassTrace{Path: "port-2/channel-2", BaseClassId: "kitty"}, "port-2/channel-2/kitty"}, + {"second tranfer", ClassTrace{Path: "port-4/channel-4/port-2/channel-2", BaseClassId: "kitty"}, "port-4/channel-4/port-2/channel-2/kitty"}, + {"third tranfer", ClassTrace{Path: "port-6/channel-6/port-4/channel-4/port-2/channel-2", BaseClassId: "kitty"}, "port-6/channel-6/port-4/channel-4/port-2/channel-2/kitty"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {