Skip to content

Commit 9aedbf3

Browse files
committed
Simplify choice handling
1 parent 99dc222 commit 9aedbf3

File tree

12 files changed

+56
-206
lines changed

12 files changed

+56
-206
lines changed

.php-cs-fixer.dist.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
->setRiskyAllowed(true)
4444
->setFinder(
4545
(new PhpCsFixer\Finder())
46-
->in([__DIR__.'/demo', __DIR__.'/examples', __DIR__.'/fixtures', __DIR__.'/src'])
46+
->in(__DIR__.'/{demo,examples,fixtures,src}')
4747
->append([__FILE__])
4848
->exclude('var')
4949
)

src/agent/phpstan.dist.neon

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ parameters:
1313
-
1414
identifier: 'symfonyAi.forbidNativeException'
1515
path: tests/*
16+
-
17+
message: "#^Method .*::test.*\\(\\) has no return type specified\\.$#"
18+
path: tests/*

src/agent/tests/StructuredOutput/AgentProcessorTest.php

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
use Symfony\AI\Platform\Capability;
2525
use Symfony\AI\Platform\Message\MessageBag;
2626
use Symfony\AI\Platform\Model;
27-
use Symfony\AI\Platform\Result\Choice;
2827
use Symfony\AI\Platform\Result\Metadata\Metadata;
2928
use Symfony\AI\Platform\Result\ObjectResult;
3029
use Symfony\AI\Platform\Result\TextResult;
@@ -34,14 +33,13 @@
3433
#[UsesClass(Input::class)]
3534
#[UsesClass(Output::class)]
3635
#[UsesClass(MessageBag::class)]
37-
#[UsesClass(Choice::class)]
3836
#[UsesClass(MissingModelSupportException::class)]
3937
#[UsesClass(TextResult::class)]
4038
#[UsesClass(ObjectResult::class)]
4139
#[UsesClass(Model::class)]
4240
final class AgentProcessorTest extends TestCase
4341
{
44-
public function testProcessInputWithOutputStructure(): void
42+
public function testProcessInputWithOutputStructure()
4543
{
4644
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
4745

@@ -53,7 +51,7 @@ public function testProcessInputWithOutputStructure(): void
5351
$this->assertSame(['response_format' => ['some' => 'format']], $input->getOptions());
5452
}
5553

56-
public function testProcessInputWithoutOutputStructure(): void
54+
public function testProcessInputWithoutOutputStructure()
5755
{
5856
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory());
5957

@@ -65,7 +63,7 @@ public function testProcessInputWithoutOutputStructure(): void
6563
$this->assertSame([], $input->getOptions());
6664
}
6765

68-
public function testProcessInputThrowsExceptionWhenLlmDoesNotSupportStructuredOutput(): void
66+
public function testProcessInputThrowsExceptionWhenLlmDoesNotSupportStructuredOutput()
6967
{
7068
self::expectException(MissingModelSupportException::class);
7169

@@ -77,7 +75,7 @@ public function testProcessInputThrowsExceptionWhenLlmDoesNotSupportStructuredOu
7775
$processor->processInput($input);
7876
}
7977

80-
public function testProcessOutputWithResponseFormat(): void
78+
public function testProcessOutputWithResponseFormat()
8179
{
8280
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
8381

@@ -99,7 +97,7 @@ public function testProcessOutputWithResponseFormat(): void
9997
$this->assertSame('data', $output->result->getContent()->some);
10098
}
10199

102-
public function testProcessOutputWithComplexResponseFormat(): void
100+
public function testProcessOutputWithComplexResponseFormat()
103101
{
104102
$processor = new AgentProcessor(new ConfigurableResponseFormatFactory(['some' => 'format']));
105103

@@ -153,7 +151,7 @@ public function testProcessOutputWithComplexResponseFormat(): void
153151
$this->assertSame('x = -3.75', $structure->finalAnswer);
154152
}
155153

156-
public function testProcessOutputWithoutResponseFormat(): void
154+
public function testProcessOutputWithoutResponseFormat()
157155
{
158156
$resultFormatFactory = new ConfigurableResponseFormatFactory();
159157
$serializer = self::createMock(SerializerInterface::class);

src/platform/phpstan.dist.neon

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ parameters:
1919
-
2020
identifier: missingType.iterableValue
2121
path: tests/*
22+
-
23+
message: "#^Method .*::test.*\\(\\) has no return type specified\\.$#"
24+
path: tests/*

src/platform/src/Bridge/Gemini/Gemini/ResultConverter.php

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
use Symfony\AI\Platform\Bridge\Gemini\Gemini;
1515
use Symfony\AI\Platform\Exception\RuntimeException;
1616
use Symfony\AI\Platform\Model;
17-
use Symfony\AI\Platform\Result\Choice;
1817
use Symfony\AI\Platform\Result\ChoiceResult;
1918
use Symfony\AI\Platform\Result\RawHttpResult;
2019
use Symfony\AI\Platform\Result\RawResultInterface;
@@ -46,21 +45,12 @@ public function convert(RawResultInterface|RawHttpResult $result, array $options
4645
$data = $result->getData();
4746

4847
if (!isset($data['candidates'][0]['content']['parts'][0])) {
49-
throw new RuntimeException('Response does not contain any content');
48+
throw new RuntimeException('Response does not contain any content.');
5049
}
5150

52-
/** @var Choice[] $choices */
5351
$choices = array_map($this->convertChoice(...), $data['candidates']);
5452

