Skip to content

Commit 185e8c7

Browse files
authored
feat: Add ml model input and output shape to allow models run on entire tiles (#205)
* add ml model shape * update ml_pipeline example * ml model validation: also allow x,y shaped 1d output * rename MlModel 3DTensorShape attributes to bands * update openapi branch * change ml model verification * add more cases to model shape converter * update backend ref * fix onnx to 0.17 * update backend ref * more tests * update backend ref and openapi version * update backend ref * update backend ref
1 parent 4c0e4ee commit 185e8c7

File tree

6 files changed

+125
-25
lines changed

6 files changed

+125
-25
lines changed

.github/.backend_git_ref

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
af126cb150c974cf47a52d2fac5b4a96a81d2c77
1+
7aadfa383e6eee63442e366890dfb1160114caed

examples/expression.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@
397397
"name": "python",
398398
"nbconvert_exporter": "python",
399399
"pygments_lexer": "ipython3",
400-
"version": "3.10.12"
400+
"version": "3.12.3"
401401
}
402402
},
403403
"nbformat": 4,

examples/ml_pipeline.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 1,
12+
"execution_count": null,
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
1616
"import geoengine as ge\n",
1717
"from geoengine.ml import MlModelConfig\n",
1818
"\n",
19-
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType\n",
19+
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, MlTensorShape3D as TensorShape3D\n",
2020
"\n",
2121
"from sklearn.tree import DecisionTreeClassifier\n",
2222
"import numpy as np\n",
@@ -88,8 +88,9 @@
8888
"metadata = MlModelMetadata(\n",
8989
" file_name=\"model.onnx\",\n",
9090
" input_type=RasterDataType.F32,\n",
91-
" num_input_bands=2,\n",
9291
" output_type=RasterDataType.I64,\n",
92+
" input_shape=TensorShape3D(y=1, x=1, bands=2),\n",
93+
" output_shape=TensorShape3D(y=1, x=1, bands=1)\n",
9394
")\n",
9495
"\n",
9596
"model_config = MlModelConfig(\n",
@@ -179,7 +180,7 @@
179180
"name": "python",
180181
"nbconvert_exporter": "python",
181182
"pygments_lexer": "ipython3",
182-
"version": "3.10.12"
183+
"version": "3.12.3"
183184
}
184185
},
185186
"nbformat": 4,

geoengine/ml.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import geoengine_openapi_client.models
1010
from onnx import TypeProto, TensorProto, ModelProto
1111
from onnx.helper import tensor_dtype_to_string
12-
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType
12+
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType, MlTensorShape3D
1313
import geoengine_openapi_client
1414
from geoengine.auth import get_session
1515
from geoengine.resource_identifier import UploadId, MlModelName
@@ -35,8 +35,10 @@ def register_ml_model(onnx_model: ModelProto,
3535
onnx_model,
3636
input_type=model_config.metadata.input_type,
3737
output_type=model_config.metadata.output_type,
38-
num_input_bands=model_config.metadata.num_input_bands,
38+
input_shape=model_config.metadata.input_shape,
39+
out_shape=model_config.metadata.output_shape
3940
)
41+
check_backend_constraints(model_config.metadata.input_shape, model_config.metadata.output_shape)
4042

4143
session = get_session()
4244

