Skip to content

Commit fa0bf3c

Browse files
authored
Update index.py
1 parent 9a4930a commit fa0bf3c

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/index.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
1-
from fastapi import FastAPI
1+
# -*- coding: utf-8 -*-
22

3-
from src.dtos.ISayHelloDto import ISayHelloDto
3+
import pandas as pd
4+
from pycaret.classification import load_model, predict_model
5+
from fastapi import FastAPI
6+
import uvicorn
7+
from pydantic import create_model
48

9+
# Create the app
510
app = FastAPI()
611

12+
# Load trained Pipeline
13+
model = load_model("rf_api")
714

8-
@app.get("/")
9-
async def root():
10-
return {"message": "Hello World"}
15+
# Create input/output pydantic models
16+
input_model = create_model("rf_api_input", **{'age': 70.0, 'anaemia': 0.0, 'creatinine_phosphokinase': 582.0, 'diabetes': 1.0, 'ejection_fraction': 38.0, 'high_blood_pressure': 0.0, 'platelets': 25100.0, 'serum_creatinine': 1.100000023841858, 'serum_sodium': 140.0, 'sex': 1.0, 'smoking': 0.0, 'time': 246.0})
17+
output_model = create_model("rf_api_output", prediction=0)
1118

1219

13-
@app.get("/hello/{name}")
14-
async def say_hello(name: str):
15-
return {"message": f"Hello {name}"}
20+
# Define predict function
21+
@app.post("/predict", response_model=output_model)
22+
def predict(data: input_model):
23+
data = pd.DataFrame([data.dict()])
24+
predictions = predict_model(model, data=data)
25+
return {"prediction": predictions["prediction_label"].iloc[0]}
1626

1727

18-
@app.post("/hello")
19-
async def hello_message(dto: ISayHelloDto):
20-
return {"message": f"Hello {dto.message}"}
28+
if __name__ == "__main__":
29+
uvicorn.run(app, host="127.0.0.1", port=8000)

0 commit comments

Comments
 (0)