Skip to content

Commit e2f2184

Browse files
authored
Fixes auth TokenPair change bug (also fix for null logprob from inference server) (LAION-AI#2155)
1 parent f1e41d1 commit e2f2184

File tree

5 files changed

+11
-10
lines changed

5 files changed

+11
-10
lines changed

CODEOWNERS

+1
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
/safety @SummerSigh @shahules786
1111
/inference/ @yk @andreaskoepf @olliestanley @AbdBarho
1212
/backend/ @andreaskoepf @melvinebenezer @yk
13+
/oasst-shared/ @andreaskoepf @melvinebenezer @yk @olliestanley @AbdBarho

inference/server/oasst_inference_server/routes/auth.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ async def callback_discord(
8686
refresh_token = auth.create_refresh_token(user.id)
8787

8888
token_pair = protocol.TokenPair(
89-
protocol.Token(access_token=access_token, token_type="bearer"),
90-
protocol.Token(access_token=refresh_token, token_type="refresh"),
89+
access_token=protocol.Token(access_token=access_token, token_type="bearer"),
90+
refresh_token=protocol.Token(access_token=refresh_token, token_type="refresh"),
9191
)
9292

9393
return token_pair
@@ -154,8 +154,8 @@ async def callback_github(
154154
refresh_token = auth.create_refresh_token(user.id)
155155

156156
token_pair = protocol.TokenPair(
157-
protocol.Token(access_token=access_token, token_type="bearer"),
158-
protocol.Token(access_token=refresh_token, token_type="refresh"),
157+
access_token=protocol.Token(access_token=access_token, token_type="bearer"),
158+
refresh_token=protocol.Token(access_token=refresh_token, token_type="refresh"),
159159
)
160160

161161
return token_pair
@@ -212,8 +212,8 @@ async def login_debug(username: str, db: database.AsyncSession = Depends(deps.cr
212212
refresh_token = auth.create_refresh_token(user.id)
213213

214214
token_pair = protocol.TokenPair(
215-
protocol.Token(access_token=access_token, token_type="bearer"),
216-
protocol.Token(access_token=refresh_token, token_type="refresh"),
215+
access_token=protocol.Token(access_token=access_token, token_type="bearer"),
216+
refresh_token=protocol.Token(access_token=refresh_token, token_type="refresh"),
217217
)
218218

219219
return token_pair

inference/text-client/text_client_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def __init__(self, backend_url, http_client=requests):
1212

1313
def login(self, username):
1414
auth_data = self.http_client.get(f"{self.backend_url}/auth/login/debug", params={"username": username}).json()
15-
assert auth_data["token_type"] == "bearer"
16-
bearer_token = auth_data["access_token"]
15+
assert auth_data["access_token"]["token_type"] == "bearer"
16+
bearer_token = auth_data["access_token"]["access_token"]
1717
logger.debug(f"Logged in as {username} with token {bearer_token}")
1818
self.auth_headers = {"Authorization": f"Bearer {bearer_token}"}
1919

inference/worker/interface.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class GenerateStreamRequest(pydantic.BaseModel):
4141

4242
class Token(pydantic.BaseModel):
4343
text: str
44-
logprob: float
44+
logprob: float | None
4545
id: int
4646

4747
def __len__(self) -> int:

oasst-shared/oasst_shared/schemas/inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class PongResponse(WorkerResponseBase):
207207
class TokenResponse(WorkerResponseBase):
208208
response_type: Literal["token"] = "token"
209209
text: str
210-
log_prob: float
210+
log_prob: float | None
211211
token_id: int
212212

213213

0 commit comments

Comments
 (0)