Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ exclude =

# Allow assigning lambdas in tests
per-file-ignores =
tests/*:E731
tests/*:E731,F841
10 changes: 6 additions & 4 deletions src/spellbind/float_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,22 @@ def _get_float(value: float | Value[int] | Value[float]) -> float:
class CombinedFloatValues(DerivedValueBase[_U], Generic[_U], ABC):
def __init__(self, *values: float | Value[int] | Value[float]):
super().__init__(*[v for v in values if isinstance(v, Value)])
self.gotten_values = [_get_float(v) for v in values]
self._gotten_values = [_get_float(v) for v in values]
self._callbacks: list[Callable] = []
for i, v in enumerate(values):
if isinstance(v, Value):
v.observe(self._create_on_n_changed(i))
v.weak_observe(self._create_on_n_changed(i))
self._value = self._calculate_value()

def _create_on_n_changed(self, index: int) -> Callable[[float], None]:
def on_change(new_value: float) -> None:
self.gotten_values[index] = new_value
self._gotten_values[index] = new_value
self._on_result_change(self._calculate_value())
self._callbacks.append(on_change) # keep strong reference to callback so it won't be garbage collected
return on_change

def _calculate_value(self) -> _U:
return self.transform(self.gotten_values)
return self.transform(self._gotten_values)

def _on_result_change(self, new_value: _U) -> None:
if new_value != self._value:
Expand Down
10 changes: 6 additions & 4 deletions src/spellbind/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class DerivedValue(DerivedValueBase[_T], Generic[_S, _T], ABC):
def __init__(self, of: Value[_S]):
super().__init__(of)
self._value = self.transform(of.value)
of.observe(self._on_source_change)
of.weak_observe(self._on_source_change)

@abstractmethod
def transform(self, value: _S) -> _T:
Expand Down Expand Up @@ -257,9 +257,9 @@ def __init__(self, left: Value[_S] | _S, right: Value[_T] | _T):
self._left_getter = _create_value_getter(left)
self._right_getter = _create_value_getter(right)
if isinstance(left, Value):
left.observe(self._on_left_change)
left.weak_observe(self._on_left_change)
if isinstance(right, Value):
right.observe(self._on_right_change)
right.weak_observe(self._on_right_change)
self._value = self.transform(self._left_getter(), self._right_getter())

def _on_left_change(self, new_left_value: _S) -> None:
Expand Down Expand Up @@ -295,15 +295,17 @@ class CombinedMixedValues(DerivedValueBase[_T], Generic[_S, _T], ABC):
def __init__(self, *sources: Value[_S] | _S):
super().__init__(*[v for v in sources if isinstance(v, Value)])
self.gotten_values = [_get_value(v) for v in sources]
self._callbacks: list[Callable] = []
for i, v in enumerate(sources):
if isinstance(v, Value):
v.observe(self._create_on_n_changed(i))
v.weak_observe(self._create_on_n_changed(i))
self._value = self._calculate_value()

def _create_on_n_changed(self, index: int) -> Callable[[_S], None]:
def on_change(new_value: _S) -> None:
self.gotten_values[index] = new_value
self._on_result_change(self._calculate_value())
self._callbacks.append(on_change) # keep strong reference to callback so it won't be garbage collected
return on_change

def _calculate_value(self) -> _T:
Expand Down
29 changes: 28 additions & 1 deletion tests/test_values/test_float_values/test_float_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from spellbind.float_values import FloatConstant, MaxFloatValues, MinFloatValues
import gc

from spellbind.float_values import FloatConstant, MaxFloatValues, MinFloatValues, FloatVariable
from spellbind.values import SimpleVariable


Expand Down Expand Up @@ -49,3 +51,28 @@ def test_min_float_values_with_literals():

a.value = 5.1
assert min_val.value == 5.1


def test_add_float_values_keeps_reference():
v0 = FloatVariable(1.5)
v1 = FloatVariable(2.5)
v2 = v0 + v1
assert len(v0._on_change._subscriptions) == 1
gc.collect()

v0.value = 3.5
assert len(v0._on_change._subscriptions) == 1


def test_add_int_values_garbage_collected():
v0 = FloatVariable(1.5)
v1 = FloatVariable(2.5)
v2 = v0 + v1
assert len(v0._on_change._subscriptions) == 1
assert len(v1._on_change._subscriptions) == 1
v2 = None
gc.collect()
v0.value = 3.5 # trigger removal of weak references
v1.value = 4.5 # trigger removal of weak references
assert len(v0._on_change._subscriptions) == 0
assert len(v1._on_change._subscriptions) == 0
28 changes: 28 additions & 0 deletions tests/test_values/test_int_values/test_add_int_values.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import gc

from spellbind.float_values import FloatVariable
from spellbind.int_values import IntVariable

Expand Down Expand Up @@ -56,3 +58,29 @@ def test_add_float_plus_int_value():

v1.value = 4
assert v2.value == 7.5


def test_add_int_values_keeps_reference():
v0 = IntVariable(1)
v1 = IntVariable(2)
v2 = v0 + v1
assert len(v0._on_change._subscriptions) == 1
gc.collect()

v0.value = 3
v1.value = 4
assert len(v0._on_change._subscriptions) == 1


def test_add_int_values_garbage_collected():
v0 = IntVariable(1)
v1 = IntVariable(2)
v2 = v0 + v1
assert len(v0._on_change._subscriptions) == 1
assert len(v1._on_change._subscriptions) == 1
v2 = None
gc.collect()
v0.value = 3 # trigger removal of weak references
v1.value = 4 # trigger removal of weak references
assert len(v0._on_change._subscriptions) == 0
assert len(v1._on_change._subscriptions) == 0