Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/openai/toolcall-stream-with-usage.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

use Symfony\AI\Agent\Agent;
use Symfony\AI\Agent\Bridge\Clock\Clock;
use Symfony\AI\Agent\Bridge\OpenMeteo\OpenMeteo;
use Symfony\AI\Agent\Toolbox\AgentProcessor;
use Symfony\AI\Agent\Toolbox\Toolbox;
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory;
use Symfony\AI\Platform\Message\Message;
use Symfony\AI\Platform\Message\MessageBag;
use Symfony\AI\Platform\Result\TextChunk;

require_once dirname(__DIR__).'/bootstrap.php';

$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client());

$clock = new Clock();
$openMeteo = new OpenMeteo(http_client());
$toolbox = new Toolbox([$clock, $openMeteo], logger: logger());
$processor = new AgentProcessor($toolbox);

$agent = new Agent($platform, 'gpt-4o-mini', [$processor], [$processor]);
$messages = new MessageBag(Message::ofUser('Tell me the time and the weather in Dublin.'));

$result = $agent->call($messages, [
'stream' => true, // enable streaming of response text
'stream_options' => [
'include_usage' => true, // include usage in the response
],
]);

/** @var TextChunk $textChunk */
foreach ($result->getContent() as $textChunk) {
echo $textChunk->getContent();
}

foreach ($result->getMetadata()->get('calls', []) as $call) {
echo \PHP_EOL.sprintf(
'%s: %d tokens - Finish reason: [%s]',
$call['id'],
$call['usage']['total_tokens'],
$call['finish_reason']
);
}

echo \PHP_EOL;
2 changes: 1 addition & 1 deletion src/agent/src/Toolbox/AgentProcessor.php
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public function processOutput(Output $output): void

