1+ from __future__ import annotations
2+ from os import PathLike
13import numpy as np
24from .units import convert_unit
35import re
4- from copy import copy
6+ from attrs import define , field
57
68from . import hdf5
79
1012_object_name_pattern = re .compile ('[a-zA-Z][a-zA-Z0-9_]*' )
1113
1214
13- class Group (object ):
15+ @define (eq = False )
16+ class Group :
1417 """ SDF Group """
1518
16- def __init__ (self , name , comment = None , attributes = dict (), groups = [], datasets = []):
17- self .name = name
18- self .comment = comment
19- self .attributes = copy (attributes )
20- self .groups = copy (groups )
21- self .datasets = copy (datasets )
19+ name : str = None
20+ comment : str = None
21+ attributes : dict [str , str ] = field (factory = dict )
22+ groups : list [Group ] = field (factory = list )
23+ datasets : list [Dataset ] = field (factory = list )
2224
2325 def __contains__ (self , key ):
2426 for obj in self .datasets + self .groups :
@@ -36,34 +38,21 @@ def __iter__(self):
3638 for obj in self .groups + self .datasets :
3739 yield obj
3840
39- def __repr__ (self ):
40- return '<SDF Group "' + self .name + '": [' + ', ' .join (map (lambda obj : obj .name , self )) + ']>'
4141
42-
43- class Dataset ( object ) :
42+ @ define ( eq = False )
43+ class Dataset :
4444 """ SDF Dataset """
4545
46- def __init__ (self , name ,
47- comment = None ,
48- attributes = dict (),
49- data = np .empty (0 ),
50- display_name = None ,
51- relative_quantity = False ,
52- unit = None ,
53- display_unit = None ,
54- is_scale = False ,
55- scales = []
56- ):
57- self .name = name
58- self .comment = comment
59- self .attributes = copy (attributes )
60- self .data = data
61- self ._display_name = display_name
62- self .relative_quantity = relative_quantity
63- self .unit = unit
64- self ._display_unit = display_unit
65- self .is_scale = is_scale
66- self .scales = scales
46+ name : str = None
47+ comment : str = None
48+ attributes : dict [str , str ] = field (factory = dict )
49+ data : np .typing .NDArray = None
50+ _display_name : str = None
51+ relative_quantity : bool = False
52+ unit : str = None
53+ _display_unit : str = None
54+ is_scale : bool = False
55+ scales : list [Dataset ] = field (factory = list )
6756
6857 @property
6958 def display_data (self ):
@@ -89,72 +78,46 @@ def display_unit(self):
8978 def display_unit (self , value ):
9079 self ._display_unit = value
9180
92- def validate (self ):
93- if self .display_unit and not self .unit :
94- return 'ERROR' , 'display_unit was set but no unit'
95-
96- return 'OK'
97-
9881 # some shorthand aliases
9982 @property
10083 def d (self ):
10184 return self .data
10285
10386 dd = display_data
10487
105- def __repr__ (self ):
106- text = '<SDF Dataset "' + self .name + '": '
107-
108- if not isinstance (self .data , np .ndarray ) or len (self .data .shape ) == 0 :
109- text += str (self .data )
110- elif len (self .data .shape ) == 1 and len (self .data ) <= 10 :
111- text += str (self .data )
112- else :
113- text += '<' + 'x' .join (map (str , self .data .shape )) + '>'
114-
115- if self .unit is not None :
116- text += ' ' + self .unit
117-
118- if any (self .scales ):
119- text += ' w.r.t. ' + ', ' .join (map (lambda s : s .name if s is not None else 'None' , self .scales ))
12088
121- text += '>'
122-
123- return text
124-
125-
126- def validate (obj ):
89+ def validate (obj : Group | Dataset ) -> list [str ]:
12790 """ Validate an sdf.Group or sdf.Dataset """
12891
129- errors = []
92+ problems = []
13093
13194 if isinstance (obj , Group ):
132- errors += _validate_group (obj , is_root = True )
95+ problems += _validate_group (obj , is_root = True )
13396 elif isinstance (obj , Dataset ):
134- errors += _validate_dataset (obj )
97+ problems += _validate_dataset (obj )
13598 else :
136- errors .append (' Unknown object type: %s' % type (obj ))
99+ problems .append (f" Unknown object type: { type (obj )} " )
137100
138- return errors
101+ return problems
139102
140103
141104def _validate_group (group , is_root = False ):
142- errors = []
105+
106+ problems = []
143107
144108 if not is_root and not _object_name_pattern .match (group .name ):
145- errors += [
146- "Object names must only contain letters, digits and underscores (\" _\" ) and must start with a letter" ]
109+ problems .append ("Object names must only contain letters, digits, and underscores (\" _\" ) and must start with a letter." )
147110
148111 for child_group in group .groups :
149- errors += _validate_dataset (child_group )
112+ problems += _validate_dataset (child_group )
150113
151114 for ds in group .datasets :
152- errors += _validate_dataset (ds )
115+ problems += _validate_dataset (ds )
153116
154- return errors
117+ return problems
155118
156119
157- def _validate_dataset (ds ) :
120+ def _validate_dataset (ds : Dataset ) -> list [ str ] :
158121
159122 if type (ds .data ) is not np .ndarray :
160123 return ['Dataset.data must be a numpy.ndarray' ]
@@ -177,10 +140,8 @@ def _validate_dataset(ds):
177140 return []
178141
179142
180- def load (filename , objectname = '/' , unit = None , scale_units = None ):
181- """ Load a dataset or group from an SDF file """
182-
183- obj = None
143+ def load (filename : str | PathLike , objectname : str = '/' , unit : str = None , scale_units : list [str ] = None ) -> Dataset | Group :
144+ """ Load a Dataset or Group from an SDF file """
184145
185146 if filename .endswith ('.mat' ):
186147 from . import dsres
@@ -212,7 +173,7 @@ def load(filename, objectname='/', unit=None, scale_units=None):
212173 return obj
213174
214175
215- def save (filename , group ):
176+ def save (filename : str | PathLike , group : Group ):
216177 """ Save an SDF group to a file """
217178
218179 hdf5 .save (filename , group )
0 commit comments