Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 49 additions & 87 deletions src/pybind/event.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,82 @@
import inspect
from inspect import Parameter
from abc import ABC, abstractmethod
from typing import Callable, TypeVar, Generic

from pybind.emitters import Emitter, TriEmitter, BiEmitter, ValueEmitter
from pybind.functions import assert_parameter_max_count
from pybind.observables import Observable, ValueObservable, BiObservable, TriObservable, Observer, \
ValueObserver, BiObserver, TriObserver
ValueObserver, BiObserver, TriObserver, Subscription, DeadReferenceError, WeakSubscription, StrongSubscription

_S = TypeVar("_S")
_T = TypeVar("_T")
_U = TypeVar("_U")
_O = TypeVar('_O', bound=Callable)


def _is_positional_parameter(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)


def count_total_parameters(function: Callable) -> int:
parameters = inspect.signature(function).parameters
return sum(1 for parameter in parameters.values() if _is_positional_parameter(parameter))


def trim_and_call(observer: Callable, *parameters):
parameter_count = count_total_parameters(observer)
trimmed_parameters = parameters[:parameter_count]
observer(*trimmed_parameters)


def _is_required_positional_parameter(param: Parameter) -> bool:
return param.default == param.empty and _is_positional_parameter(param)
class BaseEvent(Generic[_O], ABC):
_subscriptions: list[Subscription[_O]]

def __init__(self):
self._subscriptions = []

def count_non_default_parameters(function: Callable) -> int:
parameters = inspect.signature(function).parameters
return sum(1 for param in parameters.values() if _is_required_positional_parameter(param))
@abstractmethod
def _get_parameter_count(self) -> int:
raise NotImplementedError

def observe(self, observer: _O) -> None:
assert_parameter_max_count(observer, self._get_parameter_count())
self._subscriptions.append(StrongSubscription(observer))

def assert_parameter_max_count(callable_: Callable, max_count: int) -> None:
if count_non_default_parameters(callable_) > max_count:
if hasattr(callable_, '__name__'):
callable_name = callable_.__name__
elif hasattr(callable_, '__class__'):
callable_name = callable_.__class__.__name__
else:
callable_name = str(callable_)
raise ValueError(f"Callable {callable_name} has too many non-default parameters: "
f"{count_non_default_parameters(callable_)} > {max_count}")
def weak_observe(self, observer: _O) -> None:
assert_parameter_max_count(observer, self._get_parameter_count())
self._subscriptions.append(WeakSubscription(observer))

def unobserve(self, observer: _O) -> None:
for i, sub in enumerate(self._subscriptions):
if sub.matches_observer(observer):
del self._subscriptions[i]
return
raise ValueError(f"Observer {observer} is not subscribed to this event.")

class Event(Observable, Emitter):
_observers: list[Observer]
def is_observed(self, observer: _O) -> bool:
return any(sub.matches_observer(observer) for sub in self._subscriptions)

def __init__(self):
self._observers = []
def _emit(self, *args) -> None:
i = 0
while i < len(self._subscriptions):
try:
self._subscriptions[i](*args)
i += 1
except DeadReferenceError:
del self._subscriptions[i]

def observe(self, observer: Observer) -> None:
self._observers.append(observer)
assert_parameter_max_count(observer, 0)

def unobserve(self, observer: Observer) -> None:
self._observers.remove(observer)
class Event(BaseEvent[Observer], Observable, Emitter):
def _get_parameter_count(self) -> int:
return 0

def __call__(self) -> None:
for observer in self._observers:
observer()


class ValueEvent(Generic[_S], ValueObservable[_S], ValueEmitter[_S]):
_observers: list[Observer | ValueObserver[_S]]
self._emit()

def __init__(self):
self._observers = []
super().__init__()

def observe(self, observer: Observer | ValueObserver[_S]) -> None:
self._observers.append(observer)
assert_parameter_max_count(observer, 1)

def unobserve(self, observer: Observer | ValueObserver[_S]) -> None:
self._observers.remove(observer)
class ValueEvent(Generic[_S], BaseEvent[Observer | ValueObserver[_S]], ValueObservable[_S], ValueEmitter[_S]):
def _get_parameter_count(self) -> int:
return 1

def __call__(self, value: _S) -> None:
for observer in self._observers:
trim_and_call(observer, value)


class BiEvent(Generic[_S, _T], BiObservable[_S, _T], BiEmitter[_S, _T]):
_observers: list[Observer | ValueObserver[_S] | BiObserver[_S, _T]]

def __init__(self):
self._observers = []
self._emit(value)

def observe(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T]) -> None:
self._observers.append(observer)
assert_parameter_max_count(observer, 2)

def unobserve(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T]) -> None:
self._observers.remove(observer)
class BiEvent(Generic[_S, _T], BaseEvent[Observer | ValueObserver[_S] | BiObserver[_S, _T]], BiObservable[_S, _T], BiEmitter[_S, _T]):
def _get_parameter_count(self) -> int:
return 2

def __call__(self, value_0: _S, value_1: _T) -> None:
for observer in self._observers:
trim_and_call(observer, value_0, value_1)


class TriEvent(Generic[_S, _T, _U], TriObservable[_S, _T, _U], TriEmitter[_S, _T, _U]):
_observers: list[Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]]

def __init__(self):
self._observers = []
self._emit(value_0, value_1)

def observe(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]) -> None:
self._observers.append(observer)
assert_parameter_max_count(observer, 3)

def unobserve(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]) -> None:
self._observers.remove(observer)
class TriEvent(Generic[_S, _T, _U], BaseEvent[Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]], TriObservable[_S, _T, _U], TriEmitter[_S, _T, _U]):
def _get_parameter_count(self) -> int:
return 3

