Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions .github/workflows/codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: 2026 Alec Delaney
# SPDX-License-Identifier: MIT

name: Run code tests

on:
push:
branches:
- 'main'
pull_request:
paths:
- '.github/workflows/codecov.yml'
- 'pyproject.toml'
- 'circuitpython_functools'
- 'tests/**'

permissions: read-all

jobs:
codecov:
runs-on: ubuntu-latest

steps:
- name: Setup Python 3.x
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.py-version }}
- name: Checkout the repository
uses: actions/checkout@v6
- name: Install base requirements
run: |
pip install ".[optional]"
- name: Install test requirements
run: |
pip install pytest pytest-cov
- name: Run tests
run: |
pytest --cov --cov-branch --cov-report=xml
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
68 changes: 67 additions & 1 deletion circuitpython_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,70 @@ def _partial(*more_args, **more_kwargs):
# their simplified implementation of the wraps function!
def wraps(wrapped, assigned=None, updated=None):
"""Define a wrapper function when writing function decorators."""
return wrapped

def decorator(wrapper):
return wrapper

return decorator


def total_ordering(cls): # noqa: PLR0912
"""Automatically create the comparison functions."""
has_lt = "__lt__" in cls.__dict__
has_gt = "__gt__" in cls.__dict__
has_le = "__le__" in cls.__dict__
has_ge = "__ge__" in cls.__dict__

if not (has_lt or has_gt or has_le or has_ge):
raise ValueError("must define at least one ordering operation: < > <= >=")

def instance_guard(x, cls):
if not isinstance(x, cls):
raise TypeError("unsupport comparison")
return True

if not has_lt:
if has_le:
lt_func = lambda self, other: self <= other and self != other
elif has_gt:
lt_func = lambda self, other: not (self > other) and self != other
else: # has_ge
lt_func = lambda self, other: not (self >= other)
cls.__lt__ = lambda self, other: instance_guard(other, cls) and lt_func(
self, other
)

if not has_le:
if has_lt:
le_func = lambda self, other: self < other or self == other
elif has_gt:
le_func = lambda self, other: not (self > other)
else: # has_ge
le_func = lambda self, other: self == other or not (self >= other)
cls.__le__ = lambda self, other: instance_guard(other, cls) and le_func(
self, other
)

if not has_gt:
if has_lt:
gt_func = lambda self, other: self != other and not (self < other)
elif has_ge:
gt_func = lambda self, other: self >= other and self != other
else: # has_le
gt_func = lambda self, other: not (self <= other)
cls.__gt__ = lambda self, other: instance_guard(other, cls) and gt_func(
self, other
)

if not has_ge:
if has_lt:
ge_func = lambda self, other: not (self < other)
elif has_gt:
ge_func = lambda self, other: self > other or self == other
else: # has_le
ge_func = lambda self, other: self == other or not (self <= other)
cls.__ge__ = lambda self, other: instance_guard(other, cls) and ge_func(
self, other
)

return cls
34 changes: 34 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: 2026 Alec Delaney
# SPDX-FileCopyrightText: Python Software Foundation

# SPDX-License-Identifier: MIT
# SPDX-License-Identifier: PSF-2.0

"""Tests for cache."""

from circuitpython_functools import cache

# Example adapted from CPython documentation

TOTAL_CALLS = 0


def test_cache():
"""Tests that cache decorator works as expected."""

@cache
def factorial(n):
global TOTAL_CALLS # noqa: PLW0603
TOTAL_CALLS += 1
return n * factorial(n - 1) if n else 1

assert TOTAL_CALLS == 0 # noqa: PLR2004

_ = factorial(10)
assert TOTAL_CALLS == 11 # noqa: PLR2004

_ = factorial(5)
assert TOTAL_CALLS == 11 # noqa: PLR2004

_ = factorial(12)
assert TOTAL_CALLS == 13 # noqa: PLR2004
165 changes: 165 additions & 0 deletions tests/test_lru_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: 2026 Alec Delaney
# SPDX-FileCopyrightText: Python Software Foundation
#
# SPDX-License-Identifier: MIT
# SPDX-License-Identifier: PSF-2.0

"""Tests for lru_cache."""

import pytest

from circuitpython_functools import _lru_cache_records, lru_cache

# Factorial example adapted from CPython documentation

TOTAL_CALLS = 0


def test_lru_cache_default():
"""Tests the lru_cache works with default settings (no arguments provided)."""
global TOTAL_CALLS # noqa: PLW0603

@lru_cache
def factorial(n):
global TOTAL_CALLS # noqa: PLW0603
TOTAL_CALLS += 1
return n * factorial(n - 1) if n else 1

assert TOTAL_CALLS == 0 # noqa: PLR2004

_ = factorial(10)
assert TOTAL_CALLS == 11 # noqa: PLR2004

