|
| 1 | +"""Taylor/Maclaurin series expansion utilities.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import math |
| 6 | +from collections.abc import Callable |
| 7 | +from typing import Literal |
| 8 | + |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +HAS_SYMPY = True |
| 12 | +try: |
| 13 | + import sympy as sp |
| 14 | +except Exception: # pragma: no cover - optional dependency guard |
| 15 | + HAS_SYMPY = False |
| 16 | + |
| 17 | + |
| 18 | +SupportedFunction = Literal["exp", "sin", "cos", "log1p"] |
| 19 | + |
| 20 | + |
| 21 | +def _builtin_derivative_value(name: SupportedFunction, n: int, at: float) -> float: |
| 22 | + if name == "exp": |
| 23 | + return float(math.exp(at)) |
| 24 | + if name == "sin": |
| 25 | + return float(math.sin(at + n * math.pi / 2.0)) |
| 26 | + if name == "cos": |
| 27 | + return float(math.cos(at + n * math.pi / 2.0)) |
| 28 | + if name == "log1p": |
| 29 | + if n == 0: |
| 30 | + return float(math.log1p(at)) |
| 31 | + return float(((-1) ** (n - 1)) * math.factorial(n - 1) / (1.0 + at) ** n) |
| 32 | + msg = f"Unsupported function: {name}" |
| 33 | + raise ValueError(msg) |
| 34 | + |
| 35 | + |
| 36 | +def _numerical_nth_derivative(func: Callable[[float], float], n: int, at: float, h: float = 1e-4) -> float: |
| 37 | + if n == 0: |
| 38 | + return float(func(at)) |
| 39 | + if n == 1: |
| 40 | + return float((func(at + h) - func(at - h)) / (2.0 * h)) |
| 41 | + return float( |
| 42 | + (_numerical_nth_derivative(func, n - 1, at + h, h) - _numerical_nth_derivative(func, n - 1, at - h, h)) / (2.0 * h) |
| 43 | + ) |
| 44 | + |
| 45 | + |
| 46 | +def taylor_series( |
| 47 | + x: float | np.ndarray, |
| 48 | + order: int, |
| 49 | + center: float = 0.0, |
| 50 | + *, |
| 51 | + function_name: SupportedFunction | None = None, |
| 52 | + function: Callable[[float], float] | None = None, |
| 53 | +) -> np.ndarray: |
| 54 | + """Evaluate Taylor polynomial of given order around center. |
| 55 | +
|
| 56 | + f(x) ≈ Σ[n=0..order] f^(n)(a) / n! * (x-a)^n |
| 57 | + """ |
| 58 | + if function_name is None and function is None: |
| 59 | + msg = "Provide either function_name or function." |
| 60 | + raise ValueError(msg) |
| 61 | + |
| 62 | + x_arr = np.asarray(x, dtype=float) |
| 63 | + approximation = np.zeros_like(x_arr, dtype=float) |
| 64 | + |
| 65 | + for n in range(order + 1): |
| 66 | + if function_name is not None: |
| 67 | + derivative_at_center = _builtin_derivative_value(function_name, n, center) |
| 68 | + else: |
| 69 | + assert function is not None |
| 70 | + derivative_at_center = _numerical_nth_derivative(function, n, center) |
| 71 | + |
| 72 | + approximation = approximation + derivative_at_center / math.factorial(n) * (x_arr - center) ** n |
| 73 | + |
| 74 | + return approximation |
| 75 | + |
| 76 | + |
| 77 | +def maclaurin_series( |
| 78 | + x: float | np.ndarray, |
| 79 | + order: int, |
| 80 | + *, |
| 81 | + function_name: SupportedFunction | None = None, |
| 82 | + function: Callable[[float], float] | None = None, |
| 83 | +) -> np.ndarray: |
| 84 | + """Evaluate Maclaurin polynomial (Taylor around 0).""" |
| 85 | + return taylor_series(x, order=order, center=0.0, function_name=function_name, function=function) |
| 86 | + |
| 87 | + |
| 88 | +def estimate_lagrange_remainder(max_derivative: float, x: float | np.ndarray, order: int, center: float = 0.0) -> np.ndarray: |
| 89 | + """Estimate Lagrange remainder bound: M|x-a|^(n+1)/(n+1)!""" |
| 90 | + x_arr = np.asarray(x, dtype=float) |
| 91 | + return np.abs(max_derivative) * np.abs(x_arr - center) ** (order + 1) / math.factorial(order + 1) |
| 92 | + |
| 93 | + |
| 94 | +def symbolic_taylor_expression(expr_str: str, symbol: str = "x", center: float = 0.0, order: int = 6) -> str: |
| 95 | + """Return symbolic Taylor expression as a string when SymPy is available.""" |
| 96 | + if not HAS_SYMPY: |
| 97 | + msg = "SymPy is not available." |
| 98 | + raise RuntimeError(msg) |
| 99 | + x = sp.symbols(symbol) |
| 100 | + expr = sp.sympify(expr_str) |
| 101 | + series = sp.series(expr, x, center, order + 1).removeO() |
| 102 | + return str(sp.expand(series)) |
0 commit comments