Skip to content

Commit 8888c03

Browse files
committed
update default params computation in interconnect()
1 parent 8f3615b commit 8888c03

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

control/nlsys.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,11 @@ def __init__(self, syslist, connections=None, inplist=None, outlist=None,
710710
if outputs is None and outlist is not None:
711711
outputs = len(outlist)
712712

713+
if params is None:
714+
params = {}
715+
for sys in self.syslist:
716+
params = params | sys.params
717+
713718
# Create updfcn and outfcn
714719
def updfcn(t, x, u, params):
715720
self._update_params(params)
@@ -2268,7 +2273,8 @@ def interconnect(
22682273
22692274
params : dict, optional
22702275
Parameter values for the systems. Passed to the evaluation functions
2271-
for the system as default values, overriding internal defaults.
2276+
for the system as default values, overriding internal defaults. If
2277+
not specified, defaults to parameters from subsystems.
22722278
22732279
dt : timebase, optional
22742280
The timebase for the system, used to specify whether the system is

control/tests/interconnect_test.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,15 +666,29 @@ def test_interconnect_params():
666666
# Create a nominally unstable system
667667
sys1 = ct.nlsys(
668668
lambda t, x, u, params: params['a'] * x[0] + u[0],
669-
states=1, inputs='u', outputs='y', params={'a': 1})
669+
states=1, inputs='u', outputs='y', params={'a': 2, 'c':2})
670670

671671
# Simple system for serial interconnection
672672
sys2 = ct.nlsys(
673673
None, lambda t, x, u, params: u[0],
674-
inputs='r', outputs='u')
674+
inputs='r', outputs='u', params={'a': 4, 'b': 3})
675675

676-
# Create a series interconnection
676+
# Make sure default parameters get set as expected
677677
sys = ct.interconnect([sys1, sys2], inputs='r', outputs='y')
678+
assert sys.params == {'a': 4, 'c': 2, 'b': 3}
679+
assert sys.dynamics(0, [1], [0]).item() == 4
680+
681+
# Make sure we can override the parameters
682+
sys = ct.interconnect(
683+
[sys1, sys2], inputs='r', outputs='y', params={'b': 1})
684+
assert sys.params == {'b': 1}
685+
assert sys.dynamics(0, [1], [0]).item() == 2
686+
assert sys.dynamics(0, [1], [0], params={'a': 5}).item() == 5
687+
688+
# Create final series interconnection, with proper parameter values
689+
sys = ct.interconnect(
690+
[sys1, sys2], inputs='r', outputs='y', params={'a': 1})
691+
assert sys.params == {'a': 1}
678692

679693
# Make sure we can call the update function
680694
sys.updfcn(0, [0], [0], {})

0 commit comments

Comments
 (0)