Skip to content

Commit d862073

Browse files
fixed issues with to_array function
1 parent 757af75 commit d862073

File tree

2 files changed

+75
-48
lines changed

2 files changed

+75
-48
lines changed

imas/ids_slice.py

Lines changed: 60 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -77,10 +77,10 @@ def _path(self) -> str:
7777
@property
7878
def shape(self) -> Tuple[int, ...]:
7979
"""Get the virtual multi-dimensional shape.
80-
80+
8181
Returns the shape of the data as if it were organized in a multi-dimensional
8282
array, based on the hierarchy of slicing operations performed.
83-
83+
8484
Returns:
8585
Tuple of dimensions. Use with caution for jagged arrays where sizes vary.
8686
"""
@@ -119,7 +119,7 @@ def __getitem__(self, item: Union[int, slice]) -> Union[Any, "IDSSlice"]:
119119
# Preserve structure instead of flattening
120120
sliced_elements = []
121121
sliced_sizes = []
122-
122+
123123
for array in self._matched_elements:
124124
sliced = array[item]
125125
if isinstance(sliced, IDSSlice):
@@ -134,7 +134,9 @@ def __getitem__(self, item: Union[int, slice]) -> Union[Any, "IDSSlice"]:
134134

135135
# Update shape to reflect the sliced structure
136136
# Keep first dimensions, update last dimension
137-
new_virtual_shape = self._virtual_shape[:-1] + (sliced_sizes[0] if sliced_sizes else 0,)
137+
new_virtual_shape = self._virtual_shape[:-1] + (
138+
sliced_sizes[0] if sliced_sizes else 0,
139+
)
138140
new_hierarchy = self._element_hierarchy[:-1] + [sliced_sizes]
139141

