Skip to content

Commit

Permalink
Refactor some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SandersAaronD committed Oct 1, 2024
1 parent 2ef0aca commit 0f7d9c8
Showing 1 changed file with 86 additions and 80 deletions.
166 changes: 86 additions & 80 deletions ai-training-api/app/model_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,86 +273,92 @@ func TestGetModelMetrics(t *testing.T) {
App: App{_db: db},
}

// Test case 1: Basic case with multiple metrics for a single process
t.Run("Basic case", func(t *testing.T) {
processID := uuid.New()
metrics := []model.ModelMetrics{
{ProcessID: processID, MetricName: "accuracy", StepName: "train", Step: 1, MetricValue: "0.75"},
{ProcessID: processID, MetricName: "accuracy", StepName: "train", Step: 2, MetricValue: "0.80"},
{ProcessID: processID, MetricName: "loss", StepName: "train", Step: 1, MetricValue: "0.5"},
{ProcessID: processID, MetricName: "loss", StepName: "train", Step: 2, MetricValue: "0.4"},
}
insertTestMetrics(t, db, metrics)

req := setupTestRequest(processID.String())
result, err := app.getModelMetrics("0", req)
require.NoError(t, err)
response, ok := result.(GetModelMetricsResponse)
require.True(t, ok)

require.Len(t, response, 2) // Two DataFrameWrappers: one for accuracy, one for loss

// Print out the entire response for debugging
t.Logf("Response: %+v", response)

// Check accuracy metrics
require.Equal(t, "accuracy", response[0].MetricName)
require.Equal(t, "train", response[0].StepName)
require.Len(t, response[0].Fields, 2)

// Print out the Values slices for debugging
t.Logf("Step Values: %+v", response[0].Fields[0].Values)
t.Logf("Metric Values: %+v", response[0].Fields[1].Values)

require.Equal(t, []interface{}{uint32(1), uint32(2)}, response[0].Fields[0].Values)
require.Equal(t, []interface{}{"0.75", "0.80"}, response[0].Fields[1].Values)

// Check loss metrics
require.Equal(t, "loss", response[1].MetricName)
require.Equal(t, "train", response[1].StepName)
require.Len(t, response[1].Fields, 2)
require.Equal(t, []interface{}{uint32(1), uint32(2)}, response[1].Fields[0].Values)
require.Equal(t, []interface{}{"0.5", "0.4"}, response[1].Fields[1].Values)
})

// Test case 2: No metrics in the database
t.Run("No metrics", func(t *testing.T) {
// Clear the database
db.Exec("DELETE FROM model_metrics")

processID := uuid.New()
req := setupTestRequest(processID.String())
result, err := app.getModelMetrics("0", req)
require.NoError(t, err)
response, ok := result.(GetModelMetricsResponse)
require.True(t, ok)
require.Len(t, response, 0)
})

// Test case 3: Single metric for a process
t.Run("Single metric", func(t *testing.T) {
// Clear the database
db.Exec("DELETE FROM model_metrics")

processID := uuid.New()
metrics := []model.ModelMetrics{
{ProcessID: processID, MetricName: "accuracy", StepName: "train", Step: 1, MetricValue: "0.75"},
}
insertTestMetrics(t, db, metrics)

req := setupTestRequest(processID.String())
result, err := app.getModelMetrics("0", req)
require.NoError(t, err)
response, ok := result.(GetModelMetricsResponse)
require.True(t, ok)

require.Len(t, response, 1)
require.Equal(t, "accuracy", response[0].MetricName)
require.Equal(t, "train", response[0].StepName)
require.Len(t, response[0].Fields, 2)
require.Equal(t, []interface{}{uint32(1)}, response[0].Fields[0].Values)
require.Equal(t, []interface{}{"0.75"}, response[0].Fields[1].Values)
})
type testCase struct {
name string
metrics []model.ModelMetrics
check func(*testing.T, GetModelMetricsResponse)
}

testCases := []testCase{
{
name: "Basic case",
metrics: []model.ModelMetrics{
{MetricName: "accuracy", StepName: "train", Step: 1, MetricValue: "0.75"},
{MetricName: "accuracy", StepName: "train", Step: 2, MetricValue: "0.80"},
{MetricName: "loss", StepName: "train", Step: 1, MetricValue: "0.5"},
{MetricName: "loss", StepName: "train", Step: 2, MetricValue: "0.4"},
},
check: func(t *testing.T, response GetModelMetricsResponse) {
require.Len(t, response, 2) // Two DataFrameWrappers: one for accuracy, one for loss

// Check accuracy metrics
require.Equal(t, "accuracy", response[0].MetricName)
require.Equal(t, "train", response[0].StepName)
require.Len(t, response[0].Fields, 2)
require.Equal(t, []interface{}{uint32(1), uint32(2)}, response[0].Fields[0].Values)
require.Equal(t, []interface{}{"0.75", "0.80"}, response[0].Fields[1].Values)

// Check loss metrics
require.Equal(t, "loss", response[1].MetricName)
require.Equal(t, "train", response[1].StepName)
require.Len(t, response[1].Fields, 2)
require.Equal(t, []interface{}{uint32(1), uint32(2)}, response[1].Fields[0].Values)
require.Equal(t, []interface{}{"0.5", "0.4"}, response[1].Fields[1].Values)
},
},
{
name: "No metrics",
metrics: []model.ModelMetrics{},
check: func(t *testing.T, response GetModelMetricsResponse) {
require.Len(t, response, 0)
},
},
{
name: "Single metric",
metrics: []model.ModelMetrics{
{MetricName: "accuracy", StepName: "train", Step: 1, MetricValue: "0.75"},
},
check: func(t *testing.T, response GetModelMetricsResponse) {
require.Len(t, response, 1)
require.Equal(t, "accuracy", response[0].MetricName)
require.Equal(t, "train", response[0].StepName)
require.Len(t, response[0].Fields, 2)
require.Equal(t, []interface{}{uint32(1)}, response[0].Fields[0].Values)
require.Equal(t, []interface{}{"0.75"}, response[0].Fields[1].Values)
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clear the database
db.Exec("DELETE FROM model_metrics")

processID := uuid.New()
for i := range tc.metrics {
tc.metrics[i].ProcessID = processID
}
insertTestMetrics(t, db, tc.metrics)

req := setupTestRequest(processID.String())
result, err := app.getModelMetrics("0", req)
require.NoError(t, err)
response, ok := result.(GetModelMetricsResponse)
require.True(t, ok)

// Print out the entire response for debugging
t.Logf("Response: %+v", response)

if tc.name == "Basic case" {
// Print out the Values slices for debugging
t.Logf("Step Values: %+v", response[0].Fields[0].Values)
t.Logf("Metric Values: %+v", response[0].Fields[1].Values)
}

// Run the check function
tc.check(t, response)
})
}
}

func insertTestMetrics(t *testing.T, db *gorm.DB, metrics []model.ModelMetrics) {
Expand Down

0 comments on commit 0f7d9c8

Please sign in to comment.