Skip to content

Implementation of a Parametrised Reduced Functional#241

Open
divijghose wants to merge 48 commits into
dolfin-adjoint:masterfrom
divijghose:parametrised_reduced_functional
Open

Implementation of a Parametrised Reduced Functional#241
divijghose wants to merge 48 commits into
dolfin-adjoint:masterfrom
divijghose:parametrised_reduced_functional

Conversation

@divijghose
Copy link
Copy Markdown
Contributor

@divijghose divijghose commented Feb 4, 2026

Current method to use parameters

The derivative_components optional argument of ReducedFunctional is used after adding parameters to the list of controls, to specify which components are to be zeroed out (by omitting them from derivative_components). This allows the user to update parameters by calling the Reduced Functional while zeroing out the gradient with respect to the parameters.

Parametrised Reduced Functional

ParametrisedReducedFunctional is a subclass of ReducedFunctional with wrapping call and derivative methods, with the parameters as attributes. The parameter_update method is called to update parameters. The parameters are not included in the derivative calculation and the optional argument derivative_components is not required.

…educedFunctional` where parameters can be updated but are not included in the derivative calculations:

1. Adds a `parameter_update` method
2. Parameters are appended at the end of the list of optimization controls, so `derivative_components` is not a required argument.
3. The `derivative` method returns only derivative corresponding to optimization controls.
@JHopeCollins
Copy link
Copy Markdown
Contributor

In the current implementation ParameterisedReducedFunctional inherits the controls property from ReducedFunctional. I think this means that this will pass when really it shouldn't:

prf = ParameterisedReducedFunctional(functional, user_controls, parameters)
assert len(prf.controls) == len(user_controls) + len(parameters)

This will also be a problem later because the optimisers will expect:

len(prf.derivative()) == len(prf.controls)

when you will actually have (correctly):

len(prf.derivative()) == len(prf.user_controls)

If this is the case then you may need to override the controls property for ParameterisedReducedFunctional so that it returns only what the user thinks are the controls.
However, if I remember correctly how Python inheritance works, that will then mean that when the parent ReducedFunctional accesses self.controls (for example to calculate the derivative here), it will won't see the full list of user_controls + parameter, but will only see user_controls.

To get around this, you may have to instead inherit from the AbstractReducedFunctional base class and just internally create your own ReducedFunctional(functional=functional, controls=user_controls+parameters).

@divijghose
Copy link
Copy Markdown
Contributor Author

In this intermediate implementation, the ParametrisedReducedFunctional inherits from AbstractReducedFunctional. This means that self.controls returns user_controls, which is the required behaviour as discussed above.

However, a lot of the code is duplicated, and, as @colinjcotter suggested, a more efficient way to do this would be to call ReducedFunctional inside ParametrisedReducedFunctional, which will get the required behaviour from self.controls while reusing most of the methods from ReducedFunctional.

…dFunctional` internally:

