Skip to content

Commit b6a514d

Browse files
feat(providers): add BFL Images API package (withceleste#81)
* feat(providers): add BFL Images API package Add standalone provider package for BFL Images API with mixin pattern for capability-agnostic reuse. ## Client (BFLImagesClient mixin) - Async polling workflow for image generation - Phase 1: POST to /v1/{model_id} to submit job - Phase 2: Poll GET polling_url until Ready/Failed - _parse_content() extracts result from response - _parse_finish_reason() maps BFL status to FinishReason - Merges _submit_metadata into final response for usage tracking - Configurable polling interval (0.5s) and timeout (120s) ## Parameters - WidthMapper: image width - HeightMapper: image height - SeedMapper: reproducible generation seed - StepsMapper: inference steps - GuidanceMapper: guidance scale - PromptUpsamplingMapper: prompt enhancement - SafetyToleranceMapper: content safety level - OutputFormatMapper: output image format ## Config - API base URL: https://api.bfl.ai - Endpoint: /v1/{model_id} - Polling interval: 0.5 seconds - Polling timeout: 120 seconds ## Other - py.typed marker for typed package support * refactor(bfl): migrate usage parsing to API layer pattern Move usage field mapping from capability layer to API layer mixin, following the established pattern from Anthropic and Google providers. ## API Layer Changes (celeste_bfl) - Add map_usage_fields() static method with UsageField enum - Add _parse_usage() that extracts _submit_metadata and maps fields - Maps: cost → BILLED_UNITS, input_mp → INPUT_MP, output_mp → OUTPUT_MP ## Capability Layer Changes (image-generation) - Simplify _parse_usage() to call super()._parse_usage() and wrap result - Add celeste-bfl workspace dependency to pyproject.toml * chore(image-generation): remove redundant BFL config from capability layer * fix(bfl): use Int constraint for width/height validation - Enhanced Int constraint to accept int, str, float (whole numbers only) - Updated BFL WidthMapper and HeightMapper to use Int() directly - Fixes AttributeError when AspectRatioMapper calls internal mappers
1 parent f38d736 commit b6a514d

File tree

16 files changed

+470
-228
lines changed

16 files changed

+470
-228
lines changed

packages/capabilities/image-generation/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Issues = "https://github.com/withceleste/celeste-python/issues"
2727

2828
[tool.uv.sources]
2929
celeste-ai = { workspace = true }
30+
celeste-bfl = { workspace = true }
3031
celeste-google = { workspace = true }
3132

3233
[project.entry-points."celeste.packages"]

packages/capabilities/image-generation/src/celeste_image_generation/io.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class ImageGenerationUsage(Usage):
3232
reasoning_tokens: int | None = None
3333
num_images: int | None = None
3434
billed_units: float | None = None
35+
input_mp: float | None = None
36+
output_mp: float | None = None
3537

3638

3739
class ImageGenerationOutput(Output[ImageArtifact | list[ImageArtifact]]):
Lines changed: 10 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
"""BFL (Black Forest Labs) client implementation for FLUX.2 image generation."""
1+
"""BFL client implementation for image generation."""
22

3-
import asyncio
4-
import json
5-
import time
63
from typing import Any, Unpack
74

8-
import httpx
5+
from celeste_bfl.images.client import BFLImagesClient
96

107
from celeste.artifacts import ImageArtifact
11-
from celeste.mime_types import ApplicationMimeType
8+
from celeste.exceptions import ValidationError
129
from celeste.parameters import ParameterMapper
1310
from celeste_image_generation.client import ImageGenerationClient
1411
from celeste_image_generation.io import (
@@ -18,12 +15,11 @@
1815
)
1916
from celeste_image_generation.parameters import ImageGenerationParameters
2017

21-
from . import config
2218
from .parameters import BFL_PARAMETER_MAPPERS
2319

2420

25-
class BFLImageGenerationClient(ImageGenerationClient):
26-
"""Black Forest Labs client for image generation."""
21+
class BFLImageGenerationClient(BFLImagesClient, ImageGenerationClient):
22+
"""BFL client for image generation."""
2723

2824
@classmethod
2925
def parameter_mappers(cls) -> list[ParameterMapper]:
@@ -37,12 +33,8 @@ def _init_request(self, inputs: ImageGenerationInput) -> dict[str, Any]:
3733

3834
def _parse_usage(self, response_data: dict[str, Any]) -> ImageGenerationUsage:
3935
"""Parse usage from response."""
40-
submit_metadata = response_data.get("_submit_metadata", {})
41-
cost = submit_metadata.get("cost")
42-
43-
return ImageGenerationUsage(
44-
billed_units=float(cost) if cost is not None else None,
45-
)
36+
usage = super()._parse_usage(response_data)
37+
return ImageGenerationUsage(**usage)
4638

4739
def _parse_content(
4840
self,
@@ -55,95 +47,21 @@ def _parse_content(
5547

5648
if not sample_url:
5749
msg = f"No image URL in {self.provider} response"
58-
raise ValueError(msg)
50+
raise ValidationError(msg)
5951

6052
return ImageArtifact(url=sample_url)
6153

6254
def _parse_finish_reason(
6355
self, response_data: dict[str, Any]
64-
) -> ImageGenerationFinishReason | None:
56+
) -> ImageGenerationFinishReason:
6557
"""Parse finish reason from response."""
6658
status = response_data.get("status")
6759
if status == "Ready":
6860
return ImageGenerationFinishReason(reason="COMPLETE")
6961
elif status in ("Error", "Failed"):
7062
error_msg = response_data.get("error", "Generation failed")
7163
return ImageGenerationFinishReason(reason="ERROR", message=error_msg)
72-
return None
73-
74-
async def _make_request(
75-
self,
76-
request_body: dict[str, Any],
77-
**parameters: Unpack[ImageGenerationParameters],
78-
) -> httpx.Response:
79-
"""Make HTTP request(s) and return response object."""
80-
headers = {
81-
**self.auth.get_headers(),
82-
"Content-Type": ApplicationMimeType.JSON,
83-
"Accept": ApplicationMimeType.JSON,
84-
}
85-
86-
endpoint = config.ENDPOINT.format(model_id=self.model.id)
87-
88-
submit_response = await self.http_client.post(
89-
f"{config.BASE_URL}{endpoint}",
90-
headers=headers,
91-
json_body=request_body,
92-
)
93-
94-
if submit_response.status_code != 200:
95-
return submit_response
96-
97-
submit_data = submit_response.json()
98-
polling_url = submit_data.get("polling_url")
99-
100-
if not polling_url:
101-
msg = f"No polling_url in {self.provider} response"
102-
raise ValueError(msg)
103-
104-
start_time = time.monotonic()
105-
poll_headers = {
106-
**self.auth.get_headers(),
107-
"Accept": ApplicationMimeType.JSON,
108-
}
109-
110-
while True:
111-
elapsed = time.monotonic() - start_time
112-
if elapsed >= config.POLLING_TIMEOUT:
113-
msg = f"{self.provider} polling timed out after {config.POLLING_TIMEOUT} seconds"
114-
raise TimeoutError(msg)
115-
116-
poll_response = await self.http_client.get(
117-
polling_url,
118-
headers=poll_headers,
119-
)
120-
121-
if poll_response.status_code != 200:
122-
return poll_response
123-
124-
poll_data = poll_response.json()
125-
status = poll_data.get("status")
126-
127-
if status == "Ready":
128-
final_data = {
129-
**poll_data,
130-
"_submit_metadata": submit_data,
131-
}
132-
return httpx.Response(
133-
status_code=200,
134-
content=json.dumps(final_data).encode("utf-8"),
135-
headers={"content-type": "application/json"},
136-
request=httpx.Request("GET", polling_url),
137-
)
138-
elif status in ("Error", "Failed"):
139-
return httpx.Response(
140-
status_code=400,
141-
content=json.dumps(poll_data).encode("utf-8"),
142-
headers={"content-type": "application/json"},
143-
request=httpx.Request("GET", polling_url),
144-
)
145-
146-
await asyncio.sleep(config.POLLING_INTERVAL)
64+
return ImageGenerationFinishReason(reason=None)
14765

14866

14967
__all__ = ["BFLImageGenerationClient"]

packages/capabilities/image-generation/src/celeste_image_generation/providers/bfl/config.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

packages/capabilities/image-generation/src/celeste_image_generation/providers/bfl/parameters.py

Lines changed: 45 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,47 @@
1-
"""BFL parameter mappers for image generation."""
1+
"""BFL Images parameter mappers for image generation.
2+
3+
AspectRatioMapper is defined locally (handles "WxH" → width/height transformation).
4+
Other mappers subclass from provider and add capability-specific `name` attribute.
5+
"""
26

37
from typing import Any
48

5-
from celeste import Model
9+
from celeste_bfl.images.parameters import (
10+
GuidanceMapper as _GuidanceMapper,
11+
)
12+
from celeste_bfl.images.parameters import (
13+
HeightMapper as _HeightMapper,
14+
)
15+
from celeste_bfl.images.parameters import (
16+
OutputFormatMapper as _OutputFormatMapper,
17+
)
18+
from celeste_bfl.images.parameters import (
19+
PromptUpsamplingMapper as _PromptUpsamplingMapper,
20+
)
21+
from celeste_bfl.images.parameters import (
22+
SafetyToleranceMapper as _SafetyToleranceMapper,
23+
)
24+
from celeste_bfl.images.parameters import (
25+
SeedMapper as _SeedMapper,
26+
)
27+
from celeste_bfl.images.parameters import (
28+
StepsMapper as _StepsMapper,
29+
)
30+
from celeste_bfl.images.parameters import (
31+
WidthMapper as _WidthMapper,
32+
)
33+
34+
from celeste.models import Model
635
from celeste.parameters import ParameterMapper
736
from celeste_image_generation.parameters import ImageGenerationParameter
837

938

1039
class AspectRatioMapper(ParameterMapper):
11-
"""Map aspect_ratio to BFL width/height parameters."""
40+
"""Map aspect_ratio to BFL width/height parameters.
41+
42+
Converts 'WxH' string to width/height, rounded to nearest multiple of 16.
43+
Delegates to provider's WidthMapper and HeightMapper for the actual mapping.
44+
"""
1245

1346
name = ImageGenerationParameter.ASPECT_RATIO
1447

@@ -31,130 +64,35 @@ def map(
3164
width = ((width + 8) // 16) * 16
3265
height = ((height + 8) // 16) * 16
3366

34-
request["width"] = width
35-
request["height"] = height
67+
# Delegate to provider mappers
68+
request = _WidthMapper().map(request, width, model)
69+
request = _HeightMapper().map(request, height, model)
3670
return request
3771

3872

39-
class PromptUpsamplingMapper(ParameterMapper):
40-
"""Map prompt_upsampling parameter to BFL request format."""
41-
73+
class PromptUpsamplingMapper(_PromptUpsamplingMapper):
4274
name = ImageGenerationParameter.PROMPT_UPSAMPLING
4375

44-
def map(
45-
self,
46-
request: dict[str, Any],
47-
value: object,
48-
model: Model,
49-
) -> dict[str, Any]:
50-
"""Transform prompt_upsampling into provider request."""
51-
validated_value = self._validate_value(value, model)
52-
if validated_value is None:
53-
return request
54-
55-
request["prompt_upsampling"] = validated_value
56-
return request
57-
58-
59-
class SeedMapper(ParameterMapper):
60-
"""Map seed parameter to BFL request format."""
6176

77+
class SeedMapper(_SeedMapper):
6278
name = ImageGenerationParameter.SEED
6379

64-
def map(
65-
self,
66-
request: dict[str, Any],
67-
value: object,
68-
model: Model,
69-
) -> dict[str, Any]:
70-
"""Transform seed into provider request."""
71-
validated_value = self._validate_value(value, model)
72-
if validated_value is None:
73-
return request
74-
75-
request["seed"] = validated_value
76-
return request
77-
78-
79-
class SafetyToleranceMapper(ParameterMapper):
80-
"""Map safety_tolerance parameter to BFL request format."""
8180

81+
class SafetyToleranceMapper(_SafetyToleranceMapper):
8282
name = ImageGenerationParameter.SAFETY_TOLERANCE
8383

84-
def map(
85-
self,
86-
request: dict[str, Any],
87-
value: object,
88-
model: Model,
89-
) -> dict[str, Any]:
90-
"""Transform safety_tolerance into provider request."""
91-
validated_value = self._validate_value(value, model)
92-
if validated_value is None:
93-
return request
94-
95-
request["safety_tolerance"] = validated_value
96-
return request
97-
98-
99-
class OutputFormatMapper(ParameterMapper):
100-
"""Map output_format parameter to BFL request format."""
10184

85+
class OutputFormatMapper(_OutputFormatMapper):
10286
name = ImageGenerationParameter.OUTPUT_FORMAT
10387

104-
def map(
105-
self,
106-
request: dict[str, Any],
107-
value: object,
108-
model: Model,
109-
) -> dict[str, Any]:
110-
"""Transform output_format into provider request."""
111-
validated_value = self._validate_value(value, model)
112-
if validated_value is None:
113-
return request
114-
115-
request["output_format"] = validated_value
116-
return request
117-
118-
119-
class StepsMapper(ParameterMapper):
120-
"""Map steps parameter to BFL request format (flex only)."""
12188

89+
class StepsMapper(_StepsMapper):
12290
name = ImageGenerationParameter.STEPS
12391

124-
def map(
125-
self,
126-
request: dict[str, Any],
127-
value: object,
128-
model: Model,
129-
) -> dict[str, Any]:
130-
"""Transform steps into provider request."""
131-
validated_value = self._validate_value(value, model)
132-
if validated_value is None:
133-
return request
134-
135-
request["steps"] = validated_value
136-
return request
137-
138-
139-
class GuidanceMapper(ParameterMapper):
140-
"""Map guidance parameter to BFL request format (flex only)."""
14192

93+
class GuidanceMapper(_GuidanceMapper):
14294
name = ImageGenerationParameter.GUIDANCE
14395

144-
def map(
145-
self,
146-
request: dict[str, Any],
147-
value: object,
148-
model: Model,
149-
) -> dict[str, Any]:
150-
"""Transform guidance into provider request."""
151-
validated_value = self._validate_value(value, model)
152-
if validated_value is None:
153-
return request
154-
155-
request["guidance"] = validated_value
156-
return request
157-
15896

15997
BFL_PARAMETER_MAPPERS: list[ParameterMapper] = [
16098
AspectRatioMapper(),

packages/capabilities/image-generation/tests/integration_tests/test_image_generation/test_generate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ async def test_generate(provider: Provider, model: str, parameters: dict) -> Non
3535
client = create_client(
3636
capability=Capability.IMAGE_GENERATION,
3737
provider=provider,
38+
model=model,
3839
)
3940
prompt = "A red apple on a white background"
4041

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[project]
2+
name = "celeste-bfl"
3+
version = "0.3.0"
4+
description = "BFL (Black Forest Labs) provider package for Celeste AI"
5+
authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}]
6+
license = {text = "Apache-2.0"}
7+
requires-python = ">=3.12"
8+
9+
[tool.uv.sources]
10+
celeste-ai = { workspace = true }
11+
12+
[build-system]
13+
requires = ["hatchling"]
14+
build-backend = "hatchling.build"
15+
16+
[tool.hatch.build.targets.wheel]
17+
packages = ["src/celeste_bfl"]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""BFL (Black Forest Labs) provider package for Celeste AI."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""BFL Images API provider package."""

0 commit comments

Comments
 (0)