Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cluster pool assignment validation #5778

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,34 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi
return workflowExecConfig, nil
}

func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *admin.ExecutionCreateRequest) (
*admin.ClusterAssignment, error) {
if request.Spec.ClusterAssignment != nil {
return request.Spec.ClusterAssignment, nil
func (m *ExecutionManager) getClusterAssignment(ctx context.Context, req *admin.ExecutionCreateRequest) (*admin.ClusterAssignment, error) {
storedAssignment, err := m.fetchClusterAssignment(ctx, req.Project, req.Domain)
if err != nil {
return nil, err
}

reqAssignment := req.GetSpec().GetClusterAssignment()
reqPool := reqAssignment.GetClusterPoolName()
storedPool := storedAssignment.GetClusterPoolName()
if reqPool == "" {
return storedAssignment, nil
}

if storedPool == "" {
return reqAssignment, nil
}

if reqPool != storedPool {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "execution with project %q and domain %q cannot run on cluster pool %q, because its configured to run on pool %q", req.Project, req.Domain, reqPool, storedPool)
}

return storedAssignment, nil
}

func (m *ExecutionManager) fetchClusterAssignment(ctx context.Context, project, domain string) (*admin.ClusterAssignment, error) {
resource, err := m.resourceManager.GetResource(ctx, interfaces.ResourceRequest{
Project: request.Project,
Domain: request.Domain,
Project: project,
Domain: domain,
ResourceType: admin.MatchableResource_CLUSTER_ASSIGNMENT,
})
if err != nil && !errors.IsDoesNotExistError(err) {
Expand All @@ -421,11 +440,13 @@ func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *ad
if resource != nil && resource.Attributes.GetClusterAssignment() != nil {
return resource.Attributes.GetClusterAssignment(), nil
}
clusterPoolAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[request.GetDomain()]

return &admin.ClusterAssignment{
ClusterPoolName: clusterPoolAssignment.Pool,
}, nil
var clusterAssignment *admin.ClusterAssignment
domainAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[domain]
if domainAssignment.Pool != "" {
clusterAssignment = &admin.ClusterAssignment{ClusterPoolName: domainAssignment.Pool}
}
return clusterAssignment, nil
}

func (m *ExecutionManager) launchSingleTaskExecution(
Expand Down
112 changes: 97 additions & 15 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ func TestCreateExecution(t *testing.T) {
}}
repository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return transformers.CreateProjectModel(&admin.Project{
Labels: &labels}), nil
return transformers.CreateProjectModel(&admin.Project{Labels: &labels}), nil
}

clusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
Expand Down Expand Up @@ -382,8 +381,6 @@ func TestCreateExecution(t *testing.T) {

mockConfig := getMockExecutionsConfigProvider()
mockConfig.(*runtimeMocks.MockConfigurationProvider).AddQualityOfServiceConfiguration(qosProvider)

execManager := NewExecutionManager(repository, r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, &mockPublisher, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
request := testutils.GetExecutionRequest()
request.Spec.Metadata = &admin.ExecutionMetadata{
Principal: "unused - populated from authenticated context",
Expand All @@ -392,16 +389,18 @@ func TestCreateExecution(t *testing.T) {
request.Spec.ClusterAssignment = &clusterAssignment
request.Spec.ExecutionClusterLabel = &admin.ExecutionClusterLabel{Value: executionClusterLabel}

execManager := NewExecutionManager(repository, r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, &mockPublisher, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})

identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
assert.NoError(t, err)
ctx := identity.WithContext(context.Background())
response, err := execManager.CreateExecution(ctx, request, requestedAt)
assert.Nil(t, err)
assert.NoError(t, err)

expectedResponse := &admin.ExecutionCreateResponse{
Id: &executionIdentifier,
}
assert.Nil(t, err)
assert.NoError(t, err)
assert.True(t, proto.Equal(expectedResponse.Id, response.Id))

// TODO: Check for offloaded inputs
Expand Down Expand Up @@ -632,7 +631,6 @@ func TestCreateExecutionInCompatibleInputs(t *testing.T) {
}

func TestCreateExecutionPropellerFailure(t *testing.T) {
clusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
repository := getMockRepositoryForExecTest()
setDefaultLpCallbackForExecTest(repository)
expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "ABC")
Expand Down Expand Up @@ -666,7 +664,6 @@ func TestCreateExecutionPropellerFailure(t *testing.T) {
Principal: "unused - populated from authenticated context",
}
request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput}
request.Spec.ClusterAssignment = &clusterAssignment

identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
assert.NoError(t, err)
Expand Down Expand Up @@ -5467,8 +5464,32 @@ func TestGetClusterAssignment(t *testing.T) {
assert.NoError(t, err)
assert.True(t, proto.Equal(ca, &clusterAssignment))
})
t.Run("value from request", func(t *testing.T) {
reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"}
t.Run("value from config", func(t *testing.T) {
customCP := "my_cp"
clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{}
clusterPoolAsstProvider.OnGetClusterPoolAssignments().Return(runtimeInterfaces.ClusterPoolAssignments{
workflowIdentifier.GetDomain(): runtimeInterfaces.ClusterPoolAssignment{
Pool: customCP,
},
})
mockConfig := getMockExecutionsConfigProvider()
mockConfig.(*runtimeMocks.MockConfigurationProvider).AddClusterPoolAssignmentConfiguration(clusterPoolAsstProvider)

executionManager := ExecutionManager{
resourceManager: &managerMocks.MockResourceManager{},
config: mockConfig,
}

ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{},
})
assert.NoError(t, err)
assert.Equal(t, customCP, ca.GetClusterPoolName())
})
t.Run("value from request matches value from config", func(t *testing.T) {
reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Expand All @@ -5479,12 +5500,30 @@ func TestGetClusterAssignment(t *testing.T) {
assert.NoError(t, err)
assert.True(t, proto.Equal(ca, &reqClusterAssignment))
})
t.Run("value from config", func(t *testing.T) {
customCP := "my_cp"
t.Run("no value in DB nor in config, takes value from request", func(t *testing.T) {
mockConfig := getMockExecutionsConfigProvider()

executionManager := ExecutionManager{
resourceManager: &managerMocks.MockResourceManager{},
config: mockConfig,
}

reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{
ClusterAssignment: &reqClusterAssignment,
},
})
assert.NoError(t, err)
assert.True(t, proto.Equal(ca, &reqClusterAssignment))
})
t.Run("empty value in DB, takes value from request", func(t *testing.T) {
clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{}
clusterPoolAsstProvider.OnGetClusterPoolAssignments().Return(runtimeInterfaces.ClusterPoolAssignments{
workflowIdentifier.GetDomain(): runtimeInterfaces.ClusterPoolAssignment{
Pool: customCP,
Pool: "",
},
})
mockConfig := getMockExecutionsConfigProvider()
Expand All @@ -5495,13 +5534,56 @@ func TestGetClusterAssignment(t *testing.T) {
config: mockConfig,
}

reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{},
Spec: &admin.ExecutionSpec{
ClusterAssignment: &reqClusterAssignment,
},
})
assert.NoError(t, err)
assert.Equal(t, customCP, ca.GetClusterPoolName())
assert.True(t, proto.Equal(ca, &reqClusterAssignment))
})
t.Run("value from request doesn't match value from config", func(t *testing.T) {
reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"}
_, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{
ClusterAssignment: &reqClusterAssignment,
},
})
st, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, codes.InvalidArgument, st.Code())
assert.Equal(t, `execution with project "project" and domain "domain" cannot run on cluster pool "swimming-pool", because its configured to run on pool "gpu"`, st.Message())
})
t.Run("db error", func(t *testing.T) {
expected := errors.New("fail db")
resourceManager.GetResourceFunc = func(ctx context.Context,
request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) {
assert.EqualValues(t, request, managerInterfaces.ResourceRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
ResourceType: admin.MatchableResource_CLUSTER_ASSIGNMENT,
})
return &managerInterfaces.ResourceResponse{
Attributes: &admin.MatchingAttributes{
Target: &admin.MatchingAttributes_ClusterAssignment{
ClusterAssignment: &clusterAssignment,
},
},
}, expected
}

_, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{},
})

assert.Equal(t, expected, err)
})
}

Expand Down
2 changes: 2 additions & 0 deletions flyteadmin/pkg/manager/interfaces/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
)

//go:generate mockery -name ResourceInterface -output=../mocks -case=underscore

// ResourceInterface manages project, domain and workflow -specific attributes.
type ResourceInterface interface {
ListAll(ctx context.Context, request *admin.ListMatchableAttributesRequest) (
Expand Down
Loading
Loading