Skip to content

Commit 8e94de5

Browse files
committed
add shape property
1 parent a693a49 commit 8e94de5

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

src/graphnet/models/data_representation/images/mappings/pixel_mappings.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ def _set_image_feature_names(self, input_feature_names: List[str]) -> None:
4747
"""Set the final image feature names."""
4848
raise NotImplementedError
4949

50+
@property
51+
@abstractmethod
52+
def shape(
53+
self,
54+
) -> List[List[int]]:
55+
"""Return the shape of the output images as a list of tuples.
56+
57+
In the dimensions (F,D,H,W) where F is the number of features
58+
per pixel. And D,H,W are the dimension of the image
59+
"""
60+
pass
61+
5062

5163
class IC86PixelMapping(PixelMapping):
5264
"""Mapping for the IceCube86.
@@ -230,6 +242,20 @@ def _set_image_feature_names(self, input_feature_names: List[str]) -> None:
230242
if infeature not in [self._string_label, self._dom_number_label]
231243
]
232244

245+
@property
246+
def shape(
247+
self,
248+
) -> List[List[int]]:
249+
"""Return the shape of the output images as a list of tuples."""
250+
ret = []
251+
if self._include_main_array:
252+
ret.append([self._nb_cnn_features, 10, 10, 60])
253+
if self._include_upper_dc:
254+
ret.append([self._nb_cnn_features, 1, 8, 10])
255+
if self._include_lower_dc:
256+
ret.append([self._nb_cnn_features, 1, 8, 50])
257+
return ret
258+
233259

234260
class ExamplePrometheusMapping(PixelMapping):
235261
"""Mapping for the Prometheus detector.
@@ -363,3 +389,10 @@ def _set_image_feature_names(self, input_feature_names: List[str]) -> None:
363389
for infeature in input_feature_names
364390
if infeature not in [self._string_label, self._sensor_number_label]
365391
]
392+
393+
@property
394+
def shape(
395+
self,
396+
) -> List[List[int]]:
397+
"""Return the shape of the output images as a list of tuples."""
398+
return [[self._nb_cnn_features, 8, 9, 22]]

tests/models/test_pixel_mapping.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,24 @@ def test_pixel_mappings() -> None:
6868
# Apply node definition to torch tensor with raw pulses
6969
picture = pixel_mapping(dummy_data, pixel_feature_names)
7070
new_features = pixel_mapping.image_feature_names
71+
n_features = len(new_features)
7172

7273
# Check the output
7374
basic_checks_picture(picture, dtype)
7475

7576
# More checks
77+
assert (
78+
len(pixel_mapping.shape) == 3
79+
), f"Expected shape to be 3 got {len(pixel_mapping.shape)}"
80+
assert pixel_mapping.shape == [
81+
(n_features, 10, 10, 60),
82+
(n_features, 1, 8, 10),
83+
(n_features, 1, 8, 50),
84+
], (
85+
f"Expected shape to be [({n_features},10,10,60), "
86+
f"({n_features},1,8,10), ({n_features},1,8,50)] got "
87+
f"{pixel_mapping.shape}"
88+
)
7689
assert isinstance(
7790
new_features, list
7891
), f"Output should be a list of feature names got {type(new_features)}"

0 commit comments

Comments
 (0)