140142
return IDSSlice(
@@ -172,7 +174,9 @@ def __getitem__(self, item: Union[int, slice]) -> Union[Any, "IDSSlice"]:
172174

173175
# Update shape to reflect the slice on first dimension
174176
new_virtual_shape = (len(sliced_elements),) + self._virtual_shape[1:]
175-
new_element_hierarchy = [len(sliced_elements)] + self._element_hierarchy[1:]
177+
new_element_hierarchy = [
178+
len(sliced_elements)
179+
] + self._element_hierarchy[1:]
176180

177181
return IDSSlice(
178182
self.metadata,
@@ -232,11 +236,13 @@ def __getattr__(self, name: str) -> "IDSSlice":
232236
if isinstance(child_elements[0], IDSStructArray):
233237
# Children are IDSStructArray - track the new dimension
234238
child_sizes = [len(arr) for arr in child_elements]
235-
239+
236240
# New virtual shape: current shape + new dimension
237-
new_virtual_shape = self._virtual_shape + (child_sizes[0] if child_sizes else 0,)
241+
new_virtual_shape = self._virtual_shape + (
242+
child_sizes[0] if child_sizes else 0,
243+
)
238244
new_hierarchy = self._element_hierarchy + [child_sizes]
239-
245+
240246
return IDSSlice(
241247
child_metadata,
242248
child_elements,
@@ -249,12 +255,14 @@ def __getattr__(self, name: str) -> "IDSSlice":
249255
# Children are IDSNumericArray - track the array dimension
250256
# Each IDSNumericArray has a size (length of its data)
251257
child_sizes = [len(arr) for arr in child_elements]
252-
253-
# New virtual shape: current shape + new dimension (size of first numeric array)
258+
259+
# New virtual shape: current shape + new dimension
254260
# Jagged arrays handled by to_array() with object dtype
255-
new_virtual_shape = self._virtual_shape + (child_sizes[0] if child_sizes else 0,)
261+
new_virtual_shape = self._virtual_shape + (
262+
child_sizes[0] if child_sizes else 0,
263+
)
256264
new_hierarchy = self._element_hierarchy + [child_sizes]
257-
265+
258266
return IDSSlice(
259267
child_metadata,
260268
child_elements,
@@ -338,7 +346,11 @@ def values(self, reshape: bool = False) -> Any:
338346
>>> # Result: ndarray shape (106, 100)
339347
>>>
340348
>>> # 3D ions case - returns object array with structure
341-
>>> ion_rho = core_profiles.profiles_1d[:].ion[:].element[:].density.values(reshape=True)
349+
>>> ion_rho = (
350+
... core_profiles.profiles_1d[:].ion[:].element[:].density.values(
351+
... reshape=True
352+
... )
353+
... )
342354
>>> # Result: object array shape (106, 3, 2) with IDSNumericArray elements
343355
"""
344356
from imas.ids_primitive import IDSPrimitive, IDSNumericArray
@@ -359,7 +371,9 @@ def values(self, reshape: bool = False) -> Any:
359371
if isinstance(element, IDSPrimitive):
360372
flat_values.append(element.value)
361373
elif isinstance(element, IDSNumericArray):
362-
flat_values.append(element.data if hasattr(element, 'data') else element.value)
374+
flat_values.append(
375+
element.data if hasattr(element, "data") else element.value
376+
)
363377
else:
364378
flat_values.append(element)
365379

@@ -399,7 +413,8 @@ def to_array(self) -> np.ndarray:
399413
structure of the IMAS data.
400414
401415
Returns:
402-
numpy.ndarray with shape self.shape. For jagged arrays, dtype will be object.
416+
numpy.ndarray with shape self.shape. For jagged arrays,
417+
dtype will be object.
403418
404419
Raises:
405420
ValueError: If array cannot be converted to numpy
@@ -436,25 +451,27 @@ def to_array(self) -> np.ndarray:
436451

437452
# Multi-dimensional case
438453
# Check if matched elements are themselves arrays (IDSNumericArray)
439-
if self._matched_elements and isinstance(self._matched_elements[0], IDSNumericArray):
454+
if self._matched_elements and isinstance(
455+
self._matched_elements[0], IDSNumericArray
456+
):
440457
# Elements are numeric arrays - extract their values and stack them
441458
array_values = []
442459
for element in self._matched_elements:
443460
if isinstance(element, IDSNumericArray):
444461
array_values.append(element.value)
445462
else:
446463
array_values.append(element)
447-
464+
448465
# Try to stack into proper shape
449466
try:
450467
# Check if all arrays have the same size (regular)
451468
sizes = []
452469
for val in array_values:
453-
if hasattr(val, '__len__'):
470+
if hasattr(val, "__len__"):
454471
sizes.append(len(val))
455472
else:
456473
sizes.append(1)
457-
474+
458475
# If all sizes are the same, we can create a regular array
459476
if len(set(sizes)) == 1:
460477
# Regular array - all sub-arrays same size
@@ -478,7 +495,7 @@ def to_array(self) -> np.ndarray:
478495
for i, val in enumerate(array_values):
479496
result_arr[i] = val
480497
return result_arr
481-
except (ValueError, TypeError) as e:
498+
except (ValueError, TypeError):
482499
# Fallback: return object array
483500
result_arr = np.empty(self._virtual_shape[0], dtype=object)
484501
for i, val in enumerate(array_values):
@@ -488,11 +505,29 @@ def to_array(self) -> np.ndarray:
488505
# For non-numeric elements in multi-dimensional structure
489506
# Extract and try to build structure
490507
flat_values = []
491-
for element in self._matched_elements:
492-
if isinstance(element, IDSPrimitive):
493-
flat_values.append(element.value)
494-
else:
495-
flat_values.append(element)
508+
509+
# First check if matched_elements are IDSStructArray (which need flattening)
510+
from imas.ids_struct_array import IDSStructArray
511+
512+
has_struct_arrays = self._matched_elements and isinstance(
513+
self._matched_elements[0], IDSStructArray
514+
)
515+
516+
if has_struct_arrays:
517+
# Flatten IDSStructArray elements
518+
for struct_array in self._matched_elements:
519+
for element in struct_array:
520+
if isinstance(element, IDSPrimitive):
521+
flat_values.append(element.value)
522+
else:
523+
flat_values.append(element)
524+
else:
525+
# Regular elements
526+
for element in self._matched_elements:
527+
if isinstance(element, IDSPrimitive):
528+
flat_values.append(element.value)
529+
else:
530+
flat_values.append(element)
496531

497532
total_size = 1
498533
for dim in self._virtual_shape:

imas/test/test_multidim_slicing.py

Lines changed: 15 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
import pytest
77

88
from imas.ids_factory import IDSFactory
9-
from imas.ids_slice import IDSSlice
109

1110

1211
class TestMultiDimSlicing:
@@ -137,7 +136,7 @@ def test_slice_preserves_groups(self):
137136

138137
# Get all ions, then slice
139138
result = cp.profiles_1d[:].ion[:]
140-
139+
141140
# Should still know the structure: 10 profiles, 3 ions each
142141
assert result.shape == (10, 3)
143142
assert len(result) == 30 # Flattened for iteration, but shape preserved
@@ -153,7 +152,7 @@ def test_integer_index_on_nested(self):
153152

154153
# Get first ion from all profiles
155154
result = cp.profiles_1d[:].ion[0]
156-
155+
157156
assert len(result) == 5
158157
for i, ion in enumerate(result):
159158
assert ion.label == f"ion_{i}_0"
@@ -167,7 +166,7 @@ def test_slice_on_nested_arrays(self):
167166

168167
# Get first 2 ions from each profile
169168
result = cp.profiles_1d[:].ion[:2]
170-
169+
171170
assert result.shape == (5, 2)
172171
assert len(result) == 10 # 5 profiles * 2 ions each
173172

@@ -180,7 +179,7 @@ def test_step_slicing_on_nested(self):
180179

181180
# Get every other ion
182181
result = cp.profiles_1d[:].ion[::2]
183-
182+
184183
assert result.shape == (5, 3) # 5 profiles, 3 ions each (0, 2, 4)
185184
assert len(result) == 15
186185

@@ -195,7 +194,7 @@ def test_negative_indexing_on_nested(self):
195194

196195
# Get last ion from each profile
197196
result = cp.profiles_1d[:].ion[-1]
198-
197+
199198
assert len(result) == 5
200199
for ion in result:
201200
assert ion.label == "ion_2"
@@ -227,7 +226,7 @@ def test_boolean_indexing_simple(self):
227226
p.electrons.density = np.array([float(i)] * 5)
228227

229228
result = cp.profiles_1d[:].electrons.density
230-
229+
231230
mask = np.array([True, False, True, False, True])
232231
filtered = result[mask]
233232
assert len(filtered) == 3
@@ -239,15 +238,10 @@ def test_assignment_on_slice(self):
239238
for p in cp.profiles_1d:
240239
p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0])
241240

242-
# Assign new values through slice
243-
new_values = np.array([[0.1, 0.6, 1.1],
244-
[0.2, 0.7, 1.2],
245-
[0.3, 0.8, 1.3]])
246-
247241
# This requires assignment support
248242
# cp.profiles_1d[:].grid.rho_tor_norm[:] = new_values
249243
# For now, verify slicing works for reading
250-
244+
251245
result = cp.profiles_1d[:].grid.rho_tor_norm
252246
array = result.to_array()
253247
assert array.shape == (3, 3)
@@ -257,7 +251,7 @@ def test_xarray_integration_compatible(self):
257251
cp = IDSFactory("3.39.0").core_profiles()
258252
cp.profiles_1d.resize(3)
259253
cp.time = np.array([1.0, 2.0, 3.0])
260-
254+
261255
for i, p in enumerate(cp.profiles_1d):
262256
p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0])
263257
p.electrons.temperature = np.array([1.0, 2.0, 3.0]) * (i + 1)
@@ -274,7 +268,7 @@ def test_performance_large_hierarchy(self):
274268
cp = IDSFactory("3.39.0").core_profiles()
275269
n_profiles = 50
276270
cp.profiles_1d.resize(n_profiles)
277-
271+
278272
for p in cp.profiles_1d:
279273
p.grid.rho_tor_norm = np.linspace(0, 1, 100)
280274
p.ion.resize(5)
@@ -284,7 +278,7 @@ def test_performance_large_hierarchy(self):
284278
# Should handle large data without significant slowdown
285279
result = cp.profiles_1d[:].grid.rho_tor_norm
286280
array = result.to_array()
287-
281+
288282
assert array.shape == (n_profiles, 100)
289283

290284
def test_lazy_loading_with_multidim(self):
@@ -296,12 +290,10 @@ def test_lazy_loading_with_multidim(self):
296290
p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0])
297291

298292
result = cp.profiles_1d[:].grid.rho_tor_norm
299-
300-
# Verify lazy attributes are preserved
301-
assert hasattr(result, '_lazy')
302-
assert hasattr(result, '_parent_array')
303-
304293

294+
# Verify lazy attributes are preserved
295+
assert hasattr(result, "_lazy")
296+
assert hasattr(result, "_parent_array")
305297

306298

307299
class TestEdgeCases:
@@ -339,7 +331,7 @@ def test_single_dimension_value(self):
339331
i.z_ion = 1.0
340332

341333
result = cp.profiles_1d[:].ion[0].z_ion
342-
334+
343335
# Should be 3 items (one per profile)
344336
assert len(result) == 3
345337

@@ -352,6 +344,6 @@ def test_slice_of_slice(self):
352344

353345
result1 = cp.profiles_1d[::2].ion # Every other profile's ions
354346
assert result1.shape == (5, 3)
355-
347+
356348
result2 = result1[:2] # First 2 from each
357349
assert result2.shape == (5, 2)

0 commit comments

Comments
 (0)