Skip to content

Commit e398d64

Browse files
Merge pull request #5 from fullstack-ml-academy/dig-5-train-model-github-action
Train Model Github Action
2 parents c40410d + fa3090f commit e398d64

File tree

4 files changed

+59
-0
lines changed

4 files changed

+59
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name: model-training
2+
on: [push]
3+
jobs:
4+
train-model:
5+
runs-on: ubuntu-latest
6+
steps:
7+
- name: Checkout Repository
8+
uses: actions/checkout@v3
9+
- name: Setup Python und install requirements
10+
uses: actions/setup-python@v4
11+
with:
12+
python-version: "3.8"
13+
- run: pip install -r requirements.txt
14+
- name: Train model
15+
run: python src/train.py
16+
- name: Upload trained model
17+
uses: actions/upload-artifact@v4
18+
with:
19+
name: regressor_mpg.pickle
20+
path: data/models/regressor_mpg.pickle

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ scikit-learn==1.0.2
77
scipy==1.8.0
88
six==1.16.0
99
threadpoolctl==3.1.0
10+
flask==3.1.1
11+
flask_cors==6.0.1

src/api.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from flask import Flask, Response, request
2+
from flask_cors import CORS
3+
import os
4+
import pandas as pd
5+
import pickle
6+
7+
app = Flask(__name__)
8+
9+
CORS(app)
10+
11+
# load training data
12+
training_data = pd.read_csv(os.path.join("data", "auto-mpg.csv"))
13+
14+
# load model
15+
file_to_open = open(os.path.join("data", "models", "baummethoden_lr.pickle"), "rb")
16+
trained_model = pickle.load(file_to_open)
17+
file_to_open.close()
18+
19+
20+
@app.route("/", methods=["GET"])
21+
def index():
22+
return {"hello": "world"}
23+
24+
25+
@app.route("/hello_world", methods=["GET"])
26+
def hello_world():
27+
return "<p>Hello World!</p>"
28+
29+
30+
@app.route("/training_data", methods=["GET"])
31+
def get_training_data():
32+
return Response(training_data.to_json(), mimetype="application/json")

wsgi.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from src.api import app
2+
3+
4+
if __name__ == "__main__":
5+
app.run(debug=True)

0 commit comments

Comments
 (0)