def __call__(self, value_0: _S, value_1: _T, value_2: _U) -> None:
for observer in self._observers:
trim_and_call(observer, value_0, value_1, value_2)
self._emit(value_0, value_1, value_2)
7 changes: 0 additions & 7 deletions src/pybind/float_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(self, *values: float | Value[int] | Value[float]):
if isinstance(v, Value):
v.observe(self._create_on_n_changed(i))
self._value = self._calculate_value()
self._on_change: ValueEvent[_U] = ValueEvent()

def _create_on_n_changed(self, index: int) -> Callable[[float], None]:
def on_change(new_value: float) -> None:
Expand All @@ -106,12 +105,6 @@ def transform(self, values: Sequence[float]) -> _U:
def value(self) -> _U:
return self._value

def observe(self, observer: Observer | ValueObserver[_U]) -> None:
self._on_change.observe(observer)

def unobserve(self, observer: Observer | ValueObserver[_U]) -> None:
self._on_change.unobserve(observer)


class CombinedTwoFloatValues(CombinedFloatValues[_U], Generic[_U], ABC):
def __init__(self,
Expand Down
33 changes: 33 additions & 0 deletions src/pybind/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import inspect
from inspect import Parameter
from typing import Callable


def _is_positional_parameter(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)


def count_positional_parameters(function: Callable) -> int:
parameters = inspect.signature(function).parameters
return sum(1 for parameter in parameters.values() if _is_positional_parameter(parameter))


def _is_required_positional_parameter(param: Parameter) -> bool:
return param.default == param.empty and _is_positional_parameter(param)


def count_non_default_parameters(function: Callable) -> int:
parameters = inspect.signature(function).parameters
return sum(1 for param in parameters.values() if _is_required_positional_parameter(param))


def assert_parameter_max_count(callable_: Callable, max_count: int) -> None:
if count_non_default_parameters(callable_) > max_count:
if hasattr(callable_, '__name__'):
callable_name = callable_.__name__
elif hasattr(callable_, '__class__'):
callable_name = callable_.__class__.__name__
else:
callable_name = str(callable_)
raise ValueError(f"Callable {callable_name} has too many non-default parameters: "
f"{count_non_default_parameters(callable_)} > {max_count}")
73 changes: 73 additions & 0 deletions src/pybind/observables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import TypeVar, Callable, Generic, Protocol
from weakref import WeakMethod, ref

from pybind.functions import count_positional_parameters

_SC = TypeVar("_SC", contravariant=True)
_TC = TypeVar("_TC", contravariant=True)
Expand All @@ -10,6 +12,8 @@
_T = TypeVar("_T")
_U = TypeVar("_U")

_O = TypeVar('_O', bound=Callable)


class Observer(Protocol):
def __call__(self) -> None: ...
Expand All @@ -27,11 +31,68 @@ class TriObserver(Protocol[_SC, _TC, _UC]):
def __call__(self, arg1: _SC, arg2: _TC, arg3: _UC, /) -> None: ...


class DeadReferenceError(Exception):
pass


class Subscription(Generic[_O], ABC):
def __init__(self, observer: _O):
self._positional_parameter_count = count_positional_parameters(observer)

def _call(self, observer: _O, *args) -> None:
trimmed_args = args[:self._positional_parameter_count]
observer(*trimmed_args)

@abstractmethod
def __call__(self, *args) -> None:
raise NotImplementedError

@abstractmethod
def matches_observer(self, observer: _O) -> bool:
raise NotImplementedError


class StrongSubscription(Subscription[_O], Generic[_O]):
def __init__(self, observer: _O):
super().__init__(observer)
self._observer = observer

def __call__(self, *args) -> None:
self._call(self._observer, *args)

def matches_observer(self, observer: _O) -> bool:
return self._observer == observer


class WeakSubscription(Subscription[_O], Generic[_O]):
_ref: ref[_O] | WeakMethod

def __init__(self, observer: _O):
super().__init__(observer)
if hasattr(observer, '__self__'):
self._ref = WeakMethod(observer)
else:
self._ref = ref(observer)

def __call__(self, *args) -> None:
observer = self._ref()
if observer is None:
raise DeadReferenceError()
self._call(observer, *args)

def matches_observer(self, observer: _O) -> bool:
return self._ref() == observer


class Observable(ABC):
@abstractmethod
def observe(self, observer: Observer) -> None:
raise NotImplementedError

@abstractmethod
def weak_observe(self, observer: Observer) -> None:
raise NotImplementedError

@abstractmethod
def unobserve(self, observer: Observer) -> None:
raise NotImplementedError
Expand All @@ -42,6 +103,10 @@ class ValueObservable(Generic[_S], ABC):
def observe(self, observer: Observer | ValueObserver[_S]) -> None:
raise NotImplementedError

@abstractmethod
def weak_observe(self, observer: Observer | ValueObserver[_S]) -> None:
raise NotImplementedError

@abstractmethod
def unobserve(self, observer: Observer | ValueObserver[_S]) -> None:
raise NotImplementedError
Expand All @@ -52,6 +117,10 @@ class BiObservable(Generic[_S, _T], ABC):
def observe(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T]) -> None:
raise NotImplementedError

@abstractmethod
def weak_observe(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T]) -> None:
raise NotImplementedError

@abstractmethod
def unobserve(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T]) -> None:
raise NotImplementedError
Expand All @@ -62,6 +131,10 @@ class TriObservable(Generic[_S, _T, _U], ABC):
def observe(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]) -> None:
raise NotImplementedError

@abstractmethod
def weak_observe(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]) -> None:
raise NotImplementedError

@abstractmethod
def unobserve(self, observer: Observer | ValueObserver[_S] | BiObserver[_S, _T] | TriObserver[_S, _T, _U]) -> None:
raise NotImplementedError
Loading