forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoasst_api_client.ts
221 lines (192 loc) · 6.28 KB
/
oasst_api_client.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import { JWT } from "next-auth/jwt";
import type { Message } from "src/types/Conversation";
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
import type { BackendUser } from "src/types/Users";
/**
 * Error thrown by `OasstApiClient` when a backend request fails.
 *
 * This is a plain class rather than an `Error` subclass: all three fields are ordinary
 * own enumerable properties (so they survive `JSON.stringify`), but no stack trace is
 * captured.
 */
export class OasstError {
  /** Human-readable description of the failure. */
  message: string;
  /** Application-level error code reported by the backend. */
  errorCode: number;
  /**
   * HTTP status of the failed reply; `undefined` when the creator did not supply one
   * (the constructor argument is optional).
   */
  httpStatusCode: number | undefined;

  constructor(message: string, errorCode: number, httpStatusCode?: number) {
    this.message = message;
    this.errorCode = errorCode;
    this.httpStatusCode = httpStatusCode;
  }
}
/**
 * Thin HTTP client for the Open-Assistant backend REST API.
 *
 * Every request carries the `X-API-Key` header. A `204 No Content` reply resolves to
 * `null`, other successful replies resolve to their parsed JSON body, and any status
 * >= 300 rejects with an `OasstError`.
 */
export class OasstApiClient {
  /** Base URL of the backend; request paths (which start with `/`) are appended verbatim. */
  oasstApiUrl: string;
  /** Value sent in the `X-API-Key` header of every request. */
  oasstApiKey: string;

  constructor(oasstApiUrl: string, oasstApiKey: string) {
    this.oasstApiUrl = oasstApiUrl;
    this.oasstApiKey = oasstApiKey;
  }

  /**
   * Sends one request to the backend and decodes the reply.
   *
   * A JSON body (and its `Content-Type` header) is only sent when `body` is provided.
   *
   * @returns `null` for `204 No Content`, otherwise the parsed JSON reply.
   * @throws OasstError when the reply status is >= 300.
   */
  private async request(method: "GET" | "POST" | "PUT", path: string, body?: unknown): Promise<any> {
    const headers: Record<string, string> = { "X-API-Key": this.oasstApiKey };
    if (body !== undefined) {
      headers["Content-Type"] = "application/json";
    }
    const resp = await fetch(`${this.oasstApiUrl}${path}`, {
      method,
      headers,
      body: body === undefined ? undefined : JSON.stringify(body),
    });
    if (resp.status === 204) {
      return null;
    }
    if (resp.status >= 300) {
      throw OasstApiClient.toOasstError(await resp.text(), resp.status);
    }
    return await resp.json();
  }

  /**
   * Converts the body of a failed reply into an `OasstError`.
   *
   * Structured replies of the form `{ message, error_code }` keep those fields. Anything
   * else falls back to the raw reply text: non-JSON text (e.g. a proxy error page), a
   * JSON `null`/scalar, or JSON without a string `message` (e.g. `{"detail": ...}`).
   * A missing or non-numeric `error_code` becomes `0`. This way `message` is always a
   * string and `errorCode` always a number.
   */
  private static toOasstError(errorText: string, httpStatusCode: number): OasstError {
    let parsed: unknown;
    try {
      parsed = JSON.parse(errorText);
    } catch {
      return new OasstError(errorText, 0, httpStatusCode);
    }
    if (typeof parsed !== "object" || parsed === null) {
      return new OasstError(errorText, 0, httpStatusCode);
    }
    const { message, error_code } = parsed as { message?: unknown; error_code?: unknown };
    return new OasstError(
      typeof message === "string" ? message : errorText,
      typeof error_code === "number" ? error_code : 0,
      httpStatusCode
    );
  }

  /** POSTs `body` as JSON to `path`. */
  private post(path: string, body: unknown): Promise<any> {
    return this.request("POST", path, body);
  }

  /** PUTs to `path` without a body; parameters travel in the query string. */
  private put(path: string): Promise<any> {
    return this.request("PUT", path);
  }

  /** GETs `path`. */
  private get(path: string): Promise<any> {
    return this.request("GET", path);
  }

  /**
   * Builds the `user` object the task endpoints expect from the session token.
   * `||` (not `??`) is intentional: an empty display name also falls back to the email.
   */
  private static taskUser(userToken: JWT) {
    return {
      id: userToken.sub,
      display_name: userToken.name || userToken.email,
      auth_method: "local",
    };
  }

