Skip to content

Commit 46ce218

Browse files
Merge pull request #217 from geo-engine/pydantic-2
Pydantic-2
2 parents b35bad6 + f4fbbfd commit 46ce218

File tree

11 files changed

+96
-51
lines changed

11 files changed

+96
-51
lines changed

geoengine/colorizer.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,13 @@ def from_response(response: geoengine_openapi_client.Colorizer) -> Colorizer:
230230
raise TypeError("Unknown colorizer type")
231231

232232

233+
def rgba_from_list(values: list[int]) -> Rgba:
234+
"""Convert a list of integers to an RGBA tuple."""
235+
if len(values) != 4:
236+
raise ValueError(f"Expected a list of 4 integers, got {len(values)} instead.")
237+
return (values[0], values[1], values[2], values[3])
238+
239+
233240
@dataclass
234241
class LinearGradientColorizer(Colorizer):
235242
'''A linear gradient colorizer.'''
@@ -242,10 +249,10 @@ def from_response_linear(response: geoengine_openapi_client.LinearGradient) -> L
242249
"""Create a colorizer from a response."""
243250
breakpoints = [ColorBreakpoint.from_response(breakpoint) for breakpoint in response.breakpoints]
244251
return LinearGradientColorizer(
245-
no_data_color=response.no_data_color,
252+
no_data_color=rgba_from_list(response.no_data_color),
246253
breakpoints=breakpoints,
247-
over_color=response.over_color,
248-
under_color=response.under_color,
254+
over_color=rgba_from_list(response.over_color),
255+
under_color=rgba_from_list(response.under_color),
249256
)
250257