if ($result instanceof GenericStreamResponse) {
$output->setResult(
new ToolboxStreamResponse($result->getContent(), $this->handleToolCallsCallback($output))
new ToolboxStreamResponse($result, $this->handleToolCallsCallback($output))
);

return;
Expand Down
25 changes: 22 additions & 3 deletions src/agent/src/Toolbox/StreamResult.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use Symfony\AI\Platform\Message\Message;
use Symfony\AI\Platform\Result\BaseResult;
use Symfony\AI\Platform\Result\StreamResult as PlatformStreamResult;
use Symfony\AI\Platform\Result\ToolCallResult;

/**
Expand All @@ -21,15 +22,15 @@
final class StreamResult extends BaseResult
{
public function __construct(
private readonly \Generator $generator,
private readonly PlatformStreamResult $sourceStreamResult,
private readonly \Closure $handleToolCallsCallback,
) {
}

public function getContent(): \Generator
{
$streamedResult = '';
foreach ($this->generator as $value) {
foreach ($this->sourceStreamResult->getContent() as $value) {
if ($value instanceof ToolCallResult) {
$innerResult = ($this->handleToolCallsCallback)($value, Message::ofAssistant($streamedResult));

Expand All @@ -48,12 +49,30 @@ public function getContent(): \Generator
yield from $content;
}

break;
if ($innerResult->getMetadata()->has('calls')) {
$innerCalls = $innerResult->getMetadata()->get('calls');
$previousCalls = $this->getMetadata()->get('calls', []);
$calls = array_merge($previousCalls, $innerCalls);
} else {
$calls[] = $innerResult->getMetadata()->all();
}

if ($calls !== ['calls' => []]) {
$this->getMetadata()->add('calls', $calls);
}

continue;
}

$streamedResult .= $value;

yield $value;
}

// Attach the metadata from the platform stream to the agent after the stream has been fully processed
// and the post-result metadata, such as usage, has been received.
$calls = $this->getMetadata()->get('calls', []);
$calls[] = $this->sourceStreamResult->getMetadata()->all();
$this->getMetadata()->add('calls', $calls);
}
}
13 changes: 8 additions & 5 deletions src/ai-bundle/src/Profiler/TraceablePlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ public function invoke(string $model, array|string|object $input, array $options
}

if ($options['stream'] ?? false) {
$originalStream = $deferredResult->asStream();
$deferredResult = new DeferredResult(new PlainConverter($this->createTraceableStreamResult($originalStream)), $deferredResult->getRawResult(), $options);
$deferredResult = new DeferredResult(new PlainConverter($this->createTraceableStreamResult($deferredResult)), $deferredResult->getRawResult(), $options);
}

$this->calls[] = [
Expand All @@ -75,16 +74,20 @@ public function getModelCatalog(): ModelCatalogInterface
return $this->platform->getModelCatalog();
}

private function createTraceableStreamResult(\Generator $originalStream): StreamResult
private function createTraceableStreamResult(DeferredResult $sourceResult): StreamResult
{
return $result = new StreamResult((function () use (&$result, $originalStream) {
return $result = new StreamResult((function () use (&$result, $sourceResult) {
$this->resultCache[$result] = '';
foreach ($originalStream as $chunk) {
foreach ($sourceResult->asStream() as $chunk) {
yield $chunk;
if (\is_string($chunk)) {
$this->resultCache[$result] .= $chunk;
}
}

foreach ($sourceResult->getResult()->getMetadata() as $key => $value) {
$result->getMetadata()->add($key, $value);
}
})());
}
}
28 changes: 26 additions & 2 deletions src/platform/src/Bridge/OpenAi/Gpt/ResultConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
use Symfony\AI\Platform\Exception\ContentFilterException;
use Symfony\AI\Platform\Exception\RateLimitExceededException;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Metadata\Metadata;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\Result\ChoiceResult;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\AI\Platform\Result\RawResultInterface;
use Symfony\AI\Platform\Result\ResultInterface;
use Symfony\AI\Platform\Result\StreamResult;
use Symfony\AI\Platform\Result\TextChunk;
use Symfony\AI\Platform\Result\TextResult;
use Symfony\AI\Platform\Result\ToolCall;
use Symfony\AI\Platform\Result\ToolCallResult;
Expand Down Expand Up @@ -88,21 +90,43 @@ public function convert(RawResultInterface|RawHttpResult $result, array $options
private function convertStream(RawResultInterface|RawHttpResult $result): \Generator
{
$toolCalls = [];
$metadata = [];
foreach ($result->getDataStream() as $data) {
if (!$metadata && isset($data['id'])) {
$metadata['id'] = $data['id'];
}

if (isset($data['usage'])) {
$metadata['usage'] = $data['usage'];
}

if (isset($data['choices'][0]['finish_reason'])) {
$metadata['finish_reason'] = $data['choices'][0]['finish_reason'];
}

if ($this->streamIsToolCall($data)) {
$toolCalls = $this->convertStreamToToolCalls($toolCalls, $data);
}

if ([] !== $toolCalls && $this->isToolCallsStreamFinished($data)) {
yield new ToolCallResult(...array_map($this->convertToolCall(...), $toolCalls));
$toolCallResult = new ToolCallResult(...array_map($this->convertToolCall(...), $toolCalls));
$metadata['tool_calls'] = $toolCalls;
$toolCallResult->getMetadata()->set($metadata);
yield $toolCallResult;
}

if (!isset($data['choices'][0]['delta']['content'])) {
continue;
}

yield $data['choices'][0]['delta']['content'];
$textChunk = new TextChunk($data['choices'][0]['delta']['content']);
$textChunk->getMetadata()->set($metadata);
$textChunk->setRawResult($result);

yield $textChunk;
}

yield new Metadata($metadata);
}

/**
Expand Down
8 changes: 7 additions & 1 deletion src/platform/src/Result/DeferredResult.php
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,13 @@ public function asVectors(): array
*/
public function asStream(): \Generator
{
yield from $this->as(StreamResult::class)->getContent();
$streamResult = $this->as(StreamResult::class);

yield from $streamResult->getContent();

foreach ($streamResult->getMetadata() as $key => $value) {
$this->getResult()->getMetadata()->add($key, $value);
}
}

/**
Expand Down
13 changes: 12 additions & 1 deletion src/platform/src/Result/StreamResult.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

namespace Symfony\AI\Platform\Result;

use Symfony\AI\Platform\Metadata\Metadata;

/**
* @author Christopher Hertel <mail@christopher-hertel.de>
*/
Expand All @@ -23,6 +25,15 @@ public function __construct(

public function getContent(): \Generator
{
yield from $this->generator;
foreach ($this->generator as $content) {
if ($content instanceof Metadata) {
foreach ($content as $key => $value) {
$this->getMetadata()->add($key, $value);
}
continue;
}

yield $content;
}
}
}
33 changes: 33 additions & 0 deletions src/platform/src/Result/TextChunk.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <fabien@symfony.com>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Result;

/**
* @author Oscar Esteve <oscarsdt@gmail.com>
*/
final class TextChunk extends BaseResult implements \Stringable
{
public function __construct(
private readonly string $content,
) {
}

public function __toString(): string
{
return $this->content;
}

public function getContent(): string
{
return $this->content;
}
}
Loading