@@ -61,10 +63,57 @@ def register_ml_model(onnx_model: ModelProto,
6163
return MlModelName.from_response(res_name)
6264

6365

66+
def model_dim_to_tensorshape(model_dims):
    '''
    Derive a MlTensorShape3D from the dimension list of an ONNX tensor.

    All axes of the result start at 1 and are overwritten depending on the rank of `model_dims`:

    - rank 1: a positive size is used as `bands`; a size of -1 or 0 keeps the default shape
    - rank 2 to 4: a first dimension of None, -1, 0 or 1 is skipped; the remaining one, two or
      three sizes are used as `bands`, as `(y, x)` or as `(y, x, bands)` respectively

    Every other combination raises an `InputException`.
    '''

    shape = MlTensorShape3D(x=1, y=1, bands=1)
    rank = len(model_dims)

    if rank == 1:
        size = model_dims[0].dim_value
        if size in (-1, 0):
            # in this case, the model will produce as many outs as inputs
            return shape
        if size > 0:
            shape.bands = size
            return shape
    elif rank in (2, 3, 4):
        # a leading dimension with one of these values is not part of the pixel / tile shape
        first_kept = 1 if model_dims[0].dim_value in (None, -1, 0, 1) else 0
        sizes = [model_dims[index].dim_value for index in range(first_kept, rank)]

        if len(sizes) == 1:
            shape.bands = sizes[0]
            return shape
        if len(sizes) == 2:
            shape.y, shape.x = sizes
            return shape
        if len(sizes) == 3:
            shape.y, shape.x, shape.bands = sizes
            return shape

    raise InputException(f'Only 1D and 3D input tensors are supported. Got model dim {model_dims}')
95+
96+
97+
def check_backend_constraints(input_shape: MlTensorShape3D, output_shape: MlTensorShape3D, ge_tile_size=(512, 512)):
    '''
    Checks that the input and output shapes match the constraints of the backend.

    A shape is accepted if its `x` is 1 or `ge_tile_size[0]`, its `y` is 1 or `ge_tile_size[1]`
    and its `bands` is greater than zero.

    :param input_shape: the shape of the model input
    :param output_shape: the shape of the model output
    :param ge_tile_size: the tile size that `x` (index 0) and `y` (index 1) are compared against
    :raises InputException: if the input or the output shape is not accepted
    '''

    def is_supported(shape) -> bool:
        return shape.x in [1, ge_tile_size[0]] and shape.y in [1, ge_tile_size[1]] and shape.bands > 0

    if not is_supported(input_shape):
        raise InputException(f'Backend currently supports single pixel and full tile shaped input! Got {input_shape}!')

    if not is_supported(output_shape):
        # report the rejected output shape; the message previously printed the input shape here
        raise InputException(
            f'Backend currently supports single pixel and full tile shaped Output! Got {output_shape}!'
        )
109+
110+
111+
# pylint: disable=too-many-branches,too-many-statements
64112
def validate_model_config(onnx_model: ModelProto, *,
65113
input_type: RasterDataType,
66114
output_type: RasterDataType,
67-
num_input_bands: int):
115+
input_shape: MlTensorShape3D,
116+
out_shape: MlTensorShape3D):
68117
'''Validates the model config. Raises an exception if the model config is invalid'''
69118

70119
def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: 'str'):
@@ -85,18 +134,21 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
85134
raise InputException('Models with multiple inputs are not supported')
86135
check_data_type(model_inputs[0].type, input_type, 'input')
87136

88-
dims = model_inputs[0].type.tensor_type.shape.dim
89-
if len(dims) != 2:
90-
raise InputException('Only 2D input tensors are supported')
91-
if not dims[1].dim_value:
92-
raise InputException('Dimension 1 of the input tensor must have a length')
93-
if dims[1].dim_value != num_input_bands:
94-
raise InputException(f'Model input has {dims[1].dim_value} bands, but {num_input_bands} bands are expected')
137+
dim = model_inputs[0].type.tensor_type.shape.dim
138+
139+
in_ts3d = model_dim_to_tensorshape(dim)
140+
if not in_ts3d == input_shape:
141+
raise InputException(f"Input shape {in_ts3d} and metadata {input_shape} not equal!")
95142

96143
if len(model_outputs) < 1:
97144
raise InputException('Models with no outputs are not supported')
98145
check_data_type(model_outputs[0].type, output_type, 'output')
99146

147+
dim = model_outputs[0].type.tensor_type.shape.dim
148+
out_ts3d = model_dim_to_tensorshape(dim)
149+
if not out_ts3d == out_shape:
150+
raise InputException(f"Output shape {out_ts3d} and metadata {out_shape} not equal!")
151+
100152

