forked from LAION-AI/Open-Assistant

basic_hf_server.py · 94 lines (75 loc) · 2.56 KB

# a basic fastapi server to run generation on HF models
import signal
import sys

import fastapi
import interface
import torch
import transformers
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from prometheus_fastapi_instrumentator import Instrumentator
from settings import settings
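
# NOTE: `interface` and `settings` are repo-local modules from the Open-Assistant codebase and
# are not shown in this file. As an assumption, a minimal stand-in for `settings`, enough to run
# this file on its own, could look like the sketch below (pydantic v1 style):
#
#     # settings.py
#     import pydantic
#
#     class Settings(pydantic.BaseSettings):
#         model_id: str = "distilgpt2"
#
#     settings = Settings()
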
app = fastapi.FastAPI()

# Allow CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def log_exceptions(request: fastapi.Request, call_next):
    """Log any unhandled exception raised while serving a request, then re-raise it."""
    try:
        response = await call_next(request)
    except Exception:
        logger.exception("Exception in request")
        raise
    return response


# add prometheus metrics at /metrics
@app.on_event("startup")
async def enable_prom_metrics():
    Instrumentator().instrument(app).expose(app)


def terminate_server(signum, frame):
    logger.warning(f"Signal {signum}. Terminating server...")
    sys.exit(0)


# populated by load_models() at startup
model: transformers.PreTrainedModel
tokenizer: transformers.PreTrainedTokenizer
use_gpu: bool = False


@app.on_event("startup")
async def load_models():
    """Load the tokenizer and model configured via settings.model_id."""
    global model, tokenizer, use_gpu
    # exit cleanly if Ctrl+C arrives while the (potentially slow) model load is running
    signal.signal(signal.SIGINT, terminate_server)
    logger.warning(f"Loading model {settings.model_id}...")
    if "llama" in settings.model_id:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(settings.model_id)
        model = transformers.LlamaForCausalLM.from_pretrained(settings.model_id)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(settings.model_id)
        model = transformers.AutoModelForCausalLM.from_pretrained(settings.model_id)
    if torch.cuda.is_available():
        logger.warning("Using GPU")
        use_gpu = True
        model = model.cuda()
    logger.warning("Model loaded")
    # restore the default SIGINT handler once loading is done
    signal.signal(signal.SIGINT, signal.SIG_DFL)


@app.on_event("startup")
async def welcome_message():
    logger.warning("Server started")
    logger.warning("To stop the server, press Ctrl+C")


@app.post("/generate")
async def generate(request: interface.GenerateStreamRequest):
    """Run generation for the given prompt and return the decoded text."""
    global model, tokenizer, use_gpu
    prompt = request.inputs
    params = request.parameters.dict()
    # drop fields that transformers' generate() does not accept as kwargs
    params.pop("seed")
    params.pop("stop")
    params.pop("details")
    with torch.no_grad():
        ids = tokenizer.encode(prompt, return_tensors="pt")
        if use_gpu:
            ids = ids.cuda()
        output = model.generate(ids, **params)
        output = output.cpu()
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"text": decoded}
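

# --- Usage sketch (not part of the original file) ---
# Assumptions: uvicorn is installed, port 8000 is free, and the request fields below match
# interface.GenerateStreamRequest (only `inputs`, `parameters`, `seed`, `stop` and `details`
# are confirmed by the code above; `max_new_tokens` is an assumption).
#
# Example request once the server is running:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello, world", "parameters": {"max_new_tokens": 32, "seed": null, "stop": [], "details": false}}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)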