Skip to content

Commit 60410a7

Browse files
committed
Fix CI failures and add TF/Keras + register_model roundtrip tests
- Fix: guard pandas/numpy imports with pytest.importorskip in test_datasets.py and test_features.py (caused CI collection errors)
- Fix: add numpy/pandas to SDK unit test CI install step
- Fix: add fix/** branch trigger to test-python.yml
- Add TF/Keras detection tests to test_detect_framework.py
- Add TF serialization roundtrip test to test_serialize_model.py
- Add TF/Keras model registration test to test_models.py
- Add register_model roundtrip integration tests: sklearn, pytorch, tensorflow — verifies full pipeline (detect → serialize → codegen → exec train/infer) through the SDK abstraction layer
- Remove unnecessary _serialize_model mock from pytorch registration test (SimpleNet is now at module level, picklable by torch.save)

190 tests passing, 8 skipped.
1 parent 444db41 commit 60410a7

7 files changed

Lines changed: 249 additions & 14 deletions

File tree

.github/workflows/test-python.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: Python SDK & Model Runner Tests
22

33
on:
44
push:
5-
branches: [main, "dev/**"]
5+
branches: [main, "dev/**", "fix/**"]
66
paths:
77
- "sdk/python/**"
88
- "model-runner/python/**"
@@ -44,7 +44,7 @@ jobs:
4444
- name: Install dependencies
4545
run: |
4646
pip install --upgrade pip
47-
pip install pytest pytest-cov pytest-mock pytest-timeout responses
47+
pip install pytest pytest-cov pytest-mock pytest-timeout responses numpy pandas
4848
pip install ./sdk/python
4949
5050
- name: Run SDK unit tests

tests/python/frameworks/test_detect_framework.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ def test_detect_pytorch(self):
3131
def test_detect_pytorch_custom_module(self, pytorch_model):
3232
assert _detect_framework(pytorch_model) == "pytorch"
3333

34+
def test_detect_tensorflow(self, tf_model):
35+
"""TensorFlow/Keras model detected as 'tensorflow'."""
36+
assert _detect_framework(tf_model) == "tensorflow"
37+
38+
def test_detect_keras_standalone(self):
39+
"""Standalone Keras 3+ model detected as 'tensorflow'."""
40+
keras = pytest.importorskip("keras")
41+
model = keras.Sequential([
42+
keras.layers.Dense(4, input_shape=(2,)),
43+
])
44+
assert _detect_framework(model) == "tensorflow"
45+
3446
@pytest.mark.parametrize("obj", ["hello", {"a": 1}, 42])
3547
def test_detect_unknown_raises(self, obj):
3648
with pytest.raises(TypeError, match="Cannot auto-detect framework"):

tests/python/frameworks/test_serialize_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ def test_serialize_pytorch(self, pytorch_model):
2828
restored = torch.load(buf, map_location="cpu", weights_only=False)
2929
assert isinstance(restored, torch.nn.Module)
3030

31+
def test_serialize_tensorflow(self, tf_model):
32+
"""TF/Keras model serialized to .keras bytes and loadable."""
33+
keras = pytest.importorskip("keras")
34+
import tempfile, os
35+
36+
data = _serialize_model(tf_model, "tensorflow")
37+
assert isinstance(data, bytes)
38+
assert len(data) > 0
39+
# Roundtrip: write to temp file and load back
40+
tmpfile = tempfile.mktemp(suffix=".keras")
41+
try:
42+
with open(tmpfile, "wb") as f:
43+
f.write(data)
44+
restored = keras.models.load_model(tmpfile)
45+
assert restored is not None
46+
finally:
47+
if os.path.exists(tmpfile):
48+
os.unlink(tmpfile)
49+
3150
def test_serialize_unsupported_raises(self):
3251
with pytest.raises(ValueError, match="Unsupported framework"):
3352
_serialize_model(object(), "unknown_framework")
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""Integration test: register_model() with real model objects → verify generated code executes.
2+
3+
This is the critical test that verifies the SDK model abstraction works end-to-end:
4+
1. Pass a real model object to register_model()
5+
2. SDK auto-detects framework, serializes model, generates source code
6+
3. The generated source code is POSTed to the API
7+
4. We extract that source code and exec() it with MockModelContext
8+
5. Verify train(ctx) logs metrics and infer(ctx) produces predictions
9+
"""
10+
11+
import json
12+
import pytest
13+
import responses
14+
15+
from conftest import TEST_API_URL, TEST_PROJECT_ID, MockModelContext
16+
17+
18+
REGISTER_RESPONSE = {
19+
"model_id": "model-roundtrip-001",
20+
"name": "roundtrip-test",
21+
"version": 1,
22+
}
23+
24+
25+
class TestRegisterModelSklearnRoundtrip:
26+
"""register_model(model=sklearn_obj) → extract source → exec train/infer."""
27+
28+
def test_sklearn_register_and_train(self, client, mock_api, sklearn_model):
29+
"""Full pipeline: register sklearn model → exec generated train()."""
30+
mock_api.add(
31+
responses.POST,
32+
f"{TEST_API_URL}/sdk/register-model",
33+
json=REGISTER_RESPONSE,
34+
status=200,
35+
)
36+
client.register_model("test-sklearn", model=sklearn_model)
37+
38+
# Extract the source code that was POSTed
39+
body = json.loads(mock_api.calls[0].request.body)
40+
source_code = body["source_code"]
41+
assert body["framework"] == "sklearn"
42+
43+
# Execute the generated train() with MockModelContext
44+
ctx = MockModelContext(hyperparameters={"n_samples": 50, "n_features": 4})
45+
ns = {}
46+
exec(source_code, ns)
47+
ns["train"](ctx)
48+
49+
metric_names = [m[0] for m in ctx._logged_metrics]
50+
assert "accuracy" in metric_names
51+
assert "loss" in metric_names
52+
assert "progress" in metric_names
53+
# Progress should reach 100
54+
progress_values = [m[1] for m in ctx._logged_metrics if m[0] == "progress"]
55+
assert 100 in progress_values
56+
57+
def test_sklearn_register_and_infer(self, client, mock_api, sklearn_model):
58+
"""Full pipeline: register sklearn model → exec generated infer()."""
59+
mock_api.add(
60+
responses.POST,
61+
f"{TEST_API_URL}/sdk/register-model",
62+
json=REGISTER_RESPONSE,
63+
status=200,
64+
)
65+
client.register_model("test-sklearn", model=sklearn_model)
66+
67+
body = json.loads(mock_api.calls[0].request.body)
68+
source_code = body["source_code"]
69+
70+
ctx = MockModelContext(
71+
hyperparameters={"input_data": {"features": [[1.0, 2.0, 3.0, 4.0]]}}
72+
)
73+
ns = {}
74+
exec(source_code, ns)
75+
ns["infer"](ctx)
76+
77+
assert ctx._output is not None
78+
assert "predictions" in ctx._output
79+
assert isinstance(ctx._output["predictions"], list)
80+
81+
82+
class TestRegisterModelPytorchRoundtrip:
83+
"""register_model(model=pytorch_obj) → extract source → exec train/infer."""
84+
85+
pytestmark = pytest.mark.skipif(
86+
not pytest.importorskip("torch", reason="torch not installed"),
87+
reason="torch not installed",
88+
)
89+
90+
def test_pytorch_register_and_train(self, client, mock_api, pytorch_model):
91+
"""Full pipeline: register PyTorch model → exec generated train()."""
92+
mock_api.add(
93+
responses.POST,
94+
f"{TEST_API_URL}/sdk/register-model",
95+
json=REGISTER_RESPONSE,
96+
status=200,
97+
)
98+
client.register_model("test-pytorch", model=pytorch_model)
99+
100+
body = json.loads(mock_api.calls[0].request.body)
101+
source_code = body["source_code"]
102+
assert body["framework"] == "pytorch"
103+
104+
ctx = MockModelContext(hyperparameters={"epochs": 2, "batch_size": 4})
105+
ns = {}
106+
exec(source_code, ns)
107+
ns["train"](ctx)
108+
109+
metric_names = [m[0] for m in ctx._logged_metrics]
110+
assert "loss" in metric_names
111+
assert "accuracy" in metric_names
112+
assert "progress" in metric_names
113+
114+
def test_pytorch_register_and_infer(self, client, mock_api, pytorch_model):
115+
"""Full pipeline: register PyTorch model → exec generated infer()."""
116+
mock_api.add(
117+
responses.POST,
118+
f"{TEST_API_URL}/sdk/register-model",
119+
json=REGISTER_RESPONSE,
120+
status=200,
121+
)
122+
client.register_model("test-pytorch", model=pytorch_model)
123+
124+
body = json.loads(mock_api.calls[0].request.body)
125+
source_code = body["source_code"]
126+
127+
# SimpleNet has input_size=4
128+
ctx = MockModelContext(
129+
hyperparameters={"input_data": {"features": [[1.0, 2.0, 3.0, 4.0]]}}
130+
)
131+
ns = {}
132+
exec(source_code, ns)
133+
ns["infer"](ctx)
134+
135+
assert ctx._output is not None
136+
assert "predictions" in ctx._output
137+
preds = ctx._output["predictions"]
138+
assert isinstance(preds, list)
139+
assert len(preds) == 1 # one sample
140+
141+
142+
class TestRegisterModelTensorflowRoundtrip:
143+
"""register_model(model=keras_obj) → extract source → exec train/infer."""
144+
145+
def test_tensorflow_register_and_train(self, client, mock_api, tf_model):
146+
"""Full pipeline: register TF/Keras model → exec generated train()."""
147+
mock_api.add(
148+
responses.POST,
149+
f"{TEST_API_URL}/sdk/register-model",
150+
json=REGISTER_RESPONSE,
151+
status=200,
152+
)
153+
client.register_model("test-keras", model=tf_model)
154+
155+
body = json.loads(mock_api.calls[0].request.body)
156+
source_code = body["source_code"]
157+
assert body["framework"] == "tensorflow"
158+
159+
ctx = MockModelContext(hyperparameters={"epochs": 1, "n_samples": 20, "batch_size": 8})
160+
ns = {}
161+
exec(source_code, ns)
162+
ns["train"](ctx)
163+
164+
metric_names = [m[0] for m in ctx._logged_metrics]
165+
assert "loss" in metric_names
166+
assert "progress" in metric_names
167+
168+
def test_tensorflow_register_and_infer(self, client, mock_api, tf_model):
169+
"""Full pipeline: register TF/Keras model → exec generated infer()."""
170+
mock_api.add(
171+
responses.POST,
172+
f"{TEST_API_URL}/sdk/register-model",
173+
json=REGISTER_RESPONSE,
174+
status=200,
175+
)
176+
client.register_model("test-keras", model=tf_model)
177+
178+
body = json.loads(mock_api.calls[0].request.body)
179+
source_code = body["source_code"]
180+
181+
# tf_model has input_shape=(4,)
182+
ctx = MockModelContext(
183+
hyperparameters={"input_data": {"features": [[1.0, 2.0, 3.0, 4.0]]}}
184+
)
185+
ns = {}
186+
exec(source_code, ns)
187+
ns["infer"](ctx)
188+
189+
assert ctx._output is not None
190+
assert "predictions" in ctx._output

tests/python/sdk/test_datasets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import os
77
import pytest
88
import responses
9-
import pandas as pd
9+
10+
pd = pytest.importorskip("pandas", reason="pandas not installed")
1011

1112
from openmodelstudio.client import Client
1213

tests/python/sdk/test_features.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Tests for Client.create_features() and Client.load_features()."""
22

33
import json
4-
import numpy as np
5-
import pandas as pd
64
import pytest
75
import responses
86

7+
np = pytest.importorskip("numpy", reason="numpy not installed")
8+
pd = pytest.importorskip("pandas", reason="pandas not installed")
9+
910
from openmodelstudio.client import Client
1011

1112
from conftest import TEST_API_URL, TEST_PROJECT_ID

tests/python/sdk/test_models.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,26 +124,38 @@ def test_register_model_with_sklearn_object(self, client, mock_api, sklearn_mode
124124
reason="torch not available",
125125
)
126126
def test_register_model_with_pytorch_object(self, client, mock_api, pytorch_model):
127-
"""Auto-detects pytorch, serializes nn.Module, generates source with embedded blob.
128-
129-
Note: The pytorch_model fixture defines SimpleNet in a local scope,
130-
which torch.save cannot pickle. We mock _serialize_model to return
131-
dummy bytes so the rest of the registration flow can be tested.
132-
"""
127+
"""Auto-detects pytorch, serializes nn.Module, generates source with embedded blob."""
133128
mock_api.add(
134129
responses.POST,
135130
f"{TEST_API_URL}/sdk/register-model",
136131
json=REGISTER_RESPONSE,
137132
status=200,
138133
)
139-
dummy_bytes = b"fake-pytorch-model-bytes"
140-
with patch("openmodelstudio.client._serialize_model", return_value=dummy_bytes):
141-
handle = client.register_model("pytorch-net", model=pytorch_model)
134+
handle = client.register_model("pytorch-net", model=pytorch_model)
142135

143136
body = json.loads(mock_api.calls[0].request.body)
144137
assert body["framework"] == "pytorch"
145138
assert "_MODEL_B64" in body["source_code"]
146139
assert "import torch" in body["source_code"]
140+
assert "def train(ctx):" in body["source_code"]
141+
assert "def infer(ctx):" in body["source_code"]
142+
assert handle.model_id == REGISTER_RESPONSE["model_id"]
143+
144+
def test_register_model_with_tensorflow_object(self, client, mock_api, tf_model):
145+
"""Auto-detects tensorflow, serializes Keras model, generates source with embedded blob."""
146+
mock_api.add(
147+
responses.POST,
148+
f"{TEST_API_URL}/sdk/register-model",
149+
json=REGISTER_RESPONSE,
150+
status=200,
151+
)
152+
handle = client.register_model("keras-net", model=tf_model)
153+
154+
body = json.loads(mock_api.calls[0].request.body)
155+
assert body["framework"] == "tensorflow"
156+
assert "_MODEL_B64" in body["source_code"]
157+
assert "def train(ctx):" in body["source_code"]
158+
assert "def infer(ctx):" in body["source_code"]
147159
assert handle.model_id == REGISTER_RESPONSE["model_id"]
148160

149161

0 commit comments

Comments (0)