55-
if (1 !== \count($choices)) {
56-
return new ChoiceResult(...$choices);
57-
}
58-
59-
if ($choices[0]->hasToolCall()) {
60-
return new ToolCallResult(...$choices[0]->getToolCalls());
61-
}
62-
63-
return new TextResult($choices[0]->getContent());
53+
return 1 === \count($choices) ? $choices[0] : new ChoiceResult(...$choices);
6454
}
6555

6656
private function convertStream(HttpResponse $result): \Generator
@@ -91,10 +81,9 @@ private function convertStream(HttpResponse $result): \Generator
9181
try {
9282
$data = json_decode($delta, true, 512, \JSON_THROW_ON_ERROR);
9383
} catch (\JsonException $e) {
94-
throw new RuntimeException('Failed to decode JSON response', 0, $e);
84+
throw new RuntimeException('Failed to decode JSON response.', 0, $e);
9585
}
9686

97-
/** @var Choice[] $choices */
9887
$choices = array_map($this->convertChoice(...), $data['candidates'] ?? []);
9988

10089
if (!$choices) {
@@ -106,13 +95,7 @@ private function convertStream(HttpResponse $result): \Generator
10695
continue;
10796
}
10897

109-
if ($choices[0]->hasToolCall()) {
110-
yield new ToolCallResult(...$choices[0]->getToolCalls());
111-
}
112-
113-
if ($choices[0]->hasContent()) {
114-
yield $choices[0]->getContent();
115-
}
98+
yield $choices[0]->getContent();
11699
}
117100
}
118101
}
@@ -132,16 +115,16 @@ private function convertStream(HttpResponse $result): \Generator
132115
* }
133116
* } $choice
134117
*/
135-
private function convertChoice(array $choice): Choice
118+
private function convertChoice(array $choice): ToolCallResult|TextResult
136119
{
137120
$contentPart = $choice['content']['parts'][0] ?? [];
138121

139122
if (isset($contentPart['functionCall'])) {
140-
return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]);
123+
return new ToolCallResult($this->convertToolCall($contentPart['functionCall']));
141124
}
142125

143126
if (isset($contentPart['text'])) {
144-
return new Choice($contentPart['text']);
127+
return new TextResult($contentPart['text']);
145128
}
146129

147130
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason']));

src/platform/src/Bridge/Mistral/Llm/ResultConverter.php

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
use Symfony\AI\Platform\Bridge\Mistral\Mistral;
1515
use Symfony\AI\Platform\Exception\RuntimeException;
1616
use Symfony\AI\Platform\Model;
17-
use Symfony\AI\Platform\Result\Choice;
1817
use Symfony\AI\Platform\Result\ChoiceResult;
1918
use Symfony\AI\Platform\Result\RawHttpResult;
2019
use Symfony\AI\Platform\Result\RawResultInterface;
@@ -51,27 +50,18 @@ public function convert(RawResultInterface|RawHttpResult $result, array $options
5150
}
5251

5352
if (200 !== $code = $httpResponse->getStatusCode()) {
54-
throw new RuntimeException(\sprintf('Unexpected response code %d: %s', $code, $httpResponse->getContent(false)));
53+
throw new RuntimeException(\sprintf('Unexpected response code %d: "%s"', $code, $httpResponse->getContent(false)));
5554
}
5655

5756
$data = $result->getData();
5857

5958
if (!isset($data['choices'])) {
60-
throw new RuntimeException('Response does not contain choices');
59+
throw new RuntimeException('Response does not contain choices.');
6160
}
6261

63-
/** @var Choice[] $choices */
6462
$choices = array_map($this->convertChoice(...), $data['choices']);
6563

66-
if (1 !== \count($choices)) {
67-
return new ChoiceResult(...$choices);
68-
}
69-
70-
if ($choices[0]->hasToolCall()) {
71-
return new ToolCallResult(...$choices[0]->getToolCalls());
72-
}
73-
74-
return new TextResult($choices[0]->getContent());
64+
return 1 === \count($choices) ? $choices[0] : new ChoiceResult(...$choices);
7565
}
7666

7767
private function convertStream(HttpResponse $result): \Generator
@@ -152,7 +142,7 @@ private function isToolCallsStreamFinished(array $data): bool
152142

