Skip to content
Merged
Prev Previous commit
Next Next commit
refactor(tests): update model index tests for clarity and accuracy
  • Loading branch information
dido18 authored and lucarin91 committed Oct 28, 2025
commit e3eec7bf4ff799f654cd0c0811428bde0aa76684
115 changes: 40 additions & 75 deletions internal/orchestrator/modelsindex/modelsindex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,102 +9,67 @@ import (
)

func TestGenerateModelsIndexFromFile(t *testing.T) {
testdataPath := paths.New("testdata")

t.Run("Valid Model list", func(t *testing.T) {
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath)
t.Run("it parses a valid model-list.yaml", func(t *testing.T) {
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
require.NoError(t, err)
require.NotNil(t, modelsIndex)

models := modelsIndex.GetModels()
assert.Len(t, models, 3, "Expected 3 models to be parsed")

// Test first model
model1, found := modelsIndex.GetModelByID("face-detection")
assert.Equal(t, "brick", model1.Runner)
require.True(t, found, "face-detection should be found")
assert.Equal(t, "face-detection:", model1.ID)
assert.Equal(t, "Lightweight-Face-Detection", model1.Name)
assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model1.ModuleDescription)
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.L)
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.Bricks)
assert.Equal(t, "1.0.0", model1.Metadata["version"])
assert.Equal(t, "Test Author", model1.Metadata["author"])
assert.Equal(t, "1000", model1.ModelConfiguration["max_tokens"])
assert.Equal(t, "0.7", model1.ModelConfiguration["temperature"])
assert.Len(t, models, 2, "Expected 2 models to be parsed")
})

// // Test second model
// model2, found := modelsIndex.GetModelByID("test_model_2")
// // require.True(t, found, "test_model_2 should be found")
// // assert.Equal(t, "test_model_2", model2.ID)
// // assert.Equal(t, "Test Model 2", model2.Name)
// // assert.Equal(t, "Another test AI model", model2.ModuleDescription)
// // assert.Equal(t, "another_runner", model2.Runner)
// // assert.Equal(t, []string{"brick2", "brick3"}, model2.Bricks)
// // assert.Equal(t, "2.0.0", model2.Metadata["version"])
// // assert.Equal(t, "MIT", model2.Metadata["license"])
t.Run("it gets a model by ID", func(t *testing.T) {
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
require.NoError(t, err)

// // Test minimal model
// model3, found := modelsIndex.GetModelByID("minimal_model")
// require.True(t, found, "minimal_model should be found")
// assert.Equal(t, "minimal_model", model3.ID)
// assert.Equal(t, "Minimal Model", model3.Name)
// assert.Equal(t, "Minimal model with no optional fields", model3.ModuleDescription)
// assert.Equal(t, "minimal_runner", model3.Runner)
// assert.Empty(t, model3.Bricks)
// assert.Empty(t, model3.Metadata)
// assert.Empty(t, model3.ModelConfiguration)
model, found := modelsIndex.GetModelByID("face-detection")
assert.Equal(t, "brick", model.Runner)
require.True(t, found, "face-detection should be found")
assert.Equal(t, "face-detection", model.ID)
assert.Equal(t, "Lightweight-Face-Detection", model.Name)
assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model.ModuleDescription)
assert.Equal(t, []string{"face"}, model.ModelLabels)
assert.Equal(t, "/models/ootb/ei/lw-face-det.eim", model.ModelConfiguration["EI_OBJ_DETECTION_MODEL"])
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model.Bricks)
assert.Equal(t, "qualcomm-ai-hub", model.Metadata["source"])
assert.Equal(t, "false", model.Metadata["ei-gpu-mode"])
assert.Equal(t, "face-det-lite", model.Metadata["source-model-id"])
assert.Equal(t, "https://aihub.qualcomm.com/models/face_det_lite", model.Metadata["source-model-url"])
})

// Test file not found error
t.Run("FileNotFound", func(t *testing.T) {
nonExistentPath := paths.New("nonexistent")
t.Run("it fails if model-list.yaml does not exist", func(t *testing.T) {
nonExistentPath := paths.New("nonexistent.yaml")
modelsIndex, err := GenerateModelsIndexFromFile(nonExistentPath)
assert.Error(t, err)
assert.Nil(t, modelsIndex)
})

// Test invalid YAML parsing
t.Run("InvalidYAML", func(t *testing.T) {
// Create a temporary invalid YAML file
invalidPath := testdataPath.Join("invalid-models.yaml")
t.Run("it filters models by a single brick", func(t *testing.T) {
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
require.NoError(t, err)

brick1Models := modelsIndex.GetModelsByBrick("arduino:object_detection")
assert.Len(t, brick1Models, 1)
assert.Equal(t, "face-detection", brick1Models[0].ID)

// We expect this to either fail parsing or handle gracefully
// Since the current implementation may be lenient with missing fields
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath.Parent().Join("testdata-invalid"))
if err != nil {
// If it fails, that's expected for invalid files
assert.Error(t, err)
assert.Nil(t, modelsIndex)
}
// Note: Some invalid YAML might still parse successfully depending on the YAML library's behavior
_ = invalidPath // Avoid unused variable warning
brick1Models = modelsIndex.GetModelsByBrick("not-existing-brick")
assert.Nil(t, brick1Models)
})

// Test brick filtering functionality
t.Run("BrickFiltering", func(t *testing.T) {
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath)
t.Run("it filters models by multiple bricks", func(t *testing.T) {
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
require.NoError(t, err)

// Test GetModelsByBrick
brick1Models := modelsIndex.GetModelsByBrick("brick1")
assert.Len(t, brick1Models, 1)
assert.Equal(t, "test_model_1", brick1Models[0].ID)

brick2Models := modelsIndex.GetModelsByBrick("brick2")
brick2Models := modelsIndex.GetModelsByBrick("arduino:video_object_detection")
assert.Len(t, brick2Models, 2)
modelIDs := []string{brick2Models[0].ID, brick2Models[1].ID}
assert.Contains(t, modelIDs, "test_model_1")
assert.Contains(t, modelIDs, "test_model_2")
assert.Equal(t, "face-detection", brick2Models[0].ID)
assert.Equal(t, "yolox-object-detection", brick2Models[1].ID)

// Test GetModelsByBricks
multiModels := modelsIndex.GetModelsByBricks([]string{"brick1", "brick3"})
assert.Len(t, multiModels, 2)
multiModelIDs := []string{multiModels[0].ID, multiModels[1].ID}
assert.Contains(t, multiModelIDs, "test_model_1")
assert.Contains(t, multiModelIDs, "test_model_2")
bricks2Models := modelsIndex.GetModelsByBricks([]string{"arduino:object_detection", "arduino:video_object_detection"})
assert.Len(t, bricks2Models, 2)
assert.Equal(t, "face-detection", bricks2Models[0].ID)
assert.Equal(t, "yolox-object-detection", bricks2Models[1].ID)

// Test non-existent brick
nonExistentModels := modelsIndex.GetModelsByBrick("nonexistent_brick")
assert.Nil(t, nonExistentModels)
})
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,4 @@ models:
source-model-id: "YOLOX-Nano"
source-model-url: "https://github.com/Megvii-BaseDetection/YOLOX"
bricks:
- arduino:object_detection
- arduino:video_object_detection