  // TODO return a strongly typed Task?
  // This method is used to store a task in RegisteredTask.task.
  // This is a raw Json type, so we can't use it to strongly type the task.
  /**
   * Requests a new task of type `taskType` for the user identified by `userToken`.
   */
  async fetchTask(taskType: string, userToken: JWT): Promise<any> {
    return this.post("/api/v1/tasks/", {
      type: taskType,
      user: OasstApiClient.taskUser(userToken),
    });
  }

  /**
   * Acknowledges task `taskId`, associating it with `messageId`.
   */
  async ackTask(taskId: string, messageId: string): Promise<void> {
    return this.post(`/api/v1/tasks/${taskId}/ack`, {
      message_id: messageId,
    });
  }

  /**
   * Negatively acknowledges (rejects) task `taskId` with the given `reason`.
   */
  async nackTask(taskId: string, reason: string): Promise<void> {
    return this.post(`/api/v1/tasks/${taskId}/nack`, {
      reason,
    });
  }

  // TODO return a strongly typed Task?
  // This method is used to record interaction with task while fetching next task.
  // This is a raw Json type, so we can't use it to strongly type the task.
  /**
   * Records a user's interaction with a task.
   *
   * The fields of `content` are spread last into the request body, so they are sent
   * alongside, and can override, the fixed fields.
   */
  async interactTask(
    updateType: string,
    taskId: string,
    messageId: string,
    userMessageId: string,
    content: object,
    userToken: JWT
  ): Promise<any> {
    return this.post("/api/v1/tasks/interaction", {
      type: updateType,
      user: OasstApiClient.taskUser(userToken),
      task_id: taskId,
      message_id: messageId,
      user_message_id: userMessageId,
      ...content,
    });
  }

  /**
   * Returns the `BackendUser` associated with `user_id`.
   */
  async fetch_user(user_id: string): Promise<BackendUser> {
    return this.get(`/api/v1/users/users/${user_id}`);
  }

  /**
   * Returns the set of `BackendUser`s stored by the backend.
   *
   * @param max_count - The maximum number of users to fetch.
   * @param cursor - The user's `display_name` to use when paginating; empty to start
   *   from the beginning.
   * @param isForward - If true and `cursor` is not empty, pages forward. If false and
   *   `cursor` is not empty, pages backwards.
   * @returns A Promise that resolves to an array of `BackendUser` objects.
   */
  async fetch_users(max_count: number, cursor: string, isForward: boolean): Promise<BackendUser[]> {
    const params = new URLSearchParams();
    params.append("max_count", max_count.toString());
    // The backend API uses `gt` to page forward and `lt` to page backward; both take
    // the same cursor value. A missing or empty cursor adds neither.
    if (cursor) {
      params.append(isForward ? "gt" : "lt", cursor);
    }
    return this.get(`/api/v1/frontend_users/?${params.toString()}`);
  }

  /**
   * Returns the `Message`s associated with `user_id` in the backend.
   */
  async fetch_user_messages(user_id: string): Promise<Message[]> {
    return this.get(`/api/v1/users/${user_id}/messages`);
  }

  /**
   * Updates the backend's knowledge about the `user_id`: whether the account is
   * enabled, plus free-text admin notes.
   */
  async set_user_status(user_id: string, is_enabled: boolean, notes: string): Promise<void> {
    // URLSearchParams percent-encodes the free-text notes and joins the parameters with
    // a real `&`; hand-built interpolation broke on notes containing `&`, `#`, or spaces.
    const params = new URLSearchParams({ enabled: String(is_enabled), notes });
    return this.put(`/api/v1/users/users/${user_id}?${params.toString()}`);
  }

  /**
   * Returns the valid labels for messages.
   */
  async fetch_valid_text(): Promise<any> {
    return this.get(`/api/v1/text_labels/valid_labels`);
  }

  /**
   * Returns the current leaderboard ranking for `time_frame`.
   */
  async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise<LeaderboardReply> {
    return this.get(`/api/v1/leaderboards/${time_frame}`);
  }
}
/**
 * Shared client instance, configured from the `FASTAPI_URL` (backend base URL) and
 * `FASTAPI_KEY` (API key) environment variables at module load.
 * NOTE(review): if either variable is unset, `undefined` is passed through as-is.
 */
export const oasstApiClient = new OasstApiClient(process.env.FASTAPI_URL, process.env.FASTAPI_KEY);