forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhugging_face.py
68 lines (48 loc) · 2.21 KB
/
hugging_face.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from enum import Enum
from typing import Any, Dict
import aiohttp
from loguru import logger
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode
class HfUrl(str, Enum):
    """Base URLs of the HuggingFace hosted inference endpoints.

    Members are also ``str`` instances, so they can be compared with and
    used as plain strings.
    """

    # Root of the hosted models API.
    # NOTE(review): presumably the caller appends a model id to this URL -- confirm at call sites.
    HUGGINGFACE_TOXIC_CLASSIFICATION = "https://api-inference.huggingface.co/models"

    # Root of the feature-extraction pipeline API.
    HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
class HfClassificationModel(str, Enum):
    """Identifiers of the HuggingFace classification models known to this module.

    Members are also ``str`` instances, so they can be used wherever a plain
    model identifier string is expected.
    """

    # Identifier of the toxicity classifier (value has the "<owner>/<model>" form).
    TOXIC_ROBERTA = "unitary/multilingual-toxic-xlm-roberta"
class HfEmbeddingModel(str, Enum):
    """Identifiers of the HuggingFace embedding models known to this module.

    Members are also ``str`` instances, so they can be used wherever a plain
    model identifier string is expected.
    """

    # Identifier of the sentence-embedding model (value has the "<owner>/<model>" form).
    MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
class HuggingFaceAPI:
    """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace

    Attributes set in ``__init__``:
        api_url: the endpoint every ``post`` call is sent to.
        api_key: access token read from ``settings.HUGGING_FACE_API_KEY``.
        headers: HTTP headers sent with every request (bearer authorization).
    """

    def __init__(
        self,
        api_url: str,
    ):
        """Store the target endpoint and build the authorization headers.

        Args:
            api_url (str): URL of the inference endpoint to call.
        """
        # The API endpoint we want to access
        self.api_url: str = api_url

        # Access token for the api
        self.api_key: str = settings.HUGGING_FACE_API_KEY

        # Headers going to be used
        self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}

    def _redacted_headers(self) -> Dict[str, str]:
        """Return a copy of the request headers with the credentials masked, safe for logging."""
        return {
            name: "<redacted>" if name.lower() == "authorization" else value
            for name, value in self.headers.items()
        }

    async def post(self, input: str, wait_for_model: bool = True) -> Any:
        """Post request to the endpoint to get an inference

        Args:
            input (str): the input that we will pass to the model
            wait_for_model (bool): value sent as the ``wait_for_model`` key of
                the request payload. Defaults to True.

        Raises:
            OasstError: in the case we get a bad response

        Returns:
            inference: the inference we obtain from the model in HF
        """
        async with aiohttp.ClientSession() as session:
            # NOTE(review): the HF Inference API documents `wait_for_model` under an
            # "options" object; confirm that the top-level key used here is honored.
            payload: Dict[str, Any] = {"inputs": input, "wait_for_model": wait_for_model}

            async with session.post(self.api_url, headers=self.headers, json=payload) as response:
                # If we get a bad response
                if not response.ok:
                    logger.error(response)
                    # Never log the raw headers: they carry the bearer token.
                    logger.info(self._redacted_headers())
                    raise OasstError(
                        f"Response Error HuggingFace API (Status: {response.status})",
                        error_code=OasstErrorCode.HUGGINGFACE_API_ERROR,
                    )

                # Get the response from the API call
                inference = await response.json()

        return inference