Skip to content

Commit eb2ec43

Browse files
authored
Adds new features to ClassList (#11)
* Adds ability to set item in ClassList using an object as well as a dict * Adds routine "_determine_class_handle" to "classList.py" alongside appropriate tests * Amends test "test_determine_class_handle" * Adds routine "set_fields" to "classlist.py"
1 parent 5355150 commit eb2ec43

File tree

5 files changed

+135
-29
lines changed

5 files changed

+135
-29
lines changed

RAT/classlist.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, init_list: Union[Sequence[object], object] = None, name_field
4242
# Set class to be used for this instance of the ClassList, checking that all elements of the input list are of
4343
# the same type and have unique values of the specified name_field
4444
if init_list:
45-
self._class_handle = type(init_list[0])
45+
self._class_handle = self._determine_class_handle(init_list)
4646
self._check_classes(init_list)
4747
self._check_unique_name_fields(init_list)
4848

@@ -61,15 +61,15 @@ def __repr__(self):
6161
output = repr(self.data)
6262
return output
6363

64-
def __setitem__(self, index: int, set_dict: dict[str, Any]) -> None:
65-
"""Assign the values of an existing object's attributes using a dictionary."""
66-
self._setitem(index, set_dict)
64+
def __setitem__(self, index: int, item: 'RAT.models') -> None:
65+
"""Replace the object at an existing index of the ClassList."""
66+
self._setitem(index, item)
6767

68-
def _setitem(self, index: int, set_dict: dict[str, Any]) -> None:
68+
def _setitem(self, index: int, item: 'RAT.models') -> None:
6969
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
70-
self._validate_name_field(set_dict)
71-
for key, value in set_dict.items():
72-
setattr(self.data[index], key, value)
70+
self._check_classes(self + [item])
71+
self._check_unique_name_fields(self + [item])
72+
self.data[index] = item
7373

7474
def __delitem__(self, index: int) -> None:
7575
"""Delete an object from the list by index."""
@@ -85,8 +85,10 @@ def __iadd__(self, other: Sequence[object]) -> 'ClassList':
8585

8686
def _iadd(self, other: Sequence[object]) -> 'ClassList':
8787
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
88+
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
89+
other = [other]
8890
if not hasattr(self, '_class_handle'):
89-
self._class_handle = type(other[0])
91+
self._class_handle = self._determine_class_handle(self + other)
9092
self._check_classes(self + other)
9193
self._check_unique_name_fields(self + other)
9294
super().__iadd__(other)
@@ -201,12 +203,20 @@ def index(self, item: Union[object, str], *args) -> int:
201203

202204
def extend(self, other: Sequence[object]) -> None:
203205
"""Extend the ClassList by adding another sequence."""
206+
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
207+
other = [other]
204208
if not hasattr(self, '_class_handle'):
205-
self._class_handle = type(other[0])
209+
self._class_handle = self._determine_class_handle(self + other)
206210
self._check_classes(self + other)
207211
self._check_unique_name_fields(self + other)
208212
self.data.extend(other)
209213

214+
def set_fields(self, index: int, **kwargs) -> None:
215+
"""Assign the values of an existing object's attributes using keyword arguments."""
216+
self._validate_name_field(kwargs)
217+
for key, value in kwargs.items():
218+
setattr(self.data[index], key, value)
219+
210220
def get_names(self) -> list[str]:
211221
"""Return a list of the values of the name_field attribute of each class object in the list.
212222
@@ -302,3 +312,28 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
302312
object with that value of the name_field attribute cannot be found.
303313
"""
304314
return next((model for model in self.data if getattr(model, self.name_field) == value), value)
315+
316+
@staticmethod
317+
def _determine_class_handle(input_list: Sequence[object]):
318+
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
319+
element which satisfies "issubclass" for all of the other elements.
320+
321+
Parameters
322+
----------
323+
input_list : Sequence [object]
324+
A list of instances to populate the ClassList.
325+
326+
Returns
327+
-------
328+
class_handle : type
329+
The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other
330+
elements.
331+
"""
332+
for this_element in input_list:
333+
if all([issubclass(type(instance), type(this_element)) for instance in input_list]):
334+
class_handle = type(this_element)
335+
break
336+
else:
337+
class_handle = type(input_list[0])
338+
339+
return class_handle

RAT/project.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def model_post_init(self, __context: Any) -> None:
170170

171171
# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
172172
# model, handle errors and reset previous values if necessary.
173-
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
173+
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend',
174+
'set_fields']
174175
for class_list in class_lists:
175176
attribute = getattr(self, class_list)
176177
for methodName in methods_to_wrap:

tests/test_classlist.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88

99
from RAT.classlist import ClassList
10-
from tests.utils import InputAttributes
10+
from tests.utils import InputAttributes, SubInputAttributes
1111

1212

1313
@pytest.fixture
@@ -59,7 +59,21 @@ def test_input_sequence(self, input_sequence: Sequence[object]) -> None:
5959
"""
6060
class_list = ClassList(input_sequence)
6161
assert class_list.data == list(input_sequence)
62-
assert isinstance(input_sequence[-1], class_list._class_handle)
62+
for element in input_sequence:
63+
assert isinstance(element, class_list._class_handle)
64+
65+
@pytest.mark.parametrize("input_sequence", [
66+
([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')]),
67+
([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')]),
68+
])
69+
def test_input_sequence_subclass(self, input_sequence: Sequence[object]) -> None:
70+
"""For an input of a sequence containing objects of a class and its subclasses, the ClassList should be a list
71+
equal to the input sequence, and _class_handle should be set to the type of the parent class.
72+
"""
73+
class_list = ClassList(input_sequence)
74+
assert class_list.data == list(input_sequence)
75+
for element in input_sequence:
76+
assert isinstance(element, class_list._class_handle)
6377

6478
@pytest.mark.parametrize("empty_input", [([]), (())])
6579
def test_empty_input(self, empty_input: Sequence[object]) -> None:
@@ -119,26 +133,33 @@ def test_repr_empty_classlist() -> None:
119133
assert repr(ClassList()) == repr([])
120134

121135

122-
@pytest.mark.parametrize(["new_values", "expected_classlist"], [
123-
({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
124-
({'name': 'John', 'surname': 'Luther'},
136+
@pytest.mark.parametrize(["new_item", "expected_classlist"], [
137+
(InputAttributes(name='Eve'), ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
138+
(InputAttributes(name='John', surname='Luther'),
125139
ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])),
126140
])
127-
def test_setitem(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList') -> None:
128-
"""We should be able to set values in an element of a ClassList using a dictionary."""
141+
def test_setitem(two_name_class_list: ClassList, new_item: InputAttributes, expected_classlist: ClassList) -> None:
142+
"""We should be able to set values in an element of a ClassList using a new object."""
129143
class_list = two_name_class_list
130-
class_list[0] = new_values
144+
class_list[0] = new_item
131145
assert class_list == expected_classlist
132146

133147

148+
@pytest.mark.parametrize("new_item", [
149+
(InputAttributes(name='Bob')),
150+
])
151+
def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_item: InputAttributes) -> None:
152+
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
153+
with pytest.raises(ValueError, match="Input list contains objects with the same value of the name attribute"):
154+
two_name_class_list[0] = new_item
155+
156+
134157
@pytest.mark.parametrize("new_values", [
135-
({'name': 'Bob'}),
158+
'Bob',
136159
])
137-
def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
160+
def test_setitem_different_classes(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
138161
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
139-
with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} "
140-
f"'{new_values[two_name_class_list.name_field]}', "
141-
f"which is already specified in the ClassList"):
162+
with pytest.raises(ValueError, match=f"Input list contains elements of type other than 'InputAttributes'"):
142163
two_name_class_list[0] = new_values
143164

144165

@@ -160,9 +181,11 @@ def test_delitem_not_present(two_name_class_list: 'ClassList') -> None:
160181
(ClassList(InputAttributes(name='Eve'))),
161182
([InputAttributes(name='Eve')]),
162183
(InputAttributes(name='Eve'),),
184+
(InputAttributes(name='Eve')),
163185
])
164186
def test_iadd(two_name_class_list: 'ClassList', added_list: Iterable, three_name_class_list: 'ClassList') -> None:
165-
"""We should be able to use the "+=" operator to add iterables to a ClassList."""
187+
"""We should be able to use the "+=" operator to add iterables to a ClassList. Individual objects should be wrapped
188+
in a list before being added."""
166189
class_list = two_name_class_list
167190
class_list += added_list
168191
assert class_list == three_name_class_list
@@ -439,9 +462,11 @@ def test_index_not_present(two_name_class_list: 'ClassList', index_value: Union[
439462
(ClassList(InputAttributes(name='Eve'))),
440463
([InputAttributes(name='Eve')]),
441464
(InputAttributes(name='Eve'),),
465+
(InputAttributes(name='Eve')),
442466
])
443467
def test_extend(two_name_class_list: 'ClassList', extended_list: Sequence, three_name_class_list: 'ClassList') -> None:
444-
"""We should be able to extend a ClassList using another ClassList or a sequence"""
468+
"""We should be able to extend a ClassList using another ClassList or a sequence. Individual objects should be
469+
wrapped in a list before being added."""
445470
class_list = two_name_class_list
446471
class_list.extend(extended_list)
447472
assert class_list == three_name_class_list
@@ -460,6 +485,30 @@ def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: 'C
460485
assert isinstance(extended_list[-1], class_list._class_handle)
461486

462487

488+
@pytest.mark.parametrize(["new_values", "expected_classlist"], [
489+
({'name': 'Eve'}, ClassList([InputAttributes(name='Eve'), InputAttributes(name='Bob')])),
490+
({'name': 'John', 'surname': 'Luther'},
491+
ClassList([InputAttributes(name='John', surname='Luther'), InputAttributes(name='Bob')])),
492+
])
493+
def test_set_fields(two_name_class_list: 'ClassList', new_values: dict[str, Any], expected_classlist: 'ClassList')\
494+
-> None:
495+
"""We should be able to set field values in an element of a ClassList using keyword arguments."""
496+
class_list = two_name_class_list
497+
class_list.set_fields(0, **new_values)
498+
assert class_list == expected_classlist
499+
500+
501+
@pytest.mark.parametrize("new_values", [
502+
({'name': 'Bob'}),
503+
])
504+
def test_set_fields_same_name_field(two_name_class_list: 'ClassList', new_values: dict[str, Any]) -> None:
505+
"""If we set the name_field of an object in the ClassList to one already defined, we should raise a ValueError."""
506+
with pytest.raises(ValueError, match=f"Input arguments contain the {two_name_class_list.name_field} "
507+
f"'{new_values[two_name_class_list.name_field]}', "
508+
f"which is already specified in the ClassList"):
509+
two_name_class_list.set_fields(0, **new_values)
510+
511+
463512
@pytest.mark.parametrize(["class_list", "expected_names"], [
464513
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), ["Alice", "Bob"]),
465514
(ClassList([InputAttributes(id='Alice'), InputAttributes(id='Bob')], name_field='id'), ["Alice", "Bob"]),
@@ -563,3 +612,19 @@ def test__get_item_from_name_field(two_name_class_list: 'ClassList',
563612
If the value is not the name_field of an object defined in the ClassList, we should return the value.
564613
"""
565614
assert two_name_class_list._get_item_from_name_field(value) == expected_output
615+
616+
617+
@pytest.mark.parametrize(["input_list", "expected_type"], [
618+
([InputAttributes(name='Alice')], InputAttributes),
619+
([InputAttributes(name='Alice'), SubInputAttributes(name='Bob')], InputAttributes),
620+
([SubInputAttributes(name='Alice'), InputAttributes(name='Bob')], InputAttributes),
621+
([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob')], SubInputAttributes),
622+
([SubInputAttributes(name='Alice'), SubInputAttributes(name='Bob'), InputAttributes(name='Eve')], InputAttributes),
623+
([InputAttributes(name='Alice'), dict(name='Bob')], InputAttributes),
624+
([dict(name='Alice'), InputAttributes(name='Bob')], dict),
625+
])
626+
def test_determine_class_handle(input_list: 'ClassList', expected_type: type) -> None:
627+
"""The _class_handle for the ClassList should be the type that satisfies the condition "isinstance(element, type)"
628+
for all elements in the ClassList.
629+
"""
630+
assert ClassList._determine_class_handle(input_list) == expected_type

tests/test_project.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def test_project():
1818
"""Add parameters to the default project, so each ClassList can be tested properly."""
1919
test_project = RAT.project.Project()
20-
test_project.data[0] = {'data': np.array([[1, 1, 1]])}
20+
test_project.data.set_fields(0, data=np.array([[1, 1, 1]]))
2121
test_project.parameters.append(name='Test SLD')
2222
test_project.custom_files.append(name='Test Custom File')
2323
test_project.layers.append(name='Test Layer', SLD='Test SLD')
@@ -161,7 +161,7 @@ def test_rename_models(test_project, model: str, field: str) -> None:
161161
"""When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere
162162
in the project.
163163
"""
164-
getattr(test_project, model)[-1] = {'name': 'New Name'}
164+
getattr(test_project, model).set_fields(-1, name='New Name')
165165
attribute = RAT.project.model_names_used_in[model].attribute
166166
assert getattr(getattr(test_project, attribute)[-1], field) == 'New Name'
167167

@@ -307,7 +307,7 @@ def test_wrap_set(test_project, class_list: str, field: str) -> None:
307307
orig_class_list = copy.deepcopy(test_attribute)
308308

309309
with contextlib.redirect_stdout(io.StringIO()) as print_str:
310-
test_attribute[0] = {field: 'undefined'}
310+
test_attribute.set_fields(0, **{field: 'undefined'})
311311
assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "undefined" in '
312312
f'the "{field}" field of "{class_list}" must be defined in '
313313
f'"{RAT.project.values_defined_in[class_list+"."+field]}".\033[0m\n')

tests/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ def __eq__(self, other: Any):
1111
if isinstance(other, InputAttributes):
1212
return self.__dict__ == other.__dict__
1313
return False
14+
15+
16+
class SubInputAttributes(InputAttributes):
17+
"""Trivial subclass of InputAttributes"""
18+
pass

0 commit comments

Comments
 (0)