From c0bc45f03db7d4856c960190c2057dd78e76137a Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 20 Sep 2024 11:58:39 +0800 Subject: [PATCH 1/6] [Flyte][3][Attribute Access] Binary IDL With MessagePack Signed-off-by: Future-Outlier --- .../controller/nodes/attr_path_resolver.go | 97 ++- .../nodes/attr_path_resolver_test.go | 557 +++++++++++++++++- go.mod | 1 + 3 files changed, 629 insertions(+), 26 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index 42150cb887..0bcd610801 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -1,6 +1,7 @@ package nodes import ( + "github.com/shamaton/msgpack/v2" "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -37,13 +38,20 @@ func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath } } - // resolve dataclass - if currVal.GetScalar() != nil && currVal.GetScalar().GetGeneric() != nil { - st := currVal.GetScalar().GetGeneric() - // start from index "count" - currVal, err = resolveAttrPathInPbStruct(nodeID, st, bindAttrPath[count:]) - if err != nil { - return nil, err + // resolve dataclass and Pydantic BaseModel + if scalar := currVal.GetScalar(); scalar != nil { + if binary := scalar.GetBinary(); binary != nil { + // Start from index "count" + currVal, err = resolveAttrPathInBinary(nodeID, binary, bindAttrPath[count:]) + if err != nil { + return nil, err + } + } else if generic := scalar.GetGeneric(); generic != nil { + // Start from index "count" + currVal, err = resolveAttrPathInPbStruct(nodeID, generic, bindAttrPath[count:]) + if err != nil { + return nil, err + } } } @@ -84,6 +92,79 @@ func resolveAttrPathInPbStruct(nodeID string, st *structpb.Struct, bindAttrPath return literal, err } +// resolveAttrPathInBinary resolves the binary idl object (e.g. dataclass, pydantic basemodel) with attribute path +func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath []*core.PromiseAttribute) (*core. + Literal, + error) { + + binaryBytes := binaryIDL.GetValue() + serializationFormat := binaryIDL.GetTag() + + var currVal interface{} + var tmpVal interface{} + var exist bool + + if serializationFormat == "msgpack" { + err := msgpack.Unmarshal(binaryBytes, &currVal) + if err != nil { + return nil, err + + } + } else { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "Unsupported format '%v' found for literal value.\n"+ + "Please ensure the serialization format is supported.", serializationFormat) + } + + // Turn the current value to a map, so it can be resolved more easily + for _, attr := range bindAttrPath { + switch resolvedVal := currVal.(type) { + // map + case map[interface{}]interface{}: + tmpVal, exist = resolvedVal[attr.GetStringValue()] + if !exist { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal) + } + currVal = tmpVal + // list + case []interface{}: + if int(attr.GetIntValue()) >= len(resolvedVal) { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal) + } + currVal = resolvedVal[attr.GetIntValue()] + } + } + + if serializationFormat == "msgpack" { + // Marshal the current value to MessagePack bytes + resolvedBinaryBytes, err := msgpack.Marshal(currVal) + if err != nil { + return nil, err + } + // Construct and return the binary-encoded literal + return constructResolvedBinary(resolvedBinaryBytes, serializationFormat), nil + } + // Unsupported serialization format + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, + "Unsupported format '%v' found for literal value.\n"+ + "Please ensure the serialization format is supported.", serializationFormat) +} + +func constructResolvedBinary(resolvedBinaryBytes []byte, serializationFormat string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: resolvedBinaryBytes, + Tag: serializationFormat, + }, + }, + }, + }, + } +} + // convertInterfaceToLiteral converts the protobuf struct (e.g. dataclass) to literal func convertInterfaceToLiteral(nodeID string, obj interface{}) (*core.Literal, error) { @@ -128,7 +209,7 @@ func convertInterfaceToLiteral(nodeID string, obj interface{}) (*core.Literal, e return literal, nil } -// convertInterfaceToLiteralScalar converts the a single value to a literal scalar +// convertInterfaceToLiteralScalar converts a single value to a literal scalar func convertInterfaceToLiteralScalar(nodeID string, obj interface{}) (*core.Literal_Scalar, error) { value := &core.Primitive{} diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go index fb966c666e..e0486aae5e 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go @@ -3,6 +3,7 @@ package nodes import ( "testing" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/structpb" @@ -31,7 +32,7 @@ func NewStructFromMap(m map[string]interface{}) *structpb.Struct { return st } -func TestResolveAttrPathIn(t *testing.T) { +func TestResolveAttrPathInStruct(t *testing.T) { args := []struct { literal *core.Literal @@ -51,7 +52,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, @@ -73,7 +74,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 1, }, @@ -94,7 +95,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, @@ -119,12 +120,12 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 1, }, @@ -149,7 +150,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, @@ -159,7 +160,7 @@ func TestResolveAttrPathIn(t *testing.T) { Value: &core.Literal_Collection{ Collection: &core.LiteralCollection{ Literals: []*core.Literal{ - &core.Literal{ + { Value: &core.Literal_Collection{ Collection: &core.LiteralCollection{ Literals: []*core.Literal{ @@ -181,11 +182,11 @@ func TestResolveAttrPathIn(t *testing.T) { Value: &core.Literal_Map{ Map: &core.LiteralMap{ Literals: map[string]*core.Literal{ - "foo": &core.Literal{ + "foo": { Value: &core.Literal_Collection{ Collection: &core.LiteralCollection{ Literals: []*core.Literal{ - &core.Literal{ + { Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Generic{ @@ -203,17 +204,17 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 0, }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "bar", }, @@ -222,6 +223,53 @@ func TestResolveAttrPathIn(t *testing.T) { expected: NewScalarLiteral("car"), hasError: false, }, + // - nested map {"foo": {"bar": {"baz": 42}}} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": map[string]interface{}{ + "baz": 42, + }, + }, + }, + ), + }, + }, + }, + }, + // Test accessing the entire nested map at foo.bar + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "bar", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "baz": 42, + }, + ), + }, + }, + }, + }, + hasError: false, + }, // - exception key error with map { literal: &core.Literal{ @@ -234,7 +282,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "random", }, @@ -256,7 +304,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 2, }, @@ -277,7 +325,7 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "random", }, @@ -302,12 +350,12 @@ func TestResolveAttrPathIn(t *testing.T) { }, }, path: []*core.PromiseAttribute{ - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, - &core.PromiseAttribute{ + { Value: &core.PromiseAttribute_IntValue{ IntValue: 100, }, @@ -328,3 +376,476 @@ func TestResolveAttrPathIn(t *testing.T) { } } } + +func TestResolveAttrPathInBinary(t *testing.T) { + // Helper function to convert a map to msgpack bytes and then to BinaryIDL + toMsgpackBytes := func(m interface{}) []byte { + msgpackBytes, _ := msgpack.Marshal(m) + return msgpackBytes + } + + args := []struct { + literal *core.Literal + path []*core.PromiseAttribute + expected *core.Literal + hasError bool + }{ + // - nested map {"foo": {"bar": 42, "baz": {"qux": 3.14, "quux": "str"}}} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": int64(42), + "baz": map[string]interface{}{ + "qux": 3.14, + "quux": "str", + }, + }, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the int value at foo.bar + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "bar", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(int64(42)), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - nested map {"foo": {"bar": 42, "baz": {"qux": 3.14, "quux": "str"}}} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": int64(42), + "baz": map[string]interface{}{ + "qux": 3.14, + "quux": "str", + }, + }, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the float value at foo.baz.qux + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "baz", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "qux", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(3.14), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - nested map {"foo": {"bar": 42, "baz": {"qux": 3.14, "quux": "str"}}} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": int64(42), + "baz": map[string]interface{}{ + "qux": 3.14, + "quux": "str", + }, + }, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the string value at foo.baz.quux + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "baz", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "quux", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes("str"), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - nested list {"foo": [42, 3.14, "str"]} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{int64(-42), 3.14, "str"}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the int value at foo[0] + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 0, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(int64(-42)), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - nested list {"foo": [42, 3.14, "str"]} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{int64(42), 3.14, "str"}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the float value at foo[1] + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 1, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(3.14), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - nested list {"foo": [42, 3.14, "str"]} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{int64(42), 3.14, "str"}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the string value at foo[2] + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 2, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes("str"), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - test extracting a nested map as a Binary object {"foo": {"bar": {"baz": 42}}} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": map[string]interface{}{ + "baz": int64(42), + }, + }, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing the entire nested map at foo.bar + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "bar", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "baz": int64(42), + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - exception case with non-existing key in nested map + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": map[string]interface{}{ + "bar": int64(42), + "baz": map[string]interface{}{ + "qux": 3.14, + "quux": "str", + }, + }, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing a non-existing key in the nested map + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "baz", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "unknown", + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + // - exception case with out-of-range index in list + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{int64(42), 3.14, "str"}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + // Test accessing an out-of-range index in the list + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 10, + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + // - nested list struct {"foo": [["bar1", "bar2"]]} + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{[]interface{}{"bar1", "bar2"}}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]interface{}{[]interface{}{"bar1", "bar2"}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + } + + for i, arg := range args { + resolved, err := resolveAttrPathInPromise("", arg.literal, arg.path) + if arg.hasError { + assert.Error(t, err, i) + assert.ErrorContains(t, err, errors.PromiseAttributeResolveError, i) + } else { + assert.Equal(t, arg.expected, resolved, i) + } + } +} diff --git a/go.mod b/go.mod index 8c8053def6..2799c63d4d 100644 --- a/go.mod +++ b/go.mod @@ -166,6 +166,7 @@ require ( github.com/robfig/cron/v3 v3.0.0 // indirect github.com/sendgrid/rest v2.6.9+incompatible // indirect github.com/sendgrid/sendgrid-go v3.10.0+incompatible // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.4.1 // indirect From c9616ca58fa6c3eebe0adaed14bc58cdc8375a34 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 20 Sep 2024 16:36:49 +0800 Subject: [PATCH 2/6] cover all cases in test Signed-off-by: Future-Outlier --- .../nodes/attr_path_resolver_test.go | 1009 ++++++++++++++--- 1 file changed, 845 insertions(+), 164 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go index e0486aae5e..2695712925 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go @@ -1,6 +1,8 @@ package nodes import ( + "fmt" + "reflect" "testing" "github.com/shamaton/msgpack/v2" @@ -11,6 +13,49 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" ) +// FlyteFile and FlyteDirectory represented as map[interface{}]interface{} +type FlyteFile map[interface{}]interface{} +type FlyteDirectory map[interface{}]interface{} + +// InnerDC struct (equivalent to InnerDC dataclass in Python) +type InnerDC struct { + A int `json:"a"` + B float64 `json:"b"` + C string `json:"c"` + D bool `json:"d"` + E []int `json:"e"` + F []FlyteFile `json:"f"` + G [][]int `json:"g"` + H []map[int]bool `json:"h"` + I map[int]bool `json:"i"` + J map[int]FlyteFile `json:"j"` + K map[int][]int `json:"k"` + L map[int]map[int]int `json:"l"` + M map[string]string `json:"m"` + N FlyteFile `json:"n"` + O FlyteDirectory `json:"o"` +} + +// DC struct (equivalent to DC dataclass in Python) +type DC struct { + A int `json:"a"` + B float64 `json:"b"` + C string `json:"c"` + D bool `json:"d"` + E []int `json:"e"` + F []FlyteFile `json:"f"` + G [][]int `json:"g"` + H []map[int]bool `json:"h"` + I map[int]bool `json:"i"` + J map[int]FlyteFile `json:"j"` + K map[int][]int `json:"k"` + L map[int]map[int]int `json:"l"` + M map[string]string `json:"m"` + N FlyteFile `json:"n"` + O FlyteDirectory `json:"o"` + Inner InnerDC `json:"inner_dc"` +} + func NewScalarLiteral(value string) *core.Literal { return &core.Literal{ Value: &core.Literal_Scalar{ @@ -377,51 +422,141 @@ func TestResolveAttrPathInStruct(t *testing.T) { } } +func createNestedDC() DC { + flyteFile := FlyteFile{ + "path": "s3://my-s3-bucket/example.txt", + } + + flyteDirectory := FlyteDirectory{ + "path": "s3://my-s3-bucket/s3_flyte_dir", + } + + // Example of initializing InnerDC + innerDC := InnerDC{ + A: -1, + B: -2.1, + C: "Hello, Flyte", + D: false, + E: []int{0, 1, 2, -1, -2}, + F: []FlyteFile{flyteFile}, + G: [][]int{{0}, {1}, {-1}}, + H: []map[int]bool{{0: false}, {1: true}, {-1: true}}, + I: map[int]bool{0: false, 1: true, -1: false}, + J: map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }, + K: map[int][]int{ + 0: {0, 1, -1}, + }, + L: map[int]map[int]int{ + 1: {-1: 0}, + }, + M: map[string]string{ + "key": "value", + }, + N: flyteFile, + O: flyteDirectory, + } + + // Initializing DC + dc := DC{ + A: 1, + B: 2.1, + C: "Hello, Flyte", + D: false, + E: []int{0, 1, 2, -1, -2}, + F: []FlyteFile{flyteFile}, + G: [][]int{{0}, {1}, {-1}}, + H: []map[int]bool{{0: false}, {1: true}, {-1: true}}, + I: map[int]bool{0: false, 1: true, -1: false}, + J: map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }, + K: map[int][]int{ + 0: {0, 1, -1}, + }, + L: map[int]map[int]int{ + 1: {-1: 0}, + }, + M: map[string]string{ + "key": "value", + }, + N: flyteFile, + O: flyteDirectory, + Inner: innerDC, + } + return dc +} + func TestResolveAttrPathInBinary(t *testing.T) { // Helper function to convert a map to msgpack bytes and then to BinaryIDL toMsgpackBytes := func(m interface{}) []byte { - msgpackBytes, _ := msgpack.Marshal(m) + msgpackBytes, err := msgpack.Marshal(m) + assert.NoError(t, err) return msgpackBytes } + flyteFile := FlyteFile{ + "path": "s3://my-s3-bucket/example.txt", + } + + flyteDirectory := FlyteDirectory{ + "path": "s3://my-s3-bucket/s3_flyte_dir", + } + + nestedDC := createNestedDC() + literalNestedDC := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(nestedDC), + Tag: "msgpack", + }, + }, + }, + }, + } + args := []struct { literal *core.Literal path []*core.PromiseAttribute expected *core.Literal hasError bool }{ - // - nested map {"foo": {"bar": 42, "baz": {"qux": 3.14, "quux": "str"}}} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "A", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": int64(42), - "baz": map[string]interface{}{ - "qux": 3.14, - "quux": "str", - }, - }, - }), - Tag: "msgpack", + Value: toMsgpackBytes(1), + Tag: "msgpack", }, }, }, }, }, - // Test accessing the int value at foo.bar + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", - }, - }, - { - Value: &core.PromiseAttribute_StringValue{ - StringValue: "bar", + StringValue: "B", }, }, }, @@ -430,7 +565,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(int64(42)), + Value: toMsgpackBytes(2.1), Tag: "msgpack", }, }, @@ -439,43 +574,35 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - nested map {"foo": {"bar": 42, "baz": {"qux": 3.14, "quux": "str"}}} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "C", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": int64(42), - "baz": map[string]interface{}{ - "qux": 3.14, - "quux": "str", - }, - }, - }), - Tag: "msgpack", + Value: toMsgpackBytes("Hello, Flyte"), + Tag: "msgpack", }, }, }, }, }, - // Test accessing the float value at foo.baz.qux + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", - }, - }, - { - Value: &core.PromiseAttribute_StringValue{ - StringValue: "baz", - }, - }, - { - Value: &core.PromiseAttribute_StringValue{ - StringValue: "qux", + StringValue: "D", }, }, }, @@ -484,7 +611,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(3.14), + Value: toMsgpackBytes(false), Tag: "msgpack", }, }, @@ -493,43 +620,58 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - nested map {"foo": {"bar": 42, "baz": {"qux": 3.14, "quux": "str"}}} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "E", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": int64(42), - "baz": map[string]interface{}{ - "qux": 3.14, - "quux": "str", - }, - }, - }), - Tag: "msgpack", + Value: toMsgpackBytes([]int{0, 1, 2, -1, -2}), + Tag: "msgpack", }, }, }, }, }, - // Test accessing the string value at foo.baz.quux + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", + StringValue: "F", }, }, - { - Value: &core.PromiseAttribute_StringValue{ - StringValue: "baz", + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]FlyteFile{{"path": "s3://my-s3-bucket/example.txt"}}), + Tag: "msgpack", + }, + }, }, }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "quux", + StringValue: "G", }, }, }, @@ -538,7 +680,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes("str"), + Value: toMsgpackBytes([][]int{{0}, {1}, {-1}}), Tag: "msgpack", }, }, @@ -547,32 +689,35 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - nested list {"foo": [42, 3.14, "str"]} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "H", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": []interface{}{int64(-42), 3.14, "str"}, - }), - Tag: "msgpack", + Value: toMsgpackBytes([]map[int]bool{{0: false}, {1: true}, {-1: true}}), + Tag: "msgpack", }, }, }, }, }, - // Test accessing the int value at foo[0] + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", - }, - }, - { - Value: &core.PromiseAttribute_IntValue{ - IntValue: 0, + StringValue: "I", }, }, }, @@ -581,7 +726,7 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(int64(-42)), + Value: toMsgpackBytes(map[int]bool{0: false, 1: true, -1: false}), Tag: "msgpack", }, }, @@ -590,32 +735,40 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - nested list {"foo": [42, 3.14, "str"]} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "J", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": []interface{}{int64(42), 3.14, "str"}, - }), + Value: toMsgpackBytes( + map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, + }), Tag: "msgpack", }, }, }, }, }, - // Test accessing the float value at foo[1] + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", - }, - }, - { - Value: &core.PromiseAttribute_IntValue{ - IntValue: 1, + StringValue: "K", }, }, }, @@ -624,8 +777,11 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(3.14), - Tag: "msgpack", + Value: toMsgpackBytes( + map[int][]int{ + 0: {0, 1, -1}, + }), + Tag: "msgpack", }, }, }, @@ -633,32 +789,38 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - nested list {"foo": [42, 3.14, "str"]} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "L", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": []interface{}{int64(42), 3.14, "str"}, - }), + Value: toMsgpackBytes( + map[int]map[int]int{ + 1: {-1: 0}, + }), Tag: "msgpack", }, }, }, }, }, - // Test accessing the string value at foo[2] + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", - }, - }, - { - Value: &core.PromiseAttribute_IntValue{ - IntValue: 2, + StringValue: "M", }, }, }, @@ -667,8 +829,11 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes("str"), - Tag: "msgpack", + Value: toMsgpackBytes( + map[string]string{ + "key": "value", + }), + Tag: "msgpack", }, }, }, @@ -676,36 +841,35 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - test extracting a nested map as a Binary object {"foo": {"bar": {"baz": 42}}} { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "N", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": map[string]interface{}{ - "baz": int64(42), - }, - }, - }), - Tag: "msgpack", + Value: toMsgpackBytes(flyteFile), + Tag: "msgpack", }, }, }, }, }, - // Test accessing the entire nested map at foo.bar + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", - }, - }, - { - Value: &core.PromiseAttribute_StringValue{ - StringValue: "bar", + StringValue: "O", }, }, }, @@ -714,10 +878,8 @@ func TestResolveAttrPathInBinary(t *testing.T) { Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "baz": int64(42), - }), - Tag: "msgpack", + Value: toMsgpackBytes(flyteDirectory), + Tag: "msgpack", }, }, }, @@ -725,58 +887,375 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, hasError: false, }, - // - exception case with non-existing key in nested map { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": int64(42), - "baz": map[string]interface{}{ - "qux": 3.14, - "quux": "str", - }, - }, - }), - Tag: "msgpack", + Value: toMsgpackBytes(nestedDC.Inner), + Tag: "msgpack", }, }, }, }, }, - // Test accessing a non-existing key in the nested map + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", + StringValue: "Inner", }, }, { Value: &core.PromiseAttribute_StringValue{ - StringValue: "baz", + StringValue: "A", }, }, - { + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(-1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { Value: &core.PromiseAttribute_StringValue{ - StringValue: "unknown", + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "B", }, }, }, - expected: &core.Literal{}, - hasError: true, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(-2.1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, }, - // - exception case with out-of-range index in list { - literal: &core.Literal{ + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "C", + }, + }, + }, + expected: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes(map[string]interface{}{ - "foo": []interface{}{int64(42), 3.14, "str"}, + Value: toMsgpackBytes("Hello, Flyte"), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "D", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(false), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "E", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]int{0, 1, 2, -1, -2}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "F", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]FlyteFile{flyteFile}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([][]int{{0}, {1}, {-1}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 0, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]int{0}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "G", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 2, + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 0, + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(-1), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "H", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes([]map[int]bool{{0: false}, {1: true}, {-1: true}}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "I", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[int]bool{0: false, 1: true, -1: false}), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "J", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(map[int]FlyteFile{ + 0: flyteFile, + 1: flyteFile, + -1: flyteFile, }), Tag: "msgpack", }, @@ -784,23 +1263,158 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, - // Test accessing an out-of-range index in the list + hasError: false, + }, + { + literal: literalNestedDC, path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ - StringValue: "foo", + StringValue: "Inner", }, }, { - Value: &core.PromiseAttribute_IntValue{ - IntValue: 10, + Value: &core.PromiseAttribute_StringValue{ + StringValue: "K", }, }, }, - expected: &core.Literal{}, - hasError: true, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int][]int{ + 0: {0, 1, -1}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, }, - // - nested list struct {"foo": [["bar1", "bar2"]]} + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "L", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[int]map[int]int{ + 1: {-1: 0}, + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "M", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes( + map[string]string{ + "key": "value", + }), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "N", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(flyteFile), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: literalNestedDC, + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "Inner", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "O", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: toMsgpackBytes(flyteDirectory), + Tag: "msgpack", + }, + }, + }, + }, + }, + hasError: false, + }, + // - exception case with non-existing key in nested map { literal: &core.Literal{ Value: &core.Literal_Scalar{ @@ -808,7 +1422,13 @@ func TestResolveAttrPathInBinary(t *testing.T) { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: toMsgpackBytes(map[string]interface{}{ - "foo": []interface{}{[]interface{}{"bar1", "bar2"}}, + "foo": map[string]interface{}{ + "bar": int64(42), + "baz": map[string]interface{}{ + "qux": 3.14, + "quux": "str", + }, + }, }), Tag: "msgpack", }, @@ -816,26 +1436,58 @@ func TestResolveAttrPathInBinary(t *testing.T) { }, }, }, + // Test accessing a non-existing key in the nested map path: []*core.PromiseAttribute{ { Value: &core.PromiseAttribute_StringValue{ StringValue: "foo", }, }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "baz", + }, + }, + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "unknown", + }, + }, }, - expected: &core.Literal{ + expected: &core.Literal{}, + hasError: true, + }, + // - exception case with out-of-range index in list + { + literal: &core.Literal{ Value: &core.Literal_Scalar{ Scalar: &core.Scalar{ Value: &core.Scalar_Binary{ Binary: &core.Binary{ - Value: toMsgpackBytes([]interface{}{[]interface{}{"bar1", "bar2"}}), - Tag: "msgpack", + Value: toMsgpackBytes(map[string]interface{}{ + "foo": []interface{}{int64(42), 3.14, "str"}, + }), + Tag: "msgpack", }, }, }, }, }, - hasError: false, + // Test accessing an out-of-range index in the list + path: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + { + Value: &core.PromiseAttribute_IntValue{ + IntValue: 10, + }, + }, + }, + expected: &core.Literal{}, + hasError: true, }, } @@ -845,7 +1497,36 @@ func TestResolveAttrPathInBinary(t *testing.T) { assert.Error(t, err, i) assert.ErrorContains(t, err, errors.PromiseAttributeResolveError, i) } else { - assert.Equal(t, arg.expected, resolved, i) + var expectedValue, actualValue interface{} + + // Helper function to unmarshal a Binary Literal into an interface{} + unmarshalBinaryLiteral := func(literal *core.Literal) (interface{}, error) { + if scalar, ok := literal.Value.(*core.Literal_Scalar); ok { + if binary, ok := scalar.Scalar.Value.(*core.Scalar_Binary); ok { + var value interface{} + err := msgpack.Unmarshal(binary.Binary.Value, &value) + return value, err + } + } + return nil, fmt.Errorf("literal is not a Binary Scalar") + } + + // Unmarshal the expected value + expectedValue, err := unmarshalBinaryLiteral(arg.expected) + if err != nil { + t.Fatalf("Failed to unmarshal expected value in test case %d: %v", i, err) + } + + // Unmarshal the resolved value + actualValue, err = unmarshalBinaryLiteral(resolved) + if err != nil { + t.Fatalf("Failed to unmarshal resolved value in test case %d: %v", i, err) + } + + // Deeply compare the expected and actual values, ignoring map ordering + if !reflect.DeepEqual(expectedValue, actualValue) { + t.Fatalf("Test case %d: Expected %+v, but got %+v", i, expectedValue, actualValue) + } } } } From 6e6d4496a3e7279350a4a942073da51eed943f99 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 25 Sep 2024 09:38:08 +0800 Subject: [PATCH 3/6] update pingsu's nit advice Signed-off-by: Future-Outlier --- flytepropeller/pkg/controller/nodes/attr_path_resolver.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index 0bcd610801..f46db9d7bd 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -8,6 +8,8 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" ) +const messagepack = "msgpack" + // resolveAttrPathInPromise resolves the literal with attribute path // If the promise is chained with attributes (e.g. promise.a["b"][0]), then we need to resolve the promise func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) { @@ -41,13 +43,11 @@ func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath // resolve dataclass and Pydantic BaseModel if scalar := currVal.GetScalar(); scalar != nil { if binary := scalar.GetBinary(); binary != nil { - // Start from index "count" currVal, err = resolveAttrPathInBinary(nodeID, binary, bindAttrPath[count:]) if err != nil { return nil, err } } else if generic := scalar.GetGeneric(); generic != nil { - // Start from index "count" currVal, err = resolveAttrPathInPbStruct(nodeID, generic, bindAttrPath[count:]) if err != nil { return nil, err @@ -104,7 +104,7 @@ func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath var tmpVal interface{} var exist bool - if serializationFormat == "msgpack" { + if serializationFormat == messagepack { err := msgpack.Unmarshal(binaryBytes, &currVal) if err != nil { return nil, err @@ -135,7 +135,7 @@ func resolveAttrPathInBinary(nodeID string, binaryIDL *core.Binary, bindAttrPath } } - if serializationFormat == "msgpack" { + if serializationFormat == messagepack { // Marshal the current value to MessagePack bytes resolvedBinaryBytes, err := msgpack.Marshal(currVal) if err != nil { From 267852b7ef213f6d6d21f04dc04b10343eff9a2a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 25 Sep 2024 00:00:57 -0700 Subject: [PATCH 4/6] nit Signed-off-by: Kevin Su --- .../pkg/controller/nodes/attr_path_resolver.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go index f46db9d7bd..70f3fd2eff 100644 --- a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -17,7 +17,7 @@ func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath var tmpVal *core.Literal var err error var exist bool - count := 0 + index := 0 for _, attr := range bindAttrPath { switch currVal.GetValue().(type) { @@ -27,13 +27,13 @@ func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal.GetMap().GetLiterals()) } currVal = tmpVal - count++ + index++ case *core.Literal_Collection: if int(attr.GetIntValue()) >= len(currVal.GetCollection().GetLiterals()) { return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal.GetCollection().GetLiterals()) } currVal = currVal.GetCollection().GetLiterals()[attr.GetIntValue()] - count++ + index++ // scalar is always the leaf, so we can break here case *core.Literal_Scalar: break @@ -43,12 +43,12 @@ func resolveAttrPathInPromise(nodeID string, literal *core.Literal, bindAttrPath // resolve dataclass and Pydantic BaseModel if scalar := currVal.GetScalar(); scalar != nil { if binary := scalar.GetBinary(); binary != nil { - currVal, err = resolveAttrPathInBinary(nodeID, binary, bindAttrPath[count:]) + currVal, err = resolveAttrPathInBinary(nodeID, binary, bindAttrPath[index:]) if err != nil { return nil, err } } else if generic := scalar.GetGeneric(); generic != nil { - currVal, err = resolveAttrPathInPbStruct(nodeID, generic, bindAttrPath[count:]) + currVal, err = resolveAttrPathInPbStruct(nodeID, generic, bindAttrPath[index:]) if err != nil { return nil, err } From 626f9c307554eb2b796856f1f19858129b76e3dc Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 26 Sep 2024 14:39:25 +0800 Subject: [PATCH 5/6] Support flytectl Signed-off-by: Future-Outlier --- flyteadmin/go.mod | 1 + flytecopilot/go.mod | 1 + flytecopilot/go.sum | 2 + flytectl/go.mod | 1 + flytectl/go.sum | 2 + .../clients/go/coreutils/extract_literal.go | 2 + .../go/coreutils/extract_literal_test.go | 30 +------- flyteidl/clients/go/coreutils/literals.go | 34 ++++++++- .../clients/go/coreutils/literals_test.go | 73 +++++++++++++++++++ flyteidl/go.mod | 1 + flyteidl/go.sum | 2 + flyteplugins/go.mod | 1 + flyteplugins/go.sum | 2 + 13 files changed, 119 insertions(+), 33 deletions(-) diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index cfc2bfa010..ba98c3a62a 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -167,6 +167,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/sendgrid/rest v2.6.9+incompatible // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flytecopilot/go.mod b/flytecopilot/go.mod index e1dbdc7683..166411654f 100644 --- a/flytecopilot/go.mod +++ b/flytecopilot/go.mod @@ -83,6 +83,7 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flytecopilot/go.sum b/flytecopilot/go.sum index 9fb93ec715..0e3773721b 100644 --- a/flytecopilot/go.sum +++ b/flytecopilot/go.sum @@ -311,6 +311,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= diff --git a/flytectl/go.mod b/flytectl/go.mod index d783ac0513..8829eb881f 100644 --- a/flytectl/go.mod +++ b/flytectl/go.mod @@ -142,6 +142,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.4.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect diff --git a/flytectl/go.sum b/flytectl/go.sum index 1e3b5d7ef8..cb4054c995 100644 --- a/flytectl/go.sum +++ b/flytectl/go.sum @@ -420,6 +420,8 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/flyteidl/clients/go/coreutils/extract_literal.go b/flyteidl/clients/go/coreutils/extract_literal.go index 5801296dc3..64bd6efdd0 100644 --- a/flyteidl/clients/go/coreutils/extract_literal.go +++ b/flyteidl/clients/go/coreutils/extract_literal.go @@ -54,6 +54,8 @@ func ExtractFromLiteral(literal *core.Literal) (interface{}, error) { default: return nil, fmt.Errorf("unsupported literal scalar primitive type %T", scalarValue) } + case *core.Scalar_Binary: + return scalarValue.Binary, nil case *core.Scalar_Blob: return scalarValue.Blob.Uri, nil case *core.Scalar_Schema: diff --git a/flyteidl/clients/go/coreutils/extract_literal_test.go b/flyteidl/clients/go/coreutils/extract_literal_test.go index 2ce8747fd5..3542c3f225 100644 --- a/flyteidl/clients/go/coreutils/extract_literal_test.go +++ b/flyteidl/clients/go/coreutils/extract_literal_test.go @@ -113,7 +113,7 @@ func TestFetchLiteral(t *testing.T) { s := MakeBinaryLiteral([]byte{'h'}) assert.Equal(t, []byte{'h'}, s.GetScalar().GetBinary().GetValue()) _, err := ExtractFromLiteral(s) - assert.NotNil(t, err) + assert.Nil(t, err) }) t.Run("NoneType", func(t *testing.T) { @@ -124,34 +124,6 @@ func TestFetchLiteral(t *testing.T) { assert.Nil(t, err) }) - t.Run("Generic", func(t *testing.T) { - literalVal := map[string]interface{}{ - "x": 1, - "y": "ystringvalue", - } - var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} - lit, err := MakeLiteralForType(literalType, literalVal) - assert.NoError(t, err) - extractedLiteralVal, err := ExtractFromLiteral(lit) - assert.NoError(t, err) - fieldsMap := map[string]*structpb.Value{ - "x": { - Kind: &structpb.Value_NumberValue{NumberValue: 1}, - }, - "y": { - Kind: &structpb.Value_StringValue{StringValue: "ystringvalue"}, - }, - } - expectedStructVal := &structpb.Struct{ - Fields: fieldsMap, - } - extractedStructValue := extractedLiteralVal.(*structpb.Struct) - assert.Equal(t, len(expectedStructVal.Fields), len(extractedStructValue.Fields)) - for key, val := range expectedStructVal.Fields { - assert.Equal(t, val.Kind, extractedStructValue.Fields[key].Kind) - } - }) - t.Run("Generic Passed As String", func(t *testing.T) { literalVal := "{\"x\": 1,\"y\": \"ystringvalue\"}" var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} diff --git a/flyteidl/clients/go/coreutils/literals.go b/flyteidl/clients/go/coreutils/literals.go index 3527ac246b..403c5e1955 100644 --- a/flyteidl/clients/go/coreutils/literals.go +++ b/flyteidl/clients/go/coreutils/literals.go @@ -2,7 +2,6 @@ package coreutils import ( - "encoding/json" "fmt" "math" "reflect" @@ -14,11 +13,14 @@ import ( "github.com/golang/protobuf/ptypes" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytestdlib/storage" ) +const messagepack = "msgpack" + func MakePrimitive(v interface{}) (*core.Primitive, error) { switch p := v.(type) { case int: @@ -144,6 +146,7 @@ func MakeBinaryLiteral(v []byte) *core.Literal { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: v, + Tag: messagepack, }, }, }, @@ -389,7 +392,7 @@ func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error scalar.Value = &core.Scalar_Binary{ Binary: &core.Binary{ Value: []byte(s), - // TODO Tag not supported at the moment + Tag: messagepack, }, } case core.SimpleType_ERROR: @@ -559,12 +562,35 @@ func MakeLiteralForType(t *core.LiteralType, v interface{}) (*core.Literal, erro strValue = fmt.Sprintf("%.0f", math.Trunc(f)) } if newT.Simple == core.SimpleType_STRUCT { + // If the type is a STRUCT, we expect the input to be a complex object + // like the following example: + // inputs: + // dc: + // a: 1 + // b: 3.14 + // c: "example_string" + // Instead of storing it directly as a structured value, we will serialize + // the input object using MsgPack and return it as a binary IDL object. + + // If the value is not already a string (meaning it's not already serialized), + // proceed with serialization. if _, isValueStringType := v.(string); !isValueStringType { - byteValue, err := json.Marshal(v) + byteValue, err := msgpack.Marshal(v) if err != nil { return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) } - strValue = string(byteValue) + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: byteValue, + Tag: messagepack, + }, + }, + }, + }, + }, nil } } lv, err := MakeLiteralForSimpleType(newT.Simple, strValue) diff --git a/flyteidl/clients/go/coreutils/literals_test.go b/flyteidl/clients/go/coreutils/literals_test.go index 24a0af4865..35a0f8a45a 100644 --- a/flyteidl/clients/go/coreutils/literals_test.go +++ b/flyteidl/clients/go/coreutils/literals_test.go @@ -14,6 +14,7 @@ import ( "github.com/golang/protobuf/ptypes" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -242,6 +243,16 @@ func TestMakeDefaultLiteralForType(t *testing.T) { assert.NotNil(t, l.GetScalar().GetError()) }) + t.Run("binary", func(t *testing.T) { + l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BINARY, + }}) + assert.NoError(t, err) + assert.NotNil(t, l.GetScalar().GetBinary()) + assert.NotNil(t, l.GetScalar().GetBinary().GetValue()) + assert.NotNil(t, l.GetScalar().GetBinary().GetTag()) + }) + t.Run("struct", func(t *testing.T) { l, err := MakeDefaultLiteralForType(&core.LiteralType{Type: &core.LiteralType_Simple{ Simple: core.SimpleType_STRUCT, @@ -444,6 +455,68 @@ func TestMakeLiteralForType(t *testing.T) { assert.Equal(t, expectedVal, actualVal) }) + t.Run("SimpleBinary", func(t *testing.T) { + // We compare the deserialized values instead of the raw msgpack bytes because Go does not guarantee the order + // of map keys during serialization. This means that while the serialized bytes may differ, the deserialized + // values should be logically equivalent. + + var literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + v := map[string]interface{}{ + "a": int64(1), + "b": 3.14, + "c": "example_string", + "d": map[string]interface{}{ + "1": int64(100), + "2": int64(200), + }, + "e": map[string]interface{}{ + "a": int64(1), + "b": 3.14, + }, + "f": []string{"a", "b", "c"}, + } + + val, err := MakeLiteralForType(literalType, v) + assert.NoError(t, err) + + msgpackBytes, err := msgpack.Marshal(v) + assert.NoError(t, err) + + literalVal := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: msgpackBytes, + Tag: messagepack, + }, + }, + }, + }, + } + + expectedLiteralVal, err := ExtractFromLiteral(literalVal) + assert.NoError(t, err) + actualLiteralVal, err := ExtractFromLiteral(val) + assert.NoError(t, err) + + // Check if the extracted value is of type *core.Binary (not []byte) + expectedBinary, ok := expectedLiteralVal.(*core.Binary) + assert.True(t, ok, "expectedLiteralVal is not of type *core.Binary") + actualBinary, ok := actualLiteralVal.(*core.Binary) + assert.True(t, ok, "actualLiteralVal is not of type *core.Binary") + + // Now check if the Binary values match + var expectedVal, actualVal map[string]interface{} + err = msgpack.Unmarshal(expectedBinary.Value, &expectedVal) + assert.NoError(t, err) + err = msgpack.Unmarshal(actualBinary.Value, &actualVal) + assert.NoError(t, err) + + // Finally, assert that the deserialized values are equal + assert.Equal(t, expectedVal, actualVal) + }) + t.Run("ArrayStrings", func(t *testing.T) { var literalType = &core.LiteralType{Type: &core.LiteralType_CollectionType{ CollectionType: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}}} diff --git a/flyteidl/go.mod b/flyteidl/go.mod index 55ec124554..037cac70cd 100644 --- a/flyteidl/go.mod +++ b/flyteidl/go.mod @@ -13,6 +13,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/pkg/errors v0.9.1 + github.com/shamaton/msgpack/v2 v2.2.2 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 golang.org/x/net v0.27.0 diff --git a/flyteidl/go.sum b/flyteidl/go.sum index 5d5cb7e9a2..e1e7d9782d 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -217,6 +217,8 @@ github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index c4287581bc..1d75c44ac3 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -108,6 +108,7 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index fa26e3cfda..3721c28a7a 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -342,6 +342,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= From f68f75cca2d68bf819f060fa422774963b6214bf Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 1 Oct 2024 00:04:10 +0800 Subject: [PATCH 6/6] update Yee and Ketan's advice Signed-off-by: Future-Outlier --- flytepropeller/pkg/compiler/validators/utils.go | 4 +++- flytepropeller/pkg/compiler/validators/utils_test.go | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index fb4ba04548..fa22ffe84f 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -11,6 +11,8 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" ) +const messagepack = "msgpack" + func containsBindingByVariableName(bindings []*core.Binding, name string) (found bool) { for _, b := range bindings { if b.Var == name { @@ -47,7 +49,7 @@ func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { // If the binary has a tag, treat it as a structured type (e.g., dict, dataclass, Pydantic BaseModel). // Otherwise, treat it as raw binary data. // Reference: https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md - if len(v.Binary.Tag) > 0 { + if v.Binary.Tag == messagepack { literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} } else { literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 26e34988c3..5f4ca2ad06 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -55,7 +55,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: serializedBinaryData, - Tag: "msgpack", + Tag: messagepack, }, }, }, @@ -83,7 +83,7 @@ func TestLiteralTypeForLiterals(t *testing.T) { Value: &core.Scalar_Binary{ Binary: &core.Binary{ Value: serializedBinaryData, - Tag: "msgpack", + Tag: messagepack, }, }, },