-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
71 lines (61 loc) · 2 KB
/
main.py
File metadata and controls
71 lines (61 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# REST API for the salary model: GET / returns a greeting and POST /model
# returns a salary-bracket prediction ("<=50K" or ">50K").
import os
import pickle as pkl
import shutil
import subprocess
import sys

import pandas as pd
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from data_model import BasicInputData
import starter.config as config
from starter.ml.data import process_data
from starter.ml.model import inference
def _run_command(*args):
    """Run a command without a shell and return its exit status.

    Like ``os.system``, this never raises for a failing command; if the
    executable cannot be started at all, 127 is returned (the status a POSIX
    shell reports for "command not found").

    Args:
        *args: The program name followed by its arguments.

    Returns:
        int: The process exit status; 0 means success.
    """
    try:
        return subprocess.run(list(args), check=False).returncode
    except OSError:
        return 127


# NOTE(review): `DYNO` is presumably the variable Heroku sets on a running
# dyno. Only there, and only if DVC metadata is present, are the DVC-tracked
# files fetched (presumably including the pickled model opened further down).
if "DYNO" in os.environ and os.path.isdir(".dvc"):
    # core.no_scm lets DVC operate without a git repository; the deployed
    # directory presumably has none. Best-effort: a failure here shows up as a
    # failed pull on the next step.
    _run_command("dvc", "config", "core.no_scm", "true")
    if _run_command("dvc", "pull") != 0:
        # sys.exit is always available, unlike the site-provided `exit`.
        sys.exit("dvc pull failed")
    # The DVC metadata and DVC's apt-installed files are not used after the
    # pull. Removal is best-effort: paths that are missing are ignored.
    for leftover in (".dvc", ".apt/usr/lib/dvc"):
        shutil.rmtree(leftover, ignore_errors=True)
# Metadata reported by the FastAPI application (e.g. in its generated docs).
_API_METADATA = {
    "title": "API for Salary Model",
    "description": "Return prediction for salary",
    "version": "0.0.1",
}

# The application object; the route handlers in this module register on it.
app = FastAPI(**_API_METADATA)
# Load the fitted artifacts, stored together in one pickle as a 3-tuple:
# (categorical encoder, lb - presumably the label binarizer -, model).
# They are loaded once at import time and shared by every request.
# SECURITY NOTE: unpickling can execute arbitrary code; config.MODEL_PATH must
# only point at a trusted artifact (presumably produced by this project's own
# training step).
with open(config.MODEL_PATH, mode="rb") as model_file:
    bundled_artifacts = pkl.load(model_file)
encoder, lb, model = bundled_artifacts
@app.get("/")
async def welcome():
    """Greet clients that call the API root.

    Handles ``GET /``; takes no parameters.

    Returns:
        dict: The constant greeting ``{"message": "Hello"}``, which FastAPI
        serializes to JSON.
    """
    greeting = "Hello"
    return dict(message=greeting)
@app.post("/model")
async def prediction(input_data: BasicInputData):
    """Classify a single web-form submission into a salary bracket.

    The submitted fields become a one-row DataFrame keyed by the model's
    field aliases. That row is preprocessed with ``process_data`` in
    inference mode (``training=False``), using the encoder and ``lb`` loaded
    at startup, and is then scored by the loaded model.

    Args:
        input_data (BasicInputData): Validated request body holding the
            feature values of one record.

    Returns:
        dict: ``{"Result": "<=50K"}`` when the model predicts class 0,
        otherwise ``{"Result": ">50K"}``. FastAPI serializes it to JSON.
    """
    # One record -> one-row frame. Columns use the field aliases
    # (by_alias=True), presumably so they match the raw training column
    # names; verify against data_model.BasicInputData.
    row = input_data.dict(by_alias=True)
    input_df = pd.DataFrame(row, index=[0])

    # Inference-mode preprocessing with the encoder/lb fitted at training
    # time. Only the feature matrix is used here; there is no label.
    x_data, _, _, _ = process_data(
        X=input_df,
        categorical_features=config.cat_features,
        label=None,
        training=False,
        encoder=encoder,
        lb=lb,
    )

    # NOTE(review): `inference` is a synchronous call made inside an async
    # handler, so it runs on the event loop. Consider a plain `def` handler
    # (run by FastAPI in a threadpool) if request latency matters.
    pred = inference(model, x_data)
    return {"Result": "<=50K" if int(pred[0]) == 0 else ">50K"}