Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into graph-part13
Browse files Browse the repository at this point in the history
  • Loading branch information
miparnisari committed Sep 23, 2024
2 parents 900e50e + 0e33e96 commit 8d0ba01
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 86 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
* @openfga/dx @openfga/contractors-ides
README.md @openfga/product @openfga/community @openfga/dx
pkg/go/graph/* @openfga/backend
38 changes: 36 additions & 2 deletions pkg/go/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,46 @@ import (
var (
ErrBuildingGraph = errors.New("cannot build graph of model")
ErrInvalidModel = errors.New("model is invalid")
ErrQueryingGraph = errors.New("cannot query graph")
)

type DrawingDirection bool

const (
// DrawingDirectionListObjects is when terminal types have outgoing edges and no incoming edges.
DrawingDirectionListObjects DrawingDirection = true
DrawingDirectionCheck DrawingDirection = false
// DrawingDirectionCheck is when terminal types have incoming edges and no outgoing edges.
DrawingDirectionCheck DrawingDirection = false
)

type AuthorizationModelGraph struct {
*multi.DirectedGraph
drawingDirection DrawingDirection
ids NodeLabelsToIDs
}

func (g *AuthorizationModelGraph) GetDrawingDirection() DrawingDirection {
return g.drawingDirection
}

// GetNodeByLabel provides O(1) access to a node.
func (g *AuthorizationModelGraph) GetNodeByLabel(label string) (*AuthorizationModelNode, error) {
id, ok := g.ids[label]
if !ok {
return nil, fmt.Errorf("%w: node with label %s not found", ErrQueryingGraph, label)
}

node := g.Node(id)
if node == nil {
return nil, fmt.Errorf("%w: node with id %d not found", ErrQueryingGraph, id)
}

casted, ok := node.(*AuthorizationModelNode)
if !ok {
return nil, fmt.Errorf("%w: could not cast to AuthorizationModelNode", ErrQueryingGraph)
}

return casted, nil
}

// Reversed returns a full copy of the graph, but with the direction of the arrows flipped.
Expand Down Expand Up @@ -60,9 +88,15 @@ func (g *AuthorizationModelGraph) Reversed() (*AuthorizationModelGraph, error) {
}
}

// Make a brand new copy of the map.
copyIDs := make(NodeLabelsToIDs, len(g.ids))
for k, v := range g.ids {
copyIDs[k] = v
}

multigraph, ok := graphBuilder.DirectedMultigraphBuilder.(*multi.DirectedGraph)
if ok {
return &AuthorizationModelGraph{multigraph, !g.drawingDirection}, nil
return &AuthorizationModelGraph{multigraph, !g.drawingDirection, copyIDs}, nil
}

return nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph)
Expand Down
30 changes: 14 additions & 16 deletions pkg/go/graph/graph_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,28 @@ import (
"gonum.org/v1/gonum/graph/multi"
)

type NodeLabelsToIDs map[string]int64

type AuthorizationModelGraphBuilder struct {
graph.DirectedMultigraphBuilder

ids map[string]int64 // nodes: unique labels to ids. Used to find nodes by label.
ids NodeLabelsToIDs // nodes: unique labels to ids. Used to find nodes by label.
}

// NewAuthorizationModelGraph builds an authorization model in graph form.
// For example, types such as `group`, usersets such as `group#member` and wildcards `group:*` are encoded as nodes.
//
// The edges are defined by the assignments, e.g.
// `define viewer: [group]` defines an edge from group to document#viewer.
// Conditions are not encoded in the graph,
// and the two edges in an exclusion are not distinguished.
// By default, the graph is drawn from bottom to top (i.e. terminal types have outgoing edges and no incoming edges).
// Conditions are not encoded in the graph.
func NewAuthorizationModelGraph(model *openfgav1.AuthorizationModel) (*AuthorizationModelGraph, error) {
res, err := parseModel(model)
res, ids, err := parseModel(model)
if err != nil {
return nil, err
}

return &AuthorizationModelGraph{res, DrawingDirectionListObjects}, nil
return &AuthorizationModelGraph{res, DrawingDirectionListObjects, ids}, nil
}

func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, error) {
func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, NodeLabelsToIDs, error) {
graphBuilder := &AuthorizationModelGraphBuilder{
multi.NewDirectedGraph(), map[string]int64{},
}
Expand Down Expand Up @@ -67,10 +66,10 @@ func parseModel(model *openfgav1.AuthorizationModel) (*multi.DirectedGraph, erro

multigraph, ok := graphBuilder.DirectedMultigraphBuilder.(*multi.DirectedGraph)
if ok {
return multigraph, nil
return multigraph, graphBuilder.ids, nil
}

return nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph)
return nil, nil, fmt.Errorf("%w: could not cast to directed graph", ErrBuildingGraph)
}

func checkRewrite(graphBuilder *AuthorizationModelGraphBuilder, parentNode *AuthorizationModelNode, model *openfgav1.AuthorizationModel, rewrite *openfgav1.Userset, typeDef *openfgav1.TypeDefinition, relation string) {
Expand Down Expand Up @@ -136,7 +135,7 @@ func parseThis(graphBuilder *AuthorizationModelGraphBuilder, parentNode graph.No
if directlyRelatedDef.GetWildcard() != nil {
// direct assignment to wildcard
assignableWildcard := directlyRelatedDef.GetType() + ":*"
curNode = graphBuilder.GetOrAddNode(assignableWildcard, assignableWildcard, SpecificType)
curNode = graphBuilder.GetOrAddNode(assignableWildcard, assignableWildcard, SpecificTypeWildcard)
}

if directlyRelatedDef.GetRelation() != "" {
Expand Down Expand Up @@ -248,10 +247,7 @@ func (g *AuthorizationModelGraphBuilder) HasEdge(from, to graph.Node, edgeType E
}

iter := g.Lines(from.ID(), to.ID())
for {
if !iter.Next() {
return false
}
for iter.Next() {
l := iter.Line()
edge, ok := l.(*AuthorizationModelEdge)
if !ok {
Expand All @@ -261,6 +257,8 @@ func (g *AuthorizationModelGraphBuilder) HasEdge(from, to graph.Node, edgeType E
return true
}
}

return false
}

func typeAndRelationExists(model *openfgav1.AuthorizationModel, typeName, relation string) bool {
Expand Down
53 changes: 30 additions & 23 deletions pkg/go/graph/graph_edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,41 @@ type AuthorizationModelEdge struct {

var _ encoding.Attributer = (*AuthorizationModelEdge)(nil)

func (n *AuthorizationModelEdge) Attributes() []encoding.Attribute {
var attrs []encoding.Attribute

if n.edgeType == DirectEdge {
attrs = append(attrs, encoding.Attribute{
Key: "label",
Value: "direct",
})
}

if n.edgeType == ComputedEdge {
attrs = append(attrs, encoding.Attribute{
Key: "style",
Value: "dashed",
})
}
func (n *AuthorizationModelEdge) EdgeType() EdgeType {
return n.edgeType
}

if n.edgeType == TTUEdge {
func (n *AuthorizationModelEdge) Attributes() []encoding.Attribute {
switch n.edgeType {
case DirectEdge:
return []encoding.Attribute{
{
Key: "label",
Value: "direct",
},
}
case ComputedEdge:
return []encoding.Attribute{
{
Key: "style",
Value: "dashed",
},
}
case TTUEdge:
headLabelAttrValue := n.conditionedOn
if headLabelAttrValue == "" {
headLabelAttrValue = "missing"
}

attrs = append(attrs, encoding.Attribute{
Key: "headlabel",
Value: headLabelAttrValue,
})
return []encoding.Attribute{
{
Key: "headlabel",
Value: headLabelAttrValue,
},
}
case RewriteEdge:
return []encoding.Attribute{}
default:
return []encoding.Attribute{}
}

return attrs
}
1 change: 1 addition & 0 deletions pkg/go/graph/graph_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const (
SpecificType NodeType = 0 // e.g. `group`
SpecificTypeAndRelation NodeType = 1 // e.g. `group#viewer`
OperatorNode NodeType = 2 // e.g. union
SpecificTypeWildcard NodeType = 3 // e.g. `group:*`
)

type AuthorizationModelNode struct {
Expand Down
150 changes: 148 additions & 2 deletions pkg/go/graph/graph_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graph

import (
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -71,10 +72,8 @@ rankdir=TB
model := language.MustTransformDSLToProto(testCase.model)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)
require.Equal(t, DrawingDirectionListObjects, graph.drawingDirection)
reversedGraph, err := graph.Reversed()
require.NoError(t, err)
require.Equal(t, DrawingDirectionCheck, reversedGraph.drawingDirection)
actualDOT := reversedGraph.GetDOT()
actualSorted := getSorted(actualDOT)
expectedSorted := getSorted(testCase.expectedOutput)
Expand All @@ -85,3 +84,150 @@ rankdir=TB
})
}
}

func TestGetDrawingDirection(t *testing.T) {
t.Parallel()
model := language.MustTransformDSLToProto(`
model
schema 1.1
type user
type company
relations
define member: [user]`)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)
require.Equal(t, DrawingDirectionListObjects, graph.GetDrawingDirection())
reversedGraph, err := graph.Reversed()
require.NoError(t, err)
require.Equal(t, DrawingDirectionCheck, reversedGraph.GetDrawingDirection())
}

func TestGetNodeByLabel(t *testing.T) {
t.Parallel()
model := language.MustTransformDSLToProto(`
model
schema 1.1
type user
type company
relations
define member: [user with cond, user:* with cond]
define owner: [user]
define approved_member: member or owner
type group
relations
define approved_member: [user]
type license
relations
define active_member: approved_member from owner
define owner: [company, group]`)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)

testCases := []struct {
label string
expectedFound bool
}{
// found
{"user", true},
{"user:*", true},
{"company", true},
{"company#member", true},
{"company#owner", true},
{"company#approved_member", true},
{"group", true},
{"group#approved_member", true},
{"license", true},
{"license#active_member", true},
{"license#owner", true},
// not found
{"unknown", false},
{"unknown#unknown", false},
{"user with cond", false},
{"user:* with cond", false},
}
for i, testCase := range testCases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Parallel()
node, err := graph.GetNodeByLabel(testCase.label)
if testCase.expectedFound {
require.NoError(t, err)
require.NotNil(t, node)
} else {
require.ErrorIs(t, err, ErrQueryingGraph)
require.Nil(t, node)
}
})
}
}

func TestGetNodeTypes(t *testing.T) {
t.Parallel()
model := language.MustTransformDSLToProto(`
model
schema 1.1
type user
type group
relations
define member: [user]
type company
relations
define wildcard: [user:*]
define direct: [user]
define userset: [group#member]
define intersectionRelation: wildcard and direct
define unionRelation: wildcard or direct
define differenceRelation: wildcard but not direct`)
graph, err := NewAuthorizationModelGraph(model)
require.NoError(t, err)

testCases := []struct {
label string
expectedNodeType NodeType
}{
{"user", SpecificType},
{"user:*", SpecificTypeWildcard},
{"group", SpecificType},
{"group#member", SpecificTypeAndRelation},
{"company", SpecificType},
{"company#wildcard", SpecificTypeAndRelation},
{"company#direct", SpecificTypeAndRelation},
{"company#userset", SpecificTypeAndRelation},
{"company#intersectionRelation", SpecificTypeAndRelation},
{"company#unionRelation", SpecificTypeAndRelation},
{"company#differenceRelation", SpecificTypeAndRelation},
}
for _, testCase := range testCases {
t.Run(testCase.label, func(t *testing.T) {
t.Parallel()
node, err := graph.GetNodeByLabel(testCase.label)
require.NoError(t, err)
require.NotNil(t, node)
require.Equal(t, testCase.expectedNodeType, node.NodeType(), "expected node type %d but got %d", testCase.expectedNodeType, node.NodeType())
})
}

// testing the operator nodes is not so straightforward...
var unionNodes, differenceNodes, intersectionNodes []*AuthorizationModelNode

iterNodes := graph.Nodes()
for iterNodes.Next() {
node, ok := iterNodes.Node().(*AuthorizationModelNode)
require.True(t, ok)
if node.nodeType != OperatorNode {
continue
}

switch node.label {
case "union":
unionNodes = append(unionNodes, node)
case "intersection":
intersectionNodes = append(intersectionNodes, node)
case "exclusion":
differenceNodes = append(differenceNodes, node)
}
}

require.Len(t, unionNodes, 1)
require.Len(t, differenceNodes, 1)
require.Len(t, intersectionNodes, 1)
}
Loading

0 comments on commit 8d0ba01

Please sign in to comment.