Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions brian2/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,18 @@ class Variable(CacheKey):
_cache_irrelevant_attributes = {"owner"}

def __init__(
self,
name,
dimensions=DIMENSIONLESS,
owner=None,
dtype=None,
scalar=False,
constant=False,
read_only=False,
dynamic=False,
array=False,
self,
name,
dimensions=DIMENSIONLESS,
owner=None,
dtype=None,
scalar=False,
constant=False,
read_only=False,
dynamic=False,
array=False,
description=None,
inline_comments=None,
):
assert isinstance(dimensions, Dimension)

Expand Down Expand Up @@ -185,6 +187,16 @@ def __init__(
#: Whether the variable is an array
self.array = array

#: A variable associated description
self.description = description

#: inline comments associated with this variable
self.inline_comments = (
[dict(comment) for comment in inline_comments]
if inline_comments is not None
else []
)

def __getstate__(self):
state = self.__dict__.copy()
state["owner"] = state["owner"].__repr__.__self__ # replace proxy
Expand Down Expand Up @@ -474,6 +486,8 @@ def __init__(
read_only=False,
dynamic=False,
unique=False,
description=None,
inline_comments=None,
):
super().__init__(
dimensions=dimensions,
Expand All @@ -485,6 +499,8 @@ def __init__(
read_only=read_only,
dynamic=dynamic,
array=True,
description=description,
inline_comments=inline_comments,
)

#: Wether all values in this arrays are necessarily unique (only
Expand Down Expand Up @@ -712,6 +728,8 @@ def __init__(
dimensions=DIMENSIONLESS,
dtype=None,
scalar=False,
description=None,
inline_comments=None,
):
super().__init__(
dimensions=dimensions,
Expand All @@ -721,6 +739,8 @@ def __init__(
scalar=scalar,
constant=False,
read_only=True,
description=description,
inline_comments=inline_comments,
)

#: The `Device` responsible for memory access
Expand Down Expand Up @@ -1707,6 +1727,8 @@ def add_array(
scalar=False,
unique=False,
index=None,
description=None,
inline_comments=None,
):
"""
Add an array (initialized with zeros).
Expand Down Expand Up @@ -1756,6 +1778,8 @@ def add_array(
scalar=scalar,
read_only=read_only,
unique=unique,
description=description,
inline_comments=inline_comments,
)
self._add_variable(name, var, index)
# This could be avoided, but we currently need it so that standalone
Expand Down Expand Up @@ -1978,7 +2002,7 @@ def add_constant(self, name, value, dimensions=DIMENSIONLESS):
self._add_variable(name, var)

def add_subexpression(
self, name, expr, dimensions=DIMENSIONLESS, dtype=None, scalar=False, index=None
self, name, expr, dimensions=DIMENSIONLESS, dtype=None, scalar=False, index=None, description=None, inline_comments=None
):
"""
Add a named subexpression.
Expand Down Expand Up @@ -2009,6 +2033,8 @@ def add_subexpression(
dtype=dtype,
device=self.device,
scalar=scalar,
description=description,
inline_comments=inline_comments,
)
self._add_variable(name, var, index=index)

Expand Down
71 changes: 51 additions & 20 deletions brian2/equations/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import sympy
from pyparsing import (
CharsNotIn,
Combine,
Group,
LineEnd,
OneOrMore,
Optional,
ParseException,
ParseResults,
Suppress,
Word,
ZeroOrMore,
Expand All @@ -35,7 +35,7 @@
get_unit,
get_unit_for_display,
)
from brian2.utils.caching import CacheKey, cached
from brian2.utils.caching import CacheKey, _hashable, cached
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers
from brian2.utils.topsort import topsort
Expand Down Expand Up @@ -78,12 +78,13 @@

# very broad definition here, expression will be analysed by sympy anyway
# allows for multi-line expressions, where each line can have comments
EXPRESSION = Combine(
OneOrMore(
(CharsNotIn(":#\n") + Suppress(Optional(LineEnd()))).ignore("#" + restOfLine)
),
joinString=" ",
).set_results_name("expression")
FRAGMENT = Group(
CharsNotIn(":#\n").set_results_name("text")
+ Optional(Suppress("#") + restOfLine).set_results_name("comment")
)
EXPRESSION = OneOrMore(
FRAGMENT + Suppress(Optional(LineEnd()))
).set_results_name("fragments")

# a unit
# very broad definition here, again. Whether this corresponds to a valid unit
Expand All @@ -97,6 +98,7 @@
FLAGS = (
Suppress("(") + FLAG + ZeroOrMore(Suppress(",") + FLAG) + Suppress(")")
).set_results_name("flags")
DESCRIPTION = (Suppress("#") + restOfLine).set_results_name("description")

###############################################################################
# Equations
Expand All @@ -105,25 +107,24 @@
# Parameter:
# x : volt (flags)
PARAMETER_EQ = Group(
IDENTIFIER + Suppress(":") + UNIT + Optional(FLAGS)
IDENTIFIER + Suppress(":") + UNIT + Optional(FLAGS) + Optional(DESCRIPTION)
).set_results_name(PARAMETER)

# Static equation:
# x = 2 * y : volt (flags)
STATIC_EQ = Group(
IDENTIFIER + Suppress("=") + EXPRESSION + Suppress(":") + UNIT + Optional(FLAGS)
IDENTIFIER + Suppress("=") + EXPRESSION + Suppress(":") + UNIT + Optional(FLAGS) + Optional(DESCRIPTION)
).set_results_name(SUBEXPRESSION)

# Differential equation
# dx/dt = -x / tau : volt
DIFF_OP = Suppress("d") + IDENTIFIER + Suppress("/") + Suppress("dt")
DIFF_EQ = Group(
DIFF_OP + Suppress("=") + EXPRESSION + Suppress(":") + UNIT + Optional(FLAGS)
DIFF_OP + Suppress("=") + EXPRESSION + Suppress(":") + UNIT + Optional(FLAGS) + Optional(DESCRIPTION)
).set_results_name(DIFFERENTIAL_EQUATION)

# ignore comments
EQUATION = (PARAMETER_EQ | STATIC_EQ | DIFF_EQ).ignore("#" + restOfLine)
EQUATIONS = ZeroOrMore(EQUATION)
EQUATION = PARAMETER_EQ | STATIC_EQ | DIFF_EQ
EQUATIONS = ZeroOrMore(EQUATION | Suppress("#" + restOfLine) | Suppress(LineEnd()))


class EquationError(Exception):
Expand Down Expand Up @@ -383,6 +384,15 @@ def parse_string_equations(eqns):
"""
equations = {}

def _as_text(value):
if isinstance(value, ParseResults):
if len(value) == 0:
return ""
return _as_text(value[0])
if value is None:
return ""
return str(value)

try:
parsed = EQUATIONS.parse_string(eqns, parse_all=True)
except ParseException as p_exc:
Expand All @@ -408,16 +418,32 @@ def parse_string_equations(eqns):
f"Error parsing the unit specification for variable '{identifier}'."
) from ex

expression = eq_content.get("expression")
if expression is not None:
expression = None
inline_comments = []
fragments = eq_content.get("fragments")
if fragments is not None:
expression_chunks = []
comments = []
for fragment in fragments:
text = _as_text(fragment["text"]) if "text" in fragment else ""
expression_chunks.append(text)
comment = _as_text(fragment["comment"] if "comment" in fragment else "").strip()
if comment:
text_for_comment = text.strip().lstrip('(').rstrip(')')
comments.append({"text": text_for_comment, "comment": comment})

# Replace multiple whitespaces (arising from joining multiline
# strings) with single space
p = re.compile(r"\s{2,}")
expression = Expression(p.sub(" ", expression))
clean_expression = p.sub(" ", " ".join(expression_chunks)).strip()
expression = Expression(clean_expression)
inline_comments = comments

description = _as_text(eq_content.get("description")).strip() or None
flags = list(eq_content.get("flags", []))

equation = SingleEquation(
eq_type, identifier, dims, var_type=var_type, expr=expression, flags=flags
eq_type, identifier, dims, var_type=var_type, expr=expression, flags=flags, description=description, inline_comments=inline_comments
)

if identifier in equations:
Expand Down Expand Up @@ -457,7 +483,7 @@ class SingleEquation(Hashable, CacheKey):
_cache_irrelevant_attributes = {"update_order"}

def __init__(
self, type, varname, dimensions, var_type=FLOAT, expr=None, flags=None
self, type, varname, dimensions, var_type=FLOAT, expr=None, flags=None, description=None, inline_comments=None
):
self.type = type
self.varname = varname
Expand All @@ -479,6 +505,11 @@ def __init__(
self.flags = []
else:
self.flags = list(flags)
self.description = description
if inline_comments is None:
self.inline_comments = []
else:
self.inline_comments = list(inline_comments)

# will be set later in the sort_subexpressions method of Equations
self.update_order = -1
Expand Down Expand Up @@ -508,7 +539,7 @@ def __ne__(self, other):
return not self == other

def __hash__(self):
return hash(self._state_tuple)
return hash(_hashable(self._state_tuple))

def _latex(self, *args):
if self.type == DIFFERENTIAL_EQUATION:
Expand Down
4 changes: 4 additions & 0 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,8 @@ def _create_variables(self, user_dtype, events):
dtype=dtype,
constant=constant,
scalar=shared,
description=eq.description,
inline_comments=eq.inline_comments,
)
elif eq.type == SUBEXPRESSION:
self.variables.add_subexpression(
Expand All @@ -834,6 +836,8 @@ def _create_variables(self, user_dtype, events):
expr=str(eq.expr),
dtype=dtype,
scalar="shared" in eq.flags,
description=eq.description,
inline_comments=eq.inline_comments,
)
else:
raise AssertionError(f"Unknown type of equation: {eq.eq_type}")
Expand Down
37 changes: 37 additions & 0 deletions brian2/tests/test_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,43 @@ def test_parse_equations():
parse_string_equations(error_eqs)


@pytest.mark.codegen_independent
def test_parse_equations_comment_metadata():
eqs = parse_string_equations(
"""
dv/dt = (g_L*(E_L - v) + # leak current
g_e*(E_e - v) + # excitatory input
g_i*(E_i - v)) / tau : 1 # inhibitory input
I_syn = g_e*(E_e - v) + # excitatory component
g_i*(E_i - v) : amp # inhibitory component
v_t : 1 # threshold
"""
)

assert eqs["v"].expr.code == "(g_L*(E_L - v) + g_e*(E_e - v) + g_i*(E_i - v)) / tau"
assert eqs["v"].description == "inhibitory input"
assert eqs["v"].inline_comments == [
{"text": "g_L*(E_L - v) +", "comment": "leak current"},
{"text": "g_e*(E_e - v) +", "comment": "excitatory input"},
]

assert eqs["I_syn"].expr.code == "g_e*(E_e - v) + g_i*(E_i - v)"
assert eqs["I_syn"].description == "inhibitory component"
assert eqs["I_syn"].inline_comments == [
{"text": "g_e*(E_e - v) +", "comment": "excitatory component"},
]

assert eqs["v_t"].description == "threshold"
assert eqs["v_t"].inline_comments == []


@pytest.mark.codegen_independent
def test_parse_equations_parameter_description():
eqs = parse_string_equations("v_t : 1 # threshold")
assert eqs["v_t"].description == "threshold"
assert eqs["v_t"].inline_comments == []


@pytest.mark.codegen_independent
def test_correct_replacements():
"""Test replacing variables via keyword arguments"""
Expand Down
34 changes: 34 additions & 0 deletions brian2/tests/test_neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,40 @@ def test_variables():
assert "not_refractory" in G.variables and "lastspike" in G.variables


@pytest.mark.codegen_independent
def test_equation_comment_metadata_on_variables():
G = NeuronGroup(
1,
"""
dv/dt = (g_L*(E_L-v) + # leak current
g_e*(E_e-v)) / tau : 1 # membrane potential
I_syn = g_L*(E_L-v) + # leak component
g_e*(E_e-v) : amp # excitatory component
v_t : 1 # threshold
""",
)

assert G.equations["v"].description == "membrane potential"
assert G.equations["v"].inline_comments == [
{"text": "g_L*(E_L-v) +", "comment": "leak current"}
]
assert G.variables["v"].description == "membrane potential"
assert G.variables["v"].inline_comments == [
{"text": "g_L*(E_L-v) +", "comment": "leak current"}
]

assert G.equations["I_syn"].description == "excitatory component"
assert G.equations["I_syn"].inline_comments == [
{"text": "g_L*(E_L-v) +", "comment": "leak component"}
]
assert G.variables["I_syn"].description == "excitatory component"
assert G.variables["I_syn"].inline_comments == [
{"text": "g_L*(E_L-v) +", "comment": "leak component"}
]

assert G.variables["v_t"].description == "threshold"


@pytest.mark.codegen_independent
def test_variableview_calculations():
# Check that you can directly calculate with "variable views"
Expand Down