diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 23db5152..f972d14c 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -217,6 +217,7 @@ var tools = map[string]types.Tool{ "message", "The message to display to the user", "fields", "A comma-separated list of fields to prompt for", "sensitive", "(true or false) Whether the input should be hidden", + "metadata", "(optional) A JSON object of metadata to attach to the prompt", ), }, BuiltinFunc: prompt.SysPrompt, diff --git a/pkg/loader/openapi.go b/pkg/loader/openapi.go index cf3c3f34..e62fc5ef 100644 --- a/pkg/loader/openapi.go +++ b/pkg/loader/openapi.go @@ -209,6 +209,11 @@ func getOpenAPITools(t *openapi3.T, defaultHost, source, targetToolName string) } bodyMIME = mime + // requestBody content mime without schema + if content == nil || content.Schema == nil { + continue + } + arg := content.Schema.Value if arg.Description == "" { arg.Description = content.Schema.Value.Description @@ -300,7 +305,7 @@ func getOpenAPITools(t *openapi3.T, defaultHost, source, targetToolName string) if err != nil { return nil, fmt.Errorf("failed to parse operation server URL: %w", err) } - tool.Credentials = info.GetCredentialToolStrings(operationServerURL.Hostname()) + tool.Credentials = append(tool.Credentials, info.GetCredentialToolStrings(operationServerURL.Hostname())...) } } diff --git a/pkg/openapi/run.go b/pkg/openapi/run.go index 2efc2309..6c7e4ca7 100644 --- a/pkg/openapi/run.go +++ b/pkg/openapi/run.go @@ -42,7 +42,8 @@ func Run(operationID, defaultHost, args string, t *openapi3.T, envs []string) (s } if !validationResult.Valid() { - return "", false, fmt.Errorf("invalid arguments for operation %s: %s", operationID, validationResult.Errors()) + // We don't return an error here because we want the LLM to be able to maintain control and try again. 
+ return fmt.Sprintf("invalid arguments for operation %s: %s", operationID, validationResult.Errors()), true, nil } // Construct and execute the HTTP request. diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 44cb20f1..f91a04b6 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -51,25 +51,29 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types. func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) { var params struct { - Message string `json:"message,omitempty"` - Fields string `json:"fields,omitempty"` - Sensitive string `json:"sensitive,omitempty"` + Message string `json:"message,omitempty"` + Fields string `json:"fields,omitempty"` + Sensitive string `json:"sensitive,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` } if err := json.Unmarshal([]byte(input), &params); err != nil { return "", err } + var fields []string for _, env := range envs { if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok { - var fields []string if params.Fields != "" { fields = strings.Split(params.Fields, ",") } + httpPrompt := types.Prompt{ Message: params.Message, Fields: fields, Sensitive: params.Sensitive == "true", + Metadata: params.Metadata, } + return sysPromptHTTP(ctx, envs, url, httpPrompt) } } diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index baa54677..fa1d40c2 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -25,7 +25,6 @@ type Client struct { clientsLock sync.Mutex cache *cache.Client clients map[string]clientInfo - modelToProvider map[string]string runner *runner.Runner envs []string credStore credentials.CredentialStore @@ -39,17 +38,13 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent envs: envs, credStore: credStore, defaultProvider: defaultProvider, - modelToProvider: make(map[string]string), clients: make(map[string]clientInfo), } } func (c *Client) Call(ctx 
context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { - c.clientsLock.Lock() - provider, ok := c.modelToProvider[messageRequest.Model] - c.clientsLock.Unlock() - - if !ok { + _, provider := c.parseModel(messageRequest.Model) + if provider == "" { return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) } @@ -108,10 +103,6 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error) return false, err } - c.clientsLock.Lock() - defer c.clientsLock.Unlock() - - c.modelToProvider[modelString] = providerName return true, nil } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index f92b0705..93e40670 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -628,11 +628,16 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s } } + var content string + if state.ResumeInput != nil { + content = *state.ResumeInput + } monitor.Event(Event{ Time: time.Now(), CallContext: callCtx.GetCallContext(), Type: EventTypeCallContinue, ToolResults: len(callResults), + Content: content, }) e := engine.Engine{ diff --git a/pkg/sdkserver/monitor.go b/pkg/sdkserver/monitor.go index a5b0236b..bdd88c67 100644 --- a/pkg/sdkserver/monitor.go +++ b/pkg/sdkserver/monitor.go @@ -33,6 +33,7 @@ func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []str Time: time.Now(), Type: runner.EventTypeRunStart, }, + Input: input, RunID: id, Program: prg, }, @@ -43,7 +44,6 @@ func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []str id: id, prj: prg, env: env, - input: input, events: s.events, }, nil } @@ -56,7 +56,6 @@ type Session struct { id string prj *types.Program env []string - input string events *broadcaster.Broadcaster[event] runLock sync.Mutex } @@ -68,7 +67,6 @@ func (s *Session) Event(e runner.Event) { Event: gserver.Event{ Event: e, RunID: s.id, - Input: s.input, }, } } @@ -87,7 +85,6 @@ 
func (s *Session) Stop(ctx context.Context, output string, err error) { Type: runner.EventTypeRunFinish, }, RunID: s.id, - Input: s.input, Output: output, }, } diff --git a/pkg/sdkserver/prompt.go b/pkg/sdkserver/prompt.go index 8d34fc53..a519f7b2 100644 --- a/pkg/sdkserver/prompt.go +++ b/pkg/sdkserver/prompt.go @@ -76,11 +76,7 @@ func (s *server) prompt(w http.ResponseWriter, r *http.Request) { }(id) s.events.C <- event{ - Prompt: types.Prompt{ - Message: prompt.Message, - Fields: prompt.Fields, - Sensitive: prompt.Sensitive, - }, + Prompt: prompt, Event: gserver.Event{ RunID: id, Event: runner.Event{ diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 4309bc28..6cb1e620 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -73,39 +73,13 @@ func (s *server) version(w http.ResponseWriter, r *http.Request) { // listTools will return the output of `gptscript --list-tools` func (s *server) listTools(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) - var prg types.Program - if r.ContentLength != 0 { - reqObject := new(toolOrFileRequest) - err := json.NewDecoder(r.Body).Decode(reqObject) - if err != nil { - writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) - return - } - - if reqObject.Content != "" { - prg, err = loader.ProgramFromSource(r.Context(), reqObject.Content, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) - } else if reqObject.File != "" { - prg, err = loader.Program(r.Context(), reqObject.File, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) - } else { - prg, err = loader.ProgramFromSource(r.Context(), reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{Cache: s.client.Cache}) - } - if err != nil { - writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) - return - } - } - - tools := s.client.ListTools(r.Context(), prg) + tools := s.client.ListTools(r.Context(), 
types.Program{}) sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name }) lines := make([]string, 0, len(tools)) for _, tool := range tools { - if tool.Name == "" { - tool.Name = prg.Name - } - // Don't print instructions tool.Instructions = "" @@ -118,22 +92,31 @@ func (s *server) listTools(w http.ResponseWriter, r *http.Request) { // listModels will return the output of `gptscript --list-models` func (s *server) listModels(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) + client := s.client + var providers []string if r.ContentLength != 0 { reqObject := new(modelsRequest) - if err := json.NewDecoder(r.Body).Decode(reqObject); err != nil { + err := json.NewDecoder(r.Body).Decode(reqObject) + if err != nil { writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) return } providers = reqObject.Providers + + client, err = gptscript.New(r.Context(), s.gptscriptOpts, gptscript.Options{Env: reqObject.Env, Runner: runner.Options{CredentialOverrides: reqObject.CredentialOverrides}}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to create client: %w", err)) + return + } } if s.gptscriptOpts.DefaultModelProvider != "" { providers = append(providers, s.gptscriptOpts.DefaultModelProvider) } - out, err := s.client.ListModels(r.Context(), providers...) + out, err := client.ListModels(r.Context(), providers...) 
if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to list models: %w", err)) return diff --git a/pkg/sdkserver/types.go b/pkg/sdkserver/types.go index b24ca645..2889626b 100644 --- a/pkg/sdkserver/types.go +++ b/pkg/sdkserver/types.go @@ -100,7 +100,9 @@ type parseRequest struct { } type modelsRequest struct { - Providers []string `json:"providers"` + Providers []string `json:"providers"` + Env []string `json:"env"` + CredentialOverrides []string `json:"credentialOverrides"` } type runInfo struct { @@ -142,6 +144,7 @@ func (r *runInfo) process(e event) map[string]any { r.Start = e.Time r.Program = *e.Program r.State = Running + r.Input = e.Input case runner.EventTypeRunFinish: r.End = e.Time r.Output = e.Output @@ -165,9 +168,11 @@ func (r *runInfo) process(e event) map[string]any { call.Type = e.Type switch e.Type { - case runner.EventTypeCallStart: + case runner.EventTypeCallStart, runner.EventTypeCallContinue: call.Start = e.Time - call.Input = e.Content + if e.Content != "" { + call.Input = e.Content + } case runner.EventTypeCallSubCalls: call.setSubCalls(e.ToolSubCalls) diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index ea17c11c..653ad066 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -6,7 +6,8 @@ const ( ) type Prompt struct { - Message string `json:"message,omitempty"` - Fields []string `json:"fields,omitempty"` - Sensitive bool `json:"sensitive,omitempty"` + Message string `json:"message,omitempty"` + Fields []string `json:"fields,omitempty"` + Sensitive bool `json:"sensitive,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go index 2be6d0fc..086ad043 100644 --- a/pkg/types/toolstring.go +++ b/pkg/types/toolstring.go @@ -3,6 +3,7 @@ package types import ( "encoding/json" "fmt" + "os" "path/filepath" "strings" ) @@ -76,6 +77,20 @@ func ToSysDisplayString(id string, args map[string]string) (string, error) { return 
fmt.Sprintf("Writing `%s`", args["filename"]), nil case "sys.context", "sys.stat", "sys.getenv", "sys.abort", "sys.chat.current", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now", "sys.model.provider.credential": return "", nil + case "sys.openapi": + if os.Getenv("GPTSCRIPT_OPENAPI_REVAMP") == "true" && args["operation"] != "" { + // Pretty print the JSON by unmarshaling and marshaling it + var jsonArgs map[string]any + if err := json.Unmarshal([]byte(args["args"]), &jsonArgs); err != nil { + return "", err + } + jsonPretty, err := json.MarshalIndent(jsonArgs, "", " ") + if err != nil { + return "", err + } + return fmt.Sprintf("Running API operation `%s` with arguments %s", args["operation"], string(jsonPretty)), nil + } + fallthrough default: return "", fmt.Errorf("unknown tool for display string: %s", id) }