Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions python/sdist/amici/exporters/sundials/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,20 @@ class AmiciCxxCodePrinter(CXX11CodePrinter):

optimizations: Iterable[Optimization] = ()

def __init__(self):
def __init__(self, extract_cse: bool | None = None):
"""Create code printer"""
super().__init__()

# extract common subexpressions in matrix functions?
self.extract_cse = os.getenv("AMICI_EXTRACT_CSE", "0").lower() in (
"1",
"on",
"true",
self.extract_cse = (
os.getenv("AMICI_EXTRACT_CSE", "0").lower()
in (
"1",
"on",
"true",
)
if extract_cse is None
else extract_cse
)

# Floating-point optimizations
Expand Down Expand Up @@ -115,27 +120,54 @@ def _print_ComplexInfinity(self, expr):
return "std::numeric_limits<double>::infinity()"

def _get_sym_lines_array(
self, equations: sp.Matrix, variable: str, indent_level: int
self,
equations: sp.Matrix,
variable: str,
indent_level: int,
indices: Sequence[int] | None = None,
) -> list[str]:
"""
Generate C++ code for assigning symbolic terms in symbols to C++ array
`variable`.

:param equations:
vectors of symbolic expressions

:param variable:
name of the C++ array to assign to

:param indent_level:
indentation level (number of leading blanks)
:param indices:
List of custom indices corresponding to entries in `equations`.
If `None`, the indices will be 0..(N-1).

:return:
C++ code as list of lines
"""
if indices is None:
indices = range(len(equations))

if self.extract_cse:
res = self._get_sym_lines_symbols(
symbols=sp.Matrix(
[sp.Symbol(f"{variable}[{index}]") for index in indices]
),
equations=equations,
variable=variable,
indent_level=indent_level,
indices=indices,
)
# make compound statement so that extracted subexpressions are
# scoped locally and can be used in switch-cases
indent = " " * indent_level
return [
f"{indent}{{",
*(f"{indent}{l}" for l in res),
f"{indent}}}",
]

return [
" " * indent_level + f"{variable}[{index}] = {self.doprint(math)};"
for index, math in enumerate(equations)
for index, math in zip(indices, equations, strict=True)
if math not in [0, 0.0]
]

Expand Down
11 changes: 8 additions & 3 deletions python/sdist/amici/exporters/sundials/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,9 +680,14 @@ def _get_function_body(
f"if(std::find("
"reinitialization_state_idxs.cbegin(), "
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())",
f" {function}[{index}] = "
f"{self._code_printer.doprint(formula)};",
"reinitialization_state_idxs.cend()) {",
*self._code_printer._get_sym_lines_array(
equations=sp.Matrix([formula]),
indices=[index],
variable=function,
indent_level=4,
),
"}",
]
)
cases[ipar] = expressions
Expand Down
51 changes: 51 additions & 0 deletions python/tests/test_cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,54 @@ def test_float_arithmetic():
cp = AmiciCxxCodePrinter()
assert cp.doprint(sp.Rational(1, 2)) == "1.0/2.0"
assert cp.doprint(sp.Integer(1) / sp.Integer(2)) == "1.0/2.0"


@skip_on_valgrind
def test_extract_cse():
"""Test extraction of common subexpressions."""
cp = AmiciCxxCodePrinter()
cp_cse = AmiciCxxCodePrinter(extract_cse=True)

a, b, c = sp.symbols("a b c")
x1, x2, x3 = sp.symbols("x1 x2 x3")

syms = sp.Matrix([x1, x2, x3])
eqs = sp.Matrix([a * b * c, a * b, a * b * c + a])

expected = [
" x1 = a*b*c; // x[0]",
" x2 = a*b; // x[1]",
" x3 = a*b*c + a; // x[2]",
]

expected_cse = [
" const realtype __amici_cse_0 = a*b;",
" const realtype __amici_cse_1 = __amici_cse_0*c;",
" x2 = __amici_cse_0; // x[1]",
" x1 = __amici_cse_1; // x[0]",
" x3 = __amici_cse_1 + a; // x[2]",
]

assert expected == cp._get_sym_lines_symbols(
symbols=syms, equations=eqs, variable="x", indent_level=2
)
assert expected_cse == cp_cse._get_sym_lines_symbols(
symbols=syms, equations=eqs, variable="x", indent_level=2
)

expected = [" x[0] = a*b*c;", " x[1] = a*b;", " x[2] = a*b*c + a;"]
expected_cse = [
" {",
" const realtype __amici_cse_0 = a*b;",
" const realtype __amici_cse_1 = __amici_cse_0*c;",
" x[1] = __amici_cse_0; // x[1]",
" x[0] = __amici_cse_1; // x[0]",
" x[2] = __amici_cse_1 + a; // x[2]",
" }",
]
assert expected == cp._get_sym_lines_array(
equations=eqs, variable="x", indent_level=2
)
assert expected_cse == cp_cse._get_sym_lines_array(
equations=eqs, variable="x", indent_level=2
)
Loading