Skip to content

Commit caea478

Browse files
committed
Add streaming for code completion between a prompt and a suffix
1 parent f7b232f commit caea478

File tree

2 files changed

+105
-4
lines changed

2 files changed

+105
-4
lines changed

v2/code.go

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ func (oc *Config) GetBetweenResponse(prompt, suffix string) (OutputResponse, err
5858
return res, nil
5959
}
6060
}
61-
var reqBody GenerateRequest
62-
reqBody = GenerateRequest{
61+
reqBody := GenerateRequest{
6362
Model: oc.ModelName,
6463
Prompt: prompt,
6564
Suffix: suffix,
@@ -125,3 +124,71 @@ func (oc *Config) Complete(codeStart, codeEnd string) (string, error) {
125124
}
126125
return response.Response, nil
127126
}
127+
128+
// StreamBetween sends a request to the Ollama API and returns the generated output via a callback function.
129+
// The callback function is given a string and "true" when the streaming is done (or if an error occurred).
130+
func (oc *Config) StreamBetween(callbackFunction func(string, bool), prompt, suffix string) error {
131+
defer callbackFunction("", true)
132+
var (
133+
temperature float64
134+
seed = oc.SeedOrNegative
135+
)
136+
if prompt == "" {
137+
return errors.New("the prompt can not be empty")
138+
}
139+
if seed < 0 {
140+
temperature = oc.TemperatureIfNegativeSeed
141+
}
142+
reqBody := GenerateRequest{
143+
Model: oc.ModelName,
144+
System: oc.SystemPrompt,
145+
Prompt: prompt,
146+
Suffix: suffix,
147+
Stream: true,
148+
Options: RequestOptions{
149+
Seed: seed, // set to -1 to make it random
150+
Temperature: temperature, // set to 0 together with a specific seed to make output reproducible
151+
},
152+
}
153+
if oc.ContextLength != 0 {
154+
reqBody.Options.ContextLength = oc.ContextLength
155+
}
156+
reqBytes, err := json.Marshal(reqBody)
157+
if err != nil {
158+
return err
159+
}
160+
if oc.Verbose {
161+
fmt.Printf("Sending request to %s/api/generate: %s\n", oc.ServerAddr, string(reqBytes))
162+
}
163+
HTTPClient := &http.Client{
164+
Timeout: oc.HTTPTimeout,
165+
}
166+
resp, err := HTTPClient.Post(oc.ServerAddr+"/api/generate", mimeJSON, bytes.NewBuffer(reqBytes))
167+
if err != nil {
168+
return err
169+
}
170+
defer resp.Body.Close()
171+
var (
172+
decoder = json.NewDecoder(resp.Body)
173+
first = true
174+
answer string
175+
)
176+
for {
177+
var genResp GenerateResponse
178+
if err := decoder.Decode(&genResp); err != nil {
179+
break
180+
}
181+
answer = genResp.Response
182+
if first {
183+
if len(answer) > 0 && answer[0] == ' ' {
184+
answer = answer[1:]
185+
}
186+
first = false
187+
}
188+
callbackFunction(answer, false)
189+
if genResp.Done {
190+
break
191+
}
192+
}
193+
return nil
194+
}

v2/code_test.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ func TestBetween(t *testing.T) {
1111
const (
1212
codeStart = "def compute_gcd(a, b):"
1313
codeEnd = " return result"
14+
verbose = true
1415
)
16+
1517
oc := New(codeCompleteModel)
16-
oc.Verbose = true
18+
oc.Verbose = verbose
1719
if err := oc.PullIfNeeded(true); err != nil {
1820
t.Fatalf("Failed to pull model: %v", err)
1921
}
22+
2023
response, err := oc.GetBetweenResponse(codeStart, codeEnd)
2124
if err != nil {
2225
t.Fatalf("Failed to get code completion: %v", err)
@@ -30,11 +33,42 @@ func TestCodeCompletion(t *testing.T) {
3033
codeEnd = " return result"
3134
verbose = true
3235
)
36+
3337
oc := New(codeCompleteModel)
34-
oc.Verbose = true
38+
oc.Verbose = verbose
39+
3540
codeBetween, err := oc.Complete(codeStart, codeEnd)
3641
if err != nil {
3742
t.Fatal(err)
3843
}
3944
fmt.Printf("%s\n%s\n%s\n", codeStart, codeBetween, codeEnd)
4045
}
46+
47+
func TestStreamBetween(t *testing.T) {
48+
const (
49+
codeStart = "def compute_gcd(a, b):"
50+
codeEnd = " return result"
51+
verbose = false
52+
)
53+
54+
oc := New(codeCompleteModel)
55+
oc.Verbose = verbose
56+
57+
err := oc.PullIfNeeded(true)
58+
if err != nil {
59+
t.Fatalf("Failed to pull model: %v", err)
60+
}
61+
62+
callbackFunction := func(partialResult string, streamingDone bool) {
63+
if !streamingDone {
64+
fmt.Printf("%s", partialResult)
65+
} else {
66+
fmt.Println(codeEnd)
67+
}
68+
}
69+
70+
fmt.Print(codeStart)
71+
if err := oc.StreamBetween(callbackFunction, codeStart, codeEnd); err != nil {
72+
t.Fatalf("Failed to get streamed : %v", err)
73+
}
74+
}

0 commit comments

Comments
 (0)