_ = factorial(5)
assert TOTAL_CALLS == 11 # noqa: PLR2004

_ = factorial(12)
assert TOTAL_CALLS == 13 # noqa: PLR2004

TOTAL_CALLS = 0


def test_lru_cache_maxsize_arg():
"""Tests the lru_cache maxsize function when given as arg."""
global TOTAL_CALLS # noqa: PLW0603

@lru_cache(5)
def factorial(n):
global TOTAL_CALLS # noqa: PLW0603
TOTAL_CALLS += 1
return n * factorial(n - 1) if n else 1

assert TOTAL_CALLS == 0 # noqa: PLR2004

_ = factorial(10)
assert TOTAL_CALLS == 11 # noqa: PLR2004
print("---")

_ = factorial(5)
assert TOTAL_CALLS == 17 # noqa: PLR2004
print("---")

_ = factorial(12)
assert TOTAL_CALLS == 24 # noqa: PLR2004

TOTAL_CALLS = 0


def test_lru_cache_maxsize_kwarg():
"""Tests the lru_cache maxsize function when given as kwarg."""
global TOTAL_CALLS # noqa: PLW0603

@lru_cache(maxsize=5)
def factorial(n):
global TOTAL_CALLS # noqa: PLW0603
TOTAL_CALLS += 1
return n * factorial(n - 1) if n else 1

assert TOTAL_CALLS == 0 # noqa: PLR2004

_ = factorial(10)
assert TOTAL_CALLS == 11 # noqa: PLR2004
print("---")

_ = factorial(5)
assert TOTAL_CALLS == 17 # noqa: PLR2004
print("---")

_ = factorial(12)
assert TOTAL_CALLS == 24 # noqa: PLR2004

TOTAL_CALLS = 0


def test_lru_cache_func_kwarg():
"""Tests the lru_cache when function has kwargs."""
global TOTAL_CALLS # noqa: PLW0603

@lru_cache
def factorial(*, n):
global TOTAL_CALLS # noqa: PLW0603
TOTAL_CALLS += 1
return n * factorial(n=n - 1) if n else 1

assert TOTAL_CALLS == 0 # noqa: PLR2004

_ = factorial(n=10)
assert TOTAL_CALLS == 11 # noqa: PLR2004

_ = factorial(n=5)
assert TOTAL_CALLS == 11 # noqa: PLR2004

_ = factorial(n=12)
assert TOTAL_CALLS == 13 # noqa: PLR2004

TOTAL_CALLS = 0


def test_lru_cache_cache_clear():
"""Tests the automatically attached cache_clear method works."""
global TOTAL_CALLS # noqa: PLW0603

@lru_cache
def factorial(n):
global TOTAL_CALLS # noqa: PLW0603
TOTAL_CALLS += 1
return n * factorial(n=n - 1) if n else 1

assert TOTAL_CALLS == 0 # noqa: PLR2004

_ = factorial(n=10)
assert TOTAL_CALLS == 11 # noqa: PLR2004

factorial.cache_clear()

_ = factorial(n=10)
assert TOTAL_CALLS == 22 # noqa: PLR2004

TOTAL_CALLS = 0


def test_lru_cache_typed_error_args():
"""Tests that lru_cache raises an error if "typed" given as arg."""
with pytest.raises(NotImplementedError):

@lru_cache(100, True)
def factorial(n):
return n * factorial(n=n - 1) if n else 1


def test_lru_cache_typed_error_kwargs():
"""Tests that lru_cache raises an error if "typed" given as kwarg."""
with pytest.raises(NotImplementedError):

@lru_cache(typed=True)
def factorial(n):
return n * factorial(n=n - 1) if n else 1


def test_lru_cache_syntax_error():
"""Tests that lru_cache raises an error if arguments are incorrect."""
with pytest.raises(SyntaxError):

@lru_cache(100, True, "a")
def factorial(n):
return n * factorial(n=n - 1) if n else 1
26 changes: 26 additions & 0 deletions tests/test_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: 2026 Alec Delaney
#
# SPDX-License-Identifier: MIT

"""Tests for partial."""

from circuitpython_functools import partial


def test_partial():
"""Tests functionality of partial."""

def towerize(x, y, z):
return (x * 100) + (y * 10) + z

towerize100 = partial(towerize, 1)
assert towerize100(5, 3) == 153 # noqa: PLR2004
assert towerize100(z=2, y=7) == 172 # noqa: PLR2004

towerize150 = partial(towerize, 1, 5)
assert towerize150(6) == 156 # noqa: PLR2004
assert towerize150(z=8) == 158 # noqa: PLR2004

towerize20 = partial(towerize, y=2)
assert towerize20(9, z=0) == 920 # noqa: PLR2004
assert towerize20(z=1, x=5) == 521 # noqa: PLR2004
Loading
Loading