diff --git a/Taskfile.yml b/Taskfile.yml index 06e54550..c30e0036 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -167,7 +167,7 @@ tasks: cmds: - docker rm -f adbd - board:install-arduino-app-cli: + board:install: desc: Install arduino-app-cli on the board interactive: true cmds: diff --git a/cmd/arduino-app-cli/app/restart.go b/cmd/arduino-app-cli/app/restart.go index 3462a9a5..892f108f 100644 --- a/cmd/arduino-app-cli/app/restart.go +++ b/cmd/arduino-app-cli/app/restart.go @@ -16,10 +16,18 @@ package app import ( + "context" + "fmt" + "github.com/spf13/cobra" + "golang.org/x/text/cases" + "golang.org/x/text/language" "github.com/arduino/arduino-app-cli/cmd/arduino-app-cli/completion" + "github.com/arduino/arduino-app-cli/cmd/arduino-app-cli/internal/servicelocator" "github.com/arduino/arduino-app-cli/cmd/feedback" + "github.com/arduino/arduino-app-cli/internal/orchestrator" + "github.com/arduino/arduino-app-cli/internal/orchestrator/app" "github.com/arduino/arduino-app-cli/internal/orchestrator/config" ) @@ -32,17 +40,63 @@ func newRestartCmd(cfg config.Configuration) *cobra.Command { if len(args) == 0 { return cmd.Help() } - app, err := Load(args[0]) + appToStart, err := Load(args[0]) if err != nil { feedback.Fatal(err.Error(), feedback.ErrBadArgument) - return nil - } - if err := stopHandler(cmd.Context(), app); err != nil { - feedback.Warnf("failed to stop app: %s", err.Error()) } - return startHandler(cmd.Context(), cfg, app) + return restartHandler(cmd.Context(), cfg, appToStart) }, ValidArgsFunction: completion.ApplicationNames(cfg), } return cmd } + +func restartHandler(ctx context.Context, cfg config.Configuration, app app.ArduinoApp) error { + out, _, getResult := feedback.OutputStreams() + + stream := orchestrator.RestartApp( + ctx, + servicelocator.GetDockerClient(), + servicelocator.GetProvisioner(), + servicelocator.GetModelsIndex(), + servicelocator.GetBricksIndex(), + app, + cfg, + servicelocator.GetStaticStore(), + ) + for message := range stream { + switch message.GetType() { + case orchestrator.ProgressType: + fmt.Fprintf(out, "Progress[%s]: %.0f%%\n", message.GetProgress().Name, message.GetProgress().Progress) + case orchestrator.InfoType: + fmt.Fprintln(out, "[INFO]", message.GetData()) + case orchestrator.ErrorType: + errMesg := cases.Title(language.AmericanEnglish).String(message.GetError().Error()) + feedback.Fatal(fmt.Sprintf("[ERROR] %s", errMesg), feedback.ErrGeneric) + return nil + } + } + + outputResult := getResult() + feedback.PrintResult(restartAppResult{ + AppName: app.Name, + Status: "restarted", + Output: outputResult, + }) + + return nil +} + +type restartAppResult struct { + AppName string `json:"app_name"` + Status string `json:"status"` + Output *feedback.OutputStreamsResult `json:"output,omitempty"` +} + +func (r restartAppResult) String() string { + return fmt.Sprintf("✓ App %q restarted successfully", r.AppName) +} + +func (r restartAppResult) Data() interface{} { + return r +} diff --git a/cmd/arduino-app-cli/brick/bricks.go b/cmd/arduino-app-cli/brick/bricks.go index 692552cc..74dd3a3a 100644 --- a/cmd/arduino-app-cli/brick/bricks.go +++ b/cmd/arduino-app-cli/brick/bricks.go @@ -17,16 +17,18 @@ package brick import ( "github.com/spf13/cobra" + + "github.com/arduino/arduino-app-cli/internal/orchestrator/config" ) -func NewBrickCmd() *cobra.Command { +func NewBrickCmd(cfg config.Configuration) *cobra.Command { appCmd := &cobra.Command{ Use: "brick", Short: "Manage Arduino Bricks", } appCmd.AddCommand(newBricksListCmd()) - appCmd.AddCommand(newBricksDetailsCmd()) + appCmd.AddCommand(newBricksDetailsCmd(cfg)) return appCmd } diff --git a/cmd/arduino-app-cli/brick/details.go b/cmd/arduino-app-cli/brick/details.go index fe025078..a1ba72f1 100644 --- a/cmd/arduino-app-cli/brick/details.go +++ b/cmd/arduino-app-cli/brick/details.go @@ -25,21 +25,23 @@ import ( "github.com/arduino/arduino-app-cli/cmd/arduino-app-cli/internal/servicelocator" "github.com/arduino/arduino-app-cli/cmd/feedback" "github.com/arduino/arduino-app-cli/internal/orchestrator/bricks" + "github.com/arduino/arduino-app-cli/internal/orchestrator/config" ) -func newBricksDetailsCmd() *cobra.Command { +func newBricksDetailsCmd(cfg config.Configuration) *cobra.Command { return &cobra.Command{ Use: "details", Short: "Details of a specific brick", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - bricksDetailsHandler(args[0]) + bricksDetailsHandler(args[0], cfg) }, } } -func bricksDetailsHandler(id string) { - res, err := servicelocator.GetBrickService().BricksDetails(id) +func bricksDetailsHandler(id string, cfg config.Configuration) { + res, err := servicelocator.GetBrickService().BricksDetails(id, servicelocator.GetAppIDProvider(), + cfg) if err != nil { if errors.Is(err, bricks.ErrBrickNotFound) { feedback.Fatal(err.Error(), feedback.ErrBadArgument) diff --git a/cmd/arduino-app-cli/main.go b/cmd/arduino-app-cli/main.go index 765ccd69..e859aae2 100644 --- a/cmd/arduino-app-cli/main.go +++ b/cmd/arduino-app-cli/main.go @@ -71,7 +71,7 @@ func run(configuration cfg.Configuration) error { rootCmd.AddCommand( app.NewAppCmd(configuration), - brick.NewBrickCmd(), + brick.NewBrickCmd(configuration), completion.NewCompletionCommand(), daemon.NewDaemonCmd(configuration, Version), properties.NewPropertiesCmd(configuration), diff --git a/debian/arduino-app-cli/etc/systemd/system/arduino-app-cli.service b/debian/arduino-app-cli/etc/systemd/system/arduino-app-cli.service index 601348de..93ed0367 100644 --- a/debian/arduino-app-cli/etc/systemd/system/arduino-app-cli.service +++ b/debian/arduino-app-cli/etc/systemd/system/arduino-app-cli.service @@ -2,7 +2,6 @@ Description=Arduino App CLI daemon Service After=network-online.target docker.service arduino-router.service Wants=network-online.target docker.service arduino-router.service -Requires=docker.service arduino-router.service [Service] ExecStart=/usr/bin/arduino-app-cli daemon --port 8800 --log-level error diff --git a/internal/api/api.go b/internal/api/api.go index 1d825317..08d31d84 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -56,7 +56,7 @@ func NewHTTPRouter( mux.Handle("GET /v1/version", handlers.HandlerVersion(version)) mux.Handle("GET /v1/config", handlers.HandleConfig(cfg)) mux.Handle("GET /v1/bricks", handlers.HandleBrickList(brickService)) - mux.Handle("GET /v1/bricks/{brickID}", handlers.HandleBrickDetails(brickService)) + mux.Handle("GET /v1/bricks/{brickID}", handlers.HandleBrickDetails(brickService, idProvider, cfg)) mux.Handle("GET /v1/properties", handlers.HandlePropertyKeys(cfg)) mux.Handle("GET /v1/properties/{key}", handlers.HandlePropertyGet(cfg)) diff --git a/internal/api/handlers/bricks.go b/internal/api/handlers/bricks.go index d21c6b3a..7e95753a 100644 --- a/internal/api/handlers/bricks.go +++ b/internal/api/handlers/bricks.go @@ -26,6 +26,7 @@ import ( "github.com/arduino/arduino-app-cli/internal/api/models" "github.com/arduino/arduino-app-cli/internal/orchestrator/app" "github.com/arduino/arduino-app-cli/internal/orchestrator/bricks" + "github.com/arduino/arduino-app-cli/internal/orchestrator/config" "github.com/arduino/arduino-app-cli/internal/render" ) @@ -153,14 +154,15 @@ func HandleBrickCreate( } } -func HandleBrickDetails(brickService *bricks.Service) http.HandlerFunc { +func HandleBrickDetails(brickService *bricks.Service, idProvider *app.IDProvider, + cfg config.Configuration) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := r.PathValue("brickID") if id == "" { render.EncodeResponse(w, http.StatusBadRequest, models.ErrorResponse{Details: "id must be set"}) return } - res, err := brickService.BricksDetails(id) + res, err := brickService.BricksDetails(id, idProvider, cfg) if err != nil { if errors.Is(err, bricks.ErrBrickNotFound) { details := fmt.Sprintf("brick with id %q not found", id) diff --git a/internal/api/handlers/monitor.go b/internal/api/handlers/monitor.go index e421283f..5aaf8f4d 100644 --- a/internal/api/handlers/monitor.go +++ b/internal/api/handlers/monitor.go @@ -83,24 +83,53 @@ func monitorStream(mon net.Conn, ws *websocket.Conn) { }() } +func splitOrigin(origin string) (scheme, host, port string, err error) { + parts := strings.SplitN(origin, "://", 2) + if len(parts) != 2 { + return "", "", "", fmt.Errorf("invalid origin format: %s", origin) + } + scheme = parts[0] + hostPort := parts[1] + hostParts := strings.SplitN(hostPort, ":", 2) + host = hostParts[0] + if len(hostParts) == 2 { + port = hostParts[1] + } else { + port = "*" + } + return scheme, host, port, nil +} + func checkOrigin(origin string, allowedOrigins []string) bool { + scheme, host, port, err := splitOrigin(origin) + if err != nil { + slog.Error("WebSocket origin check failed", slog.String("origin", origin), slog.String("error", err.Error())) + return false + } for _, allowed := range allowedOrigins { - if strings.HasSuffix(allowed, "*") { - // String ends with *, match the prefix - if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) { - return true - } - } else { - // Exact match - if allowed == origin { - return true - } + allowedScheme, allowedHost, allowedPort, err := splitOrigin(allowed) + if err != nil { + panic(err) + } + if allowedScheme != scheme { + continue } + if allowedHost != host && allowedHost != "*" { + continue + } + if allowedPort != port && allowedPort != "*" { + continue + } + return true } + slog.Error("WebSocket origin check failed", slog.String("origin", origin)) return false } func HandleMonitorWS(allowedOrigins []string) http.HandlerFunc { + // Do a dry-run of checkorigin, so it can panic if misconfigured now, not on first request + _ = checkOrigin("http://localhost", allowedOrigins) + upgrader := websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, diff --git a/internal/api/handlers/monitor_test.go b/internal/api/handlers/monitor_test.go new file mode 100644 index 00000000..e54c7f24 --- /dev/null +++ b/internal/api/handlers/monitor_test.go @@ -0,0 +1,52 @@ +// This file is part of arduino-app-cli. +// +// Copyright 2025 ARDUINO SA (http://www.arduino.cc/) +// +// This software is released under the GNU General Public License version 3, +// which covers the main part of arduino-app-cli. +// The terms of this license can be found at: +// https://www.gnu.org/licenses/gpl-3.0.en.html +// +// You can be released from the requirements of the above licenses by purchasing +// a commercial license. Buying such a license is mandatory if you want to +// modify or otherwise use the software for commercial activities involving the +// Arduino software without disclosing the source code of your own applications. +// To purchase a commercial license, send an email to license@arduino.cc. + +package handlers + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckOrigin(t *testing.T) { + origins := []string{ + "wails://wails", + "wails://wails.localhost:*", + "http://wails.localhost:*", + "http://localhost:*", + "https://localhost:*", + "http://example.com:7000", + "https://*:443", + } + + allow := func(origin string) { + require.True(t, checkOrigin(origin, origins), "Expected origin %s to be allowed", origin) + } + deny := func(origin string) { + require.False(t, checkOrigin(origin, origins), "Expected origin %s to be denied", origin) + } + allow("wails://wails") + allow("wails://wails:8000") + allow("http://wails.localhost") + allow("http://localhost") + allow("http://example.com:7000") + allow("https://blah.com:443") + deny("wails://evil.com") + deny("https://wails.localhost:8000") + deny("http://example.com:8000") + deny("http://blah.com:443") + deny("https://blah.com:8080") +} diff --git a/internal/e2e/daemon/brick_test.go b/internal/e2e/daemon/brick_test.go index ab04859f..fa1cab40 100644 --- a/internal/e2e/daemon/brick_test.go +++ b/internal/e2e/daemon/brick_test.go @@ -24,13 +24,44 @@ import ( "github.com/arduino/go-paths-helper" "github.com/stretchr/testify/require" + "go.bug.st/f" "github.com/arduino/arduino-app-cli/internal/api/models" + "github.com/arduino/arduino-app-cli/internal/e2e/client" "github.com/arduino/arduino-app-cli/internal/orchestrator/bricksindex" "github.com/arduino/arduino-app-cli/internal/orchestrator/config" "github.com/arduino/arduino-app-cli/internal/store" ) +func setupTestBrick(t *testing.T) (*client.CreateAppResp, *client.ClientWithResponses) { + httpClient := GetHttpclient(t) + createResp, err := httpClient.CreateAppWithResponse( + t.Context(), + &client.CreateAppParams{SkipSketch: f.Ptr(true)}, + client.CreateAppRequest{ + Icon: f.Ptr("💻"), + Name: "test-app", + Description: f.Ptr("My app description"), + }, + func(ctx context.Context, req *http.Request) error { return nil }, + ) + require.NoError(t, err) + require.Equal(t, http.StatusCreated, createResp.StatusCode()) + require.NotNil(t, createResp.JSON201) + + resp, err := httpClient.UpsertAppBrickInstanceWithResponse( + t.Context(), + *createResp.JSON201.Id, + ImageClassifactionBrickID, + client.BrickCreateUpdateRequest{Model: f.Ptr("mobilenet-image-classification")}, + func(ctx context.Context, req *http.Request) error { return nil }, + ) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode()) + + return createResp, httpClient +} + func TestBricksList(t *testing.T) { httpClient := GetHttpclient(t) @@ -56,8 +87,8 @@ func TestBricksList(t *testing.T) { } func TestBricksDetails(t *testing.T) { + _, httpClient := setupTestBrick(t) - httpClient := GetHttpclient(t) t.Run("should return 404 Not Found for an invalid brick ID", func(t *testing.T) { invalidBrickID := "notvalidBrickId" var actualBody models.ErrorResponse @@ -76,6 +107,14 @@ func TestBricksDetails(t *testing.T) { t.Run("should return 200 OK with full details for a valid brick ID", func(t *testing.T) { validBrickID := "arduino:image_classification" + expectedUsedByApps := []client.AppReference{ + { + Id: f.Ptr("dXNlcjp0ZXN0LWFwcA"), + Name: f.Ptr("test-app"), + Icon: f.Ptr("💻"), + }, + } + response, err := httpClient.GetBrickDetailsWithResponse(t.Context(), validBrickID, func(ctx context.Context, req *http.Request) error { return nil }) require.NoError(t, err) require.Equal(t, http.StatusOK, response.StatusCode(), "status code should be 200 ok") @@ -92,6 +131,7 @@ func TestBricksDetails(t *testing.T) { require.Equal(t, "path to the model file", *(*response.JSON200.Variables)["EI_CLASSIFICATION_MODEL"].Description) require.Equal(t, false, *(*response.JSON200.Variables)["EI_CLASSIFICATION_MODEL"].Required) require.NotEmpty(t, *response.JSON200.Readme) - require.Nil(t, response.JSON200.UsedByApps) + require.NotNil(t, response.JSON200.UsedByApps, "UsedByApps should not be nil") + require.Equal(t, expectedUsedByApps, *(response.JSON200.UsedByApps)) }) } diff --git a/internal/orchestrator/app/app.go b/internal/orchestrator/app/app.go index 165b0062..c40c9009 100644 --- a/internal/orchestrator/app/app.go +++ b/internal/orchestrator/app/app.go @@ -48,7 +48,7 @@ func Load(appPath string) (ArduinoApp, error) { return ArduinoApp{}, fmt.Errorf("app path is not valid: %w", err) } if !exist { - return ArduinoApp{}, fmt.Errorf("no such file or directory: %s", path) + return ArduinoApp{}, fmt.Errorf("app path must be a directory: %s", path) } path, err = path.Abs() if err != nil { diff --git a/internal/orchestrator/app/app_test.go b/internal/orchestrator/app/app_test.go index 1256c648..47e3f53f 100644 --- a/internal/orchestrator/app/app_test.go +++ b/internal/orchestrator/app/app_test.go @@ -25,13 +25,26 @@ import ( ) func TestLoad(t *testing.T) { - t.Run("empty", func(t *testing.T) { + t.Run("it fails if the app path is empty", func(t *testing.T) { app, err := Load("") assert.Error(t, err) assert.Empty(t, app) + assert.Contains(t, err.Error(), "empty app path") }) - t.Run("AppSimple", func(t *testing.T) { + t.Run("it fails if the app path exist but it's a file", func(t *testing.T) { + _, err := Load("testdata/app.yaml") + assert.Error(t, err) + assert.Contains(t, err.Error(), "app path must be a directory") + }) + + t.Run("it fails if the app path does not exist", func(t *testing.T) { + _, err := Load("testdata/this-folder-does-not-exist") + assert.Error(t, err) + assert.Contains(t, err.Error(), "app path is not valid") + }) + + t.Run("it loads an app correctly", func(t *testing.T) { app, err := Load("testdata/AppSimple") assert.NoError(t, err) assert.NotEmpty(t, app) diff --git a/internal/orchestrator/bricks/bricks.go b/internal/orchestrator/bricks/bricks.go index 4e08b2c2..759b5cc6 100644 --- a/internal/orchestrator/bricks/bricks.go +++ b/internal/orchestrator/bricks/bricks.go @@ -18,6 +18,7 @@ package bricks import ( "errors" "fmt" + "log/slog" "maps" "slices" @@ -26,6 +27,7 @@ import ( "github.com/arduino/arduino-app-cli/internal/orchestrator/app" "github.com/arduino/arduino-app-cli/internal/orchestrator/bricksindex" + "github.com/arduino/arduino-app-cli/internal/orchestrator/config" "github.com/arduino/arduino-app-cli/internal/orchestrator/modelsindex" "github.com/arduino/arduino-app-cli/internal/store" ) @@ -125,7 +127,8 @@ func (s *Service) AppBrickInstanceDetails(a *app.ArduinoApp, brickID string) (Br }, nil } -func (s *Service) BricksDetails(id string) (BrickDetailsResult, error) { +func (s *Service) BricksDetails(id string, idProvider *app.IDProvider, + cfg config.Configuration) (BrickDetailsResult, error) { brick, found := s.bricksIndex.FindBrickByID(id) if !found { return BrickDetailsResult{}, ErrBrickNotFound @@ -160,6 +163,11 @@ func (s *Service) BricksDetails(id string) (BrickDetailsResult, error) { } }) + usedByApps, err := getUsedByApps(cfg, brick.ID, idProvider) + if err != nil { + return BrickDetailsResult{}, fmt.Errorf("unable to get used by apps: %w", err) + } + return BrickDetailsResult{ ID: id, Name: brick.Name, @@ -171,9 +179,63 @@ func (s *Service) BricksDetails(id string) (BrickDetailsResult, error) { Readme: readme, ApiDocsPath: apiDocsPath, CodeExamples: codeExamples, + UsedByApps: usedByApps, }, nil } +func getUsedByApps( + cfg config.Configuration, brickId string, idProvider *app.IDProvider) ([]AppReference, error) { + var ( + pathsToExplore paths.PathList + appPaths paths.PathList + ) + pathsToExplore.Add(cfg.ExamplesDir()) + pathsToExplore.Add(cfg.AppsDir()) + usedByApps := []AppReference{} + + for _, p := range pathsToExplore { + res, err := p.ReadDirRecursiveFiltered(func(file *paths.Path) bool { + if file.Base() == ".cache" { + return false + } + if file.Join("app.yaml").NotExist() && file.Join("app.yml").NotExist() { + return true + } + return false + }, paths.FilterDirectories(), paths.FilterOutNames("python", "sketch", ".cache")) + if err != nil { + slog.Error("unable to list apps", slog.String("error", err.Error())) + return usedByApps, err + } + appPaths.AddAllMissing(res) + } + + for _, file := range appPaths { + app, err := app.Load(file.String()) + if err != nil { + // we are not considering the broken apps + slog.Warn("unable to parse app.yaml, skipping", "path", file.String(), "error", err.Error()) + continue + } + + for _, b := range app.Descriptor.Bricks { + if b.ID == brickId { + id, err := idProvider.IDFromPath(app.FullPath) + if err != nil { + return usedByApps, fmt.Errorf("failed to get app ID for %s: %w", app.FullPath, err) + } + usedByApps = append(usedByApps, AppReference{ + Name: app.Name, + ID: id.String(), + Icon: app.Descriptor.Icon, + }) + break + } + } + } + return usedByApps, nil +} + type BrickCreateUpdateRequest struct { ID string `json:"-"` Model *string `json:"model"` diff --git a/internal/orchestrator/modelsindex/models_index.go b/internal/orchestrator/modelsindex/models_index.go index a966a678..e18797f1 100644 --- a/internal/orchestrator/modelsindex/models_index.go +++ b/internal/orchestrator/modelsindex/models_index.go @@ -48,6 +48,7 @@ type AIModel struct { ModuleDescription string `yaml:"description"` Runner string `yaml:"runner"` Bricks []string `yaml:"bricks,omitempty"` + ModelLabels []string `yaml:"model_labels,omitempty"` Metadata map[string]string `yaml:"metadata,omitempty"` ModelConfiguration map[string]string `yaml:"model_configuration,omitempty"` } diff --git a/internal/orchestrator/modelsindex/modelsindex_test.go b/internal/orchestrator/modelsindex/modelsindex_test.go new file mode 100644 index 00000000..53ffb585 --- /dev/null +++ b/internal/orchestrator/modelsindex/modelsindex_test.go @@ -0,0 +1,72 @@ +package modelsindex + +import ( + "testing" + + "github.com/arduino/go-paths-helper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestModelsIndex(t *testing.T) { + modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata")) + require.NoError(t, err) + require.NotNil(t, modelsIndex) + + t.Run("it parses a valid model-list.yaml", func(t *testing.T) { + models := modelsIndex.GetModels() + assert.Len(t, models, 2, "Expected 2 models to be parsed") + }) + + t.Run("it gets a model by ID", func(t *testing.T) { + model, found := modelsIndex.GetModelByID("not-existing-model") + assert.False(t, found) + assert.Nil(t, model) + + model, found = modelsIndex.GetModelByID("face-detection") + assert.Equal(t, "brick", model.Runner) + require.True(t, found, "face-detection should be found") + assert.Equal(t, "face-detection", model.ID) + assert.Equal(t, "Lightweight-Face-Detection", model.Name) + assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model.ModuleDescription) + assert.Equal(t, []string{"face"}, model.ModelLabels) + assert.Equal(t, "/models/ootb/ei/lw-face-det.eim", model.ModelConfiguration["EI_OBJ_DETECTION_MODEL"]) + assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model.Bricks) + assert.Equal(t, "qualcomm-ai-hub", model.Metadata["source"]) + assert.Equal(t, "false", model.Metadata["ei-gpu-mode"]) + assert.Equal(t, "face-det-lite", model.Metadata["source-model-id"]) + assert.Equal(t, "https://aihub.qualcomm.com/models/face_det_lite", model.Metadata["source-model-url"]) + }) + + t.Run("it fails if model-list.yaml does not exist", func(t *testing.T) { + nonExistentPath := paths.New("nonexistentdir") + modelsIndex, err := GenerateModelsIndexFromFile(nonExistentPath) + assert.Error(t, err) + assert.Nil(t, modelsIndex) + }) + + t.Run("it gets models by a brick", func(t *testing.T) { + model := modelsIndex.GetModelsByBrick("not-existing-brick") + assert.Nil(t, model) + + model = modelsIndex.GetModelsByBrick("arduino:object_detection") + assert.Len(t, model, 1) + assert.Equal(t, "face-detection", model[0].ID) + }) + + t.Run("it gets models by bricks", func(t *testing.T) { + models := modelsIndex.GetModelsByBricks([]string{"arduino:non_existing"}) + assert.Len(t, models, 0) + assert.Nil(t, models) + + models = modelsIndex.GetModelsByBricks([]string{"arduino:video_object_detection"}) + assert.Len(t, models, 2) + assert.Equal(t, "face-detection", models[0].ID) + assert.Equal(t, "yolox-object-detection", models[1].ID) + + models = modelsIndex.GetModelsByBricks([]string{"arduino:object_detection", "arduino:video_object_detection"}) + assert.Len(t, models, 2) + assert.Equal(t, "face-detection", models[0].ID) + assert.Equal(t, "yolox-object-detection", models[1].ID) + }) +} diff --git a/internal/orchestrator/modelsindex/testdata/models-list.yaml b/internal/orchestrator/modelsindex/testdata/models-list.yaml new file mode 100644 index 00000000..7d0aefb5 --- /dev/null +++ b/internal/orchestrator/modelsindex/testdata/models-list.yaml @@ -0,0 +1,111 @@ +models: + - face-detection: + runner: brick + name : "Lightweight-Face-Detection" + description: "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images." + model_configuration: + "EI_OBJ_DETECTION_MODEL": "/models/ootb/ei/lw-face-det.eim" + model_labels: + - face + bricks: + - arduino:object_detection + - arduino:video_object_detection + metadata: + source: "qualcomm-ai-hub" + ei-gpu-mode: false + source-model-id: "face-det-lite" + source-model-url: "https://aihub.qualcomm.com/models/face_det_lite" + - yolox-object-detection: + runner: brick + name : "General purpose object detection - YoloX" + description: "General purpose object detection model based on YoloX Nano. This model is trained on the COCO dataset and can detect 80 different object classes." + model_configuration: + "EI_OBJ_DETECTION_MODEL": "/models/ootb/ei/yolo-x-nano.eim" + model_labels: + - airplane + - apple + - backpack + - banana + - baseball bat + - baseball glove + - bear + - bed + - bench + - bicycle + - bird + - boat + - book + - bottle + - bowl + - broccoli + - bus + - cake + - car + - carrot + - cat + - cell phone + - chair + - clock + - couch + - cow + - cup + - dining table + - dog + - donut + - elephant + - fire hydrant + - fork + - frisbee + - giraffe + - hair drier + - handbag + - hot dog + - horse + - keyboard + - kite + - knife + - laptop + - microwave + - motorcycle + - mouse + - orange + - oven + - parking meter + - person + - pizza + - potted plant + - refrigerator + - remote + - sandwich + - scissors + - sheep + - sink + - skateboard + - skis + - snowboard + - spoon + - sports ball + - stop sign + - suitcase + - surfboard + - teddy bear + - tennis racket + - tie + - toaster + - toilet + - toothbrush + - traffic light + - train + - truck + - tv + - umbrella + - vase + - wine glass + - zebra + metadata: + source: "edgeimpulse" + ei-project-id: 717280 + source-model-id: "YOLOX-Nano" + source-model-url: "https://github.com/Megvii-BaseDetection/YOLOX" + bricks: + - arduino:video_object_detection diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index 41890fa4..884515be 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -131,6 +131,9 @@ func StartApp( yield(StreamMessage{error: fmt.Errorf("app %q is running", running.Name)}) return } + if !yield(StreamMessage{data: fmt.Sprintf("Starting app %q", app.Name)}) { + return + } if err := setStatusLeds(LedTriggerNone); err != nil { slog.Debug("unable to set status leds", slog.String("error", err.Error())) @@ -379,6 +382,9 @@ func stopAppWithCmd(ctx context.Context, app app.ArduinoApp, cmd string) iter.Se ctx, cancel := context.WithCancel(ctx) defer cancel() + if !yield(StreamMessage{data: fmt.Sprintf("Stopping app %q", app.Name)}) { + return + } if err := setStatusLeds(LedTriggerDefault); err != nil { slog.Debug("unable to set status leds", slog.String("error", err.Error())) } @@ -427,6 +433,46 @@ func StopAndDestroyApp(ctx context.Context, app app.ArduinoApp) iter.Seq[StreamM return stopAppWithCmd(ctx, app, "down") } +func RestartApp( + ctx context.Context, + docker command.Cli, + provisioner *Provision, + modelsIndex *modelsindex.ModelsIndex, + bricksIndex *bricksindex.BricksIndex, + appToStart app.ArduinoApp, + cfg config.Configuration, + staticStore *store.StaticStore, +) iter.Seq[StreamMessage] { + return func(yield func(StreamMessage) bool) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + runningApp, err := getRunningApp(ctx, docker.Client()) + if err != nil { + yield(StreamMessage{error: err}) + return + } + + if runningApp != nil { + if runningApp.FullPath.String() != appToStart.FullPath.String() { + yield(StreamMessage{error: fmt.Errorf("another app %q is running", runningApp.Name)}) + return + } + + stopStream := StopApp(ctx, *runningApp) + for msg := range stopStream { + if !yield(msg) { + return + } + if msg.error != nil { + return + } + } + } + startStream := StartApp(ctx, docker, provisioner, modelsIndex, bricksIndex, appToStart, cfg, staticStore) + startStream(yield) + } +} + func StartDefaultApp( ctx context.Context, docker command.Cli, diff --git a/pkg/board/remote/adb/adb.go b/pkg/board/remote/adb/adb.go index 745ff0af..6efdb1f4 100644 --- a/pkg/board/remote/adb/adb.go +++ b/pkg/board/remote/adb/adb.go @@ -82,12 +82,13 @@ func (a *ADBConnection) Forward(ctx context.Context, localPort int, remotePort i if err != nil { return err } - if err := cmd.RunWithinContext(ctx); err != nil { + if out, err := cmd.RunAndCaptureCombinedOutput(ctx); err != nil { return fmt.Errorf( - "failed to forward ADB port %s to %s: %w", + "failed to forward ADB port %s to %s: %w: %s", local, remote, err, + out, ) } @@ -99,8 +100,8 @@ func (a *ADBConnection) ForwardKillAll(ctx context.Context) error { if err != nil { return err } - if err := cmd.RunWithinContext(ctx); err != nil { - return fmt.Errorf("failed to kill all ADB forwarded ports: %w", err) + if out, err := cmd.RunAndCaptureCombinedOutput(ctx); err != nil { + return fmt.Errorf("failed to kill all ADB forwarded ports: %w: %s", err, out) } return nil }