forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi_client.py
156 lines (132 loc) · 6.55 KB
/
api_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""API Client for interacting with the OASST backend."""
import enum
import typing as t
from http import HTTPStatus
from typing import Optional, Type
from uuid import UUID
import aiohttp
from loguru import logger
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import ValidationError
# TODO: Move to `protocol`?
class TaskType(str, enum.Enum):
    """Enumeration of the task kinds exchanged with the OASST backend.

    Because the enum mixes in ``str``, every member compares equal to its
    string value. Those values are the strings carried in a task payload's
    ``type`` field.
    """

    # Story summarisation
    summarize_story = "summarize_story"
    rate_summary = "rate_summary"

    # Writing a message in the conversation
    initial_prompt = "initial_prompt"
    prompter_reply = "prompter_reply"
    assistant_reply = "assistant_reply"

    # Ranking candidate messages
    rank_initial_prompts = "rank_initial_prompts"
    rank_prompter_replies = "rank_prompter_replies"
    rank_assistant_replies = "rank_assistant_replies"

    # Labelling a single message
    label_initial_prompt = "label_initial_prompt"
    label_assistant_reply = "label_assistant_reply"
    label_prompter_reply = "label_prompter_reply"

    # NOTE: the wire value is "task_done", not the member name "done".
    done = "task_done"
class OasstApiClient:
    """API Client for interacting with the OASST backend."""

    def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
        """Create a new OasstApiClient.

        Args:
        ----
            backend_url (str): The base backend URL; request paths are appended to it verbatim.
            api_key (str): The API key to use for authentication (sent as the ``x-api-key`` header).
            session (Optional[aiohttp.ClientSession]): HTTP session used for every request.
                When omitted, a new ``aiohttp.ClientSession`` is created. ``close()`` closes
                this session whether or not it was supplied by the caller.
        """
        if session is None:
            logger.debug("Opening OasstApiClient session")
            session = aiohttp.ClientSession()
        self.session = session
        self.backend_url = backend_url
        self.api_key = api_key
        # Maps the value of a task payload's "type" field to the model used to parse it.
        self.task_models_map: dict[TaskType, Type[protocol_schema.Task]] = {
            TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
            TaskType.rate_summary: protocol_schema.RateSummaryTask,
            TaskType.initial_prompt: protocol_schema.InitialPromptTask,
            TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
            TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
            TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
            TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
            TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
            TaskType.label_initial_prompt: protocol_schema.LabelInitialPromptTask,
            TaskType.label_prompter_reply: protocol_schema.LabelPrompterReplyTask,
            TaskType.label_assistant_reply: protocol_schema.LabelAssistantReplyTask,
            TaskType.done: protocol_schema.TaskDone,
        }

    async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
        """Make a POST request to the backend.

        Args:
        ----
            path (str): Path appended to ``backend_url``.
            data (dict[str, Any]): Request body, sent as JSON.

        Returns:
        -------
            The decoded JSON response body, or ``None`` for a 204 (No Content) response.

        Raises:
        ------
            OasstError: If the backend responds with a status code >= 300.
        """
        url = f"{self.backend_url}{path}"
        logger.debug(f"POST {url} DATA: {data}")
        # `async with` releases the connection on every exit path, including
        # when an OasstError is raised while handling an error response.
        async with self.session.post(url, json=data, headers={"x-api-key": self.api_key}) as response:
            logger.debug(f"response: {response}")
            # If the response is not a 2XX, check to see if the json has the
            # fields to create an OasstError.
            if response.status >= 300:
                await self._raise_for_error(response)
            if response.status == 204:
                # No content
                return None
            return await response.json()

    @staticmethod
    async def _raise_for_error(response: aiohttp.ClientResponse) -> t.NoReturn:
        """Convert a non-2XX backend response into an OasstError.

        Raises the backend's own error code and message when the body parses as an
        ``OasstErrorResponse``. Otherwise raises a GENERIC_ERROR carrying the raw
        response text and the HTTP status.
        """
        text = await response.text()
        logger.debug(f"resp text: {text}")
        # Error bodies are not guaranteed to be JSON (e.g. an HTML error page from
        # a proxy). Skip aiohttp's content-type check and treat undecodable or
        # non-object bodies as "no structured error" so the generic branch runs.
        try:
            error_data = await response.json(content_type=None)
        except ValueError:
            error_data = None
        if not isinstance(error_data, dict):
            error_data = {}
        try:
            oasst_error = protocol_schema.OasstErrorResponse(**error_data)
        except ValidationError as e:
            logger.debug(f"Got error from API but could not parse: {e}")
            logger.debug(f"Raw response: {text}")
            try:
                http_status = HTTPStatus(response.status)
            except ValueError:
                # Non-standard codes (e.g. 520) are not HTTPStatus members; don't
                # let that mask the real error.
                http_status = HTTPStatus.INTERNAL_SERVER_ERROR
            raise OasstError(
                text,
                OasstErrorCode.GENERIC_ERROR,
                http_status,
            ) from e
        raise OasstError(
            error_code=oasst_error.error_code,
            message=oasst_error.message,
        )

    def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
        """Parse a backend task payload into the matching protocol Task model.

        The model is chosen from ``task_models_map`` by the payload's "type" field.

        Raises:
        ------
            ValueError: If ``data`` is None, or its "type" is missing or unsupported.
        """
        if data is None:
            raise ValueError("Cannot parse data as a task: data is none")
        raw_type = data.get("type")
        try:
            # TaskType(...) raises ValueError for missing or unknown values;
            # KeyError covers a TaskType that has no registered model.
            model = self.task_models_map[TaskType(raw_type)]
        except (ValueError, KeyError):
            logger.error(f"Unsupported task type: {raw_type}")
            raise ValueError(f"Unsupported task type: {raw_type}") from None
        return model.parse_obj(data)  # type: ignore

    async def fetch_task(
        self,
        task_type: protocol_schema.TaskRequestType,
        user: Optional[protocol_schema.User] = None,
        collective: bool = False,
        lang: Optional[str] = None,
    ) -> protocol_schema.Task:
        """Fetch a task of the given request type from the backend and parse it."""
        logger.debug(f"Fetching task {task_type} for user {user}")
        req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
        resp = await self.post("/api/v1/tasks/", data=req.dict())
        logger.debug(f"RESP {resp}")
        return self._parse_task(resp)

    async def fetch_random_task(
        self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
    ) -> protocol_schema.Task:
        """Fetch a random task from the backend (``TaskRequestType.random``)."""
        logger.debug(f"Fetching random for user {user}")
        return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)

    async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
        """Send an ACK for a task to the backend."""
        logger.debug(f"ACK task {task_id} with post {message_id}")
        req = protocol_schema.TaskAck(message_id=message_id)
        await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())

    async def nack_task(self, task_id: str | UUID, reason: str) -> None:
        """Send a NACK for a task to the backend, with a reason."""
        logger.debug(f"NACK task {task_id} with reason {reason}")
        req = protocol_schema.TaskNAck(reason=reason)
        await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict())

    async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
        """Send a completed task to the backend and parse the task it returns."""
        logger.debug(f"Interaction: {interaction}")
        resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict())
        return self._parse_task(resp)

    async def close(self):
        """Close the underlying HTTP session (including a caller-supplied one)."""
        logger.debug("Closing OasstApiClient session")
        await self.session.close()