Skip to content

Commit

Permalink
Merge pull request #48 from bianjieai/dreamer/apply-audit
Browse files Browse the repository at this point in the history
optimization: apply audit  suggestion
  • Loading branch information
aofengli authored Oct 23, 2023
2 parents ec385f7 + c409209 commit e1cc238
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 66 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@ Ref: https://keepachangelog.com/en/1.0.0/

### Improvements

* [\#48](https://github.com/bianjieai/nft-transfer/pull/48) apply audit suggestion.
* [\#17](https://github.com/bianjieai/nft-transfer/pull/17) replace param proposal with MsgUpdateParams.
* [\#15](https://github.com/bianjieai/nft-transfer/pull/15) solve the problem of "/" parsing error in classID.

### Features

### Bug Fixes

* [\#29](https://github.com/bianjieai/nft-transfer/pull/29) bump up Cosmos-SDK and IBC-Go.

## [v1.1.2]

### API Breaking
Expand Down
13 changes: 7 additions & 6 deletions ibc_module.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package nfttransfer

import (
"fmt"
"math"
"strings"

Expand Down Expand Up @@ -185,23 +184,25 @@ func (im IBCModule) OnRecvPacket(
var (
ack = channeltypes.NewResultAcknowledgement([]byte{byte(1)})
data types.NonFungibleTokenPacketData
err error
ackErr error
)

if err = types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil {
if err := types.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil {
ack = channeltypes.NewErrorAcknowledgement(
errorsmod.Wrapf(sdkerrors.ErrInvalidType, "cannot unmarshal ICS-721 nft-transfer packet data"),
)
ackErr = err
}

// only attempt the application logic if the packet data
// was successfully decoded
if ack.Success() {
if err = im.keeper.OnRecvPacket(ctx, packet, data); err != nil {
if err := im.keeper.OnRecvPacket(ctx, packet, data); err != nil {
ack = channeltypes.NewErrorAcknowledgement(err)
ackErr = err
}
}
keeper.EmitAcknowledgementEvent(ctx, data, ack, err)
keeper.EmitAcknowledgementEvent(ctx, data, ack, ackErr)
// NOTE: acknowledgement will be written synchronously during IBC handler execution.
return ack
}
Expand Down Expand Up @@ -237,7 +238,7 @@ func (im IBCModule) OnAcknowledgementPacket(
sdk.NewAttribute(types.AttributeKeyReceiver, data.Receiver),
sdk.NewAttribute(types.AttributeKeyClassID, data.ClassId),
sdk.NewAttribute(types.AttributeKeyTokenIDs, strings.Join(data.TokenIds, ",")),
sdk.NewAttribute(types.AttributeKeyAckSuccess, fmt.Sprintf("%t", ack.Success())),
sdk.NewAttribute(types.AttributeKeyAck, ack.String()),
),
)

Expand Down
2 changes: 1 addition & 1 deletion keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (k Keeper) ClassTrace(c context.Context,

hash, err := types.ParseHexHash(strings.TrimPrefix(req.Hash, "ibc/"))
if err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid denom trace hash: %s, error: %s", hash.String(), err))
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid class trace hash: %s, error: %s", req.Hash, err))
}

ctx := sdk.UnwrapSDKContext(c)
Expand Down
31 changes: 9 additions & 22 deletions keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,10 @@ func (k Keeper) SendTransfer(
packet, err := k.createOutgoingPacket(ctx,
sourcePort,
sourceChannel,
destinationPort,
destinationChannel,
classID,
tokenIDs,
sender,
receiver,
timeoutHeight,
timeoutTimestamp,
memo,
)
if err != nil {
Expand Down Expand Up @@ -172,8 +168,8 @@ func (k Keeper) refundPacketToken(ctx sdk.Context, packet channeltypes.Packet, d
return err
}
if types.IsAwayFromOrigin(packet.GetSourcePort(), packet.GetSourceChannel(), data.ClassId) {
for _, tokenID := range data.TokenIds {
if err := k.nftKeeper.Transfer(ctx, voucherClassID, tokenID, "", sender); err != nil {
for i, tokenID := range data.TokenIds {
if err := k.nftKeeper.Transfer(ctx, voucherClassID, tokenID, types.GetIfExist(i, data.TokenData), sender); err != nil {
return err
}
}
Expand Down Expand Up @@ -201,14 +197,10 @@ func (k Keeper) refundPacketToken(ctx sdk.Context, packet channeltypes.Packet, d
func (k Keeper) createOutgoingPacket(ctx sdk.Context,
sourcePort,
sourceChannel,
destinationPort,
destinationChannel,
classID string,
tokenIDs []string,
sender sdk.AccAddress,
receiver string,
timeoutHeight clienttypes.Height,
timeoutTimestamp uint64,
memo string,
) (types.NonFungibleTokenPacketData, error) {
class, exist := k.nftKeeper.GetClass(ctx, classID)
Expand All @@ -220,8 +212,8 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,
// NOTE: class and hex hash correctness checked during msg.ValidateBasic
fullClassPath = classID
err error
tokenURIs []string
tokenData []string
tokenURIs = make([]string, len(tokenIDs))
tokenData = make([]string, len(tokenIDs))
)

// deconstruct the token denomination into the denomination trace info
Expand All @@ -235,7 +227,7 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,

isAwayFromOrigin := types.IsAwayFromOrigin(sourcePort,
sourceChannel, fullClassPath)
for _, tokenID := range tokenIDs {
for i, tokenID := range tokenIDs {
nft, exist := k.nftKeeper.GetNFT(ctx, classID, tokenID)
if !exist {
return types.NonFungibleTokenPacketData{}, errorsmod.Wrap(types.ErrInvalidTokenID, "tokenId not exist")
Expand All @@ -246,13 +238,13 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,
return types.NonFungibleTokenPacketData{}, errorsmod.Wrap(sdkerrors.ErrUnauthorized, "not token owner")
}

tokenURIs = append(tokenURIs, nft.GetURI())
tokenData = append(tokenData, nft.GetData())
tokenURIs[i] = nft.GetURI()
tokenData[i] = nft.GetData()

if isAwayFromOrigin {
// create the escrow address for the tokens
escrowAddress := types.GetEscrowAddress(sourcePort, sourceChannel)
if err := k.nftKeeper.Transfer(ctx, classID, tokenID, "", escrowAddress); err != nil {
if err := k.nftKeeper.Transfer(ctx, classID, tokenID, nft.GetData(), escrowAddress); err != nil {
return types.NonFungibleTokenPacketData{}, err
}
} else {
Expand All @@ -273,12 +265,7 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,
tokenData,
memo,
)

// check packet
if err := packetData.ValidateBasic(); err != nil {
return types.NonFungibleTokenPacketData{}, err
}
return packetData, nil
return packetData, packetData.ValidateBasic()
}

// processReceivedPacket will mint the tokens to receiver account
Expand Down
14 changes: 7 additions & 7 deletions keeper/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func (k Keeper) GetClassTrace(ctx sdk.Context, classTraceHash tmbytes.HexBytes)
return types.ClassTrace{}, false
}

denomTrace := k.MustUnmarshalClassTrace(bz)
return denomTrace, true
classTrace := k.MustUnmarshalClassTrace(bz)
return classTrace, true
}

// GetAllClassTraces returns the trace information for all the class.
Expand Down Expand Up @@ -67,16 +67,16 @@ func (k Keeper) ClassPathFromHash(ctx sdk.Context, classID string) (string, erro
}

// HasClassTrace checks if a the key with the given denomination trace hash exists on the store.
func (k Keeper) HasClassTrace(ctx sdk.Context, denomTraceHash tmbytes.HexBytes) bool {
func (k Keeper) HasClassTrace(ctx sdk.Context, classTraceHash tmbytes.HexBytes) bool {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.ClassTraceKey)
return store.Has(denomTraceHash)
return store.Has(classTraceHash)
}

// SetClassTrace sets a new {trace hash -> class trace} pair to the store.
func (k Keeper) SetClassTrace(ctx sdk.Context, denomTrace types.ClassTrace) {
func (k Keeper) SetClassTrace(ctx sdk.Context, classTrace types.ClassTrace) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.ClassTraceKey)
bz := k.MustMarshalClassTrace(denomTrace)
store.Set(denomTrace.Hash(), bz)
bz := k.MustMarshalClassTrace(classTrace)
store.Set(classTrace.Hash(), bz)
}

// MustUnmarshalClassTrace attempts to decode and return an ClassTrace object from
Expand Down
6 changes: 5 additions & 1 deletion types/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ import (
// on the provided LegacyAmino codec. These types are used for Amino JSON serialization.
func RegisterLegacyAminoCodec(cdc *codec.LegacyAmino) {
cdc.RegisterConcrete(&MsgTransfer{}, "cosmos-sdk/MsgTransferNFT", nil)
cdc.RegisterConcrete(&MsgUpdateParams{}, "cosmos-sdk/MsgUpdateParams", nil)
}

// RegisterInterfaces register the ibc nft-transfer module interfaces to protobuf
// Any.
func RegisterInterfaces(registry codectypes.InterfaceRegistry) {
registry.RegisterImplementations((*sdk.Msg)(nil), &MsgTransfer{})
registry.RegisterImplementations((*sdk.Msg)(nil),
&MsgTransfer{},
&MsgUpdateParams{},
)
msgservice.RegisterMsgServiceDesc(registry, &_Msg_serviceDesc)
}

Expand Down
1 change: 1 addition & 0 deletions types/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
AttributeKeyReceiver = "receiver"
AttributeKeyClassID = "classID"
AttributeKeyTokenIDs = "tokenIDs"
AttributeKeyAck = "acknowledgement"
AttributeKeyAckSuccess = "success"
AttributeKeyAckError = "error"
AttributeKeyTraceHash = "trace_hash"
Expand Down
14 changes: 11 additions & 3 deletions types/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ func (msg MsgTransfer) ValidateBasic() error {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank")
}

for _, tokenID := range msg.TokenIds {
if strings.TrimSpace(tokenID) == "" {
seen := make(map[string]int64)
for i, id := range msg.TokenIds {
if strings.TrimSpace(id) == "" {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank")
}
if j, exist := seen[id]; exist {
return errorsmod.Wrapf(ErrInvalidTokenID, "the tokenId at positions %d and %d in the array are repeated", i, j)
}
seen[id] = int64(i)
}

// NOTE: sender format must be validated as it is required by the GetSigners function.
Expand Down Expand Up @@ -116,6 +121,9 @@ func (msg MsgUpdateParams) GetSignBytes() []byte {

// GetSigners returns the expected signers for a MsgUpdateParams.
func (msg MsgUpdateParams) GetSigners() []sdk.AccAddress {
authority, _ := sdk.AccAddressFromBech32(msg.Authority)
authority, err := sdk.AccAddressFromBech32(msg.Authority)
if err != nil {
panic(err)
}
return []sdk.AccAddress{authority}
}
7 changes: 6 additions & 1 deletion types/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ func (nftpd NonFungibleTokenPacketData) ValidateBasic() error {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be empty")
}

for _, id := range nftpd.TokenIds {
seen := make(map[string]int64)
for i, id := range nftpd.TokenIds {
if strings.TrimSpace(id) == "" {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank")
}
if j, exist := seen[id]; exist {
return errorsmod.Wrapf(ErrInvalidTokenID, "the tokenId at positions %d and %d in the array are repeated", i, j)
}
seen[id] = int64(i)
}

if (len(nftpd.TokenUris) != 0) && len(nftpd.TokenIds) != len(nftpd.TokenUris) {
Expand Down
5 changes: 5 additions & 0 deletions types/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ func TestNonFungibleTokenPacketData_ValidateBasic(t *testing.T) {
packet: NonFungibleTokenPacketData{"cryptoCat", "uri", "", []string{}, []string{"kitty_uri"}, tokenData, sender, receiver, "memo"},
wantErr: true,
},
{
name: "invalid packet with repeated tokenIds",
packet: NonFungibleTokenPacketData{"cryptoCat", "uri", "", []string{"kitty","kitty"}, []string{"kitty_uri","kitty_uri"}, tokenData, sender, receiver, "memo"},
wantErr: true,
},
{
name: "valid packet with empty tokenUris",
packet: NonFungibleTokenPacketData{"cryptoCat", "uri", "", []string{"kitty"}, []string{}, tokenData, sender, receiver, "memo"},
Expand Down
48 changes: 39 additions & 9 deletions types/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,20 @@ import (
tmbytes "github.com/cometbft/cometbft/libs/bytes"
tmtypes "github.com/cometbft/cometbft/types"

channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
host "github.com/cosmos/ibc-go/v7/modules/core/24-host"
)

// ParseHexHash parses a hex hash in string format to bytes and validates its correctness.
func ParseHexHash(hexHash string) (tmbytes.HexBytes, error) {
if strings.TrimSpace(hexHash) == "" {
return nil, fmt.Errorf("empty hex hash")
}
hash, err := hex.DecodeString(hexHash)
if err != nil {
return nil, err
}

if err := tmtypes.ValidateHash(hash); err != nil {
return nil, err
}

return hash, nil
return hash, tmtypes.ValidateHash(hash)
}

// GetClassPrefix returns the receiving class prefix
Expand Down Expand Up @@ -63,6 +62,7 @@ func IsAwayFromOrigin(sourcePort, sourceChannel, fullClassPath string) bool {
// Examples:
//
// - "port-1/channel-1/class-1" => ClassTrace{Path: "port-1/channel-1", BaseClassId: "class-1"}
// - "port-1/channel-1/class/1" => ClassTrace{Path: "port-1/channel-1", BaseClassId: "class/1"}
// - "class-1" => ClassTrace{Path: "", BaseClassId: "class-1"}
func ParseClassTrace(rawClassID string) ClassTrace {
classSplit := strings.Split(rawClassID, "/")
Expand All @@ -74,9 +74,10 @@ func ParseClassTrace(rawClassID string) ClassTrace {
}
}

path, baseClassId := extractPathAndBaseFromFullClassID(classSplit)
return ClassTrace{
Path: strings.Join(classSplit[:len(classSplit)-1], "/"),
BaseClassId: classSplit[len(classSplit)-1],
Path: path,
BaseClassId: baseClassId,
}
}

Expand Down Expand Up @@ -139,7 +140,7 @@ func validateTraceIdentifiers(identifiers []string) error {
return errorsmod.Wrapf(err, "invalid port ID at position %d", i)
}
if err := host.ChannelIdentifierValidator(identifiers[i+1]); err != nil {
return errorsmod.Wrapf(err, "invalid channel ID at position %d", i)
return errorsmod.Wrapf(err, "invalid channel ID at position %d", i+1)
}
}
return nil
Expand Down Expand Up @@ -181,3 +182,32 @@ func (t Traces) Sort() Traces {
sort.Sort(t)
return t
}

// extractPathAndBaseFromFullClassID returns the trace path and the base classID from
// the elements that constitute the complete classID.
func extractPathAndBaseFromFullClassID(fullClassIdItems []string) (string, string) {
var (
pathSlice []string
baseClassIdSlice []string
)

length := len(fullClassIdItems)
for i := 0; i < length; i += 2 {
// The IBC specification does not guarantee the expected format of the
// destination port or destination channel identifier. A short term solution
// to determine base classID is to expect the channel identifier to be the
// one ibc-go specifies. A longer term solution is to separate the path and base
// denomination in the ICS721 packet.
if i < length-1 && length > 2 && channeltypes.IsValidChannelID(fullClassIdItems[i+1]) {
pathSlice = append(pathSlice, fullClassIdItems[i], fullClassIdItems[i+1])
} else {
baseClassIdSlice = fullClassIdItems[i:]
break
}
}

path := strings.Join(pathSlice, "/")
baseClassID := strings.Join(baseClassIdSlice, "/")

return path, baseClassID
}
Loading

0 comments on commit e1cc238

Please sign in to comment.