Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/spellbind/bool_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC
from typing import TypeVar, Generic, Callable

from spellbind.values import Value, DerivedValue, Constant
from spellbind.values import Value, DerivedValue, Constant, SimpleVariable

_S = TypeVar('_S')

Expand Down Expand Up @@ -34,5 +34,9 @@ class BoolConstant(BoolValue, Constant[bool]):
pass


class BoolVariable(SimpleVariable[bool], BoolValue):
pass


TRUE = BoolConstant(True)
FALSE = BoolConstant(False)
26 changes: 15 additions & 11 deletions src/spellbind/float_values.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations
from typing_extensions import Self

import operator
from abc import ABC, abstractmethod
from typing import Generic, Callable, Sequence, TypeVar, overload

from typing_extensions import Self
from typing_extensions import TYPE_CHECKING

from spellbind.bool_values import BoolValue
Expand Down Expand Up @@ -112,6 +112,10 @@ class FloatConstant(FloatValue, Constant[float]):
pass


class FloatVariable(SimpleVariable[float], FloatValue):
pass


def _create_float_getter(value: float | Value[int] | Value[float]) -> Callable[[], float]:
if isinstance(value, Value):
return lambda: value.value
Expand Down Expand Up @@ -158,6 +162,16 @@ def value(self) -> _U:
return self._value


class MaxFloatValues(CombinedFloatValues[float], FloatValue):
def transform(self, values: Sequence[float]) -> float:
return max(values)


class MinFloatValues(CombinedFloatValues[float], FloatValue):
def transform(self, values: Sequence[float]) -> float:
return min(values)


class CombinedTwoFloatValues(CombinedFloatValues[_U], Generic[_U], ABC):
def __init__(self, left: FloatLike, right: FloatLike):
super().__init__(left, right)
Expand All @@ -171,9 +185,6 @@ def transform_two(self, left: float, right: float) -> _U:


class AddFloatValues(CombinedFloatValues[float], FloatValue):
def __init__(self, *values: FloatLike):
super().__init__(*values)

def transform(self, values: Sequence[float]) -> float:
return sum(values)

Expand All @@ -184,9 +195,6 @@ def transform_two(self, left: float, right: float) -> float:


class MultiplyFloatValues(CombinedFloatValues[float], FloatValue):
def __init__(self, *values: FloatLike):
super().__init__(*values)

def transform(self, values: Sequence[float]) -> float:
result = 1.0
for value in values:
Expand All @@ -199,10 +207,6 @@ def transform_two(self, left: float, right: float) -> float:
return left / right


class FloatVariable(SimpleVariable[float], FloatValue):
pass


class RoundFloatValue(CombinedTwoValues[float, int, float], FloatValue):
def __init__(self, value: FloatValue, ndigits: IntLike):
super().__init__(value, ndigits)
Expand Down
10 changes: 10 additions & 0 deletions src/spellbind/int_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ class IntVariable(SimpleVariable[int], IntValue):
pass


class MaxIntValues(CombinedMixedValues[int, int], IntValue):
def transform(self, *values: int) -> int:
return max(values)


class MinIntValues(CombinedMixedValues[int, int], IntValue):
def transform(self, *values: int) -> int:
return min(values)


class AddIntValues(CombinedMixedValues[int, int], IntValue):
def transform(self, *values: int) -> int:
return sum(values)
Expand Down
47 changes: 46 additions & 1 deletion tests/test_float_values.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,51 @@
from spellbind.float_values import FloatConstant
from spellbind.float_values import FloatConstant, MaxFloatValues, MinFloatValues
from spellbind.values import SimpleVariable


def test_float_constant_str():
const = FloatConstant(3.14)
assert str(const) == "3.14"


def test_max_float_values():
a = SimpleVariable(10.5)
b = SimpleVariable(20.3)
c = SimpleVariable(5.7)

max_val = MaxFloatValues(a, b, c)
assert max_val.value == 20.3

a.value = 30.1
assert max_val.value == 30.1


def test_max_float_values_with_literals():
a = SimpleVariable(10.5)

max_val = MaxFloatValues(a, 25.7, 15.2)
assert max_val.value == 25.7

a.value = 30.1
assert max_val.value == 30.1


def test_min_float_values():
a = SimpleVariable(10.5)
b = SimpleVariable(20.3)
c = SimpleVariable(5.7)

min_val = MinFloatValues(a, b, c)
assert min_val.value == 5.7

c.value = 2.1
assert min_val.value == 2.1


def test_min_float_values_with_literals():
a = SimpleVariable(10.5)

min_val = MinFloatValues(a, 25.7, 15.2)
assert min_val.value == 10.5

a.value = 5.1
assert min_val.value == 5.1
47 changes: 46 additions & 1 deletion tests/test_int_values.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,51 @@
from spellbind.int_values import IntConstant
from spellbind.int_values import IntConstant, MaxIntValues, MinIntValues
from spellbind.values import SimpleVariable


def test_int_constant_str():
const = IntConstant(42)
assert str(const) == "42"


def test_max_int_values():
a = SimpleVariable(10)
b = SimpleVariable(20)
c = SimpleVariable(5)

max_val = MaxIntValues(a, b, c)
assert max_val.value == 20

a.value = 30
assert max_val.value == 30


def test_max_int_values_with_literals():
a = SimpleVariable(10)

max_val = MaxIntValues(a, 25, 15)
assert max_val.value == 25

a.value = 30
assert max_val.value == 30


def test_min_int_values():
a = SimpleVariable(10)
b = SimpleVariable(20)
c = SimpleVariable(5)

min_val = MinIntValues(a, b, c)
assert min_val.value == 5

c.value = 2
assert min_val.value == 2


def test_min_int_values_with_literals():
a = SimpleVariable(10)

min_val = MinIntValues(a, 25, 15)
assert min_val.value == 10

a.value = 5
assert min_val.value == 5