101153
RASTER_TYPE_TO_ONNX_TYPE = {
102154
RasterDataType.F32: TensorProto.FLOAT,

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package_dir =
1818
packages = find:
1919
python_requires = >=3.10
2020
install_requires =
21-
geoengine-openapi-client == 0.0.23
21+
geoengine-openapi-client == 0.0.25
2222
geopandas >=1.0,<2.0
2323
matplotlib >=3.5,<3.11
2424
numpy >=1.21,<2.3

tests/test_ml.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,63 @@
11
'''Tests ML functionality'''
22

3+
from typing import List
34
import unittest
5+
from onnx import TensorShapeProto as TSP
46
from sklearn.ensemble import RandomForestClassifier
57
from skl2onnx import to_onnx
68
import numpy as np
7-
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType
9+
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, MlTensorShape3D
810
import geoengine as ge
11+
from geoengine.ml import model_dim_to_tensorshape
912
from tests.ge_test import GeoEngineTestInstance
1013

1114

12-
class WorkflowStorageTests(unittest.TestCase):
13-
'''Test methods for storing workflows as datasets'''
15+
class MlModelTests(unittest.TestCase):
16+
'''Test methods for MlModels'''
1417

1518
def setUp(self) -> None:
1619
ge.reset(False)
1720

21+
def test_model_dim_to_tensorshape(self):
    ''' Test model_dim_to_tensorshape for dimension lists of rank one to four '''

    def as_dims(*sizes) -> List[TSP.Dimension]:
        return [TSP.Dimension(dim_value=size) for size in sizes]

    # (ONNX dimensions, expected tensor shape)
    cases = [
        (as_dims(7), MlTensorShape3D(bands=7, y=1, x=1)),
        (as_dims(None, 7), MlTensorShape3D(bands=7, y=1, x=1)),
        (as_dims(512, 512), MlTensorShape3D(bands=1, y=512, x=512)),
        (as_dims(1, 7), MlTensorShape3D(bands=7, y=1, x=1)),
        (as_dims(512, 512, 7), MlTensorShape3D(bands=7, y=512, x=512)),
        (as_dims(None, 512, 512), MlTensorShape3D(bands=1, y=512, x=512)),
        (as_dims(None, 512, 512, 4), MlTensorShape3D(bands=4, y=512, x=512)),
    ]

    for model_dims, expected_shape in cases:
        self.assertEqual(model_dim_to_tensorshape(model_dims), expected_shape)
60+
1861
def test_uploading_onnx_model(self):
1962

2063
clf = RandomForestClassifier(random_state=42)
@@ -40,8 +83,9 @@ def test_uploading_onnx_model(self):
4083
metadata=MlModelMetadata(
4184
file_name="model.onnx",
4285
input_type=RasterDataType.F32,
43-
num_input_bands=2,
4486
output_type=RasterDataType.I64,
87+
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
88+
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
4589
),
4690
display_name="Decision Tree",
4791
description="A simple decision tree model",
@@ -77,16 +121,17 @@ def test_uploading_onnx_model(self):
77121
metadata=MlModelMetadata(
78122
file_name="model.onnx",
79123
input_type=RasterDataType.F32,
80-
num_input_bands=4,
81124
output_type=RasterDataType.I64,
125+
input_shape=MlTensorShape3D(y=1, x=1, bands=4),
126+
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
82127
),
83128
display_name="Decision Tree",
84129
description="A simple decision tree model",
85130
)
86131
)
87132
self.assertEqual(
88133
str(exception.exception),
89-
'Model input has 2 bands, but 4 bands are expected'
134+
'Input shape bands=2 x=1 y=1 and metadata bands=4 x=1 y=1 not equal!'
90135
)
91136

92137
with self.assertRaises(ge.InputException) as exception:
@@ -97,8 +142,9 @@ def test_uploading_onnx_model(self):
97142
metadata=MlModelMetadata(
98143
file_name="model.onnx",
99144
input_type=RasterDataType.F64,
100-
num_input_bands=2,
101145
output_type=RasterDataType.I64,
146+
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
147+
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
102148
),
103149
display_name="Decision Tree",
104150
description="A simple decision tree model",
@@ -117,8 +163,9 @@ def test_uploading_onnx_model(self):
117163
metadata=MlModelMetadata(
118164
file_name="model.onnx",
119165
input_type=RasterDataType.F32,
120-
num_input_bands=2,
121166
output_type=RasterDataType.I32,
167+
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
168+
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
122169
),
123170
display_name="Decision Tree",
124171
description="A simple decision tree model",

0 commit comments

Comments
 (0)