1
+ from unittest import mock
1
2
from uuid import uuid4
2
3
4
+ import aiohttp
3
5
import pytest
4
6
from oasst_shared .api_client import OasstApiClient
7
+ from oasst_shared .exceptions import OasstError , OasstErrorCode
5
8
from oasst_shared .schemas import protocol as protocol_schema
6
9
7
10
8
11
@pytest .fixture
9
12
def oasst_api_client_mocked ():
13
+ """
14
+ A an oasst_api_client pointed at the mocked backend.
15
+ Relies on ./scripts/backend-development/start-mock-server.sh
16
+ being run.
17
+ """
10
18
client = OasstApiClient (backend_url = "http://localhost:8080" , api_key = "123" )
11
19
yield client
12
20
# TODO The fixture should close this connection, but there seems to be a bug
@@ -15,6 +23,20 @@ def oasst_api_client_mocked():
15
23
# await client.close()
16
24
17
25
26
+ @pytest .fixture
27
+ def mock_http_session ():
28
+ yield mock .AsyncMock (spec = aiohttp .ClientSession )
29
+
30
+
31
+ @pytest .fixture
32
+ def oasst_api_client_fake_http (mock_http_session ):
33
+ """
34
+ An oasst_api_client that uses a mocked http session. No real requests are made.
35
+ """
36
+ client = OasstApiClient (backend_url = "http://localhost:8080" , api_key = "123" , session = mock_http_session )
37
+ yield client
38
+
39
+
18
40
@pytest .mark .asyncio
19
41
@pytest .mark .parametrize ("task_type" , protocol_schema .TaskRequestType )
20
42
async def test_can_fetch_task (task_type : protocol_schema .TaskRequestType , oasst_api_client_mocked : OasstApiClient ):
@@ -49,3 +71,22 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
49
71
)
50
72
is not None
51
73
)
74
+
75
+
76
+ @pytest .mark .asyncio
77
+ async def test_can_handle_oasst_error_from_api (
78
+ oasst_api_client_fake_http : OasstApiClient ,
79
+ mock_http_session : mock .AsyncMock ,
80
+ ):
81
+ # Return a 400 response with an OasstErrorResponse body
82
+ response_body = protocol_schema .OasstErrorResponse (
83
+ error_code = OasstErrorCode .GENERIC_ERROR ,
84
+ message = "Some error" ,
85
+ ).json ()
86
+ status_code = 400
87
+
88
+ mock_http_session .post .return_value .__aenter__ .return_value .json .return_value = response_body
89
+ mock_http_session .post .return_value .__aenter__ .return_value .status = status_code
90
+
91
+ with pytest .raises (OasstError ):
92
+ await oasst_api_client_fake_http .post ("/some-path" , data = {})
0 commit comments