Skip to content

Commit 7e21a4d

Browse files
committed
Use attrs and add type hints
1 parent 5d299e7 commit 7e21a4d

File tree

6 files changed

+61
-83
lines changed

6 files changed

+61
-83
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ description = "Work with Scientific Data Format files in Python"
55
readme = "README.rst"
66
requires-python = ">=3.10"
77
dependencies = [
8+
"attrs>=25.3.0",
89
"h5py>=3.13.0",
910
"matplotlib>=3.10.3",
1011
"numpy>=2.2.6",

src/sdf/__init__.py

Lines changed: 38 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
from os import PathLike
13
import numpy as np
24
from .units import convert_unit
35
import re
4-
from copy import copy
6+
from attrs import define, field
57

68
from . import hdf5
79

@@ -10,15 +12,15 @@
1012
_object_name_pattern = re.compile('[a-zA-Z][a-zA-Z0-9_]*')
1113

1214

13-
class Group(object):
15+
@define(eq=False)
16+
class Group:
1417
""" SDF Group """
1518

16-
def __init__(self, name, comment=None, attributes=dict(), groups=[], datasets=[]):
17-
self.name = name
18-
self.comment = comment
19-
self.attributes = copy(attributes)
20-
self.groups = copy(groups)
21-
self.datasets = copy(datasets)
19+
name: str = None
20+
comment: str = None
21+
attributes: dict[str, str] = field(factory=dict)
22+
groups: list[Group] = field(factory=list)
23+
datasets: list[Dataset] = field(factory=list)
2224

2325
def __contains__(self, key):
2426
for obj in self.datasets + self.groups:
@@ -36,34 +38,21 @@ def __iter__(self):
3638
for obj in self.groups + self.datasets:
3739
yield obj
3840

39-
def __repr__(self):
40-
return '<SDF Group "' + self.name + '": [' + ', '.join(map(lambda obj: obj.name, self)) + ']>'
4141

42-
43-
class Dataset(object):
42+
@define(eq=False)
43+
class Dataset:
4444
""" SDF Dataset """
4545

46-
def __init__(self, name,
47-
comment=None,
48-
attributes=dict(),
49-
data=np.empty(0),
50-
display_name=None,
51-
relative_quantity=False,
52-
unit=None,
53-
display_unit=None,
54-
is_scale=False,
55-
scales=[]
56-
):
57-
self.name = name
58-
self.comment = comment
59-
self.attributes = copy(attributes)
60-
self.data = data
61-
self._display_name = display_name
62-
self.relative_quantity = relative_quantity
63-
self.unit = unit
64-
self._display_unit = display_unit
65-
self.is_scale = is_scale
66-
self.scales = scales
46+
name: str = None
47+
comment: str = None
48+
attributes: dict[str, str] = field(factory=dict)
49+
data: np.typing.NDArray = None
50+
_display_name: str = None
51+
relative_quantity: bool = False
52+
unit: str = None
53+
_display_unit: str = None
54+
is_scale: bool = False
55+
scales: list[Dataset] = field(factory=list)
6756

6857
@property
6958
def display_data(self):
@@ -89,72 +78,46 @@ def display_unit(self):
8978
def display_unit(self, value):
9079
self._display_unit = value
9180

92-
def validate(self):
93-
if self.display_unit and not self.unit:
94-
return 'ERROR', 'display_unit was set but no unit'
95-
96-
return 'OK'
97-
9881
# some shorthand aliases
9982
@property
10083
def d(self):
10184
return self.data
10285

10386
dd = display_data
10487

105-
def __repr__(self):
106-
text = '<SDF Dataset "' + self.name + '": '
107-
108-
if not isinstance(self.data, np.ndarray) or len(self.data.shape) == 0:
109-
text += str(self.data)
110-
elif len(self.data.shape) == 1 and len(self.data) <= 10:
111-
text += str(self.data)
112-
else:
113-
text += '<' + 'x'.join(map(str, self.data.shape)) + '>'
114-
115-
if self.unit is not None:
116-
text += ' ' + self.unit
117-
118-
if any(self.scales):
119-
text += ' w.r.t. ' + ', '.join(map(lambda s: s.name if s is not None else 'None', self.scales))
12088

121-
text += '>'
122-
123-
return text
124-
125-
126-
def validate(obj):
89+
def validate(obj: Group | Dataset) -> list[str]:
12790
""" Validate an sdf.Group or sdf.Dataset """
12891

129-
errors = []
92+
problems = []
13093

13194
if isinstance(obj, Group):
132-
errors += _validate_group(obj, is_root=True)
95+
problems += _validate_group(obj, is_root=True)
13396
elif isinstance(obj, Dataset):
134-
errors += _validate_dataset(obj)
97+
problems += _validate_dataset(obj)
13598
else:
136-
errors.append('Unknown object type: %s' % type(obj))
99+
problems.append(f"Unknown object type: {type(obj)}")
137100

138-
return errors
101+
return problems
139102

140103

141104
def _validate_group(group, is_root=False):
142-
errors = []
105+
106+
problems = []
143107

144108
if not is_root and not _object_name_pattern.match(group.name):
145-
errors += [
146-
"Object names must only contain letters, digits and underscores (\"_\") and must start with a letter"]
109+
problems.append("Object names must only contain letters, digits, and underscores (\"_\") and must start with a letter.")
147110

148111
for child_group in group.groups:
149-
errors += _validate_dataset(child_group)
112+
problems += _validate_dataset(child_group)
150113

151114
for ds in group.datasets:
152-
errors += _validate_dataset(ds)
115+
problems += _validate_dataset(ds)
153116

154-
return errors
117+
return problems
155118

156119

157-
def _validate_dataset(ds):
120+
def _validate_dataset(ds: Dataset) -> list[str]:
158121

159122
if type(ds.data) is not np.ndarray:
160123
return ['Dataset.data must be a numpy.ndarray']
@@ -177,10 +140,8 @@ def _validate_dataset(ds):
177140
return []
178141

179142

180-
def load(filename, objectname='/', unit=None, scale_units=None):
181-
""" Load a dataset or group from an SDF file """
182-
183-
obj = None
143+
def load(filename: str | PathLike, objectname: str = '/', unit: str = None, scale_units: list[str] = None) -> Dataset | Group:
144+
""" Load a Dataset or Group from an SDF file """
184145

185146
if filename.endswith('.mat'):
186147
from . import dsres
@@ -212,7 +173,7 @@ def load(filename, objectname='/', unit=None, scale_units=None):
212173
return obj
213174

214175

215-
def save(filename, group):
176+
def save(filename: str | PathLike, group: Group):
216177
""" Save an SDF group to a file """
217178

218179
hdf5.save(filename, group)

src/sdf/dsres.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
from os import PathLike
2+
13
import numpy as np
24
from sdf import Group, Dataset
35
import scipy.io
46

7+
58
# extract strings from the matrix
69
def strMatNormal(a):
710
return [''.join(s).rstrip() for s in a]
11+
812
def strMatTrans(a):
913
return [''.join(s).rstrip() for s in zip(*a)]
1014

1115

12-
def _split_description(comment):
16+
def _split_description(comment: str) -> tuple[str | None, str | None, str | None, dict[str, str]]:
1317

1418
unit = None
1519
display_unit = None
@@ -35,7 +39,7 @@ def _split_description(comment):
3539
return unit, display_unit, comment, info
3640

3741

38-
def load(filename, objectname):
42+
def load(filename: str | PathLike, objectname: str) -> Dataset | Group:
3943

4044
g_root = _load_mat(filename)
4145

@@ -50,7 +54,7 @@ def load(filename, objectname):
5054
return obj
5155

5256

53-
def _load_mat(filename):
57+
def _load_mat(filename: str) -> Group:
5458

5559
mat = scipy.io.loadmat(filename, chars_as_strings=False)
5660

src/sdf/hdf5.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import h5py
23
import sdf
34
import numpy as np
@@ -14,7 +15,7 @@ def _to_python_str(s):
1415
return s
1516

1617

17-
def load(filename, objectname):
18+
def load(filename: str | os.PathLike, objectname: str) -> sdf.Dataset | sdf.Group:
1819

1920
with h5py.File(filename, 'r') as f:
2021

@@ -43,7 +44,7 @@ def load(filename, objectname):
4344
raise Exception('Unexpected object')
4445

4546

46-
def save(filename, group):
47+
def save(filename: str | os.PathLike, group: sdf.Group) -> None:
4748

4849
with h5py.File(filename, 'w') as f:
4950

tests/test_sdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_3D_example(self):
201201
def test_validate_group(self):
202202
g = sdf.Group('8')
203203
errors = sdf._validate_group(g, is_root=False)
204-
self.assertEqual(["Object names must only contain letters, digits and underscores (\"_\") and must start with a letter"], errors)
204+
self.assertEqual(["Object names must only contain letters, digits, and underscores (\"_\") and must start with a letter."], errors)
205205

206206
g.name = 'G1'
207207
errors = sdf._validate_group(g, is_root=False)

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)