Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions RMK_support/common_variables.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional
from abc import ABC, abstractmethod
from typing import Optional, cast

import numpy as np

from .variable_container import Variable, node
from .derivations import Species
from . import derivations
from abc import ABC, abstractmethod
from .derivations import Species
from .variable_container import Variable, node


def timeDerivative(
Expand Down Expand Up @@ -340,10 +341,11 @@ def density(self, initVals: Optional[np.ndarray] = None) -> Variable:
data=initVals,
defaultLatex="n_{" + self.species.latex() + "}",
).withDual()
n_dual = cast(Variable, n.dual)
if self.__associateOnCreation__:
if n.name not in self.species.associatedVarNames:
self.species.associateVar(n, n.dual)
self.species[n.subtype] = n
self.species.associateVar(n, n_dual)
self.species[n.subtype] = n
if self.__addOnCreation__:
if n.name not in self.__context__.variables.varNames:
self.__context__.variables.add(n)
Expand All @@ -369,10 +371,11 @@ def flux(self, initVals: Optional[np.ndarray] = None) -> Variable:
).withDual(
"G" + self.species.name, "\\vec{\\Gamma}_{" + self.species.latex() + "}"
)
G_dual = cast(Variable, G.dual)
if self.__associateOnCreation__:
if G.name not in self.species.associatedVarNames:
self.species.associateVar(G, G.dual)
self.species[G.subtype] = G.dual
self.species.associateVar(G, G_dual)
self.species[G.subtype] = G_dual
if self.__addOnCreation__:
if G.name not in self.__context__.variables.varNames:
self.__context__.variables.add(G)
Expand All @@ -393,10 +396,11 @@ def energyDensity(self, initVals: Optional[np.ndarray] = None) -> Variable:
data=initVals,
defaultLatex="W_{" + self.species.latex() + "}",
).withDual()
W_dual = cast(Variable, W.dual)
if self.__associateOnCreation__:
if W.name not in self.species.associatedVarNames:
self.species.associateVar(W, W.dual)
self.species[W.subtype] = W
self.species.associateVar(W, W_dual)
self.species[W.subtype] = W
if self.__addOnCreation__:
if W.name not in self.__context__.variables.varNames:
self.__context__.variables.add(W)
Expand All @@ -418,10 +422,11 @@ def temperature(self, initVals: Optional[np.ndarray] = None) -> Variable:
isStationary=True,
defaultLatex="T_{" + self.species.latex() + "}",
).withDual()
T_dual = cast(Variable, T.dual)
if self.__associateOnCreation__:
if T.name not in self.species.associatedVarNames:
self.species.associateVar(T, T.dual)
self.species[T.subtype] = T
self.species.associateVar(T, T_dual)
self.species[T.subtype] = T
if self.__addOnCreation__:
if T.name not in self.__context__.variables.varNames:
self.__context__.variables.add(T)
Expand Down Expand Up @@ -450,10 +455,11 @@ def flowSpeed(self, initVals: Optional[np.ndarray] = None) -> Variable:
+ self.species.latex()
+ "}\\right)_{dual}",
).withDual("u" + self.species.name, "\\vec{u}_{" + self.species.latex() + "}")
u_dual = cast(Variable, u.dual)
if self.__associateOnCreation__:
if u.name not in self.species.associatedVarNames:
self.species.associateVar(u, u.dual)
self.species[u.subtype] = u.dual
self.species.associateVar(u, u_dual)
self.species[u.subtype] = u_dual
if self.__addOnCreation__:
if u.name not in self.__context__.variables.varNames:
self.__context__.variables.add(u)
Expand All @@ -479,10 +485,11 @@ def heatflux(self, initVals: Optional[np.ndarray] = None) -> Variable:
+ self.species.latex()
+ "}\\right)_{dual}",
).withDual("q" + self.species.name, "\\vec{q}_{" + self.species.latex() + "}")
q_dual = cast(Variable, q.dual)
if self.__associateOnCreation__:
if q.name not in self.species.associatedVarNames:
self.species.associateVar(q, q.dual)
self.species[q.subtype] = q.dual
self.species.associateVar(q, q_dual)
self.species[q.subtype] = q_dual
if self.__addOnCreation__:
if q.name not in self.__context__.variables.varNames:
self.__context__.variables.add(q)
Expand All @@ -508,10 +515,11 @@ def pressure(self, initVals: Optional[np.ndarray] = None) -> Variable:
subtype="pressure",
defaultLatex="p_{" + self.species.latex() + "}",
).withDual()
p_dual = cast(Variable, p.dual)
if self.__associateOnCreation__:
if p.name not in self.species.associatedVarNames:
self.species.associateVar(p, p.dual)
self.species[p.subtype] = p
self.species.associateVar(p, p_dual)
self.species[p.subtype] = p
if self.__addOnCreation__:
if p.name not in self.__context__.variables.varNames:
self.__context__.variables.add(p)
Expand All @@ -534,10 +542,11 @@ def viscosity(self, initVals: Optional[np.ndarray] = None) -> Variable:
subtype="viscosity",
defaultLatex="\\Pi_{" + self.species.latex() + "}",
).withDual()
pi_dual = cast(Variable, pi.dual)
if self.__associateOnCreation__:
if pi.name not in self.species.associatedVarNames:
self.species.associateVar(pi, pi.dual)
self.species[pi.subtype] = pi
self.species.associateVar(pi, pi_dual)
self.species[pi.subtype] = pi
if self.__addOnCreation__:
if pi.name not in self.__context__.variables.varNames:
self.__context__.variables.add(pi)
Expand Down
80 changes: 66 additions & 14 deletions RMK_support/tests/test_common_variables.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest

