Commit e978dc5

MiriamScharnke authored and ccreutzi committed

Allow ToolChoice="required"

* Add "required" ToolChoice
* ToolChoice not actually supported for Ollama (yet): setting the ToolChoice argument does not have any effect on Ollama models. Ollama lists `tool_choice` as not yet supported (though it is on the plan): https://github.com/ollama/ollama/blob/main/docs/openai.md

1 parent 53dd267 commit e978dc5
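As a quick illustration of what the commit enables, a hypothetical call might look like this (the tool definition and prompts are illustrative, not part of the commit):

```matlab
% Hypothetical sketch of the new ToolChoice="required" option.
% The function name and prompts below are illustrative only.
f = openAIFunction("sayHello", "Says hello");
chat = openAIChat("You are a helpful agent", Tools=f);

% With ToolChoice="required", the model must call at least one tool
% rather than answering in plain text.
[text, message] = generate(chat, "Greet the user", ToolChoice="required");
```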

File tree

11 files changed: +57 −38 lines

+llms/+internal/callOllamaChatAPI.m

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@
 parameters.tools = functions;
 end

-if ~isempty(nvp.ToolChoice)
+if isfield(nvp,"ToolChoice") && ~isempty(nvp.ToolChoice)
 parameters.tool_choice = nvp.ToolChoice;
 end

+llms/+internal/hasTools.m

Lines changed: 5 additions & 4 deletions

@@ -20,7 +20,7 @@ function mustBeValidFunctionCall(this, functionCall)
 if isempty(this.FunctionNames)
 error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall"));
 end
-mustBeMember(functionCall, ["none","auto", this.FunctionNames]);
+mustBeMember(functionCall, ["none","auto","required", this.FunctionNames]);
 end
 end

@@ -31,9 +31,10 @@ function mustBeValidFunctionCall(this, functionCall)
 if ~isempty(this.Tools)
 toolChoice = "auto";
 end
-elseif ~ismember(toolChoice,["auto","none"])
-% if toolChoice is not empty, then it must be "auto", "none" or in the format
-% {"type": "function", "function": {"name": "my_function"}}
+elseif ~ismember(toolChoice,["auto","none","required"])
+% if toolChoice is not empty, then it must be "auto", "none",
+% "required", or in the format {"type": "function", "function":
+% {"name": "my_function"}}
 toolChoice = struct("type","function","function",struct("name",toolChoice));
 end
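The branch changed above wraps a specific tool name in the OpenAI `tool_choice` structure. A standalone sketch of that mapping (outside the class, for illustration only):

```matlab
% Standalone sketch of the convertToolChoice mapping (illustrative).
toolChoice = "my_function";
if ~ismember(toolChoice, ["auto","none","required"])
    % A specific tool name becomes the OpenAI tool_choice object:
    % {"type": "function", "function": {"name": "my_function"}}
    toolChoice = struct("type","function","function",struct("name",toolChoice));
end
```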

azureChat.m

Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@
 % MaxNumTokens - Maximum number of tokens in the generated response.
 % Default value is inf.
 %
-% ToolChoice - Function to execute. 'none', 'auto',
+% ToolChoice - Function to execute. 'none', 'auto', 'required',
 % or specify the function to call.
 %
 % Seed - An integer value to use to obtain

doc/functions/generate.md

Lines changed: 9 additions & 8 deletions

@@ -99,7 +99,7 @@ The supported name-value arguments depend on the chat completion API.
 | `PresencePenalty` | Supported | Supported | |
 | `FrequencyPenalty` | Supported | Supported | |
 | `NumCompletions` | Supported | Supported | |
-| `ToolChoice` | Supported | Supported | Supported |
+| `ToolChoice` | Supported | Supported | |
 | `MinP` | | | Supported |
 | `TopK` | | | Supported |
 | `TailFreeSamplingZ` | | | Supported |
@@ -261,25 +261,26 @@ This option is only supported for these chat completion APIs:
 - [`azureChat`](azureChat.md) objects

 ### `ToolChoice` — Tool choice

-`model.ToolChoice` (default) | `"auto"` | `"none"` | `openAIFunction` object | array of `openAIFunction` objects
+`"auto"` (default) | `"none"` | `"required"` | string scalar

-OpenAI functions to call during output generation. For more information on OpenAI function calling, see [`openAIFunction`](openAIFunction.md).
+Tools that a model is allowed to call during output generation, specified as `"auto"`, `"none"`, `"required"`, or as a tool name. For more information on OpenAI function calling, see [`openAIFunction`](openAIFunction.md).

-If the tool choice is `"auto"`, then any function calls specified in `chat` are executed during generation. To see whether any function calls are specified, check the `FunctionNames` property of `chat`.
+If the tool choice is set to `"auto"`, then any tools available to the model can be called during output generation. To find out which tools are available to the model, see the `FunctionNames` property of the `model` input argument.

-If the tool choice is `"none"`, then no function call is executed during generation.
+If the tool choice is set to `"none"`, then no tools are called during output generation.

+If the tool choice is set to `"required"`, then one or more tools are called during output generation.

-You can also specify one or more [`openAIFunction`](openAIFunction.md) objects directly.
+You can also require that the model uses a specific tool by setting `ToolChoice` to the name of that tool. The name must be part of `model.FunctionNames`.

 This option is only supported for these chat completion APIs:

 - [`openAIChat`](openAIChat.md) objects
-- [`azureChat`](azureChat.md) objects
+- [`azureChat`](azureChat.md) objects

 ### `MinP` — Minimum probability ratio

 `model.MinP` (default) | numeric scalar between `0` and `1`
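To make the documented `ToolChoice` values concrete, a hypothetical call for each setting (the tool name and prompts are illustrative):

```matlab
% Illustrative only: one call per documented ToolChoice value.
f = openAIFunction("getWeather", "Returns the current weather");
chat = openAIChat("You are a weather bot", Tools=f);

generate(chat, "Hello!", ToolChoice="none");       % never call a tool
generate(chat, "Hello!", ToolChoice="auto");       % model decides
generate(chat, "Hello!", ToolChoice="required");   % must call some tool
generate(chat, "Hello!", ToolChoice="getWeather"); % must call this specific tool
```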

functionSignatures.json

Lines changed: 3 additions & 2 deletions

@@ -30,7 +30,7 @@
 {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]},
 {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
 {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
-{"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"},
+{"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",\"required\",this.FunctionNames]"},
 {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
 {"name":"ModelName","kind":"namevalue","type":"choices=llms.openai.models"},
 {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
@@ -82,7 +82,7 @@
 {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]},
 {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
 {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
-{"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"},
+{"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",\"required\",this.FunctionNames]"},
 {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
 {"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
 {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
@@ -108,6 +108,7 @@
 {"name":"modelName","kind":"positional","type":"choices=ollamaChat.models"},
 {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]},
 {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
+{"name":"Tools","kind":"namevalue","type":"openAIFunction"},
 {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
 {"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
 {"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]},

ollamaChat.m

Lines changed: 1 addition & 7 deletions

@@ -160,9 +160,6 @@
 % values reduce it. Setting Temperature=0 removes
 % randomness from the output altogether.
 %
-% ToolChoice - Function to execute. 'none', 'auto',
-% or specify the function to call.
-%
 % TopP - Top probability mass value for controlling the
 % diversity of the output. Default value is CHAT.TopP;
 % lower values imply that only the more likely
@@ -218,7 +215,6 @@
 nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
 nvp.Endpoint (1,1) string = this.Endpoint
 nvp.MaxNumTokens (1,1) {mustBeNumeric,mustBePositive} = inf
-nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
 nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
 end
@@ -233,8 +229,6 @@
 messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
 end

-toolChoice = convertToolChoice(this, nvp.ToolChoice);
-
 if isfield(nvp,"StreamFun")
 streamFun = nvp.StreamFun;
 else
@@ -244,7 +238,7 @@
 try
 [text, message, response] = llms.internal.callOllamaChatAPI(...
 nvp.ModelName, messagesStruct, this.FunctionsStruct, ...
-Temperature=nvp.Temperature, ToolChoice=toolChoice, ...
+Temperature=nvp.Temperature, ...
 TopP=nvp.TopP, MinP=nvp.MinP, TopK=nvp.TopK,...
 TailFreeSamplingZ=nvp.TailFreeSamplingZ,...
 StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...

openAIChat.m

Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@
 % MaxNumTokens - Maximum number of tokens in the generated response.
 % Default value is inf.
 %
-% ToolChoice - Function to execute. 'none', 'auto',
+% ToolChoice - Function to execute. 'none', 'auto', 'required',
 % or specify the function to call.
 %
 % Seed - An integer value to use to obtain

tests/hopenAIChat.m

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-classdef (Abstract) hopenAIChat < hstructuredOutput & htoolCalls
+classdef (Abstract) hopenAIChat < hstructuredOutput & htoolCalls & htoolChoice
 % Tests for OpenAI-based chats (openAIChat, azureChat)

 % Copyright 2023-2025 The MathWorks, Inc.

tests/htoolCalls.m

Lines changed: 0 additions & 13 deletions

@@ -7,20 +7,7 @@
 defaultModel
 end

-methods (Test) % not calling the server
-function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
-testCase.verifyError(@()generate(testCase.defaultModel,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
-end
-end
-
 methods (Test) % calling the server, end-to-end tests
-function settingToolChoiceWithNone(testCase)
-functions = openAIFunction("funName");
-chat = testCase.constructor(Tools=functions);
-
-testCase.verifyWarningFree(@()generate(chat,"This is okay","ToolChoice","none"));
-end
-
 function generateWithToolsAndStreamFunc(testCase)
 import matlab.unittest.constraints.HasField
tests/htoolChoice.m

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+classdef (Abstract) htoolChoice < matlab.mock.TestCase
+% Tests for backends with ToolChoice support
+
+% Copyright 2023-2025 The MathWorks, Inc.
+properties(Abstract)
+constructor
+defaultModel
+end
+
+methods (Test) % not calling the server
+function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
+testCase.verifyError(@()generate(testCase.defaultModel,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
+end
+end
+
+methods (Test) % calling the server, end-to-end tests
+function settingToolChoiceWithNone(testCase)
+functions = openAIFunction("funName");
+chat = testCase.constructor(Tools=functions);
+
+testCase.verifyWarningFree(@()generate(chat,"This is okay","ToolChoice","none"));
+end
+
+function settingToolChoiceAsRequired(testCase)
+functions = openAIFunction("funName");
+chat = testCase.constructor(Tools=functions);
+
+testCase.verifyWarningFree(@()generate(chat,"This is okay","ToolChoice","required"));
+end
+end
+end

0 commit comments