forked from gptscript-ai/gptscript
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprompt.go
112 lines (95 loc) · 2.83 KB
/
prompt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package sdkserver
import (
"encoding/json"
"fmt"
"net/http"
"time"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/runner"
gserver "github.com/gptscript-ai/gptscript/pkg/server"
"github.com/gptscript-ai/gptscript/pkg/types"
)
// promptResponse delivers a client's answer to a pending prompt.
//
// The prompt is looked up by the "id" path value; if no prompt channel is
// registered under that ID the handler replies 404 Not Found. The request body
// is decoded as a JSON object of string values, and a decode failure yields
// 400 Bad Request.
//
// The decoded answer is offered to the prompt's channel without blocking:
// if a receiver takes it the handler replies 202 Accepted, otherwise it
// replies 409 Conflict.
func (s *server) promptResponse(w http.ResponseWriter, r *http.Request) {
	lgr := gcontext.GetLogger(r.Context())
	promptID := r.PathValue("id")

	s.lock.RLock()
	waiting := s.waitingToPrompt[promptID]
	s.lock.RUnlock()

	if waiting == nil {
		writeError(lgr, w, http.StatusNotFound, fmt.Errorf("no prompt found with id %q", promptID))
		return
	}

	var answer map[string]string
	if err := json.NewDecoder(r.Body).Decode(&answer); err != nil {
		writeError(lgr, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
		return
	}

	// Never block on the send: if the prompt handler is no longer receiving,
	// a blocking send would hang this request forever.
	status := http.StatusConflict
	select {
	case waiting <- answer:
		status = http.StatusAccepted
	default:
	}
	w.WriteHeader(status)
}
// prompt registers a pending prompt under the "id" path value, publishes it as
// a Prompt event on s.events.C, and then blocks until an answer arrives on the
// prompt's channel or the request context ends.
//
// Responses:
//   - 401 if the Authorization header is not "Bearer " + s.token.
//   - 400 if a prompt is already pending for the same ID, or if the body does
//     not decode as types.Prompt.
//   - 500 if the request context is done before the event is published or
//     before an answer arrives.
//   - 200 with the JSON-encoded answer otherwise.
//
// The registration is removed from s.waitingToPrompt when the handler returns.
func (s *server) prompt(w http.ResponseWriter, r *http.Request) {
	logger := gcontext.GetLogger(r.Context())
	if r.Header.Get("Authorization") != "Bearer "+s.token {
		writeError(logger, w, http.StatusUnauthorized, fmt.Errorf("invalid token"))
		return
	}

	id := r.PathValue("id")

	// Cheap early rejection of duplicates before reading the body. This is
	// not authoritative; the real check-and-register happens below under the
	// write lock.
	s.lock.RLock()
	promptChan := s.waitingToPrompt[id]
	s.lock.RUnlock()
	if promptChan != nil {
		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("prompt called multiple times for same ID: %s", id))
		return
	}

	var prompt types.Prompt
	if err := json.NewDecoder(r.Body).Decode(&prompt); err != nil {
		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
		return
	}

	// Re-check and register under a single write lock. Otherwise two
	// concurrent requests for the same ID could both pass the check above; the
	// second would overwrite the first's channel, and the first's deferred
	// delete would then remove the second's registration.
	s.lock.Lock()
	if s.waitingToPrompt[id] != nil {
		s.lock.Unlock()
		writeError(logger, w, http.StatusBadRequest, fmt.Errorf("prompt called multiple times for same ID: %s", id))
		return
	}
	promptChan = make(chan map[string]string)
	s.waitingToPrompt[id] = promptChan
	s.lock.Unlock()

	defer func() {
		s.lock.Lock()
		delete(s.waitingToPrompt, id)
		s.lock.Unlock()
	}()

	writeCanceled := func() {
		writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("context canceled: %w", r.Context().Err()))
	}

	ev := event{
		Prompt: prompt,
		Event: gserver.Event{
			RunID: id,
			Event: runner.Event{
				Type: Prompt,
				Time: time.Now(),
			},
		},
	}

	// Publishing blocks until something consumes s.events.C; give up if the
	// caller goes away so this handler cannot hang indefinitely.
	select {
	case <-r.Context().Done():
		writeCanceled()
		return
	case s.events.C <- ev:
	}

	// Wait for the prompt response to come through.
	select {
	case <-r.Context().Done():
		writeCanceled()
		return
	case promptResponse := <-promptChan:
		writePromptResponse(logger, w, http.StatusOK, promptResponse)
	}
}
// writePromptResponse marshals resp as JSON and writes it to w with status
// code and a JSON Content-Type.
//
// If resp cannot be marshaled, the error is logged and a bare 500 Internal
// Server Error (no body) is written instead. A failure while writing the body
// is only logged, because the status line has already been committed by then.
func writePromptResponse(logger mvl.Logger, w http.ResponseWriter, code int, resp any) {
	b, err := json.Marshal(resp)
	if err != nil {
		logger.Errorf("failed to marshal response: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Without an explicit Content-Type, net/http sniffs the first bytes of the
	// body and would label this JSON payload "text/plain; charset=utf-8".
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if _, err := w.Write(b); err != nil {
		logger.Errorf("failed to write response: %v", err)
	}
}