from RMK_support import Grid, RMKContext, Species
import RMK_support.common_variables as cv
from RMK_support import Grid, RMKContext, Species


@pytest.fixture
Expand All @@ -22,64 +22,116 @@ def context():
return rk


def test_standard_variable_factory(context):
@pytest.mark.parametrize(
"associateOnCreation,addOnCreation",
[
pytest.param(
True,
True,
id="default",
),
pytest.param(
True,
False,
id="associateOnly",
),
pytest.param(
False,
True,
id="addOnly",
),
],
)
def test_standard_variable_factory(
context: RMKContext, associateOnCreation: bool, addOnCreation: bool
):

rk = context

factory = cv.StandardFluidVariables(rk, rk.species["e"])
associateOnCreation = False
addOnCreation = False

factory = cv.StandardFluidVariables(
rk, rk.species["e"], associateOnCreation, addOnCreation
)

n = factory.density()

assert n.name == "ne"
assert "ne" in rk.variables.varNames
assert "ne" in rk.species["e"].associatedVarNames
assert n.units == "norm. density"
assert n.normConst == rk.normDensity
assert n.unitsSI == "$m^{-3}$"
assert (n.name in rk.variables.varNames) == addOnCreation
assert (n.name in rk.species["e"].associatedVarNames) == associateOnCreation

G = factory.flux()

assert G.name == "Ge_dual"
assert "Ge" in rk.variables.varNames
assert G.normConst == rk.normDensity * rk.norms["speed"]

factory.species = rk.species["n"]
assert (G.name in rk.variables.varNames) == addOnCreation
assert (G.name in rk.species["e"].associatedVarNames) == associateOnCreation

T = factory.temperature()

assert T.isStationary
assert T.normConst == rk.normTemperature
assert (T.name in rk.variables.varNames) == addOnCreation
assert (T.name in rk.species["e"].associatedVarNames) == associateOnCreation

# Use factory to create a new species 'n' and associate variables to it
factory.species = rk.species["n"]

p = factory.pressure()

assert p.name == "pn"
assert p.isDerived
assert p.normConst == rk.normTemperature * rk.normDensity
assert (p.name in rk.variables.varNames) == addOnCreation
assert (p.name in rk.species["n"].associatedVarNames) == associateOnCreation
# Pressure requires n and T as dependents - check if these were automatically added
assert ("nn" in rk.variables.varNames) == addOnCreation
assert ("Tn" in rk.variables.varNames) == addOnCreation
if associateOnCreation:
assert factory.species["density"].name == "nn"
assert factory.species["temperature"].name == "Tn"

u = factory.flowSpeed()

assert u.name == "un_dual"
assert u.isDerived
assert u.normConst == rk.norms["speed"]
assert (u.name in rk.variables.varNames) == addOnCreation
assert (u.name in rk.species["n"].associatedVarNames) == associateOnCreation

W = factory.energyDensity()
assert W.name in rk.species["n"].associatedVarNames

assert W.name == "Wn"
assert W.normConst == rk.normTemperature * rk.normDensity
assert (W.name in rk.variables.varNames) == addOnCreation
assert (W.name in rk.species["n"].associatedVarNames) == associateOnCreation

q = factory.heatflux()

assert q.name == "qn_dual"
assert q.isStationary
assert q.isOnDualGrid
assert q.normConst == rk.norms["heatFlux"]
assert (q.name in rk.variables.varNames) == addOnCreation
assert (q.name in rk.species["n"].associatedVarNames) == associateOnCreation

pi = factory.viscosity()
assert pi.isStationary

assert pi.name == "pin"
assert pi.isStationary
assert (pi.name in rk.variables.varNames) == addOnCreation
assert (pi.name in rk.species["n"].associatedVarNames) == associateOnCreation

E = cv.electricField("E", rk)
assert E.normConst == rk.norms["EField"]

assert factory.species["density"].name == "nn"
assert factory.species["temperature"].name == "Tn"
assert factory.species["heatflux"].name == "qn"
assert E.normConst == rk.norms["EField"]

dndt = cv.timeDerivative("dndt", rk.norms["time"], n)

assert dndt.units == "norm. density / time norm."
assert dndt.normConst == rk.normDensity / rk.norms["time"]
assert dndt.unitsSI == "$m^{-3}/s$"