251258
def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
@@ -273,9 +280,9 @@ def from_response_logarithmic(
273280
breakpoints = [ColorBreakpoint.from_response(breakpoint) for breakpoint in response.breakpoints]
274281
return LogarithmicGradientColorizer(
275282
breakpoints=breakpoints,
276-
no_data_color=response.no_data_color,
277-
over_color=response.over_color,
278-
under_color=response.under_color,
283+
no_data_color=rgba_from_list(response.no_data_color),
284+
over_color=rgba_from_list(response.over_color),
285+
under_color=rgba_from_list(response.under_color),
279286
)
280287

281288
def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
@@ -300,16 +307,16 @@ def from_response_palette(response: geoengine_openapi_client.PaletteColorizer) -
300307
"""Create a colorizer from a response."""
301308

302309
return PaletteColorizer(
303-
colors={float(k): v for k, v in response.colors.items()},
304-
no_data_color=response.no_data_color,
305-
default_color=response.default_color,
310+
colors={float(k): rgba_from_list(v) for k, v in response.colors.items()},
311+
no_data_color=rgba_from_list(response.no_data_color),
312+
default_color=rgba_from_list(response.default_color),
306313
)
307314

308315
def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
309316
"""Return the colorizer as a dictionary."""
310317
return geoengine_openapi_client.Colorizer(geoengine_openapi_client.PaletteColorizer(
311318
type='palette',
312-
colors=self.colors,
319+
colors={str(k): v for k, v in self.colors.items()},
313320
default_color=self.default_color,
314321
no_data_color=self.no_data_color,
315322
))

geoengine/datasets.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,14 @@ def upload_dataframe(
466466
ints = [key for (key, value) in columns.items() if value.data_type == 'int']
467467
texts = [key for (key, value) in columns.items() if value.data_type == 'text']
468468

469+
result_descriptor = VectorResultDescriptor(
470+
data_type=vector_type,
471+
spatial_reference=df.crs.to_string(),
472+
columns=columns,
473+
).to_api_dict().actual_instance
474+
if not isinstance(result_descriptor, geoengine_openapi_client.TypedVectorResultDescriptor):
475+
raise TypeError('Expected TypedVectorResultDescriptor')
476+
469477
create = geoengine_openapi_client.CreateDataset(
470478
data_path=geoengine_openapi_client.DataPath(geoengine_openapi_client.DataPathOneOf1(
471479
upload=str(upload_id)
@@ -494,11 +502,9 @@ def upload_dataframe(
494502
),
495503
on_error=on_error.to_api_enum(),
496504
),
497-
result_descriptor=VectorResultDescriptor(
498-
data_type=vector_type,
499-
spatial_reference=df.crs.to_string(),
500-
columns=columns,
501-
).to_api_dict().actual_instance
505+
result_descriptor=geoengine_openapi_client.VectorResultDescriptor.from_dict(
506+
result_descriptor.to_dict()
507+
)
502508
)
503509
)
504510
)

geoengine/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, response: Union[geoengine_openapi_client.ApiException, Dict[s
2121
super().__init__()
2222

2323
if isinstance(response, geoengine_openapi_client.ApiException):
24-
obj = json.loads(response.body)
24+
obj = json.loads(response.body) if response.body else {'error': 'unknown', 'message': 'unknown'}
2525
else:
2626
obj = response
2727

geoengine/ml.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
103103
if not data_type.tensor_type:
104104
raise InputException('Only tensor input types are supported')
105105
elem_type = data_type.tensor_type.elem_type
106-
if elem_type != RASTER_TYPE_TO_ONNX_TYPE[expected_type]:
106+
expected_tensor_type = RASTER_TYPE_TO_ONNX_TYPE[expected_type]
107+
if elem_type != expected_tensor_type:
107108
elem_type_str = tensor_dtype_to_string(elem_type)
109+
expected_type_str = tensor_dtype_to_string(expected_tensor_type)
108110
raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
109-
f'expected type `{expected_type}`')
111+
f'expected type `{expected_type_str}`')
110112

111113
model_inputs = onnx_model.graph.input
112114
model_outputs = onnx_model.graph.output

geoengine/types.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,29 @@ def __repr__(self) -> str:
660660
return f'{self.name}: {self.measurement}'
661661

662662

663+
def literal_raster_data_type(
664+
data_type: geoengine_openapi_client.RasterDataType
665+
) -> Literal['U8', 'U16', 'U32', 'U64', 'I8', 'I16', 'I32', 'I64', 'F32', 'F64']:
666+
'''Convert a `RasterDataType` to a literal'''
667+
668+
data_type_map: dict[
669+
geoengine_openapi_client.RasterDataType,
670+
Literal['U8', 'U16', 'U32', 'U64', 'I8', 'I16', 'I32', 'I64', 'F32', 'F64']
671+
] = {
672+
geoengine_openapi_client.RasterDataType.U8: 'U8',
673+
geoengine_openapi_client.RasterDataType.U16: 'U16',
674+
geoengine_openapi_client.RasterDataType.U32: 'U32',
675+
geoengine_openapi_client.RasterDataType.U64: 'U64',
676+
geoengine_openapi_client.RasterDataType.I8: 'I8',
677+
geoengine_openapi_client.RasterDataType.I16: 'I16',
678+
geoengine_openapi_client.RasterDataType.I32: 'I32',
679+
geoengine_openapi_client.RasterDataType.I64: 'I64',
680+
geoengine_openapi_client.RasterDataType.F32: 'F32',
681+
geoengine_openapi_client.RasterDataType.F64: 'F64',
682+
}
683+
return data_type_map[data_type]
684+
685+
663686
class RasterResultDescriptor(ResultDescriptor):
664687
'''
665688
A raster result descriptor
@@ -701,7 +724,7 @@ def from_response_raster(
701724
response: geoengine_openapi_client.TypedRasterResultDescriptor) -> RasterResultDescriptor:
702725
'''Parse a raster result descriptor from an http response'''
703726
spatial_ref = response.spatial_reference
704-
data_type = response.data_type.value
727+
data_type = literal_raster_data_type(response.data_type)
705728
bands = [RasterBandDescriptor.from_response(band) for band in response.bands]
706729

707730
time_bounds = None

geoengine/workflow.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,8 +984,10 @@ def data_usage(offset: int = 0, limit: int = 10) -> List[geoengine_openapi_clien
984984
response = user_api.data_usage_handler(offset=offset, limit=limit)
985985

986986
# create dataframe from response
987-
usage_dicts = [data_usage.dict(by_alias=True) for data_usage in response]
987+
usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
988988
df = pd.DataFrame(usage_dicts)
989+
if 'timestamp' in df.columns:
990+
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
989991

990992
return df
991993

@@ -1005,7 +1007,9 @@ def data_usage_summary(granularity: geoengine_openapi_client.UsageSummaryGranula
10051007
offset=offset, limit=limit)
10061008

10071009
# create dataframe from response
1008-
usage_dicts = [data_usage.dict(by_alias=True) for data_usage in response]
1010+
usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
10091011
df = pd.DataFrame(usage_dicts)
1012+
if 'timestamp' in df.columns:
1013+
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
10101014

10111015
return df

geoengine/workflow_builder/operators.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,11 +694,12 @@ def from_operator_dict(cls, operator_dict: Dict[str, Any]) -> 'Expression':
694694

695695
output_band = None
696696
if "outputBand" in operator_dict["params"] and operator_dict["params"]["outputBand"] is not None:
697-
output_band = RasterBandDescriptor.from_response(
698-
geoengine_openapi_client.RasterBandDescriptor.from_dict(
699-
operator_dict["params"]["outputBand"]
700-
)
697+
raster_band_descriptor = geoengine_openapi_client.RasterBandDescriptor.from_dict(
698+
operator_dict["params"]["outputBand"]
701699
)
700+
if raster_band_descriptor is None:
701+
raise ValueError("Invalid output band")
702+
output_band = RasterBandDescriptor.from_response(raster_band_descriptor)
702703

703704
return Expression(
704705
expression=operator_dict["params"]["expression"],

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package_dir =
1818
packages = find:
1919
python_requires = >=3.9
2020
install_requires =
21-
geoengine-openapi-client == 0.0.19
21+
geoengine-openapi-client == 0.0.21
2222
geopandas >=0.9,<0.15
2323
matplotlib >=3.5,<3.8
2424
numpy >=1.21,<2.1
@@ -34,7 +34,7 @@ install_requires =
3434
websockets >= 10.0,<11
3535
xarray >=0.19,<2024.12
3636
urllib3 >= 2.0, < 2.3
37-
pydantic >= 1.10.5, < 2
37+
pydantic >= 2.10.6, < 2.11
3838
skl2onnx >=1.17,<2
3939

4040
[options.extras_require]

tests/ge_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def _start(self) -> None:
201201
'GEOENGINE__POSTGRES__PASSWORD': POSTGRES_PASSWORD,
202202
'GEOENGINE__POSTGRES__SCHEMA': self.db_schema,
203203
'GEOENGINE__LOGGING__LOG_SPEC': GE_LOG_SPEC,
204+
'GEOENGINE__POSTGRES__CLEAR_DATABASE_ON_START': 'true',
204205
'PATH': os.environ['PATH'],
205206
},
206207
stderr=subprocess.PIPE,

tests/test_ml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_uploading_onnx_model(self):
106106
)
107107
self.assertEqual(
108108
str(exception.exception),
109-
'Model input type `TensorProto.FLOAT` does not match the expected type `RasterDataType.F64`'
109+
'Model input type `TensorProto.FLOAT` does not match the expected type `TensorProto.DOUBLE`'
110110
)
111111

112112
with self.assertRaises(ge.InputException) as exception:
@@ -126,5 +126,5 @@ def test_uploading_onnx_model(self):
126126
)
127127
self.assertEqual(
128128
str(exception.exception),
129-
'Model output type `TensorProto.INT64` does not match the expected type `RasterDataType.I32`'
129+
'Model output type `TensorProto.INT64` does not match the expected type `TensorProto.INT32`'
130130
)

0 commit comments

Comments
 (0)