Skip to content

Commit 6e650bc

Browse files
refactor(client): add APIMixin abstract base for type-safe provider mixins (withceleste#87)
Introduce APIMixin class as abstract base for all provider API mixins: - Declares model, auth, provider attributes with proper type hints - Declares abstract http_client property returning HTTPClient - Provides _build_request, _build_metadata, _handle_error_response stubs that chain to Client via MRO Updates all 15 provider client mixins to inherit from APIMixin: - anthropic/messages, bfl/images, byteplus/images, byteplus/videos - cohere/chat, elevenlabs/text_to_speech, google/generate_content - google/imagen, google/veo, mistral/chat, openai/audio - openai/images, openai/responses, openai/videos, xai/responses This eliminates ~50 type: ignore comments across provider clients while maintaining the same runtime behavior through MRO chaining.
1 parent 78292c2 commit 6e650bc

File tree

16 files changed

+187
-118
lines changed

16 files changed

+187
-118
lines changed

packages/providers/anthropic/src/celeste_anthropic/messages/client.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import httpx
77

8+
from celeste.client import APIMixin
89
from celeste.core import UsageField
910
from celeste.io import FinishReason
1011
from celeste.mime_types import ApplicationMimeType
1112

1213
from . import config
1314

1415

15-
class AnthropicMessagesClient:
16+
class AnthropicMessagesClient(APIMixin):
1617
"""Mixin for Anthropic Messages API capabilities.
1718
1819
Provides shared implementation for all capabilities using the Messages API:
@@ -39,8 +40,8 @@ def _build_request(
3940
**parameters: Any,
4041
) -> Any:
4142
"""Build request with Anthropic-specific defaults."""
42-
request = super()._build_request(inputs, **parameters) # type: ignore[misc]
43-
request["model"] = self.model.id # type: ignore[attr-defined]
43+
request = super()._build_request(inputs, **parameters)
44+
request["model"] = self.model.id
4445

4546
# Apply max_tokens default if not set (Anthropic requires it)
4647
if "max_tokens" not in request:
@@ -53,7 +54,7 @@ def _build_headers(self, request_body: dict[str, Any]) -> dict[str, str]:
5354
beta_features: list[str] = request_body.pop("_beta_features", [])
5455

5556
headers: dict[str, str] = {
56-
**self.auth.get_headers(), # type: ignore[attr-defined]
57+
**self.auth.get_headers(),
5758
config.HEADER_ANTHROPIC_VERSION: config.ANTHROPIC_VERSION,
5859
"Content-Type": ApplicationMimeType.JSON,
5960
}
@@ -75,7 +76,7 @@ async def _make_request(
7576
"""Make HTTP request to Anthropic Messages API endpoint."""
7677
headers = self._build_headers(request_body)
7778

78-
return await self.http_client.post( # type: ignore[attr-defined,no-any-return]
79+
return await self.http_client.post(
7980
f"{config.BASE_URL}{config.AnthropicMessagesEndpoint.CREATE_MESSAGE}",
8081
headers=headers,
8182
json_body=request_body,
@@ -90,7 +91,7 @@ def _make_stream_request(
9091
request_body["stream"] = True
9192
headers = self._build_headers(request_body)
9293

93-
return self.http_client.stream_post( # type: ignore[attr-defined,no-any-return]
94+
return self.http_client.stream_post(
9495
f"{config.BASE_URL}{config.AnthropicMessagesEndpoint.CREATE_MESSAGE}",
9596
headers=headers,
9697
json_body=request_body,
@@ -147,7 +148,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
147148
filtered_data = {
148149
k: v for k, v in response_data.items() if k not in content_fields
149150
}
150-
return super()._build_metadata(filtered_data) # type: ignore[misc,no-any-return]
151+
return super()._build_metadata(filtered_data)
151152

152153

153154
__all__ = ["AnthropicMessagesClient"]

packages/providers/bfl/src/celeste_bfl/images/client.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77
import httpx
88

9+
from celeste.client import APIMixin
910
from celeste.core import UsageField
1011
from celeste.io import FinishReason
1112
from celeste.mime_types import ApplicationMimeType
1213

1314
from . import config
1415

1516

16-
class BFLImagesClient:
17+
class BFLImagesClient(APIMixin):
1718
"""Mixin for BFL Images API operations.
1819
1920
Provides shared implementation:
@@ -44,30 +45,30 @@ async def _make_request(
4445
2. Poll polling_url until Ready/Failed
4546
3. Return response with _submit_metadata for usage parsing
4647
"""
47-
auth_headers = self.auth.get_headers() # type: ignore[attr-defined]
48+
auth_headers = self.auth.get_headers()
4849
headers = {
4950
**auth_headers,
5051
"Content-Type": ApplicationMimeType.JSON,
5152
"Accept": ApplicationMimeType.JSON,
5253
}
5354

54-
endpoint = config.BFLImagesEndpoint.CREATE_IMAGE.format(model_id=self.model.id) # type: ignore[attr-defined]
55+
endpoint = config.BFLImagesEndpoint.CREATE_IMAGE.format(model_id=self.model.id)
5556

5657
# Phase 1: Submit job
57-
submit_response = await self.http_client.post( # type: ignore[attr-defined]
58+
submit_response = await self.http_client.post(
5859
f"{config.BASE_URL}{endpoint}",
5960
headers=headers,
6061
json_body=request_body,
6162
)
6263

6364
if submit_response.status_code != 200:
64-
return submit_response # type: ignore[no-any-return]
65+
return submit_response
6566

6667
submit_data = submit_response.json()
6768
polling_url = submit_data.get("polling_url")
6869

6970
if not polling_url:
70-
msg = f"No polling_url in {self.provider} response" # type: ignore[attr-defined]
71+
msg = f"No polling_url in {self.provider} response"
7172
raise ValueError(msg)
7273

7374
# Phase 2: Poll for completion
@@ -80,16 +81,16 @@ async def _make_request(
8081
while True:
8182
elapsed = time.monotonic() - start_time
8283
if elapsed >= config.POLLING_TIMEOUT:
83-
msg = f"{self.provider} polling timed out after {config.POLLING_TIMEOUT} seconds" # type: ignore[attr-defined]
84+
msg = f"{self.provider} polling timed out after {config.POLLING_TIMEOUT} seconds"
8485
raise TimeoutError(msg)
8586

86-
poll_response = await self.http_client.get( # type: ignore[attr-defined]
87+
poll_response = await self.http_client.get(
8788
polling_url,
8889
headers=poll_headers,
8990
)
9091

9192
if poll_response.status_code != 200:
92-
return poll_response # type: ignore[no-any-return]
93+
return poll_response
9394

9495
poll_data = poll_response.json()
9596
status = poll_data.get("status")

packages/providers/byteplus/src/celeste_byteplus/images/client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import httpx
77

8+
from celeste.client import APIMixin
89
from celeste.core import UsageField
910
from celeste.io import FinishReason
1011
from celeste.mime_types import ApplicationMimeType
1112

1213
from . import config
1314

1415

15-
class BytePlusImagesClient:
16+
class BytePlusImagesClient(APIMixin):
1617
"""Mixin for BytePlus Images API capabilities.
1718
1819
Provides shared implementation for all capabilities using the Images API:
@@ -39,11 +40,11 @@ async def _make_request(
3940
request_body["stream"] = False
4041

4142
headers = {
42-
**self.auth.get_headers(), # type: ignore[attr-defined]
43+
**self.auth.get_headers(),
4344
"Content-Type": ApplicationMimeType.JSON,
4445
}
4546

46-
return await self.http_client.post( # type: ignore[attr-defined,no-any-return]
47+
return await self.http_client.post(
4748
f"{config.BASE_URL}{config.BytePlusImagesEndpoint.CREATE_IMAGE}",
4849
headers=headers,
4950
json_body=request_body,
@@ -58,11 +59,11 @@ def _make_stream_request(
5859
request_body["stream"] = True
5960

6061
headers = {
61-
**self.auth.get_headers(), # type: ignore[attr-defined]
62+
**self.auth.get_headers(),
6263
"Content-Type": ApplicationMimeType.JSON,
6364
}
6465

65-
return self.http_client.stream_post( # type: ignore[attr-defined,no-any-return]
66+
return self.http_client.stream_post(
6667
f"{config.BASE_URL}{config.BytePlusImagesEndpoint.CREATE_IMAGE}",
6768
headers=headers,
6869
json_body=request_body,
@@ -114,7 +115,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> Any:
114115
filtered_data = {
115116
k: v for k, v in response_data.items() if k not in content_fields
116117
}
117-
metadata = super()._build_metadata(filtered_data) # type: ignore[misc]
118+
metadata = super()._build_metadata(filtered_data)
118119

119120
# Add provider-specific parsed fields
120121
seed = response_data.get("seed")

packages/providers/byteplus/src/celeste_byteplus/videos/client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import httpx
99

10+
from celeste.client import APIMixin
1011
from celeste.core import UsageField
1112
from celeste.io import FinishReason
1213
from celeste.mime_types import ApplicationMimeType
@@ -16,7 +17,7 @@
1617
logger = logging.getLogger(__name__)
1718

1819

19-
class BytePlusVideosClient:
20+
class BytePlusVideosClient(APIMixin):
2021
"""Mixin for BytePlus ModelArk Videos API with async polling.
2122
2223
Provides shared implementation:
@@ -46,22 +47,22 @@ async def _make_request(
4647
2. Poll CONTENT_STATUS endpoint until succeeded/failed/canceled
4748
3. Return response with final status data
4849
"""
49-
auth_headers = self.auth.get_headers() # type: ignore[attr-defined]
50+
auth_headers = self.auth.get_headers()
5051
headers = {
5152
**auth_headers,
5253
"Content-Type": ApplicationMimeType.JSON,
5354
}
5455

5556
# Phase 1: Submit job
5657
logger.debug("Submitting video generation task to BytePlus")
57-
submit_response = await self.http_client.post( # type: ignore[attr-defined]
58+
submit_response = await self.http_client.post(
5859
f"{config.BASE_URL}{config.BytePlusVideosEndpoint.CREATE_VIDEO}",
5960
headers=headers,
6061
json_body=request_body,
6162
)
6263

6364
if submit_response.status_code != 200:
64-
return submit_response # type: ignore[no-any-return]
65+
return submit_response
6566

6667
submit_data = submit_response.json()
6768
task_id = submit_data["id"]
@@ -82,13 +83,13 @@ async def _make_request(
8283
status_url = f"{config.BASE_URL}{config.BytePlusVideosEndpoint.GET_VIDEO_STATUS.format(task_id=task_id)}"
8384
logger.debug(f"Polling BytePlus task status: {task_id}")
8485

85-
status_response = await self.http_client.get( # type: ignore[attr-defined]
86+
status_response = await self.http_client.get(
8687
status_url,
8788
headers=headers,
8889
)
8990

9091
if status_response.status_code != 200:
91-
return status_response # type: ignore[no-any-return]
92+
return status_response
9293

9394
status_data = status_response.json()
9495
status = status_data.get("status")

packages/providers/cohere/src/celeste_cohere/chat/client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import httpx
77

8+
from celeste.client import APIMixin
89
from celeste.core import UsageField
910
from celeste.io import FinishReason
1011
from celeste.mime_types import ApplicationMimeType
1112

1213
from . import config
1314

1415

15-
class CohereChatClient:
16+
class CohereChatClient(APIMixin):
1617
"""Mixin for Cohere Chat API capabilities.
1718
1819
Provides shared implementation for all capabilities using the Chat API:
@@ -37,14 +38,14 @@ async def _make_request(
3738
**parameters: Any,
3839
) -> httpx.Response:
3940
"""Make HTTP request to Cohere Chat API endpoint."""
40-
request_body["model"] = self.model.id # type: ignore[attr-defined]
41+
request_body["model"] = self.model.id
4142

4243
headers = {
43-
**self.auth.get_headers(), # type: ignore[attr-defined]
44+
**self.auth.get_headers(),
4445
"Content-Type": ApplicationMimeType.JSON,
4546
}
4647

47-
return await self.http_client.post( # type: ignore[attr-defined,no-any-return]
48+
return await self.http_client.post(
4849
f"{config.BASE_URL}{config.CohereChatEndpoint.CREATE_CHAT}",
4950
headers=headers,
5051
json_body=request_body,
@@ -56,15 +57,15 @@ def _make_stream_request(
5657
**parameters: Any,
5758
) -> AsyncIterator[dict[str, Any]]:
5859
"""Make streaming request to Cohere Chat API endpoint."""
59-
request_body["model"] = self.model.id # type: ignore[attr-defined]
60+
request_body["model"] = self.model.id
6061
request_body["stream"] = True
6162

6263
headers = {
63-
**self.auth.get_headers(), # type: ignore[attr-defined]
64+
**self.auth.get_headers(),
6465
"Content-Type": ApplicationMimeType.JSON,
6566
}
6667

67-
return self.http_client.stream_post( # type: ignore[attr-defined,no-any-return]
68+
return self.http_client.stream_post(
6869
f"{config.BASE_URL}{config.CohereChatEndpoint.CREATE_CHAT}",
6970
headers=headers,
7071
json_body=request_body,
@@ -113,7 +114,7 @@ def _build_metadata(self, response_data: dict[str, Any]) -> dict[str, Any]:
113114
filtered_data = {
114115
k: v for k, v in response_data.items() if k not in content_fields
115116
}
116-
return super()._build_metadata(filtered_data) # type: ignore[misc,no-any-return]
117+
return super()._build_metadata(filtered_data)
117118

118119

119120
__all__ = ["CohereChatClient"]

packages/providers/elevenlabs/src/celeste_elevenlabs/text_to_speech/client.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
import httpx
77

8+
from celeste.client import APIMixin
89
from celeste.mime_types import ApplicationMimeType, AudioMimeType
910

1011
from . import config
1112

1213

13-
class ElevenLabsTextToSpeechClient:
14+
class ElevenLabsTextToSpeechClient(APIMixin):
1415
"""Mixin for ElevenLabs Text-to-Speech API.
1516
1617
Provides shared implementation for speech generation:
@@ -44,19 +45,19 @@ async def _make_request(
4445
voice_id = parameters.get("voice", config.DEFAULT_VOICE_ID)
4546

4647
# Set model_id
47-
request_body["model_id"] = self.model.id # type: ignore[attr-defined]
48+
request_body["model_id"] = self.model.id
4849

4950
# Build URL with voice_id in path
5051
endpoint = config.ElevenLabsTextToSpeechEndpoint.CREATE_SPEECH.format(
5152
voice_id=voice_id
5253
)
5354

5455
headers = {
55-
**self.auth.get_headers(), # type: ignore[attr-defined]
56+
**self.auth.get_headers(),
5657
"Content-Type": ApplicationMimeType.JSON,
5758
}
5859

59-
return await self.http_client.post( # type: ignore[attr-defined,no-any-return]
60+
return await self.http_client.post(
6061
f"{config.BASE_URL}{endpoint}",
6162
headers=headers,
6263
json_body=request_body,
@@ -78,15 +79,15 @@ def _make_stream_request(
7879
voice_id = parameters.get("voice", config.DEFAULT_VOICE_ID)
7980

8081
# Set model_id
81-
request_body["model_id"] = self.model.id # type: ignore[attr-defined]
82+
request_body["model_id"] = self.model.id
8283

8384
# Build URL with voice_id in path
8485
endpoint = config.ElevenLabsTextToSpeechEndpoint.STREAM_SPEECH.format(
8586
voice_id=voice_id
8687
)
8788

8889
headers = {
89-
**self.auth.get_headers(), # type: ignore[attr-defined]
90+
**self.auth.get_headers(),
9091
"Content-Type": ApplicationMimeType.JSON,
9192
}
9293

@@ -106,7 +107,7 @@ async def _stream_binary_audio(
106107
107108
Wraps httpx streaming to yield dicts compatible with Stream interface.
108109
"""
109-
client = await self.http_client._get_client() # type: ignore[attr-defined]
110+
client = await self.http_client._get_client()
110111

111112
async with client.stream(
112113
"POST",

0 commit comments

Comments
 (0)