Skip to content

Commit 78292c2

Browse files
refactor(providers): standardize usage field mapping and add streaming mixins (withceleste#86)
* refactor(providers): standardize map_usage_fields as static methods Refactor all provider clients to expose map_usage_fields() as static methods, enabling streaming mixins to share usage parsing logic with clients for consistent usage field mapping across sync and streaming. Affected providers: - Anthropic Messages - BFL Images - BytePlus Images/Videos - Cohere Chat - ElevenLabs Text-to-Speech - Google GenerateContent/Imagen/Veo - Mistral Chat - OpenAI Audio/Images/Responses/Videos - xAI Responses * feat(providers): add OpenAI Images streaming mixin Add OpenAIImagesStream mixin for SSE parsing in OpenAI Images API: - Handle image_generation.partial_image and image_generation.completed - Handle image_edit.partial_image and image_edit.completed - Extract content, usage, and metadata from SSE events - Use static map_usage_fields for consistent usage parsing * feat(providers): add BytePlus Images streaming mixin Add BytePlusImagesStream mixin for SSE parsing in BytePlus Images API: - Handle image_generation.partial_succeeded with url or b64_json - Handle image_generation.partial_failed error events - Handle image_generation.completed with usage data - Use static map_usage_fields for consistent usage parsing Also fix missing dependencies in BytePlus pyproject.toml. * refactor(image-generation): migrate streaming to use provider mixins Update capability streaming implementations to use new provider mixins: - OpenAIImageGenerationStream now inherits from OpenAIImagesStream - BytePlusImageGenerationStream now inherits from BytePlusImagesStream Both now use super()._parse_chunk() to get raw parsed data, then wrap in typed ImageGenerationChunk with proper ImageArtifact handling. * fix(bfl): add missing dependencies to pyproject.toml
1 parent 78393d0 commit 78292c2

File tree

21 files changed

+360
-97
lines changed

21 files changed

+360
-97
lines changed

packages/capabilities/image-generation/src/celeste_image_generation/providers/byteplus/streaming.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
"""BytePlus streaming for image generation."""
22

3+
import base64
34
import logging
45
from collections.abc import AsyncIterator
56
from typing import Any
67

8+
from celeste_byteplus.images.streaming import BytePlusImagesStream
9+
710
from celeste.artifacts import ImageArtifact
11+
from celeste.core import UsageField
812
from celeste.mime_types import ImageMimeType
913
from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage
1014
from celeste_image_generation.streaming import ImageGenerationStream
1115

1216
logger = logging.getLogger(__name__)
1317

1418

15-
class BytePlusImageGenerationStream(ImageGenerationStream):
19+
class BytePlusImageGenerationStream(BytePlusImagesStream, ImageGenerationStream):
1620
"""BytePlus streaming for image generation."""
1721

1822
def __init__(self, sse_iterator: AsyncIterator[dict[str, Any]]) -> None:
@@ -21,40 +25,53 @@ def __init__(self, sse_iterator: AsyncIterator[dict[str, Any]]) -> None:
2125
self._completed_usage: ImageGenerationUsage | None = None
2226

2327
def _parse_chunk(self, chunk_data: dict[str, Any]) -> ImageGenerationChunk | None:
24-
"""Parse chunk from SSE event."""
25-
event_type = chunk_data.get("type")
26-
27-
if event_type == "image_generation.partial_succeeded":
28-
url = chunk_data.get("url")
29-
if not url:
30-
logger.warning("partial_succeeded event missing URL")
31-
return None
32-
33-
artifact = ImageArtifact(url=url, mime_type=ImageMimeType.PNG)
34-
return ImageGenerationChunk(content=artifact)
35-
36-
if event_type == "image_generation.completed":
37-
usage_data = chunk_data.get("usage")
38-
if usage_data:
39-
self._completed_usage = ImageGenerationUsage(
40-
total_tokens=usage_data.get("total_tokens"),
41-
)
28+
"""Parse chunk from SSE event.
29+
30+
Uses provider mixin to parse raw SSE event, then wraps in typed chunk.
31+
"""
32+
raw = super()._parse_chunk(chunk_data)
33+
if not raw:
4234
return None
4335

44-
if event_type == "image_generation.partial_failed":
45-
error = chunk_data.get("error", {})
36+
# Handle error events
37+
if raw.get("is_error"):
38+
error = raw.get("error", {})
4639
logger.error(
4740
"Image generation failed: %s - %s",
4841
error.get("code"),
4942
error.get("message"),
5043
)
5144
return None
5245

53-
logger.warning("Unknown event type: %s", event_type)
54-
return None
46+
# Handle completed event (usage only)
47+
usage_data = raw.get("usage")
48+
if usage_data:
49+
self._completed_usage = ImageGenerationUsage(
50+
total_tokens=usage_data.get(UsageField.TOTAL_TOKENS),
51+
output_tokens=usage_data.get(UsageField.OUTPUT_TOKENS),
52+
num_images=usage_data.get(UsageField.NUM_IMAGES),
53+
)
54+
return None
55+
56+
# Handle partial succeeded (image content)
57+
content = raw.get("content")
58+
content_type = raw.get("content_type")
59+
if not content:
60+
return None
61+
62+
if content_type == "url":
63+
artifact = ImageArtifact(url=content, mime_type=ImageMimeType.PNG)
64+
else: # b64_json
65+
image_data = base64.b64decode(content)
66+
artifact = ImageArtifact(data=image_data)
67+
68+
return ImageGenerationChunk(content=artifact)
5569

5670
def _parse_usage(self, chunks: list[ImageGenerationChunk]) -> ImageGenerationUsage:
57-
"""Parse usage from chunks."""
71+
"""Parse usage from chunks.
72+
73+
Usage is stored from the completed event.
74+
"""
5875
if self._completed_usage is not None:
5976
return self._completed_usage
6077

packages/capabilities/image-generation/src/celeste_image_generation/providers/openai/streaming.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,71 +4,56 @@
44
import logging
55
from typing import Any
66

7+
from celeste_openai.images.streaming import OpenAIImagesStream
8+
79
from celeste.artifacts import ImageArtifact
10+
from celeste.core import UsageField
811
from celeste_image_generation.io import ImageGenerationChunk, ImageGenerationUsage
912
from celeste_image_generation.streaming import ImageGenerationStream
1013

1114
logger = logging.getLogger(__name__)
1215

1316

14-
class OpenAIImageGenerationStream(ImageGenerationStream):
17+
class OpenAIImageGenerationStream(OpenAIImagesStream, ImageGenerationStream):
1518
"""OpenAI streaming for image generation."""
1619

1720
def _parse_chunk(self, chunk_data: dict[str, Any]) -> ImageGenerationChunk | None:
1821
"""Parse chunk from SSE event.
1922
20-
OpenAI returns two event types:
21-
- image_generation.partial_image: Progressive image chunks
22-
- image_generation.completed: Final image with usage data
23+
Uses provider mixin to parse raw SSE event, then wraps in typed chunk.
2324
"""
24-
event_type = chunk_data.get("type")
25-
26-
if event_type == "image_generation.partial_image":
27-
# Partial image chunk
28-
b64_json = chunk_data.get("b64_json")
29-
if not b64_json:
30-
return None
31-
32-
image_data = base64.b64decode(b64_json)
33-
artifact = ImageArtifact(data=image_data)
34-
35-
return ImageGenerationChunk(content=artifact)
36-
37-
if event_type == "image_generation.completed":
38-
# Final image with usage
39-
b64_json = chunk_data.get("b64_json")
40-
if not b64_json:
41-
return None
25+
raw = super()._parse_chunk(chunk_data)
26+
if not raw:
27+
return None
4228

43-
image_data = base64.b64decode(b64_json)
44-
artifact = ImageArtifact(data=image_data)
29+
b64_json = raw.get("content")
30+
if not b64_json:
31+
return None
4532

46-
# Parse usage from completed event
47-
usage_data = chunk_data.get("usage")
48-
usage = None
49-
if usage_data:
50-
usage = ImageGenerationUsage(
51-
total_tokens=usage_data.get("total_tokens"),
52-
input_tokens=usage_data.get("input_tokens"),
53-
output_tokens=usage_data.get("output_tokens"),
54-
)
33+
image_data = base64.b64decode(b64_json)
34+
artifact = ImageArtifact(data=image_data)
5535

56-
return ImageGenerationChunk(content=artifact, usage=usage)
36+
# Parse usage from raw dict (already mapped to UsageField keys)
37+
usage = None
38+
usage_data = raw.get("usage")
39+
if usage_data:
40+
usage = ImageGenerationUsage(
41+
total_tokens=usage_data.get(UsageField.TOTAL_TOKENS),
42+
input_tokens=usage_data.get(UsageField.INPUT_TOKENS),
43+
output_tokens=usage_data.get(UsageField.OUTPUT_TOKENS),
44+
)
5745

58-
logger.warning("Unknown event type: %s", event_type)
59-
return None
46+
return ImageGenerationChunk(content=artifact, usage=usage)
6047

6148
def _parse_usage(self, chunks: list[ImageGenerationChunk]) -> ImageGenerationUsage:
6249
"""Parse usage from chunks.
6350
6451
Usage is only available in the final completed event.
6552
"""
66-
# Look for usage in final chunk (completed event)
6753
for chunk in reversed(chunks):
6854
if chunk.usage is not None:
6955
return chunk.usage
7056

71-
# No usage found
7257
return ImageGenerationUsage()
7358

7459

packages/providers/anthropic/src/celeste_anthropic/messages/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def map_usage_fields(usage_data: dict[str, Any]) -> dict[str, int | None]:
120120
def _parse_usage(self, response_data: dict[str, Any]) -> dict[str, int | None]:
121121
"""Extract usage data from Messages API response."""
122122
usage_data = response_data.get("usage", {})
123-
return self.map_usage_fields(usage_data)
123+
return AnthropicMessagesClient.map_usage_fields(usage_data)
124124

125125
def _parse_content(self, response_data: dict[str, Any]) -> Any:
126126
"""Parse content array from Messages API.

packages/providers/bfl/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ description = "BFL (Black Forest Labs) provider package for Celeste AI"
55
authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}]
66
license = {text = "Apache-2.0"}
77
requires-python = ">=3.12"
8+
dependencies = ["celeste-ai", "httpx"]
89

910
[tool.uv.sources]
1011
celeste-ai = { workspace = true }

packages/providers/bfl/src/celeste_bfl/images/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def map_usage_fields(usage_data: dict[str, Any]) -> dict[str, float | None]:
136136
def _parse_usage(self, response_data: dict[str, Any]) -> dict[str, float | None]:
137137
"""Extract usage data from BFL response."""
138138
submit_metadata = response_data.get("_submit_metadata", {})
139-
return self.map_usage_fields(submit_metadata)
139+
return BFLImagesClient.map_usage_fields(submit_metadata)
140140

141141
def _parse_content(self, response_data: dict[str, Any]) -> Any:
142142
"""Parse result from response."""

packages/providers/byteplus/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ description = "BytePlus provider package for Celeste AI"
55
authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}]
66
license = {text = "Apache-2.0"}
77
requires-python = ">=3.12"
8+
dependencies = ["celeste-ai", "httpx"]
89

910
[tool.uv.sources]
1011
celeste-ai = { workspace = true }

packages/providers/byteplus/src/celeste_byteplus/images/client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,26 @@ def _make_stream_request(
6868
json_body=request_body,
6969
)
7070

71-
def _parse_usage(self, response_data: dict[str, Any]) -> dict[str, int | None]:
72-
"""Extract usage data from Images API response.
71+
@staticmethod
72+
def map_usage_fields(usage_data: dict[str, Any]) -> dict[str, int | None]:
73+
"""Map BytePlus Images usage fields to unified names.
7374
74-
Returns dict that capability clients wrap in their specific Usage type.
75+
Shared by client and streaming across all capabilities.
7576
"""
76-
usage_data = response_data.get("usage", {})
77-
7877
return {
7978
UsageField.TOTAL_TOKENS: usage_data.get("total_tokens"),
8079
UsageField.OUTPUT_TOKENS: usage_data.get("output_tokens"),
8180
UsageField.NUM_IMAGES: usage_data.get("generated_images"),
8281
}
8382

83+
def _parse_usage(self, response_data: dict[str, Any]) -> dict[str, int | None]:
84+
"""Extract usage data from Images API response.
85+
86+
Returns dict that capability clients wrap in their specific Usage type.
87+
"""
88+
usage_data = response_data.get("usage", {})
89+
return BytePlusImagesClient.map_usage_fields(usage_data)
90+
8491
def _parse_content(self, response_data: dict[str, Any]) -> Any:
8592
"""Parse images/data array from Images API response.
8693
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""BytePlus Images SSE parsing for streaming."""
2+
3+
from typing import Any
4+
5+
from .client import BytePlusImagesClient
6+
7+
8+
class BytePlusImagesStream:
9+
"""Mixin for BytePlus Images API SSE parsing.
10+
11+
Provides shared implementation for capabilities using BytePlus Images API streaming:
12+
- _parse_chunk() - Parse SSE event into raw chunk dict
13+
14+
Handles all image streaming event types:
15+
- image_generation.partial_succeeded - Partial image with url or b64_json
16+
- image_generation.partial_failed - Error event
17+
- image_generation.completed - Final event with usage only
18+
19+
Capability streams extend via super() to wrap results in typed Chunks.
20+
21+
Usage:
22+
class BytePlusImageGenerationStream(BytePlusImagesStream, ImageGenerationStream):
23+
def _parse_chunk(self, event):
24+
raw = super()._parse_chunk(event)
25+
if not raw:
26+
return None
27+
return ImageGenerationChunk(...)
28+
"""
29+
30+
def _parse_chunk(self, event: dict[str, Any]) -> dict[str, Any] | None:
31+
"""Parse SSE event into raw chunk data.
32+
33+
Returns dict with:
34+
- content_type: "url" or "b64_json" or None
35+
- content: url string or b64_json string
36+
- is_error: True for partial_failed events
37+
- error: error dict for failed events
38+
- usage: usage dict from completed event (None otherwise)
39+
- metadata: model, created, image_index, size
40+
- raw_event: original event dict
41+
"""
42+
event_type = event.get("type")
43+
if not event_type:
44+
return None
45+
46+
# Handle successful partial image
47+
if event_type == "image_generation.partial_succeeded":
48+
url = event.get("url")
49+
b64_json = event.get("b64_json")
50+
51+
content_type = None
52+
content = None
53+
if url:
54+
content_type = "url"
55+
content = url
56+
elif b64_json:
57+
content_type = "b64_json"
58+
content = b64_json
59+
60+
if not content:
61+
return None
62+
63+
return {
64+
"content_type": content_type,
65+
"content": content,
66+
"is_error": False,
67+
"error": None,
68+
"usage": None,
69+
"metadata": {
70+
"model": event.get("model"),
71+
"created": event.get("created"),
72+
"image_index": event.get("image_index"),
73+
"size": event.get("size"),
74+
},
75+
"raw_event": event,
76+
}
77+
78+
# Handle failed partial image
79+
if event_type == "image_generation.partial_failed":
80+
return {
81+
"content_type": None,
82+
"content": None,
83+
"is_error": True,
84+
"error": event.get("error"),
85+
"usage": None,
86+
"metadata": {
87+
"model": event.get("model"),
88+
"created": event.get("created"),
89+
"image_index": event.get("image_index"),
90+
},
91+
"raw_event": event,
92+
}
93+
94+
# Handle completed event (usage only, no image)
95+
if event_type == "image_generation.completed":
96+
usage_data = event.get("usage")
97+
usage = None
98+
if usage_data:
99+
usage = BytePlusImagesClient.map_usage_fields(usage_data)
100+
return {
101+
"content_type": None,
102+
"content": None,
103+
"is_error": False,
104+
"error": None,
105+
"usage": usage,
106+
"metadata": {
107+
"model": event.get("model"),
108+
"created": event.get("created"),
109+
},
110+
"raw_event": event,
111+
}
112+
113+
return None
114+
115+
116+
__all__ = ["BytePlusImagesStream"]

0 commit comments

Comments
 (0)