153143
/**
154144
* @param array{
155-
* index: integer,
145+
* index: int,
156146
* message: array{
157147
* role: 'assistant',
158148
* content: ?string,
@@ -170,14 +160,14 @@ private function isToolCallsStreamFinished(array $data): bool
170160
* finish_reason: 'stop'|'length'|'tool_calls'|'content_filter',
171161
* } $choice
172162
*/
173-
private function convertChoice(array $choice): Choice
163+
private function convertChoice(array $choice): ToolCallResult|TextResult
174164
{
175165
if ('tool_calls' === $choice['finish_reason']) {
176-
return new Choice(toolCalls: array_map([$this, 'convertToolCall'], $choice['message']['tool_calls']));
166+
return new ToolCallResult(...array_map([$this, 'convertToolCall'], $choice['message']['tool_calls']));
177167
}
178168

179169
if ('stop' === $choice['finish_reason']) {
180-
return new Choice($choice['message']['content']);
170+
return new TextResult($choice['message']['content']);
181171
}
182172

183173
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finish_reason']));

src/platform/src/Bridge/OpenAI/GPT/ResultConverter.php

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
use Symfony\AI\Platform\Exception\ContentFilterException;
1616
use Symfony\AI\Platform\Exception\RuntimeException;
1717
use Symfony\AI\Platform\Model;
18-
use Symfony\AI\Platform\Result\Choice;
1918
use Symfony\AI\Platform\Result\ChoiceResult;
2019
use Symfony\AI\Platform\Result\RawHttpResult;
2120
use Symfony\AI\Platform\Result\RawResultInterface;
@@ -54,21 +53,12 @@ public function convert(RawResultInterface|RawHttpResult $result, array $options
5453
}
5554

5655
if (!isset($data['choices'])) {
57-
throw new RuntimeException('Response does not contain choices');
56+
throw new RuntimeException('Response does not contain choices.');
5857
}
5958

60-
/** @var Choice[] $choices */
6159
$choices = array_map($this->convertChoice(...), $data['choices']);
6260

63-
if (1 !== \count($choices)) {
64-
return new ChoiceResult(...$choices);
65-
}
66-
67-
if ($choices[0]->hasToolCall()) {
68-
return new ToolCallResult(...$choices[0]->getToolCalls());
69-
}
70-
71-
return new TextResult($choices[0]->getContent());
61+
return 1 === \count($choices) ? $choices[0] : new ChoiceResult(...$choices);
7262
}
7363

7464
private function convertStream(HttpResponse $result): \Generator
@@ -149,7 +139,7 @@ private function isToolCallsStreamFinished(array $data): bool
149139

150140
/**
151141
* @param array{
152-
* index: integer,
142+
* index: int,
153143
* message: array{
154144
* role: 'assistant',
155145
* content: ?string,
@@ -167,14 +157,14 @@ private function isToolCallsStreamFinished(array $data): bool
167157
* finish_reason: 'stop'|'length'|'tool_calls'|'content_filter',
168158
* } $choice
169159
*/
170-
private function convertChoice(array $choice): Choice
160+
private function convertChoice(array $choice): ToolCallResult|TextResult
171161
{
172162
if ('tool_calls' === $choice['finish_reason']) {
173-
return new Choice(toolCalls: array_map([$this, 'convertToolCall'], $choice['message']['tool_calls']));
163+
return new ToolCallResult(...array_map([$this, 'convertToolCall'], $choice['message']['tool_calls']));
174164
}
175165

176166
if (\in_array($choice['finish_reason'], ['stop', 'length'], true)) {
177-
return new Choice($choice['message']['content']);
167+
return new TextResult($choice['message']['content']);
178168
}
179169

180170
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finish_reason']));

src/platform/src/Result/Choice.php

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/platform/src/Result/ChoiceResult.php

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,24 @@
1919
final class ChoiceResult extends BaseResult
2020
{
2121
/**
22-
* @var Choice[]
22+
* @var ResultInterface[]
2323
*/
24-
private readonly array $choices;
24+
private readonly array $results;
2525

26-
public function __construct(Choice ...$choices)
26+
public function __construct(ResultInterface ...$results)
2727
{
28-
if ([] === $choices) {
29-
throw new InvalidArgumentException('Result must have at least one choice.');
28+
if (1 >= \count($results)) {
29+
throw new InvalidArgumentException('A choice result must contain at least two results.');
3030
}
3131

32-
$this->choices = $choices;
32+
$this->results = $results;
3333
}
3434

3535
/**
36-
* @return Choice[]
36+
* @return ResultInterface[]
3737
*/
3838
public function getContent(): array
3939
{
40-
return $this->choices;
40+
return $this->results;
4141
}
4242
}

0 commit comments

Comments
 (0)