- Instead of inheriting the `AbstractReducedFunctional` or `ReducedFunctional` classes, `ParametrisedReducedFunctional` simply calls a `ReducedFunctional` object internally and passes the controls and parameters together as `all_controls`
Comment thread pyadjoint/reduced_functional.py Outdated
`ParametrisedReducedFunctional` must be a subclass of `AbstractReducedFunctional`

Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
Comment thread pyadjoint/reduced_functional.py Outdated
…h component of the parameter list must first be wrapped in `Control`.
1. Basic test to check `call`, `derivative` and `parameter_update` methods
2. Combination tests with single/multiple controls and single/multiple parameters
3. Tests to check behaviour of `controls` and `parameters` property
4. Evaluation on a more complex example
5. Tests to check behaviour in case of multiple parameter updates before call.
…cedFunctional` with `derivative_components`.
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread tests/pyadjoint/test_parametrised_rf.py Outdated
Comment thread tests/pyadjoint/test_parametrised_rf.py
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread tests/pyadjoint/test_parametrised_rf.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
Comment thread pyadjoint/reduced_functional.py Outdated
functional. Input is a list of Controls.
eval_cb_pos (function): Callback function after evaluating the
functional. Inputs are the functional value and a list of Controls.
derivative_cb_pre_for_controls (function): Callback function before evaluating
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very large number of arguments for init! Why not keep derivative_cb_pre etc and make the calling signature derivative_cb_pre(controls, parameters) instead of jut derivative_cb_pre(controls) like in ReducedFunctional?

Comment thread pyadjoint/reduced_functional.py Outdated
Comment on lines +462 to +465
if parameters is None:
raise ValueError("Parameters must be provided. If no parameters are needed, use ReducedFunctional instead.")
if len(Enlist(parameters)) == 0:
raise ValueError("Parameters list cannot be empty. If no parameters are needed, use ReducedFunctional instead.")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would be fine with allowing zero parameters, but this is an API choice and could be something to discuss in the meeting.

Comment thread tests/pyadjoint/test_parametrised_rf.py Outdated
Comment thread tests/pyadjoint/test_parametrised_rf.py Outdated
Comment on lines +37 to +39
assert min(taylor_results["R0"]["Rate"]) >= 0.95, f"Error in R0 rate: {taylor_results['R0']['Rate']}"
assert min(taylor_results["R1"]["Rate"]) >= 1.95, f"Error in R1 rate: {taylor_results['R1']['Rate']}"
assert min(taylor_results["R2"]["Rate"]) >= 2.95, f"Error in R2 rate: {taylor_results['R2']['Rate']}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If something failed I want to see more information.

Suggested change
assert min(taylor_results["R0"]["Rate"]) >= 0.95, f"Error in R0 rate: {taylor_results['R0']['Rate']}"
assert min(taylor_results["R1"]["Rate"]) >= 1.95, f"Error in R1 rate: {taylor_results['R1']['Rate']}"
assert min(taylor_results["R2"]["Rate"]) >= 2.95, f"Error in R2 rate: {taylor_results['R2']['Rate']}"
assert min(taylor_results["R0"]["Rate"]) >= 0.95, f"Error in R0 rate: {taylor_results['R0']}"
assert min(taylor_results["R1"]["Rate"]) >= 1.95, f"Error in R1 rate: {taylor_results['R1']}"
assert min(taylor_results["R2"]["Rate"]) >= 2.95, f"Error in R2 rate: {taylor_results['R2']}"

Comment thread tests/pyadjoint/test_parametrised_rf.py Outdated
Comment thread tests/pyadjoint/test_parametrised_rf.py
divijghose and others added 9 commits April 28, 2026 11:15
Change `complex_expression` to `complicated_expression` in test

Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
Correction in the description of PRF in the doctoring

Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
Correction for typo in callback name

Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
Co-authored-by: Josh Hope-Collins <jhc.jss@gmail.com>
@divijghose
Copy link
Copy Markdown
Contributor Author

divijghose commented Apr 30, 2026

Considering the discussion at the Firedrake meeting, the following commits rewrite ReducedFunctional to accept a parameters argument, with the eventual aim of deprecating the use of derivative_components. For now, a user can pass either parameters or derivative_components (but not both together, this would throw up an error).

1. `ParametrisedReducedFunctional` has been removed in favour of a `ReducedFunctional` that accepts `parameters` as an argument, along with a check to make sure either `derivative_components` or `parameters` is passed, but not both simultaneously.
2. If `parameters` is passed, a new `ReducedFunctional` object is created recursively. Methods will check if the `parameters` attribute is present to branch out their implementation.
3. Derivative callback include a `parameters` argument in their signature.
…est to validate initialization of `ReducedFunctional` with either `derivative_components` or `parameters`
@divijghose
Copy link
Copy Markdown
Contributor Author

@JHopeCollins, with respect to the recent CI failure:

    @no_annotations
    def derivative(self, adj_input=1.0, apply_riesz=False):
        values = [c.tape_value() for c in self.controls]
>       controls = self.derivative_cb_pre(self.parameters if hasattr(self, "_parameters") else None, self.controls)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: EnsembleReducedFunctional.<lambda>() takes 1 positional argument but 2 were given

This happens because the derivative callback now has the signature derivative_cb_pre(parameters, controls), which does not reflect downstream in EnsembleReducedFunctional.

@JHopeCollins
Copy link
Copy Markdown
Contributor

This happens because the derivative callback now has the signature derivative_cb_pre(parameters, controls), which does not reflect downstream in EnsembleReducedFunctional.

If the user has not passed parameters then we shouldn't change the callback interface otherwise we will break current code before the deprecation cycle is done. For now the callback signature will depend on whether parameters are passed or not.
I also think it makes more sense to put the controls first in the signature: derivative_cb_pre(controls, parameters)

divijghose and others added 2 commits May 4, 2026 12:34
Co-authored-by: Copilot <copilot@github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants