|
38 | 38 | }, |
39 | 39 | "id": "49abde692940b09e" |
40 | 40 | }, |
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "source": [ |
| 44 | + "## Single-round chat completion" |
| 45 | + ], |
| 46 | + "metadata": { |
| 47 | + "collapsed": false |
| 48 | + }, |
| 49 | + "id": "84b663418e3e3b19" |
| 50 | + }, |
41 | 51 | { |
42 | 52 | "cell_type": "code", |
43 | 53 | "execution_count": null, |
44 | 54 | "outputs": [], |
45 | 55 | "source": [ |
46 | | - "# normal \n", |
47 | 56 | "chat_completion_result = taskingai.inference.chat_completion(\n", |
48 | 57 | " model_id=model_id,\n", |
49 | 58 | " messages=[\n", |
|
58 | 67 | }, |
59 | 68 | "id": "43dcc632665f0de4" |
60 | 69 | }, |
| 70 | + { |
| 71 | + "cell_type": "markdown", |
| 72 | + "source": [ |
| 73 | + "## Multi-round chat completion" |
| 74 | + ], |
| 75 | + "metadata": { |
| 76 | + "collapsed": false |
| 77 | + }, |
| 78 | + "id": "9f84e86d19409580" |
| 79 | + }, |
61 | 80 | { |
62 | 81 | "cell_type": "code", |
63 | 82 | "execution_count": null, |
64 | 83 | "outputs": [], |
65 | 84 | "source": [ |
66 | | - "# multi round chat completion\n", |
67 | 85 | "chat_completion_result = taskingai.inference.chat_completion(\n", |
68 | 86 | " model_id=model_id,\n", |
69 | 87 | " messages=[\n", |
|
87 | 105 | "execution_count": null, |
88 | 106 | "outputs": [], |
89 | 107 | "source": [ |
90 | | - "# config max tokens\n", |
| 108 | + "# Add max tokens configs\n", |
91 | 109 | "chat_completion_result = taskingai.inference.chat_completion(\n", |
92 | 110 | " model_id=model_id,\n", |
93 | 111 | " messages=[\n", |
|
109 | 127 | }, |
110 | 128 | "id": "f7c1b8be2579d9e0" |
111 | 129 | }, |
| 130 | + { |
| 131 | + "cell_type": "markdown", |
| 132 | + "source": [ |
| 133 | + "## Function call" |
| 134 | + ], |
| 135 | + "metadata": { |
| 136 | + "collapsed": false |
| 137 | + }, |
| 138 | + "id": "c615ece16c777029" |
| 139 | + }, |
112 | 140 | { |
113 | 141 | "cell_type": "code", |
114 | 142 | "execution_count": null, |
115 | 143 | "outputs": [], |
116 | 144 | "source": [ |
117 | | - "# function call\n", |
| 145 | + "# function definition\n", |
118 | 146 | "function = Function(\n", |
119 | 147 | " name=\"plus_a_and_b\",\n", |
120 | 148 | " description=\"Sum up a and b and return the result\",\n", |
|
132 | 160 | " },\n", |
133 | 161 | " \"required\": [\"a\", \"b\"]\n", |
134 | 162 | " },\n", |
135 | | - ")\n", |
| 163 | + ")" |
| 164 | + ], |
| 165 | + "metadata": { |
| 166 | + "collapsed": false |
| 167 | + }, |
| 168 | + "id": "2645bdc3df011e7d" |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "code", |
| 172 | + "execution_count": null, |
| 173 | + "outputs": [], |
| 174 | + "source": [ |
| 175 | + "# chat completion with the function call\n", |
136 | 176 | "chat_completion_result = taskingai.inference.chat_completion(\n", |
137 | 177 | " model_id=model_id,\n", |
138 | 178 | " messages=[\n", |
139 | | - " SystemMessage(\"You are a professional assistant.\"),\n", |
140 | 179 | " UserMessage(\"What is the result of 112 plus 22?\"),\n", |
141 | 180 | " ],\n", |
142 | 181 | " functions=[function]\n", |
143 | 182 | ")\n", |
144 | | - "print(f\"chat_completion_result = {chat_completion_result}\")\n", |
145 | | - "\n", |
146 | | - "assistant_function_call_message = chat_completion_result.message\n", |
147 | | - "fucntion_name = assistant_function_call_message.function_call.name\n", |
148 | | - "argument_content = json.dumps(assistant_function_call_message.function_call.arguments)\n", |
149 | | - "print(f\"function name: {fucntion_name}, argument content: {argument_content}\")" |
| 183 | + "function_call_assistant_message = chat_completion_result.message\n", |
| 184 | + "print(f\"function_call_assistant_message = {function_call_assistant_message}\")" |
150 | 185 | ], |
151 | 186 | "metadata": { |
152 | 187 | "collapsed": false |
153 | 188 | }, |
154 | | - "id": "2645bdc3df011e7d" |
| 189 | + "id": "850adc819aa228fc" |
155 | 190 | }, |
156 | 191 | { |
157 | | - "cell_type": "markdown", |
158 | | - "source": [], |
| 192 | + "cell_type": "code", |
| 193 | + "execution_count": null, |
| 194 | + "outputs": [], |
| 195 | + "source": [ |
| 196 | + "# get the function call result\n", |
| 197 | + "def plus_a_and_b(a, b):\n", |
| 198 | + " return a + b\n", |
| 199 | + "\n", |
| 200 | + "arguments = function_call_assistant_message.function_call.arguments\n", |
| 201 | + "function_call_result = plus_a_and_b(**arguments)\n", |
| 202 | + "print(f\"function_call_result = {function_call_result}\")" |
| 203 | + ], |
159 | 204 | "metadata": { |
160 | 205 | "collapsed": false |
161 | 206 | }, |
162 | | - "id": "ed6957f0c380ba9f" |
| 207 | + "id": "45787662d2148352" |
163 | 208 | }, |
164 | 209 | { |
165 | 210 | "cell_type": "code", |
166 | 211 | "execution_count": null, |
167 | 212 | "outputs": [], |
168 | 213 | "source": [ |
169 | | - "# add function message\n", |
| 214 | + "# chat completion with the function result\n", |
170 | 215 | "chat_completion_result = taskingai.inference.chat_completion(\n", |
171 | 216 | " model_id=model_id,\n", |
172 | 217 | " messages=[\n", |
173 | | - " SystemMessage(\"You are a professional assistant.\"),\n", |
174 | 218 | " UserMessage(\"What is the result of 112 plus 22?\"),\n", |
175 | | - " assistant_function_call_message,\n", |
176 | | - " FunctionMessage(name=fucntion_name, content=\"144\")\n", |
| 219 | + " function_call_assistant_message,\n", |
| 220 | + " FunctionMessage(name=\"plus_a_and_b\", content=str(function_call_result))\n", |
177 | 221 | " ],\n", |
178 | 222 | " functions=[function]\n", |
179 | 223 | ")\n", |
|
184 | 228 | }, |
185 | 229 | "id": "9df9a8b9eafa17d9" |
186 | 230 | }, |
| 231 | + { |
| 232 | + "cell_type": "markdown", |
| 233 | + "source": [ |
| 234 | + "## Stream mode" |
| 235 | + ], |
| 236 | + "metadata": { |
| 237 | + "collapsed": false |
| 238 | + }, |
| 239 | + "id": "a64da98251c5d3c5" |
| 240 | + }, |
187 | 241 | { |
188 | 242 | "cell_type": "code", |
189 | 243 | "execution_count": null, |
|
0 commit comments