|
| 1 | +# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries |
| 2 | +# SPDX-FileCopyrightText: Copyright (c) 2022 Alec Delaney |
| 3 | +# SPDX-FileCopyrightText: Python Software Foundation |
| 4 | +# SPDX-FileCopyrightText: MicroPython Developers |
| 5 | +# |
| 6 | +# SPDX-License-Identifier: MIT |
| 7 | +# SPDX-License-Identifier: PSF-2.0 |
| 8 | + |
| 9 | +"""CircuitPython implementation of CPython's functools library. |
| 10 | +
|
| 11 | +* Author(s): Alec Delaney |
| 12 | +
|
| 13 | +Implementation Notes |
| 14 | +-------------------- |
| 15 | +
|
| 16 | +**Software and Dependencies:** |
| 17 | +
|
| 18 | +* Adafruit CircuitPython firmware for the supported boards: |
| 19 | + https://circuitpython.org/downloads |
| 20 | +
|
| 21 | +""" |
| 22 | + |
| 23 | +import gc |
| 24 | +from collections import OrderedDict |
| 25 | + |
| 26 | +__version__ = "0.0.0+auto.0" |
| 27 | +__repo__ = "https://github.com/tekktrik/CircuitPython_functools.git" |
| 28 | + |
| 29 | +_cache_records = {} |
| 30 | +_lru_cache_records = {} |
| 31 | + |
| 32 | + |
| 33 | +class _ObjectMark: |
| 34 | + pass |
| 35 | + |
| 36 | + |
| 37 | +# Cache-related code ported from CPython |
| 38 | + |
| 39 | +# As a general note, some weird things are happening here due to internal differences between |
| 40 | +# CPython and CircuitPython, such as the fact that closures can't have things added to them, |
| 41 | +# hence the need to create objects so that cache_clear can be called from the returned "wrapped" |
| 42 | +# functions. |
| 43 | + |
| 44 | + |
| 45 | +def _make_key(args, kwargs, kwd_mark=(_ObjectMark(),)): |
| 46 | + """Make a key for the cache records.""" |
| 47 | + key = tuple(args) |
| 48 | + if kwargs: |
| 49 | + key += kwd_mark |
| 50 | + for item in kwargs.items(): |
| 51 | + key += tuple(item) |
| 52 | + return hash(key) |
| 53 | + |
| 54 | + |
| 55 | +class _CachedFunc: |
| 56 | + """Wrapped unbounded cache function.""" |
| 57 | + |
| 58 | + def __init__(self, maxsize, user_func): |
| 59 | + """Initialize the wrapped cache function.""" |
| 60 | + self._maxsize = maxsize |
| 61 | + checked_records = _cache_records if maxsize < 0 else _lru_cache_records |
| 62 | + |
| 63 | + def cache_wrapper(*args, **kwargs): |
| 64 | + sentinel = object() |
| 65 | + |
| 66 | + # Make the key for the inner dictionary |
| 67 | + key = _make_key(args, kwargs) |
| 68 | + |
| 69 | + # if there is no inner dictionary yet, make one |
| 70 | + if checked_records.get(user_func) is None: |
| 71 | + checked_records[user_func] = OrderedDict() |
| 72 | + |
| 73 | + # Attempt to get an existing entry, updating its location in the queue |
| 74 | + # and returning it if so |
| 75 | + result = checked_records[user_func].get(key, sentinel) |
| 76 | + if result is not sentinel: |
| 77 | + checked_records[user_func].move_to_end(key) |
| 78 | + return result |
| 79 | + |
| 80 | + # Calculate the actual value |
| 81 | + result = user_func(*args, **kwargs) |
| 82 | + |
| 83 | + # If the cache is bounded and too full to store the new result, eject the |
| 84 | + # least-recently-use entry |
| 85 | + if maxsize >= 0 and len(checked_records[user_func]) >= maxsize: |
| 86 | + first_key = next(iter(checked_records[user_func])) |
| 87 | + del checked_records[user_func][first_key] |
| 88 | + |
| 89 | + # Store the result |
| 90 | + checked_records[user_func][key] = result |
| 91 | + |
| 92 | + # Return the new result |
| 93 | + return result |
| 94 | + |
| 95 | + self._user_func = user_func |
| 96 | + self._wrapped_func = cache_wrapper |
| 97 | + |
| 98 | + def __call__(self, *args, **kwargs): |
| 99 | + """Call the wrapped function.""" |
| 100 | + return self._wrapped_func(*args, **kwargs) |
| 101 | + |
| 102 | + def cache_clear(self): |
| 103 | + """Clear the cache.""" |
| 104 | + checked_records = _cache_records if self._maxsize < 0 else _lru_cache_records |
| 105 | + if self._user_func in checked_records: |
| 106 | + checked_records[self._user_func].clear() |
| 107 | + gc.collect() |
| 108 | + |
| 109 | + |
| 110 | +def cache(user_function): |
| 111 | + """Create an unbounded cache.""" |
| 112 | + return _CachedFunc(-1, user_function) |
| 113 | + |
| 114 | + |
| 115 | +def lru_cache(*args, **kwargs): |
| 116 | + """Create a bounded cache which ejects the least recently used entry.""" |
| 117 | + cpython_max_args = 2 |
| 118 | + if len(args) == cpython_max_args or "typed" in kwargs: |
| 119 | + raise NotImplementedError("Using typed is not supported") |
| 120 | + |
| 121 | + if len(args) == 1 and isinstance(args[0], int): |
| 122 | + maxsize = args[0] |
| 123 | + elif len(args) == 1 and str(type(args[0]) == "<class 'function'>"): |
| 124 | + return _CachedFunc(128, args[0]) |
| 125 | + elif "maxsize" in kwargs: |
| 126 | + maxsize = kwargs["maxsize"] |
| 127 | + else: |
| 128 | + raise SyntaxError("lru_cache syntax incorrect") |
| 129 | + |
| 130 | + return partial(_CachedFunc, maxsize) |
| 131 | + |
| 132 | + |
| 133 | +# Partial ported from the MicroPython library |
| 134 | +def partial(func, *args, **kwargs): |
| 135 | + """Create a partial of the function.""" |
| 136 | + |
| 137 | + def _partial(*more_args, **more_kwargs): |
| 138 | + local_kwargs = kwargs.copy() |
| 139 | + local_kwargs.update(more_kwargs) |
| 140 | + return func(*(args + more_args), **local_kwargs) |
| 141 | + |
| 142 | + return _partial |
| 143 | + |
| 144 | + |
| 145 | +# Thank you to the MicroPython Development team for |
| 146 | +# their simplified implementation of the wraps function! |
| 147 | +def wraps(wrapped, assigned=None, updated=None): |
| 148 | + """Define a wrapper function when writing function decorators.""" |
| 149 | + |
| 150 | + def decorator(wrapper): |
| 151 | + return wrapper |
| 152 | + |
| 153 | + return decorator |
| 154 | + |
| 155 | + |
| 156 | +def total_ordering(cls): # noqa: PLR0912 |
| 157 | + """Automatically create the comparison functions.""" |
| 158 | + has_lt = "__lt__" in cls.__dict__ |
| 159 | + has_gt = "__gt__" in cls.__dict__ |
| 160 | + has_le = "__le__" in cls.__dict__ |
| 161 | + has_ge = "__ge__" in cls.__dict__ |
| 162 | + |
| 163 | + if not (has_lt or has_gt or has_le or has_ge): |
| 164 | + raise ValueError("must define at least one ordering operation: < > <= >=") |
| 165 | + |
| 166 | + def instance_guard(x, cls): |
| 167 | + if not isinstance(x, cls): |
| 168 | + raise TypeError("unsupport comparison") |
| 169 | + return True |
| 170 | + |
| 171 | + if not has_lt: |
| 172 | + if has_le: |
| 173 | + lt_func = lambda self, other: self <= other and self != other |
| 174 | + elif has_gt: |
| 175 | + lt_func = lambda self, other: not (self > other) and self != other |
| 176 | + else: # has_ge |
| 177 | + lt_func = lambda self, other: not (self >= other) |
| 178 | + cls.__lt__ = lambda self, other: instance_guard(other, cls) and lt_func( |
| 179 | + self, other |
| 180 | + ) |
| 181 | + |
| 182 | + if not has_le: |
| 183 | + if has_lt: |
| 184 | + le_func = lambda self, other: self < other or self == other |
| 185 | + elif has_gt: |
| 186 | + le_func = lambda self, other: not (self > other) |
| 187 | + else: # has_ge |
| 188 | + le_func = lambda self, other: self == other or not (self >= other) |
| 189 | + cls.__le__ = lambda self, other: instance_guard(other, cls) and le_func( |
| 190 | + self, other |
| 191 | + ) |
| 192 | + |
| 193 | + if not has_gt: |
| 194 | + if has_lt: |
| 195 | + gt_func = lambda self, other: self != other and not (self < other) |
| 196 | + elif has_ge: |
| 197 | + gt_func = lambda self, other: self >= other and self != other |
| 198 | + else: # has_le |
| 199 | + gt_func = lambda self, other: not (self <= other) |
| 200 | + cls.__gt__ = lambda self, other: instance_guard(other, cls) and gt_func( |
| 201 | + self, other |
| 202 | + ) |
| 203 | + |
| 204 | + if not has_ge: |
| 205 | + if has_lt: |
| 206 | + ge_func = lambda self, other: not (self < other) |
| 207 | + elif has_gt: |
| 208 | + ge_func = lambda self, other: self > other or self == other |
| 209 | + else: # has_le |
| 210 | + ge_func = lambda self, other: self == other or not (self <= other) |
| 211 | + cls.__ge__ = lambda self, other: instance_guard(other, cls) and ge_func( |
| 212 | + self, other |
| 213 | + ) |
| 214 | + |
| 215 | + return cls |
0 commit comments