Skip to content

Commit a7842e6

Browse files
authored
Introduces ruff as linter and formatter (#38)
* Adds ruff and resolves linting errors on standard rule set * Resolves automatically fixable linting errors on advanced rule set * Resolves most manually fixable linting errors on advanced rule set * Switches code to use double quotes * Resolves automatically fixable linting errors on full rule set * Applies ruff formatter * Finalises rule selection and tidies up code * Adds "requirements-dev.txt" * Adds "requirements-dev.txt" * Adds new github action for linter and formatter * Addresses review comments
1 parent a14835e commit a7842e6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+9475
-4484
lines changed

.github/workflows/run_ruff.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
name: Ruff
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
ruff:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
- uses: chartboost/ruff-action@v1
15+
- uses: chartboost/ruff-action@v1
16+
with:
17+
args: 'format --check'
18+

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ docs/*.inv
2828
build/*
2929
dist/*
3030
*.whl
31+
32+
# Local pre-commit hooks
33+
.pre-commit-config.yaml

RAT/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import os
2+
3+
from RAT import models
24
from RAT.classlist import ClassList
3-
from RAT.project import Project
45
from RAT.controls import set_controls
6+
from RAT.project import Project
57
from RAT.run import run
6-
import RAT.models
8+
9+
__all__ = ["ClassList", "Project", "run", "set_controls", "models"]
710

811
dir_path = os.path.dirname(os.path.realpath(__file__))
9-
os.environ["RAT_PATH"] = os.path.join(dir_path, '')
12+
os.environ["RAT_PATH"] = os.path.join(dir_path, "")

RAT/classlist.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
"""The classlist module. Contains the ClassList class, which defines a list containing instances of a particular class.
1+
"""The classlist module. Contains the ClassList class, which defines a list containing instances of a particular
2+
class.
23
"""
34

45
import collections
5-
from collections.abc import Iterable, Sequence
66
import contextlib
7-
import prettytable
8-
from typing import Any, Union
97
import warnings
8+
from collections.abc import Iterable, Sequence
9+
from typing import Any, Union
10+
11+
import prettytable
1012

1113

1214
class ClassList(collections.UserList):
@@ -31,7 +33,9 @@ class ClassList(collections.UserList):
3133
An instance, or list of instance(s), of the class to be used in this ClassList.
3234
name_field : str, optional
3335
The field used to define unique objects in the ClassList (default is "name").
36+
3437
"""
38+
3539
def __init__(self, init_list: Union[Sequence[object], object] = None, name_field: str = "name") -> None:
3640
self.name_field = name_field
3741

@@ -56,7 +60,7 @@ def __repr__(self):
5660
else:
5761
if any(model.__dict__ for model in self.data):
5862
table = prettytable.PrettyTable()
59-
table.field_names = ['index'] + [key.replace('_', ' ') for key in self.data[0].__dict__.keys()]
63+
table.field_names = ["index"] + [key.replace("_", " ") for key in self.data[0].__dict__]
6064
table.add_rows([[index] + list(model.__dict__.values()) for index, model in enumerate(self.data)])
6165
output = table.get_string()
6266
else:
@@ -81,15 +85,15 @@ def _delitem(self, index: int) -> None:
8185
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
8286
del self.data[index]
8387

84-
def __iadd__(self, other: Sequence[object]) -> 'ClassList':
88+
def __iadd__(self, other: Sequence[object]) -> "ClassList":
8589
"""Define in-place addition using the "+=" operator."""
8690
return self._iadd(other)
8791

88-
def _iadd(self, other: Sequence[object]) -> 'ClassList':
92+
def _iadd(self, other: Sequence[object]) -> "ClassList":
8993
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
9094
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
9195
other = [other]
92-
if not hasattr(self, '_class_handle'):
96+
if not hasattr(self, "_class_handle"):
9397
self._class_handle = self._determine_class_handle(self + other)
9498
self._check_classes(self + other)
9599
self._check_unique_name_fields(self + other)
@@ -129,20 +133,27 @@ def append(self, obj: object = None, **kwargs) -> None:
129133
SyntaxWarning
130134
Raised if the input arguments contain BOTH an object and keyword arguments. In this situation the object is
131135
appended to the ClassList and the keyword arguments are discarded.
136+
132137
"""
133138
if obj and kwargs:
134-
warnings.warn('ClassList.append() called with both an object and keyword arguments. '
135-
'The keyword arguments will be ignored.', SyntaxWarning)
139+
warnings.warn(
140+
"ClassList.append() called with both an object and keyword arguments. "
141+
"The keyword arguments will be ignored.",
142+
SyntaxWarning,
143+
stacklevel=2,
144+
)
136145
if obj:
137-
if not hasattr(self, '_class_handle'):
146+
if not hasattr(self, "_class_handle"):
138147
self._class_handle = type(obj)
139148
self._check_classes(self + [obj])
140149
self._check_unique_name_fields(self + [obj])
141150
self.data.append(obj)
142151
else:
143-
if not hasattr(self, '_class_handle'):
144-
raise TypeError('ClassList.append() called with keyword arguments for a ClassList without a class '
145-
'defined. Call ClassList.append() with an object to define the class.')
152+
if not hasattr(self, "_class_handle"):
153+
raise TypeError(
154+
"ClassList.append() called with keyword arguments for a ClassList without a class "
155+
"defined. Call ClassList.append() with an object to define the class.",
156+
)
146157
self._validate_name_field(kwargs)
147158
self.data.append(self._class_handle(**kwargs))
148159

@@ -169,20 +180,27 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
169180
SyntaxWarning
170181
Raised if the input arguments contain both an object and keyword arguments. In this situation the object is
171182
inserted into the ClassList and the keyword arguments are discarded.
183+
172184
"""
173185
if obj and kwargs:
174-
warnings.warn('ClassList.insert() called with both an object and keyword arguments. '
175-
'The keyword arguments will be ignored.', SyntaxWarning)
186+
warnings.warn(
187+
"ClassList.insert() called with both an object and keyword arguments. "
188+
"The keyword arguments will be ignored.",
189+
SyntaxWarning,
190+
stacklevel=2,
191+
)
176192
if obj:
177-
if not hasattr(self, '_class_handle'):
193+
if not hasattr(self, "_class_handle"):
178194
self._class_handle = type(obj)
179195
self._check_classes(self + [obj])
180196
self._check_unique_name_fields(self + [obj])
181197
self.data.insert(index, obj)
182198
else:
183-
if not hasattr(self, '_class_handle'):
184-
raise TypeError('ClassList.insert() called with keyword arguments for a ClassList without a class '
185-
'defined. Call ClassList.insert() with an object to define the class.')
199+
if not hasattr(self, "_class_handle"):
200+
raise TypeError(
201+
"ClassList.insert() called with keyword arguments for a ClassList without a class "
202+
"defined. Call ClassList.insert() with an object to define the class.",
203+
)
186204
self._validate_name_field(kwargs)
187205
self.data.insert(index, self._class_handle(**kwargs))
188206

@@ -209,7 +227,7 @@ def extend(self, other: Sequence[object]) -> None:
209227
"""Extend the ClassList by adding another sequence."""
210228
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
211229
other = [other]
212-
if not hasattr(self, '_class_handle'):
230+
if not hasattr(self, "_class_handle"):
213231
self._class_handle = self._determine_class_handle(self + other)
214232
self._check_classes(self + other)
215233
self._check_unique_name_fields(self + other)
@@ -229,6 +247,7 @@ def get_names(self) -> list[str]:
229247
-------
230248
names : list [str]
231249
The value of the name_field attribute of each object in the ClassList.
250+
232251
"""
233252
return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)]
234253

@@ -244,9 +263,14 @@ def get_all_matches(self, value: Any) -> list[tuple]:
244263
-------
245264
: list [tuple]
246265
A list of (index, field) tuples matching the given value.
266+
247267
"""
248-
return [(index, field) for index, element in enumerate(self.data) for field in vars(element)
249-
if getattr(element, field) == value]
268+
return [
269+
(index, field)
270+
for index, element in enumerate(self.data)
271+
for field in vars(element)
272+
if getattr(element, field) == value
273+
]
250274

251275
def _validate_name_field(self, input_args: dict[str, Any]) -> None:
252276
"""Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already
@@ -261,12 +285,15 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
261285
------
262286
ValueError
263287
Raised if the input arguments contain a name_field value already defined in the ClassList.
288+
264289
"""
265290
names = self.get_names()
266291
with contextlib.suppress(KeyError):
267292
if input_args[self.name_field] in names:
268-
raise ValueError(f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
269-
f"which is already specified in the ClassList")
293+
raise ValueError(
294+
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
295+
f"which is already specified in the ClassList",
296+
)
270297

271298
def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
272299
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
@@ -281,6 +308,7 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
281308
------
282309
ValueError
283310
Raised if the input list defines more than one object with the same value of name_field.
311+
284312
"""
285313
names = [getattr(model, self.name_field) for model in input_list if hasattr(model, self.name_field)]
286314
if len(set(names)) != len(names):
@@ -298,6 +326,7 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
298326
------
299327
ValueError
300328
Raised if the input list defines objects of different types.
329+
301330
"""
302331
if not (all(isinstance(element, self._class_handle) for element in input_list)):
303332
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")
@@ -315,6 +344,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
315344
instance : object or str
316345
Either the object with the value of the name_field attribute given by value, or the input value if an
317346
object with that value of the name_field attribute cannot be found.
347+
318348
"""
319349
return next((model for model in self.data if getattr(model, self.name_field) == value), value)
320350

@@ -333,6 +363,7 @@ def _determine_class_handle(input_list: Sequence[object]):
333363
class_handle : type
334364
The type object of the element fulfilling the condition of satisfying "issubclass" for all of the other
335365
elements.
366+
336367
"""
337368
for this_element in input_list:
338369
if all([issubclass(type(instance), type(this_element)) for instance in input_list]):

RAT/controls.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from dataclasses import dataclass, field
2-
import prettytable
3-
from pydantic import BaseModel, Field, field_validator, ValidationError
42
from typing import Literal, Union
53

6-
from RAT.utils.enums import Parallel, Procedures, Display, BoundHandling, Strategies
4+
import prettytable
5+
from pydantic import BaseModel, Field, ValidationError, field_validator
6+
77
from RAT.utils.custom_errors import custom_pydantic_validation_error
8+
from RAT.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies
89

910

1011
@dataclass(frozen=True)
1112
class Controls:
1213
"""The full set of controls parameters required for the compiled RAT code."""
14+
1315
# All Procedures
1416
procedure: Procedures = Procedures.Calculate
1517
parallel: Parallel = Parallel.Single
@@ -44,8 +46,9 @@ class Controls:
4446
adaptPCR: bool = False
4547

4648

47-
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
49+
class Calculate(BaseModel, validate_assignment=True, extra="forbid"):
4850
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
51+
4952
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
5053
parallel: Parallel = Parallel.Single
5154
calcSldDuringFit: bool = False
@@ -56,20 +59,21 @@ class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
5659
@classmethod
5760
def check_resample_params(cls, resampleParams):
5861
if not 0 < resampleParams[0] < 1:
59-
raise ValueError('resampleParams[0] must be between 0 and 1')
62+
raise ValueError("resampleParams[0] must be between 0 and 1")
6063
if resampleParams[1] < 0:
61-
raise ValueError('resampleParams[1] must be greater than or equal to 0')
64+
raise ValueError("resampleParams[1] must be greater than or equal to 0")
6265
return resampleParams
6366

6467
def __repr__(self) -> str:
6568
table = prettytable.PrettyTable()
66-
table.field_names = ['Property', 'Value']
69+
table.field_names = ["Property", "Value"]
6770
table.add_rows([[k, v] for k, v in self.__dict__.items()])
6871
return table.get_string()
6972

7073

7174
class Simplex(Calculate):
7275
"""Defines the additional fields for the simplex procedure."""
76+
7377
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
7478
xTolerance: float = Field(1.0e-6, gt=0.0)
7579
funcTolerance: float = Field(1.0e-6, gt=0.0)
@@ -81,6 +85,7 @@ class Simplex(Calculate):
8185

8286
class DE(Calculate):
8387
"""Defines the additional fields for the Differential Evolution procedure."""
88+
8489
procedure: Literal[Procedures.DE] = Procedures.DE
8590
populationSize: int = Field(20, ge=1)
8691
fWeight: float = 0.5
@@ -92,6 +97,7 @@ class DE(Calculate):
9297

9398
class NS(Calculate):
9499
"""Defines the additional fields for the Nested Sampler procedure."""
100+
95101
procedure: Literal[Procedures.NS] = Procedures.NS
96102
nLive: int = Field(150, ge=1)
97103
nMCMC: float = Field(0.0, ge=0.0)
@@ -101,6 +107,7 @@ class NS(Calculate):
101107

102108
class Dream(Calculate):
103109
"""Defines the additional fields for the Dream procedure."""
110+
104111
procedure: Literal[Procedures.Dream] = Procedures.Dream
105112
nSamples: int = Field(50000, ge=0)
106113
nChains: int = Field(10, gt=0)
@@ -110,28 +117,31 @@ class Dream(Calculate):
110117
adaptPCR: bool = False
111118

112119

113-
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
114-
-> Union[Calculate, Simplex, DE, NS, Dream]:
120+
def set_controls(
121+
procedure: Procedures = Procedures.Calculate,
122+
**properties,
123+
) -> Union[Calculate, Simplex, DE, NS, Dream]:
115124
"""Returns the appropriate controls model given the specified procedure."""
116125
controls = {
117126
Procedures.Calculate: Calculate,
118127
Procedures.Simplex: Simplex,
119128
Procedures.DE: DE,
120129
Procedures.NS: NS,
121-
Procedures.Dream: Dream
130+
Procedures.Dream: Dream,
122131
}
123132

124133
try:
125134
model = controls[procedure](**properties)
126135
except KeyError:
127136
members = list(Procedures.__members__.values())
128137
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
129-
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
138+
raise ValueError(f"The controls procedure must be one of: {allowed_values}") from None
130139
except ValidationError as exc:
131-
custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}'
132-
f' controls procedure are:\n '
133-
f'{", ".join(controls[procedure].model_fields.keys())}\n'
134-
}
140+
custom_error_msgs = {
141+
"extra_forbidden": f'Extra inputs are not permitted. The fields for the {procedure}'
142+
f' controls procedure are:\n '
143+
f'{", ".join(controls[procedure].model_fields.keys())}\n',
144+
}
135145
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
136146
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
137147

0 commit comments

Comments
 (0)