-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_api.py
More file actions
89 lines (78 loc) · 3.39 KB
/
inference_api.py
File metadata and controls
89 lines (78 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import logging
import os
from io import StringIO
from typing import Optional

import joblib
import mne
import numpy as np
import pandas as pd
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC

from classify import FeatureWrapper, ASR
# Load variables from a local .env file (if one exists) into the process
# environment, so the startup hook's os.getenv lookups can see them.
load_dotenv()

app = FastAPI()

# Fitted artefacts, assigned by the startup hook ``load_model``.
# They stay ``None`` until that hook has run, hence Optional.
model: Optional[SVC] = None
scaler: Optional[MinMaxScaler] = None
class CSVInput(BaseModel):
    """Request body for ``POST /inference``."""

    # Full text of the EEG recording in CSV form. It is parsed with
    # pandas.read_csv using default options, so the first row is taken as the
    # header row.
    csv_data: str
@app.on_event("startup")
def load_model():
    """Load the classifier and its feature scaler once, at application startup.

    File locations come from the ``MODEL_PATH`` and ``SCALER_PATH``
    environment variables. When they are unset, the three-class joblib files
    in the working directory are used. The loaded objects are stored in the
    module-level ``model`` and ``scaler`` globals.
    """
    global model, scaler
    # NOTE: joblib.load unpickles arbitrary Python objects. Only point these
    # variables at files from a trusted source.
    model = joblib.load(os.getenv("MODEL_PATH", "three_class_model.joblib"))
    scaler = joblib.load(os.getenv("SCALER_PATH", "three_class_scaler.joblib"))
def preprocess_and_epoch_full_eeg_array(eeg_data, sfreq=125, epoch_length_sec=10,
                                        l_freq=0.5, h_freq=40.0, ch_names=None):
    """Filter, clean and cut a continuous EEG recording into fixed-length epochs.

    Pipeline:
    1. Wrap the data in an MNE ``RawArray`` with the ``standard_1020`` montage.
    2. Band-pass filter it.
    3. Apply ``ASR``, fitted on this same recording.
    4. Split it into consecutive, non-overlapping epochs, each linearly
       detrended.

    Args:
        eeg_data: 2-D array ``(n_channels, n_samples)``. MNE expects EEG
            values in volts.
        sfreq: Sampling frequency in Hz.
        epoch_length_sec: Epoch length in seconds.
        l_freq, h_freq: Band-pass edges in Hz, forwarded to ``Raw.filter``.
        ch_names: One channel name per row of ``eeg_data``. Defaults to the
            16-channel layout below.

    Returns:
        Array ``(n_epochs, n_channels, epoch_length_samples)``, where
        ``epoch_length_samples = int(epoch_length_sec * sfreq)``. One epoch is
        produced per whole window; trailing samples that do not fill a window
        are dropped. This assumes ``ASR.transform`` preserves the recording
        length (TODO confirm).

    Raises:
        ValueError: if the channel count does not match ``ch_names``, or the
            recording is shorter than one epoch.
    """
    if ch_names is None:
        ch_names = ['Fp1', 'Fp2', 'AF3', 'AF4', 'F3', 'F4', 'F7', 'F8',
                    'C3', 'C4', 'T5', 'T6', 'PO3', 'PO4', 'O1', 'O2']
    num_channels, num_timesteps = eeg_data.shape
    if num_channels != len(ch_names):
        raise ValueError(
            f"expected {len(ch_names)} EEG channels, got {num_channels}")

    epoch_length_samples = int(epoch_length_sec * sfreq)
    if epoch_length_samples <= 0:
        raise ValueError(
            f"epoch_length_sec={epoch_length_sec} at sfreq={sfreq} gives no samples")
    total_epochs = num_timesteps // epoch_length_samples
    if total_epochs == 0:
        # Without this check MNE receives an empty, malformed events array
        # and fails with an obscure error.
        raise ValueError(
            f"recording has {num_timesteps} samples; at least "
            f"{epoch_length_samples} are needed for one {epoch_length_sec} s epoch")

    info = mne.create_info(ch_names=list(ch_names), sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(eeg_data, info)
    raw.set_montage(mne.channels.make_standard_montage("standard_1020"))
    raw.filter(l_freq=l_freq, h_freq=h_freq)

    # Artifact cleaning, calibrated on the recording itself (cutoff=50).
    asr = ASR(sfreq=raw.info["sfreq"], cutoff=50)
    asr.fit(raw)
    raw = asr.transform(raw)

    # One event (sample, 0, id=1) at the start of every non-overlapping window.
    onsets = np.arange(total_epochs) * epoch_length_samples
    events = np.column_stack(
        [onsets, np.zeros_like(onsets), np.ones_like(onsets)])
    # tmax is derived from the sample count, so every epoch spans exactly
    # epoch_length_samples samples, matching the event spacing.
    epochs = mne.Epochs(raw, events=events, event_id={'Rest': 1},
                        tmin=0, tmax=(epoch_length_samples - 1) / sfreq,
                        baseline=None, detrend=1, preload=True)
    return epochs.get_data()
def preprocess(block, sfreq=125, start_sec=5, duration_sec=100):
    """Turn a raw EEG CSV, loaded as a DataFrame, into preprocessed epochs.

    Column handling:
    - Positional columns 1..16 (``iloc[:, 1:17]``) are used as the EEG
      channels.
    - Column 0 is skipped; presumably it is a sample index (confirm against
      the recording format).
    - Values are scaled by 1e-6, presumably microvolts to the volts MNE
      expects.

    Trimming: the first ``start_sec`` seconds are discarded and the following
    ``duration_sec`` seconds are kept. With the defaults at 125 Hz this is
    samples 625..13124. The result is passed to
    ``preprocess_and_epoch_full_eeg_array`` at the same ``sfreq``.

    Args:
        block: DataFrame of the raw recording.
        sfreq: Sampling frequency in Hz. It is used both for the trim window
            and for filtering and epoching.
        start_sec: Seconds to skip at the start of the recording.
        duration_sec: Seconds to keep after the skipped part.

    Returns:
        Array ``(n_epochs, 16, n_times)``. With the defaults and a long
        enough recording this is ``(10, 16, 1250)``: ten 10-second epochs at
        125 Hz. Shorter recordings yield fewer epochs.
    """
    eeg_data = block.iloc[:, 1:17].to_numpy().T * 1e-6
    # Trim window in samples, derived from sfreq rather than hard-coded for
    # 125 Hz. The defaults reproduce the historical 625:13125 slice.
    start = int(round(start_sec * sfreq))
    stop = start + int(round(duration_sec * sfreq))
    eeg_data = eeg_data[:, start:stop]
    return preprocess_and_epoch_full_eeg_array(eeg_data, sfreq=sfreq)
def predict(samples, sfreq=125):
    """Score preprocessed EEG epochs with the loaded model.

    Steps:
    1. ``FeatureWrapper.compute_features`` builds a feature vector from each
       epoch, using all 16 channels and the features listed in
       ``desired_features``.
    2. The vectors are flattened and scaled with ``scaler``.
    3. ``model.predict_proba`` is applied and averaged over epochs.
    4. The averaged probabilities collapse to one score: the expected class
       index divided by the largest index. This gives a value in [0, 1],
       rounded to two decimals.

    NOTE(review): this treats the column order of ``model.classes_`` as an
    ordinal scale. Confirm the classes are sorted in the intended order.

    Args:
        samples: Sequence of epochs, e.g. the array returned by ``preprocess``.
        sfreq: Sampling frequency of the epochs in Hz.

    Returns:
        The score as a Python ``float``.

    Raises:
        RuntimeError: if the model or scaler has not been loaded yet.
        ValueError: if ``samples`` is empty or the model has fewer than two
            classes (the normalisation would divide by zero).
    """
    if model is None or scaler is None:
        raise RuntimeError("model/scaler not loaded; the startup hook has not run")
    if len(samples) == 0:
        raise ValueError("no epochs to score")
    n_classes = len(model.classes_)
    if n_classes < 2:
        raise ValueError(f"model must have at least 2 classes, has {n_classes}")

    selected_channels = list(range(16))
    desired_features = ["delta_bandpower", "alpha_bandpower", "beta_bandpower",
                        "theta_bandpower", "clustering_pli", "clustering_plv",
                        "betweenness_centrality"]
    wrapper = FeatureWrapper()
    features = np.array([
        wrapper.compute_features(sample, i, sfreq, selected_channels,
                                 desired_features=desired_features)
        for i, sample in enumerate(samples)
    ])
    # One flat feature row per epoch, as the scaler and model expect.
    features = features.reshape(features.shape[0], -1)
    features = scaler.transform(features)

    mean_probs = model.predict_proba(features).mean(axis=0)
    class_index = np.arange(n_classes)
    score = (mean_probs * class_index).sum() / class_index[-1]
    return float(np.round(score, 2))
@app.post("/inference")
def inference(data: CSVInput):
    """Run the full pipeline (parse CSV -> preprocess -> predict) on one request.

    Request body: ``{"csv_data": "<CSV text>"}``.

    Response:
    - ``{"prediction": <float>}`` on success.
    - ``{"error": "<message>"}`` if any step raises.

    NOTE: failures are still returned with HTTP 200 to keep the existing
    client contract. Consider a 4xx/5xx status in a future API version.

    Declared with plain ``def`` (not ``async def``), so FastAPI runs the
    CPU-heavy pipeline in its threadpool rather than on the event loop.
    """
    try:
        df = pd.read_csv(StringIO(data.csv_data))
        preprocessed_samples = preprocess(df)
        prediction = predict(preprocessed_samples)
    except Exception as e:
        # Top-level request boundary: keep the JSON error contract, but log
        # the full traceback so failures are diagnosable on the server.
        logging.getLogger(__name__).exception("Inference request failed")
        return {"error": str(e)}
    return {"prediction": prediction}