Skip to content

Commit

Permalink
cast null_value
Browse files Browse the repository at this point in the history
  • Loading branch information
utahta committed Sep 6, 2024
1 parent 153ea3f commit 63c6611
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 0 deletions.
153 changes: 153 additions & 0 deletions grpc/federation/cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ import (
"sync"

"github.com/google/cel-go/cel"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
celtypes "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/parser"
"golang.org/x/sync/singleflight"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/genproto/googleapis/rpc/errdetails"
grpccodes "google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
Expand Down Expand Up @@ -731,6 +735,25 @@ func createCELAst(req *EvalCELRequest, env *cel.Env) (*cel.Ast, error) {
if iss.Err() != nil {
return nil, iss.Err()
}

checkedExpr, err := cel.AstToCheckedExpr(ast)
if err != nil {
return nil, err
}
if newNullValueFuncReplacer().Replace(checkedExpr) {
ca, err := celast.ToAST(checkedExpr)
if err != nil {
return nil, err
}
expr, err = parser.Unparse(ca.Expr(), ca.SourceInfo())
if err != nil {
return nil, err
}
ast, iss = env.Compile(expr)
if iss.Err() != nil {
return nil, iss.Err()
}
}
return ast, nil
}

Expand Down Expand Up @@ -832,3 +855,133 @@ func SetCELValue[T any](ctx context.Context, param *SetCELValueParam[T]) error {
}
return nil
}

type nullValueFuncReplacer struct {
checkedExpr *exprpb.CheckedExpr
lastID int64
replaced bool
unsupported bool
}

func newNullValueFuncReplacer() *nullValueFuncReplacer {
return &nullValueFuncReplacer{}
}

func (r *nullValueFuncReplacer) init(checkedExpr *exprpb.CheckedExpr) {
var lastID int64
for k := range checkedExpr.GetReferenceMap() {
if lastID < k {
lastID = k
}
}
for k := range checkedExpr.GetTypeMap() {
if lastID < k {
lastID = k
}
}
r.checkedExpr = checkedExpr
r.lastID = lastID
r.replaced = false
r.unsupported = false
}

func (r *nullValueFuncReplacer) nextID() int64 {
r.lastID++
return r.lastID
}

func (r *nullValueFuncReplacer) Replace(checkedExpr *exprpb.CheckedExpr) bool {
r.init(checkedExpr)
r.replace(checkedExpr.GetExpr())
return r.replaced && !r.unsupported
}

func (r *nullValueFuncReplacer) replace(e *exprpb.Expr) {
if e == nil {
return
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
r.convertSelect(e)
case *exprpb.Expr_CallExpr:
r.convertCall(e)
case *exprpb.Expr_ListExpr:
r.convertList(e)
case *exprpb.Expr_StructExpr:
r.convertStruct(e)
case *exprpb.Expr_ComprehensionExpr:
r.convertComprehension(e)
}
}

func (r *nullValueFuncReplacer) convertSelect(e *exprpb.Expr) {
sel := e.GetSelectExpr()
r.replace(sel.GetOperand())
}

func (r *nullValueFuncReplacer) convertCall(e *exprpb.Expr) {
call := e.GetCallExpr()
fnName := call.GetFunction()
if fnName == operators.Equals || fnName == operators.NotEquals {
lhs := call.GetArgs()[0]
rhs := call.GetArgs()[1]
var target *exprpb.Expr
if _, ok := lhs.GetConstExpr().GetConstantKind().(*exprpb.Constant_NullValue); ok {
target = rhs
}
if _, ok := rhs.GetConstExpr().GetConstantKind().(*exprpb.Constant_NullValue); ok {
if target != nil {
// maybe null == null
return
}
target = lhs
}
if target == nil {
return
}
newID := r.nextID()
target.ExprKind = &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: grpcfedcel.CastNullValueFunc,
Args: []*exprpb.Expr{
{
Id: target.GetId(),
ExprKind: target.GetExprKind(),
},
},
},
}
target.Id = newID
r.checkedExpr.GetReferenceMap()[newID] = &exprpb.Reference{
OverloadId: []string{grpcfedcel.CastNullValueFunc},
}
r.checkedExpr.GetTypeMap()[newID] = &exprpb.Type{
TypeKind: &exprpb.Type_Dyn{},
}
r.replaced = true
} else {
for _, arg := range call.GetArgs() {
r.replace(arg)
}
r.replace(call.GetTarget())
}
}

func (r *nullValueFuncReplacer) convertList(e *exprpb.Expr) {
l := e.GetListExpr()
for _, elem := range l.GetElements() {
r.replace(elem)
}
}

func (r *nullValueFuncReplacer) convertStruct(e *exprpb.Expr) {
msg := e.GetStructExpr()
for _, ent := range msg.GetEntries() {
r.replace(ent.GetValue())
}
}

func (r *nullValueFuncReplacer) convertComprehension(_ *exprpb.Expr) {
// Comprehension is not supported by parser.Unparse.
r.unsupported = true
}
46 changes: 46 additions & 0 deletions grpc/federation/cel/cast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package cel

import (
"reflect"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

const CastPackageName = "cast"

const (
CastNullValueFunc = "grpc.federation.cast.null_value"
)

type CastLibrary struct{}

func (lib *CastLibrary) LibraryName() string {
return CastPackageName
}

func (lib *CastLibrary) CompileOptions() []cel.EnvOption {
opts := []cel.EnvOption{
cel.OptionalTypes(),
cel.Function(CastNullValueFunc,
cel.Overload("grpc_federation_cast_null_value",
[]*cel.Type{cel.DynType}, cel.DynType,
cel.UnaryBinding(lib.nullValue),
),
),
}
return opts
}

func (lib *CastLibrary) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

func (lib *CastLibrary) nullValue(val ref.Val) ref.Val {
rv := reflect.ValueOf(val.Value())
if val.Value() == nil || val.Equal(types.NullValue) == types.True || rv.Kind() == reflect.Ptr && rv.IsNil() {
return types.NullValue
}
return val
}
1 change: 1 addition & 0 deletions grpc/federation/cel/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func NewLibrary(typeAdapter types.Adapter) *Library {
new(EnumLibrary),
mdLib,
logLib,
new(CastLibrary),
},
}
}
Expand Down

0 comments on commit 63c6611

Please sign in to comment.