From 6d7f9d7ef0e5becad83e44b11786e58b3bc079b4 Mon Sep 17 00:00:00 2001 From: Kamil Benkirane <62942280+Kamilbenkirane@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:25:27 +0100 Subject: [PATCH 1/4] Fix Gemini response modalities to include TEXT and IMAGE (#125) * fix(images): use TEXT+IMAGE responseModalities for Gemini models The Gemini API's documented default is responseModalities: ["TEXT", "IMAGE"]. The previous hardcoded ["Image"] worked for gemini-2.5-flash-image but fails for gemini-3-pro-image-preview (a thinking model that requires both modalities). This aligns with Google's API documentation and ensures compatibility with both current and future Gemini image models. Fixes #123 https://claude.ai/code/session_01KYduqFZTvWMNMBW9b1nLXF * chore: bump version to 0.9.4 https://claude.ai/code/session_01KYduqFZTvWMNMBW9b1nLXF --------- Co-authored-by: Claude --- pyproject.toml | 2 +- src/celeste/modalities/images/providers/google/gemini.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aa6b104..00ab458 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "celeste-ai" -version = "0.9.3" +version = "0.9.4" description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface" authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] readme = "README.md" diff --git a/src/celeste/modalities/images/providers/google/gemini.py b/src/celeste/modalities/images/providers/google/gemini.py index 898fcb1..e714322 100644 --- a/src/celeste/modalities/images/providers/google/gemini.py +++ b/src/celeste/modalities/images/providers/google/gemini.py @@ -84,7 +84,7 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: return { "contents": [{"parts": parts}], "generationConfig": { - "responseModalities": ["Image"], + "responseModalities": ["TEXT", "IMAGE"], "imageConfig": {}, }, } From 4381e8a0a3c6d05664c9207e1bcc919335eaadb1 Mon Sep 17 00:00:00 2001 From: Kamil Benkirane <62942280+Kamilbenkirane@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:52:55 +0100 Subject: [PATCH 2/4] feat: expose extra_body parameter on all modalities (#126) * feat: expose extra_body parameter on all modalities Add extra_body parameter to all public methods across images, audio, videos, and embeddings modalities. This allows users to pass provider-specific request fields (e.g., Google's generationConfig, imageConfig) without resorting to private methods. Updated methods: - images: generate, edit (stream, sync, sync.stream) - audio: speak (stream, sync, sync.stream) - videos: generate (sync) - embeddings: embed (async, sync) Also updated the modality client template for future modalities. 
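Usage sketch (illustrative only; the nested generationConfig/imageConfig payload is an assumed example of a Google-specific field, not a verified API shape):

    import celeste

    # extra_body fields are merged into the provider request body;
    # the aspectRatio value here is a hypothetical illustration.
    result = celeste.images.sync.generate(
        "A nano banana on the beach",
        model="gemini-2.5-flash-image",
        extra_body={"generationConfig": {"imageConfig": {"aspectRatio": "16:9"}}},
    )
    result.content.show()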
Fixes #124 https://claude.ai/code/session_01KYduqFZTvWMNMBW9b1nLXF * style: format with ruff https://claude.ai/code/session_01KYduqFZTvWMNMBW9b1nLXF --------- Co-authored-by: Claude --- pyproject.toml | 2 +- src/celeste/modalities/audio/client.py | 15 ++++++++-- src/celeste/modalities/embeddings/client.py | 13 ++++++-- src/celeste/modalities/images/client.py | 30 +++++++++++++++---- src/celeste/modalities/videos/client.py | 8 +++-- .../client.py.template | 21 +++++++++---- 6 files changed, 70 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 00ab458..1ffb621 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "celeste-ai" -version = "0.9.4" +version = "0.9.5" description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface" authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] readme = "README.md" diff --git a/src/celeste/modalities/audio/client.py b/src/celeste/modalities/audio/client.py index 6c6f45b..eada7e7 100644 --- a/src/celeste/modalities/audio/client.py +++ b/src/celeste/modalities/audio/client.py @@ -1,6 +1,6 @@ """Audio modality client.""" -from typing import Unpack +from typing import Any, Unpack from asgiref.sync import async_to_sync @@ -45,6 +45,8 @@ def __init__(self, client: AudioClient) -> None: def speak( self, text: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[AudioParameters], ) -> AudioStream: """Stream speech generation.""" @@ -52,6 +54,7 @@ def speak( return self._client._stream( inputs, stream_class=self._client._stream_class(), + extra_body=extra_body, **parameters, ) @@ -65,11 +68,15 @@ def __init__(self, client: AudioClient) -> None: def speak( self, text: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[AudioParameters], ) -> AudioOutput: """Blocking speech generation.""" inputs = AudioInput(text=text) - return async_to_sync(self._client._predict)(inputs, **parameters) + return async_to_sync(self._client._predict)( + inputs, extra_body=extra_body, **parameters + ) @property def stream(self) -> "AudioSyncStreamNamespace": @@ -86,6 +93,8 @@ def __init__(self, client: AudioClient) -> None: def speak( self, text: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[AudioParameters], ) -> AudioStream: """Sync streaming speech generation. @@ -99,7 +108,7 @@ def speak( stream.output.content.save("output.mp3") """ # Return same stream as async version - __iter__/__next__ handle sync iteration - return self._client.stream.speak(text, **parameters) + return self._client.stream.speak(text, extra_body=extra_body, **parameters) __all__ = [ diff --git a/src/celeste/modalities/embeddings/client.py b/src/celeste/modalities/embeddings/client.py index 01126cb..1d46110 100644 --- a/src/celeste/modalities/embeddings/client.py +++ b/src/celeste/modalities/embeddings/client.py @@ -1,6 +1,6 @@ """Embeddings modality client.""" -from typing import Unpack +from typing import Any, Unpack from asgiref.sync import async_to_sync @@ -29,12 +29,15 @@ def _output_class(cls) -> type[EmbeddingsOutput]: async def embed( self, text: str | list[str], + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: """Generate embeddings from text. Args: text: Text to embed. Single string or list of strings. + extra_body: Additional provider-specific fields to merge into request. **parameters: Embedding parameters (e.g., dimensions). 
Returns: @@ -43,7 +46,7 @@ async def embed( - list[list[float]] if text was a list """ inputs = EmbeddingsInput(text=text) - output = await self._predict(inputs, **parameters) + output = await self._predict(inputs, extra_body=extra_body, **parameters) # If single text input, unwrap from batch format to single embedding if ( @@ -71,10 +74,14 @@ def __init__(self, client: EmbeddingsClient) -> None: def embed( self, text: str | list[str], + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: """Blocking embeddings generation.""" - return async_to_sync(self._client.embed)(text, **parameters) + return async_to_sync(self._client.embed)( + text, extra_body=extra_body, **parameters + ) __all__ = [ diff --git a/src/celeste/modalities/images/client.py b/src/celeste/modalities/images/client.py index 14cf72d..b0d7aff 100644 --- a/src/celeste/modalities/images/client.py +++ b/src/celeste/modalities/images/client.py @@ -1,6 +1,6 @@ """Images modality client.""" -from typing import Unpack +from typing import Any, Unpack from asgiref.sync import async_to_sync @@ -49,6 +49,8 @@ def __init__(self, client: ImagesClient) -> None: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Stream image generation.""" @@ -56,6 +58,7 @@ def generate( return self._client._stream( inputs, stream_class=self._client._stream_class(), + extra_body=extra_body, **parameters, ) @@ -63,6 +66,8 @@ def edit( self, image: ImageArtifact, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Stream image editing.""" @@ -70,6 +75,7 @@ def edit( return self._client._stream( inputs, stream_class=self._client._stream_class(), + extra_body=extra_body, **parameters, ) @@ -86,6 +92,8 @@ def __init__(self, client: ImagesClient) -> None: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[ImageParameters], ) -> ImageOutput: """Blocking image generation. @@ -95,12 +103,16 @@ def generate( result.content.show() """ inputs = ImageInput(prompt=prompt) - return async_to_sync(self._client._predict)(inputs, **parameters) + return async_to_sync(self._client._predict)( + inputs, extra_body=extra_body, **parameters + ) def edit( self, image: ImageArtifact, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[ImageParameters], ) -> ImageOutput: """Blocking image edit. @@ -110,7 +122,9 @@ def edit( result.content.show() """ inputs = ImageInput(prompt=prompt, image=image) - return async_to_sync(self._client._predict)(inputs, **parameters) + return async_to_sync(self._client._predict)( + inputs, extra_body=extra_body, **parameters + ) @property def stream(self) -> "ImagesSyncStreamNamespace": @@ -127,6 +141,8 @@ def __init__(self, client: ImagesClient) -> None: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Sync streaming image generation. 
@@ -140,12 +156,14 @@ def generate( print(stream.output.usage) """ # Return same stream as async version - __iter__/__next__ handle sync iteration - return self._client.stream.generate(prompt, **parameters) + return self._client.stream.generate(prompt, extra_body=extra_body, **parameters) def edit( self, image: ImageArtifact, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Sync streaming image editing. @@ -158,7 +176,9 @@ def edit( print(chunk.content) print(stream.output.usage) """ - return self._client.stream.edit(image, prompt, **parameters) + return self._client.stream.edit( + image, prompt, extra_body=extra_body, **parameters + ) __all__ = [ diff --git a/src/celeste/modalities/videos/client.py b/src/celeste/modalities/videos/client.py index 618989e..80a89be 100644 --- a/src/celeste/modalities/videos/client.py +++ b/src/celeste/modalities/videos/client.py @@ -1,6 +1,6 @@ """Videos modality client.""" -from typing import Unpack +from typing import Any, Unpack from asgiref.sync import async_to_sync @@ -42,6 +42,8 @@ def __init__(self, client: VideosClient) -> None: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[VideoParameters], ) -> VideoOutput: """Blocking video generation. @@ -51,7 +53,9 @@ def generate( result.content.save("video.mp4") """ inputs = VideoInput(prompt=prompt) - return async_to_sync(self._client._predict)(inputs, **parameters) + return async_to_sync(self._client._predict)( + inputs, extra_body=extra_body, **parameters + ) __all__ = [ diff --git a/templates/modalities/{modality_slug}/src/celeste_{modality_slug}/client.py.template b/templates/modalities/{modality_slug}/src/celeste_{modality_slug}/client.py.template index 150d30e..cb74625 100644 --- a/templates/modalities/{modality_slug}/src/celeste_{modality_slug}/client.py.template +++ b/templates/modalities/{modality_slug}/src/celeste_{modality_slug}/client.py.template @@ -1,6 +1,6 @@ """{Modality} modality client.""" -from typing import Unpack +from typing import Any, Unpack from asgiref.sync import async_to_sync @@ -70,6 +70,8 @@ class {Modality}StreamNamespace: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Stream {modality} generation. @@ -82,6 +84,7 @@ class {Modality}StreamNamespace: return self._client._stream( inputs, stream_class=self._client._stream_class(), + extra_body=extra_body, **parameters, ) @@ -92,6 +95,7 @@ class {Modality}StreamNamespace: image: ImageContent | None = None, video: VideoContent | None = None, audio: AudioContent | None = None, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Stream media analysis (image, video, or audio). @@ -105,6 +109,7 @@ class {Modality}StreamNamespace: return self._client._stream( inputs, stream_class=self._client._stream_class(), + extra_body=extra_body, **parameters, ) @@ -121,6 +126,8 @@ class {Modality}SyncNamespace: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Output: """Blocking {modality} generation. 
@@ -130,7 +137,7 @@ class {Modality}SyncNamespace: print(result.content) """ inputs = {Modality}Input(prompt=prompt) - return async_to_sync(self._client._predict)(inputs, **parameters) + return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters) def analyze( self, @@ -139,6 +146,7 @@ class {Modality}SyncNamespace: image: ImageContent | None = None, video: VideoContent | None = None, audio: AudioContent | None = None, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Output: """Blocking media analysis (image, video, or audio). @@ -149,7 +157,7 @@ class {Modality}SyncNamespace: """ self._client._check_media_support(image=image, video=video, audio=audio) inputs = {Modality}Input(prompt=prompt, image=image, video=video, audio=audio) - return async_to_sync(self._client._predict)(inputs, **parameters) + return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters) @property def stream(self) -> "{Modality}SyncStreamNamespace": @@ -166,6 +174,8 @@ class {Modality}SyncStreamNamespace: def generate( self, prompt: str, + *, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Sync streaming {modality} generation. @@ -179,7 +189,7 @@ class {Modality}SyncStreamNamespace: print(stream.output.usage) """ # Return same stream as async version - __iter__/__next__ handle sync iteration - return self._client.stream.generate(prompt, **parameters) + return self._client.stream.generate(prompt, extra_body=extra_body, **parameters) def analyze( self, @@ -188,6 +198,7 @@ class {Modality}SyncStreamNamespace: image: ImageContent | None = None, video: VideoContent | None = None, audio: AudioContent | None = None, + extra_body: dict[str, Any] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Sync streaming media analysis (image, video, or audio). @@ -202,7 +213,7 @@ class {Modality}SyncStreamNamespace: """ # Return same stream as async version - __iter__/__next__ handle sync iteration return self._client.stream.analyze( - prompt, image=image, video=video, audio=audio, **parameters + prompt, image=image, video=video, audio=audio, extra_body=extra_body, **parameters ) From 0ca70d0ae7cc4f2eff670ff35844bf5d2f5d7366 Mon Sep 17 00:00:00 2001 From: Kamil Benkirane <62942280+Kamilbenkirane@users.noreply.github.com> Date: Thu, 29 Jan 2026 22:54:15 +0100 Subject: [PATCH 3/4] feat(xai): add Grok Imagine for image and video generation (#128) * feat(xai): add grok-imagine-image and grok-imagine-video models Add xAI Grok Imagine support for image and video generation: Images (grok-imagine-image): - Generate and edit operations - Parameters: aspect_ratio, num_images, output_format - Aspect ratios: 1:1, 3:4, 4:3, 9:16, 16:9, 2:3, 3:2, and more Videos (grok-imagine-video): - Generate and edit operations - Async polling pattern (HTTP 200=ready, 202=processing) - Parameters: duration (1-15s), aspect_ratio, resolution Co-Authored-By: Claude Opus 4.5 * fix(xai): serialize image artifacts and validate video URLs - Image edit: serialize ImageArtifact to URL or base64 string instead of passing object directly (xAI API expects string) - Video edit: validate video has URL before using, raise clear error if not (xAI only supports URL, not base64/path) Co-Authored-By: Claude Opus 4.5 * fix(xai): remove client-side video URL validation Let the xAI API handle validation for video edit requests. 
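Usage sketch for the new models (illustrative only; assumes the celeste.images.sync and celeste.videos.sync entry points used elsewhere in this series, with arbitrary parameter values picked from the documented constraints):

    import celeste

    # Image generation; aspect_ratio must be one of the supported ratios.
    img = celeste.images.sync.generate(
        "A red fox in the snow",
        model="grok-imagine-image",
        aspect_ratio="16:9",
    )
    img.content.show()

    # Video generation; duration is in seconds (1-15), resolution 480p or 720p.
    vid = celeste.videos.sync.generate(
        "A red fox running through snow",
        model="grok-imagine-video",
        duration=2,
        resolution="480p",
    )
    vid.content.save("fox.mp4")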
Co-Authored-By: Claude Opus 4.5 --------- Co-authored-by: Claude Opus 4.5 --- .gitignore | 1 + notebooks/working-with-images.ipynb | 64 +++---- pyproject.toml | 2 +- src/celeste/modalities/images/models.py | 2 + .../modalities/images/providers/__init__.py | 2 + .../images/providers/xai/__init__.py | 6 + .../modalities/images/providers/xai/client.py | 101 +++++++++++ .../modalities/images/providers/xai/models.py | 38 +++++ .../images/providers/xai/parameters.py | 41 +++++ src/celeste/modalities/videos/io.py | 3 +- src/celeste/modalities/videos/models.py | 2 + .../modalities/videos/providers/__init__.py | 2 + .../videos/providers/xai/__init__.py | 6 + .../modalities/videos/providers/xai/client.py | 78 +++++++++ .../modalities/videos/providers/xai/models.py | 23 +++ .../videos/providers/xai/parameters.py | 41 +++++ src/celeste/providers/xai/images/__init__.py | 1 + src/celeste/providers/xai/images/client.py | 116 +++++++++++++ src/celeste/providers/xai/images/config.py | 13 ++ .../providers/xai/images/parameters.py | 73 ++++++++ src/celeste/providers/xai/videos/__init__.py | 1 + src/celeste/providers/xai/videos/client.py | 159 ++++++++++++++++++ src/celeste/providers/xai/videos/config.py | 20 +++ .../providers/xai/videos/parameters.py | 67 ++++++++ tests/integration_tests/images/test_edit.py | 1 + .../integration_tests/images/test_generate.py | 1 + .../integration_tests/videos/test_generate.py | 2 + 27 files changed, 832 insertions(+), 34 deletions(-) create mode 100644 src/celeste/modalities/images/providers/xai/__init__.py create mode 100644 src/celeste/modalities/images/providers/xai/client.py create mode 100644 src/celeste/modalities/images/providers/xai/models.py create mode 100644 src/celeste/modalities/images/providers/xai/parameters.py create mode 100644 src/celeste/modalities/videos/providers/xai/__init__.py create mode 100644 src/celeste/modalities/videos/providers/xai/client.py create mode 100644 src/celeste/modalities/videos/providers/xai/models.py create mode 100644 src/celeste/modalities/videos/providers/xai/parameters.py create mode 100644 src/celeste/providers/xai/images/__init__.py create mode 100644 src/celeste/providers/xai/images/client.py create mode 100644 src/celeste/providers/xai/images/config.py create mode 100644 src/celeste/providers/xai/images/parameters.py create mode 100644 src/celeste/providers/xai/videos/__init__.py create mode 100644 src/celeste/providers/xai/videos/client.py create mode 100644 src/celeste/providers/xai/videos/config.py create mode 100644 src/celeste/providers/xai/videos/parameters.py diff --git a/.gitignore b/.gitignore index c8fccfd..bd252c2 100644 --- a/.gitignore +++ b/.gitignore @@ -159,3 +159,4 @@ uv.lock bandit-report.json scripts/ +assets/ diff --git a/notebooks/working-with-images.ipynb b/notebooks/working-with-images.ipynb index 8c7ca10..c6ae50d 100644 --- a/notebooks/working-with-images.ipynb +++ b/notebooks/working-with-images.ipynb @@ -20,13 +20,13 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "import celeste\n", "from IPython.display import Image, display" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -41,24 +41,24 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "img_gen_result = await celeste.images.generate(\n", " \"A nano banana on the beach\",\n", " model=\"gemini-2.5-flash-image\",\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + 
"execution_count": null, "metadata": {}, + "outputs": [], "source": [ "display(Image(data=img_gen_result.content.data))" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -73,25 +73,25 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "img_edit_result = await celeste.images.edit(\n", " image=img_gen_result.content,\n", " prompt=\"Make it night time\",\n", " model=\"gemini-2.5-flash-image\",\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "display(Image(data=img_edit_result.content.data))" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -106,25 +106,25 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "analyze_result = await celeste.images.analyze(\n", " prompt=\"What fruit is in this image and what color is it?\",\n", " image=img_gen_result.content,\n", " model=\"gemini-2.5-flash-lite\",\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "print(analyze_result.content)" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -159,7 +159,9 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "prompt = \"A blurry iPhone-style photograph showing the window of a moving train. Through the window, a scenic landscape appears: tall green cliffs running alongside a river, with a small European village built on the slopes. The motion blur suggests the train is moving quickly, with soft reflections on the glass, natural daylight, and a casual handheld phone-camera aesthetic. Sharp textures where possible, rich colors, and a realistic sense of depth and distance.\"\n", "\n", @@ -170,13 +172,11 @@ " steps=1,\n", ")\n", "display(Image(data=local_result.content.data))" - ], - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "---\n", "\n", @@ -187,7 +187,9 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "from tqdm.asyncio import tqdm\n", "\n", @@ -204,24 +206,22 @@ " pass\n", "\n", "display(Image(data=chunk.content.data))" - ], - "outputs": [], - "execution_count": null + ] }, { - "metadata": {}, "cell_type": "markdown", + "metadata": {}, "source": [ "---\n", "Star on GitHub 👉 [withceleste/celeste-python](https://github.com/withceleste/celeste-python)" ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "" + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 1ffb621..b48ca75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "celeste-ai" -version = "0.9.5" +version = "0.9.6" description = "Open source, type-safe primitives for multi-modal AI. 
All capabilities, all providers, one interface" authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}] readme = "README.md" diff --git a/src/celeste/modalities/images/models.py b/src/celeste/modalities/images/models.py index e80db1f..d7548ce 100644 --- a/src/celeste/modalities/images/models.py +++ b/src/celeste/modalities/images/models.py @@ -6,10 +6,12 @@ from .providers.byteplus.models import MODELS as BYTEPLUS_MODELS from .providers.google.models import MODELS as GOOGLE_MODELS from .providers.openai.models import MODELS as OPENAI_MODELS +from .providers.xai.models import MODELS as XAI_MODELS MODELS: list[Model] = [ *BFL_MODELS, *BYTEPLUS_MODELS, *GOOGLE_MODELS, *OPENAI_MODELS, + *XAI_MODELS, ] diff --git a/src/celeste/modalities/images/providers/__init__.py b/src/celeste/modalities/images/providers/__init__.py index 8540b99..e703e21 100644 --- a/src/celeste/modalities/images/providers/__init__.py +++ b/src/celeste/modalities/images/providers/__init__.py @@ -8,6 +8,7 @@ from .google import GoogleImagesClient from .ollama import OllamaImagesClient from .openai import OpenAIImagesClient +from .xai import XAIImagesClient PROVIDERS: dict[Provider, type[ImagesClient]] = { Provider.BFL: BFLImagesClient, @@ -15,4 +16,5 @@ Provider.GOOGLE: GoogleImagesClient, Provider.OLLAMA: OllamaImagesClient, Provider.OPENAI: OpenAIImagesClient, + Provider.XAI: XAIImagesClient, } diff --git a/src/celeste/modalities/images/providers/xai/__init__.py b/src/celeste/modalities/images/providers/xai/__init__.py new file mode 100644 index 0000000..a169fc5 --- /dev/null +++ b/src/celeste/modalities/images/providers/xai/__init__.py @@ -0,0 +1,6 @@ +"""xAI provider for images modality.""" + +from .client import XAIImagesClient +from .models import MODELS + +__all__ = ["MODELS", "XAIImagesClient"] diff --git a/src/celeste/modalities/images/providers/xai/client.py b/src/celeste/modalities/images/providers/xai/client.py new file mode 100644 index 0000000..31e1688 --- /dev/null +++ b/src/celeste/modalities/images/providers/xai/client.py @@ -0,0 +1,101 @@ +"""xAI images client.""" + +from typing import Any, Unpack + +from celeste.artifacts import ImageArtifact +from celeste.parameters import ParameterMapper +from celeste.providers.xai.images import config +from celeste.providers.xai.images.client import XAIImagesClient as XAIImagesMixin + +from ...client import ImagesClient +from ...io import ( + ImageFinishReason, + ImageInput, + ImageOutput, + ImageUsage, +) +from ...parameters import ImageParameters +from .parameters import XAI_PARAMETER_MAPPERS + + +class XAIImagesClient(XAIImagesMixin, ImagesClient): + """xAI images client.""" + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper]: + return XAI_PARAMETER_MAPPERS + + def _init_request(self, inputs: ImageInput) -> dict[str, Any]: + """Initialize request from inputs.""" + request: dict[str, Any] = {"prompt": inputs.prompt} + if inputs.image is not None: + # xAI accepts URL or base64 string + if inputs.image.url: + request["image"] = inputs.image.url + else: + request["image"] = inputs.image.get_base64() + return request + + async def generate( + self, + prompt: str, + **parameters: Unpack[ImageParameters], + ) -> ImageOutput: + """Generate images from prompt.""" + inputs = ImageInput(prompt=prompt) + return await self._predict( + inputs, + endpoint=config.XAIImagesEndpoint.CREATE_IMAGE, + **parameters, + ) + + async def edit( + self, + image: ImageArtifact, + prompt: str, + **parameters: Unpack[ImageParameters], + ) -> ImageOutput: + """Edit 
an image with text instructions.""" + inputs = ImageInput(image=image, prompt=prompt) + return await self._predict( + inputs, + endpoint=config.XAIImagesEndpoint.CREATE_EDIT, + **parameters, + ) + + def _parse_usage(self, response_data: dict[str, Any]) -> ImageUsage: + """Parse usage from response.""" + usage = super()._parse_usage(response_data) + return ImageUsage(**usage) + + def _parse_content( + self, + response_data: dict[str, Any], + **parameters: Unpack[ImageParameters], + ) -> ImageArtifact: + """Parse content from response.""" + data = super()._parse_content(response_data) + image_data = data[0] + + # xAI returns either b64_json or url + b64_json = image_data.get("b64_json") + if b64_json: + import base64 + + image_bytes = base64.b64decode(b64_json) + return ImageArtifact(data=image_bytes) + + url = image_data.get("url") + if url: + return ImageArtifact(url=url) + + msg = "No image URL or base64 data in response" + raise ValueError(msg) + + def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason: + """Parse finish reason from response.""" + finish_reason = super()._parse_finish_reason(response_data) + return ImageFinishReason(reason=finish_reason.reason) + + +__all__ = ["XAIImagesClient"] diff --git a/src/celeste/modalities/images/providers/xai/models.py b/src/celeste/modalities/images/providers/xai/models.py new file mode 100644 index 0000000..85aabb2 --- /dev/null +++ b/src/celeste/modalities/images/providers/xai/models.py @@ -0,0 +1,38 @@ +"""xAI models for images modality.""" + +from celeste.constraints import Choice, Range +from celeste.core import Modality, Operation, Provider +from celeste.models import Model + +from ...parameters import ImageParameter + +MODELS: list[Model] = [ + Model( + id="grok-imagine-image", + provider=Provider.XAI, + display_name="Grok Imagine Image", + operations={Modality.IMAGES: {Operation.GENERATE, Operation.EDIT}}, + parameter_constraints={ + ImageParameter.NUM_IMAGES: Range(min=1, max=10), + ImageParameter.ASPECT_RATIO: Choice( + options=[ + "1:1", + "3:4", + "4:3", + "9:16", + "16:9", + "2:3", + "3:2", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + "auto", + ] + ), + ImageParameter.OUTPUT_FORMAT: Choice(options=["url", "b64_json"]), + }, + ), +] diff --git a/src/celeste/modalities/images/providers/xai/parameters.py b/src/celeste/modalities/images/providers/xai/parameters.py new file mode 100644 index 0000000..3078a09 --- /dev/null +++ b/src/celeste/modalities/images/providers/xai/parameters.py @@ -0,0 +1,41 @@ +"""xAI parameter mappers for images.""" + +from celeste.parameters import ParameterMapper +from celeste.providers.xai.images.parameters import ( + AspectRatioMapper as _AspectRatioMapper, +) +from celeste.providers.xai.images.parameters import ( + NumImagesMapper as _NumImagesMapper, +) +from celeste.providers.xai.images.parameters import ( + ResponseFormatMapper as _ResponseFormatMapper, +) + +from ...parameters import ImageParameter + + +class AspectRatioMapper(_AspectRatioMapper): + """Map aspect_ratio to xAI's aspect_ratio parameter.""" + + name = ImageParameter.ASPECT_RATIO + + +class NumImagesMapper(_NumImagesMapper): + """Map num_images to xAI's n parameter.""" + + name = ImageParameter.NUM_IMAGES + + +class OutputFormatMapper(_ResponseFormatMapper): + """Map output_format to xAI's response_format parameter.""" + + name = ImageParameter.OUTPUT_FORMAT + + +XAI_PARAMETER_MAPPERS: list[ParameterMapper] = [ + AspectRatioMapper(), + NumImagesMapper(), + OutputFormatMapper(), +] + +__all__ = 
["XAI_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/videos/io.py b/src/celeste/modalities/videos/io.py index ec4643e..0a005c6 100644 --- a/src/celeste/modalities/videos/io.py +++ b/src/celeste/modalities/videos/io.py @@ -7,9 +7,10 @@ class VideoInput(Input): - """Input for video generation operations.""" + """Input for video generation and edit operations.""" prompt: str + video: VideoArtifact | None = None # For edit operations class VideoFinishReason(FinishReason): diff --git a/src/celeste/modalities/videos/models.py b/src/celeste/modalities/videos/models.py index f4821c9..4788b20 100644 --- a/src/celeste/modalities/videos/models.py +++ b/src/celeste/modalities/videos/models.py @@ -5,9 +5,11 @@ from .providers.byteplus.models import MODELS as BYTEPLUS_MODELS from .providers.google.models import MODELS as GOOGLE_MODELS from .providers.openai.models import MODELS as OPENAI_MODELS +from .providers.xai.models import MODELS as XAI_MODELS MODELS: list[Model] = [ *BYTEPLUS_MODELS, *GOOGLE_MODELS, *OPENAI_MODELS, + *XAI_MODELS, ] diff --git a/src/celeste/modalities/videos/providers/__init__.py b/src/celeste/modalities/videos/providers/__init__.py index 7108886..e3db955 100644 --- a/src/celeste/modalities/videos/providers/__init__.py +++ b/src/celeste/modalities/videos/providers/__init__.py @@ -6,9 +6,11 @@ from .byteplus import BytePlusVideosClient from .google import GoogleVideosClient from .openai import OpenAIVideosClient +from .xai import XAIVideosClient PROVIDERS: dict[Provider, type[VideosClient]] = { Provider.BYTEPLUS: BytePlusVideosClient, Provider.GOOGLE: GoogleVideosClient, Provider.OPENAI: OpenAIVideosClient, + Provider.XAI: XAIVideosClient, } diff --git a/src/celeste/modalities/videos/providers/xai/__init__.py b/src/celeste/modalities/videos/providers/xai/__init__.py new file mode 100644 index 0000000..6d6426e --- /dev/null +++ b/src/celeste/modalities/videos/providers/xai/__init__.py @@ -0,0 +1,6 @@ +"""xAI provider for videos modality.""" + +from .client import XAIVideosClient +from .models import MODELS + +__all__ = ["MODELS", "XAIVideosClient"] diff --git a/src/celeste/modalities/videos/providers/xai/client.py b/src/celeste/modalities/videos/providers/xai/client.py new file mode 100644 index 0000000..3225a85 --- /dev/null +++ b/src/celeste/modalities/videos/providers/xai/client.py @@ -0,0 +1,78 @@ +"""xAI videos client.""" + +from typing import Any, Unpack + +from celeste.artifacts import VideoArtifact +from celeste.parameters import ParameterMapper +from celeste.providers.xai.videos import config +from celeste.providers.xai.videos.client import XAIVideosClient as XAIVideosMixin + +from ...client import VideosClient +from ...io import VideoFinishReason, VideoInput, VideoOutput, VideoUsage +from ...parameters import VideoParameters +from .parameters import XAI_PARAMETER_MAPPERS + + +class XAIVideosClient(XAIVideosMixin, VideosClient): + """xAI client for video generation.""" + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper]: + return XAI_PARAMETER_MAPPERS + + def _init_request(self, inputs: VideoInput) -> dict[str, Any]: + """Initialize request from inputs.""" + request: dict[str, Any] = {"prompt": inputs.prompt} + if inputs.video is not None: + request["video"] = {"url": inputs.video.url} + return request + + async def generate( + self, + prompt: str, + **parameters: Unpack[VideoParameters], + ) -> VideoOutput: + """Generate videos from prompt.""" + inputs = VideoInput(prompt=prompt) + return await self._predict( + inputs, + 
endpoint=config.XAIVideosEndpoint.CREATE_VIDEO, + **parameters, + ) + + async def edit( + self, + video: VideoArtifact, + prompt: str, + **parameters: Unpack[VideoParameters], + ) -> VideoOutput: + """Edit a video with text instructions.""" + inputs = VideoInput(prompt=prompt, video=video) + return await self._predict( + inputs, + endpoint=config.XAIVideosEndpoint.CREATE_EDIT, + **parameters, + ) + + def _parse_usage(self, response_data: dict[str, Any]) -> VideoUsage: + """Parse usage from response.""" + usage = super()._parse_usage(response_data) + return VideoUsage(**usage) + + def _parse_content( + self, + response_data: dict[str, Any], + **parameters: Unpack[VideoParameters], + ) -> VideoArtifact: + """Parse content from response.""" + # xAI returns video URL directly + url = super()._parse_content(response_data) + return VideoArtifact(url=url) + + def _parse_finish_reason(self, response_data: dict[str, Any]) -> VideoFinishReason: + """Parse finish reason from response.""" + finish_reason = super()._parse_finish_reason(response_data) + return VideoFinishReason(reason=finish_reason.reason) + + +__all__ = ["XAIVideosClient"] diff --git a/src/celeste/modalities/videos/providers/xai/models.py b/src/celeste/modalities/videos/providers/xai/models.py new file mode 100644 index 0000000..dfb0b03 --- /dev/null +++ b/src/celeste/modalities/videos/providers/xai/models.py @@ -0,0 +1,23 @@ +"""xAI models for videos modality.""" + +from celeste.constraints import Choice, Range +from celeste.core import Modality, Operation, Provider +from celeste.models import Model + +from ...parameters import VideoParameter + +MODELS: list[Model] = [ + Model( + id="grok-imagine-video", + provider=Provider.XAI, + display_name="Grok Imagine Video", + operations={Modality.VIDEOS: {Operation.GENERATE, Operation.EDIT}}, + parameter_constraints={ + VideoParameter.DURATION: Range(min=1, max=15), + VideoParameter.ASPECT_RATIO: Choice( + options=["16:9", "4:3", "1:1", "9:16", "3:4", "3:2", "2:3"] + ), + VideoParameter.RESOLUTION: Choice(options=["720p", "480p"]), + }, + ), +] diff --git a/src/celeste/modalities/videos/providers/xai/parameters.py b/src/celeste/modalities/videos/providers/xai/parameters.py new file mode 100644 index 0000000..aa35548 --- /dev/null +++ b/src/celeste/modalities/videos/providers/xai/parameters.py @@ -0,0 +1,41 @@ +"""xAI parameter mappers for videos.""" + +from celeste.parameters import ParameterMapper +from celeste.providers.xai.videos.parameters import ( + AspectRatioMapper as _AspectRatioMapper, +) +from celeste.providers.xai.videos.parameters import ( + DurationMapper as _DurationMapper, +) +from celeste.providers.xai.videos.parameters import ( + ResolutionMapper as _ResolutionMapper, +) + +from ...parameters import VideoParameter + + +class DurationMapper(_DurationMapper): + """Map duration to xAI's duration parameter.""" + + name = VideoParameter.DURATION + + +class AspectRatioMapper(_AspectRatioMapper): + """Map aspect_ratio to xAI's aspect_ratio parameter.""" + + name = VideoParameter.ASPECT_RATIO + + +class ResolutionMapper(_ResolutionMapper): + """Map resolution to xAI's resolution parameter.""" + + name = VideoParameter.RESOLUTION + + +XAI_PARAMETER_MAPPERS: list[ParameterMapper] = [ + DurationMapper(), + AspectRatioMapper(), + ResolutionMapper(), +] + +__all__ = ["XAI_PARAMETER_MAPPERS"] diff --git a/src/celeste/providers/xai/images/__init__.py b/src/celeste/providers/xai/images/__init__.py new file mode 100644 index 0000000..a3a95d3 --- /dev/null +++ 
b/src/celeste/providers/xai/images/__init__.py @@ -0,0 +1 @@ +"""xAI Images API provider package.""" diff --git a/src/celeste/providers/xai/images/client.py b/src/celeste/providers/xai/images/client.py new file mode 100644 index 0000000..9a31405 --- /dev/null +++ b/src/celeste/providers/xai/images/client.py @@ -0,0 +1,116 @@ +"""xAI Images API client mixin.""" + +from collections.abc import AsyncIterator +from typing import Any + +from celeste.client import APIMixin +from celeste.core import UsageField +from celeste.exceptions import StreamingNotSupportedError +from celeste.io import FinishReason +from celeste.mime_types import ApplicationMimeType + +from . import config + + +class XAIImagesClient(APIMixin): + """Mixin for xAI Images API. + + Provides shared HTTP implementation: + - _make_request(endpoint=...) - HTTP POST to specified endpoint + - _make_stream_request() - Raises StreamingNotSupportedError (xAI Images doesn't support streaming) + - _parse_usage() - Extract usage dict from response + - _parse_content() - Extract data array from response + - _parse_finish_reason() - Returns None (no finish reasons for images) + - _build_metadata() - Filter content fields + + Modality clients pass endpoint parameter to route operations: + await self._predict(inputs, endpoint=config.XAIImagesEndpoint.CREATE_IMAGE, **parameters) + """ + + def _build_request( + self, + inputs: Any, + extra_body: dict[str, Any] | None = None, + streaming: bool = False, + **parameters: Any, + ) -> dict[str, Any]: + """Build request with model ID.""" + request_body = super()._build_request( + inputs, extra_body=extra_body, streaming=streaming, **parameters + ) + request_body["model"] = self.model.id + return request_body + + async def _make_request( + self, + request_body: dict[str, Any], + *, + endpoint: str | None = None, + **parameters: Any, + ) -> dict[str, Any]: + """Make HTTP request to xAI Images API.""" + if endpoint is None: + endpoint = config.XAIImagesEndpoint.CREATE_IMAGE + + headers = { + **self.auth.get_headers(), + "Content-Type": ApplicationMimeType.JSON, + } + + response = await self.http_client.post( + f"{config.BASE_URL}{endpoint}", + headers=headers, + json_body=request_body, + ) + self._handle_error_response(response) + data: dict[str, Any] = response.json() + return data + + def _make_stream_request( + self, + request_body: dict[str, Any], + *, + endpoint: str | None = None, + **parameters: Any, + ) -> AsyncIterator[dict[str, Any]]: + """xAI Images does not support SSE streaming.""" + raise StreamingNotSupportedError(model_id=self.model.id) + + @staticmethod + def map_usage_fields(usage_data: dict[str, Any]) -> dict[str, int | float | None]: + """Map xAI Images usage fields to unified names.""" + return { + UsageField.INPUT_TOKENS: usage_data.get("input_tokens"), + UsageField.OUTPUT_TOKENS: usage_data.get("output_tokens"), + UsageField.TOTAL_TOKENS: usage_data.get("total_tokens"), + } + + def _parse_usage( + self, response_data: dict[str, Any] + ) -> dict[str, int | float | None]: + """Extract usage data from xAI Images API response.""" + usage_data = response_data.get("usage", {}) + return XAIImagesClient.map_usage_fields(usage_data) + + def _parse_content(self, response_data: dict[str, Any]) -> Any: + """Parse data array from xAI Images API response.""" + data = response_data.get("data", []) + if not data: + msg = "No image data in response" + raise ValueError(msg) + return data + + def _parse_finish_reason(self, response_data: dict[str, Any]) -> FinishReason: + """xAI Images API doesn't 
provide finish reasons.""" + return FinishReason(reason=None) + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary, filtering out content fields.""" + content_fields = {"data"} + filtered_data = { + k: v for k, v in response_data.items() if k not in content_fields + } + return super()._build_metadata(filtered_data) + + +__all__ = ["XAIImagesClient"] diff --git a/src/celeste/providers/xai/images/config.py b/src/celeste/providers/xai/images/config.py new file mode 100644 index 0000000..7da1e3f --- /dev/null +++ b/src/celeste/providers/xai/images/config.py @@ -0,0 +1,13 @@ +"""Configuration for xAI Images API.""" + +from enum import StrEnum + + +class XAIImagesEndpoint(StrEnum): + """Endpoints for xAI Images API.""" + + CREATE_IMAGE = "/v1/images/generations" + CREATE_EDIT = "/v1/images/edits" + + +BASE_URL = "https://api.x.ai" diff --git a/src/celeste/providers/xai/images/parameters.py b/src/celeste/providers/xai/images/parameters.py new file mode 100644 index 0000000..7d8670b --- /dev/null +++ b/src/celeste/providers/xai/images/parameters.py @@ -0,0 +1,73 @@ +"""xAI Images API parameter mappers. + +Naming convention: +- Mapper class name MUST match the provider's API parameter name +- Example: API param "aspect_ratio" → class AspectRatioMapper +- The request key should match the provider's expected field name exactly +""" + +from typing import Any + +from celeste.models import Model +from celeste.parameters import ParameterMapper + + +class AspectRatioMapper(ParameterMapper): + """Map aspect_ratio to xAI aspect_ratio field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform aspect_ratio into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + request["aspect_ratio"] = validated_value + return request + + +class NumImagesMapper(ParameterMapper): + """Map num_images to xAI n field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform num_images into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + request["n"] = validated_value + return request + + +class ResponseFormatMapper(ParameterMapper): + """Map response_format to xAI response_format field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform response_format into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + request["response_format"] = validated_value + return request + + +__all__ = [ + "AspectRatioMapper", + "NumImagesMapper", + "ResponseFormatMapper", +] diff --git a/src/celeste/providers/xai/videos/__init__.py b/src/celeste/providers/xai/videos/__init__.py new file mode 100644 index 0000000..e34a607 --- /dev/null +++ b/src/celeste/providers/xai/videos/__init__.py @@ -0,0 +1 @@ +"""xAI Videos API provider package.""" diff --git a/src/celeste/providers/xai/videos/client.py b/src/celeste/providers/xai/videos/client.py new file mode 100644 index 0000000..f6dcbdb --- /dev/null +++ b/src/celeste/providers/xai/videos/client.py @@ -0,0 +1,159 @@ +"""xAI Videos API client mixin.""" + +import asyncio +from collections.abc import AsyncIterator +from typing import Any + +from celeste.client import APIMixin +from celeste.core import UsageField +from 
celeste.exceptions import StreamingNotSupportedError +from celeste.io import FinishReason +from celeste.mime_types import ApplicationMimeType + +from . import config + + +class XAIVideosClient(APIMixin): + """Mixin for xAI Videos API video generation. + + Provides shared implementation for video generation: + - _make_request() - HTTP POST with async polling pattern + - _parse_usage() - Extract usage dict from response + - _parse_content() - Extract video URL from response + - _parse_finish_reason() - Returns None (Videos API doesn't provide finish reasons) + - _build_metadata() - Filter content fields + + The Videos API uses async polling: + 1. POST to /v1/videos/generations returns request_id + 2. Poll GET /v1/videos/{request_id} until completed/failed + 3. Response contains video URL directly + """ + + def _build_request( + self, + inputs: Any, + extra_body: dict[str, Any] | None = None, + streaming: bool = False, + **parameters: Any, + ) -> dict[str, Any]: + """Build request with model ID.""" + request_body = super()._build_request( + inputs, extra_body=extra_body, streaming=streaming, **parameters + ) + request_body["model"] = self.model.id + return request_body + + def _make_stream_request( + self, + request_body: dict[str, Any], + *, + endpoint: str | None = None, + **parameters: Any, + ) -> AsyncIterator[dict[str, Any]]: + """xAI Videos API does not support SSE streaming.""" + raise StreamingNotSupportedError(model_id=self.model.id) + + async def _make_request( + self, + request_body: dict[str, Any], + *, + endpoint: str | None = None, + **parameters: Any, + ) -> dict[str, Any]: + """Make HTTP request with async polling for xAI video generation.""" + if endpoint is None: + endpoint = config.XAIVideosEndpoint.CREATE_VIDEO + + headers = { + **self.auth.get_headers(), + "Content-Type": ApplicationMimeType.JSON, + } + + # Submit video generation request + response = await self.http_client.post( + f"{config.BASE_URL}{endpoint}", + headers=headers, + json_body=request_body, + ) + self._handle_error_response(response) + video_obj: dict[str, Any] = response.json() + + request_id = video_obj.get("request_id") + if not request_id: + # Response already has URL (e.g., cached result) + if "url" in video_obj: + return video_obj + msg = "No request_id in video generation response" + raise ValueError(msg) + + # Poll for completion + poll_endpoint = f"/v1/videos/{request_id}" + for _ in range(config.MAX_POLLS): + await asyncio.sleep(config.POLL_INTERVAL) + + status_response = await self.http_client.get( + f"{config.BASE_URL}{poll_endpoint}", + headers=headers, + ) + self._handle_error_response(status_response) + + # xAI uses HTTP status codes: 200 = ready, 202 = still processing + if status_response.status_code == 200: + return status_response.json() + + # 202 Accepted means still processing, continue polling + if status_response.status_code == 202: + continue + + # Parse response for error handling + video_obj = status_response.json() + status = video_obj.get("status", "") + if status == config.STATUS_FAILED: + error = video_obj.get("error", "Video generation failed") + raise RuntimeError(error) + + msg = f"Video generation timeout after {config.MAX_POLLS * config.POLL_INTERVAL} seconds" + raise TimeoutError(msg) + + @staticmethod + def map_usage_fields(usage_data: dict[str, Any]) -> dict[str, int | float | None]: + """Map xAI Videos usage fields to unified names.""" + return { + UsageField.INPUT_TOKENS: usage_data.get("input_tokens"), + UsageField.OUTPUT_TOKENS: usage_data.get("output_tokens"), + 
UsageField.TOTAL_TOKENS: usage_data.get("total_tokens"), + } + + def _parse_usage( + self, response_data: dict[str, Any] + ) -> dict[str, int | float | None]: + """Extract usage data from xAI Videos API response.""" + usage_data = response_data.get("usage", {}) + return XAIVideosClient.map_usage_fields(usage_data) + + def _parse_content(self, response_data: dict[str, Any]) -> Any: + """Parse video URL from xAI Videos API response. + + Response structure: {"video": {"url": "...", "duration": 8}, "model": "..."} + """ + video = response_data.get("video", {}) + url = video.get("url") + if not url: + msg = "No video URL in response" + raise ValueError(msg) + return url + + def _parse_finish_reason(self, response_data: dict[str, Any]) -> FinishReason: + """Videos API doesn't provide finish reasons.""" + return FinishReason(reason=None) + + def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]: + """Build metadata dictionary, filtering out content fields.""" + content_fields = {"video"} + filtered_data = { + k: v for k, v in response_data.items() if k not in content_fields + } + return super()._build_metadata(filtered_data) + + +__all__ = ["XAIVideosClient"] diff --git a/src/celeste/providers/xai/videos/config.py b/src/celeste/providers/xai/videos/config.py new file mode 100644 index 0000000..7405f01 --- /dev/null +++ b/src/celeste/providers/xai/videos/config.py @@ -0,0 +1,20 @@ +"""Configuration for xAI Videos API.""" + +from enum import StrEnum + + +class XAIVideosEndpoint(StrEnum): + """Endpoints for xAI Videos API.""" + + CREATE_VIDEO = "/v1/videos/generations" + CREATE_EDIT = "/v1/videos/edits" + + +BASE_URL = "https://api.x.ai" + +# Polling Configuration +MAX_POLLS = 60 +POLL_INTERVAL = 5 # seconds + +# Status Constants +STATUS_FAILED = "failed" diff --git a/src/celeste/providers/xai/videos/parameters.py b/src/celeste/providers/xai/videos/parameters.py new file mode 100644 index 0000000..fc0b7a2 --- /dev/null +++ b/src/celeste/providers/xai/videos/parameters.py @@ -0,0 +1,67 @@ +"""xAI Videos API parameter mappers.""" + +from typing import Any + +from celeste.models import Model +from celeste.parameters import ParameterMapper + + +class DurationMapper(ParameterMapper): + """Map duration to xAI duration field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform duration into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + request["duration"] = validated_value + return request + + +class AspectRatioMapper(ParameterMapper): + """Map aspect_ratio to xAI aspect_ratio field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform aspect_ratio into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + request["aspect_ratio"] = validated_value + return request + + +class ResolutionMapper(ParameterMapper): + """Map resolution to xAI resolution field.""" + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform resolution into provider request.""" + validated_value = self._validate_value(value, model) + if validated_value is None: + return request + + request["resolution"] = validated_value + return request + + +__all__ = [ + "AspectRatioMapper", + "DurationMapper", + "ResolutionMapper", +] diff --git 
a/tests/integration_tests/images/test_edit.py b/tests/integration_tests/images/test_edit.py index 5eaa0cd..d05ab1f 100644 --- a/tests/integration_tests/images/test_edit.py +++ b/tests/integration_tests/images/test_edit.py @@ -22,6 +22,7 @@ (Provider.OPENAI, "gpt-image-1-mini"), (Provider.GOOGLE, "gemini-2.5-flash-image"), (Provider.BFL, "flux-2-pro"), + (Provider.XAI, "grok-imagine-image"), ], ) @pytest.mark.integration diff --git a/tests/integration_tests/images/test_generate.py b/tests/integration_tests/images/test_generate.py index ab2371e..705b298 100644 --- a/tests/integration_tests/images/test_generate.py +++ b/tests/integration_tests/images/test_generate.py @@ -23,6 +23,7 @@ (Provider.GOOGLE, "imagen-4.0-fast-generate-001", {"num_images": 1}), (Provider.BYTEPLUS, "seedream-4-0-250828", {}), (Provider.BFL, "flux-2-pro", {}), + (Provider.XAI, "grok-imagine-image", {}), ], ) @pytest.mark.integration diff --git a/tests/integration_tests/videos/test_generate.py b/tests/integration_tests/videos/test_generate.py index 1e7edc7..9e9ef83 100644 --- a/tests/integration_tests/videos/test_generate.py +++ b/tests/integration_tests/videos/test_generate.py @@ -29,6 +29,8 @@ "seedance-1-0-lite-t2v-250428", {"duration": 2, "resolution": "480p"}, ), + # xAI Grok Imagine: duration 1-15s, 480p/720p + (Provider.XAI, "grok-imagine-video", {"duration": 2}), ], ) @pytest.mark.integration From 4902f3e2eab3e570785ed8f802f02b2947593bec Mon Sep 17 00:00:00 2001 From: Kamil Benkirane <62942280+Kamilbenkirane@users.noreply.github.com> Date: Thu, 29 Jan 2026 23:08:33 +0100 Subject: [PATCH 4/4] fix(xai): use ImageUrl struct format for image edits (#129) xAI API expects {"image": {"url": "..."}} not a raw string. Supports both URL and data URI (for base64 images). Co-authored-by: Claude Opus 4.5 --- src/celeste/modalities/images/providers/xai/client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/celeste/modalities/images/providers/xai/client.py b/src/celeste/modalities/images/providers/xai/client.py index 31e1688..66c19eb 100644 --- a/src/celeste/modalities/images/providers/xai/client.py +++ b/src/celeste/modalities/images/providers/xai/client.py @@ -29,11 +29,13 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: """Initialize request from inputs.""" request: dict[str, Any] = {"prompt": inputs.prompt} if inputs.image is not None: - # xAI accepts URL or base64 string + # xAI expects {"image": {"url": "..."}} with URL or data URI if inputs.image.url: - request["image"] = inputs.image.url + request["image"] = {"url": inputs.image.url} else: - request["image"] = inputs.image.get_base64() + mime_type = inputs.image.mime_type + base64_data = inputs.image.get_base64() + request["image"] = {"url": f"data:{mime_type};base64,{base64_data}"} return request async def generate(