First, load the necessary packages.

```julia
using HybridVariationalInference
using HybridVariationalInference: HybridVariationalInference as HVI
using ComponentArrays: ComponentArrays as CA
using Bijectors
using StableRNGs
using SimpleChains
using StatsFuns
using MLUtils
using DistributionFits
using UnPack
```

Next, specify the many moving parts of Hybrid Variational Inference (HVI).

The example process-based model (PBM) predicts a double-Monod constrained rate
for different substrate concentrations, S1 and S2.
```julia
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
    # extract parameters by name, independent of whether they are in θP or θM
    @unpack r0, r1, K1, K2 = θc
    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
end
```

Its formulation is independent of which parameters are global, site-specific,
or fixed during the model inversion.
However, it cannot assume an ordering of the parameters, but needs to
access the components by their symbolic names in the provided ComponentArray.
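The name-based access can be illustrated without ComponentArrays. The following minimal sketch assumes that plain NamedTuples stand in for the ComponentVectors (the function name double_monod is hypothetical). It shows that reordering the parameters does not change the prediction, because each component is looked up by its name.

```julia
# Sketch: NamedTuples stand in for the ComponentVectors used by HVI.
# Parameters are accessed by name (θ.r0, θ.K1, ...), never by position.
double_monod(θ, x) = θ.r0 .+ θ.r1 .* x.S1 ./ (θ.K1 .+ x.S1) .* x.S2 ./ (θ.K2 .+ x.S2)

x = (S1 = [0.5, 0.1], S2 = [1.0, 5.0])          # drivers for two observations
θ_a = (r0 = 0.3, r1 = 0.5, K1 = 0.2, K2 = 2.0)
θ_b = (K2 = 2.0, K1 = 0.2, r1 = 0.5, r0 = 0.3)  # same values, different order

pred_a = double_monod(θ_a, x)
pred_b = double_monod(θ_b, x)                   # identical to pred_a
```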
HVI requires the evaluation of the likelihood of the predictions, which corresponds to the cost of the predictions given some observations.
The user specifies a function of the negative log-likelihood
neg_logden(obs, pred, uncertainty_parameters),
where all of the arguments are arrays with one column per site.
Here, we use the neg_logden_indep_normal function,
which assumes observations to be distributed independently
normal around a true value.
The provided y_unc uncertainty parameters here correspond to
logσ2, the log of the variance parameter of the normal distribution.
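For reference, the negative log-density of independent normal observations with log-variance parameters is, up to an additive constant, ½ Σ (logσ2 + (obs - pred)² / exp(logσ2)). The sketch below is a hypothetical stand-in for illustration only (neg_logden_normal_sketch is not part of HVI); it does not show whether neg_logden_indep_normal keeps the constant term or how it treats missing observations.

```julia
"""
Hypothetical sketch: negative log-density of `obs` distributed independently
normal around `pred`, with log-variance `logσ2`. All arguments are arrays of
the same size (columns are sites). The constant term 0.5*log(2π) per
observation is dropped.
"""
function neg_logden_normal_sketch(obs, pred, logσ2)
    r2 = (obs .- pred) .^ 2                       # squared residuals
    return sum(logσ2 .+ r2 ./ exp.(logσ2)) / 2    # variance = exp(logσ2)
end

obs   = [1.0 2.0; 3.0 4.0]    # 2 observations (rows) at 2 sites (columns)
pred  = [1.5 2.0; 2.0 4.0]
logσ2 = zeros(2, 2)           # variance 1 everywhere

nll = neg_logden_normal_sketch(obs, pred, logσ2)   # (0.25 + 1.0) / 2 = 0.625
```

Larger logσ2 values down-weight the squared residuals but are penalized by the logσ2 term itself, which is why observations with large stated uncertainty contribute less to the fit.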
```julia
py = neg_logden_indep_normal
```

In this example, we assign a fixed value to the r0 parameter, treat the K2 parameter as unknown but the same across sites, and predict r1 and K1 for each site separately, based on covariates known at the sites.
Here, we provide initial values for them by using ComponentVector.

```julia
FT = Float32
θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # separately for each individual
θP0 = θP = CA.ComponentVector{FT}(K2=2.0)         # population: same across individuals
θFix = CA.ComponentVector{FT}(r0=0.3)             # fixed: r0 is not estimated
```

HVI allows parameters to be transformed from an unconstrained space, where the probability density is nowhere strictly zero, to the original constrained space.
Here, our model parameters are strictly positive, and we use the exponential function to transform unconstrained estimates to the original constrained domain.
```julia
transP = Stacked(HVI.Exp())
transM = Stacked(HVI.Exp(), HVI.Exp())
```

Parameter transformations are specified using the Bijectors package.
Because Bijectors.elementwise(exp) has problems with automatic differentiation (AD)
on the GPU, we use the public but non-exported HVI.Exp wrapper inside Bijectors.Stacked.
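The effect of the exp transform can be seen in a small sketch that uses plain functions rather than the Bijectors API (all names below are hypothetical): any real unconstrained value maps to a strictly positive value, log inverts the map, and the log-absolute-determinant of the Jacobian, which is needed when transforming densities between the two spaces, is simply the sum of the unconstrained values.

```julia
# Sketch of an elementwise exp bijection: unconstrained ζ ↦ constrained θ = exp(ζ).
exp_forward(ζ) = exp.(ζ)
exp_inverse(θ) = log.(θ)

# The Jacobian of the forward map is diagonal with entries exp(ζ_i),
# so log|det J| = Σ log(exp(ζ_i)) = Σ ζ_i.
logabsdetjac_exp(ζ) = sum(ζ)

ζ = [-1.0, 0.0, 2.0]        # any real values are allowed
θ_pos = exp_forward(ζ)      # strictly positive
ζ_back = exp_inverse(θ_pos) # recovers ζ
```

A density p(θ) at constrained scale corresponds to log p(exp(ζ)) + logabsdetjac_exp(ζ) at unconstrained scale.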
HVI is an approximate Bayesian analysis and combines prior information on the parameters with the model and observed data.
Here, we provide a wide prior by fitting a LogNormal distribution to
- the mode, corresponding to the initial value provided above, and
- the 0.95-quantile, set to 3 times the mode,

using the DistributionFits.jl package.
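For this particular specification, the LogNormal parameters can also be derived by hand. With mode = exp(μ - σ²) and 0.95-quantile = exp(μ + z σ), where z ≈ 1.645 is the standard-normal 0.95-quantile, taking logs and subtracting gives σ² + z σ - log(q/m) = 0, a quadratic in σ. The sketch below (lognormal_from_mode_q95 is a hypothetical name) solves it; DistributionFits may reach the same result by a different numerical route.

```julia
"""
Hypothetical sketch: parameters (μ, σ) of a LogNormal whose mode is `m` and
whose 0.95-quantile is `q` (requires q > m).
"""
function lognormal_from_mode_q95(m, q)
    z = 1.6448536269514722                # standard-normal 0.95-quantile
    c = log(q / m)
    σ = (-z + sqrt(z^2 + 4c)) / 2         # positive root of σ² + zσ - c = 0
    μ = log(m) + σ^2                      # from mode = exp(μ - σ²)
    return μ, σ
end

# prior for r1: mode 0.5 (the initial value), 0.95-quantile 3 * 0.5
μ_r1, σ_r1 = lognormal_from_mode_q95(0.5, 3 * 0.5)
```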
```julia
θall = vcat(θP, θM)
priors_dict = Dict{Symbol, Distribution}(
    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
```

The model parameters are inverted using information on the
- observed data, y_o,
- its uncertainty, y_unc,
- known covariates across sites, xM, and
- model drivers, xP.

Here, we use synthetic data generated by the package.
```julia
rng = StableRNG(111)
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
    rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
```

Let's look at them.

```julia
size(xM), size(xP), size(y_o), size(y_unc)
```

```
((5, 800), (16, 800), (8, 800), (8, 800))
```

All of them have 800 columns, corresponding to 800 sites. There are 5 site covariates, 16 driver values (8 each for S1 and S2), and 8 observations per site.

```julia
xP[:,1]
```

```
ComponentVector{Float32}(S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1], S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0])
```

Each column of the model drivers is a ComponentVector with
components S1 and S2, corresponding to the concentrations for which outputs
were observed.
This allows the notation x.S1 in the PBM above.
The meaning of y_unc is given by the likelihood function specified with
the problem below.
HVI uses MLUtils.DataLoader to provide batches of the data during each
iteration of the solver. In addition to the data, the tuple passed to the
DataLoader contains the indices of the sites.
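The batch structure produced with batchsize=n_batch and partial=false can be mimicked on the site indices alone. This plain-Julia sketch (site_batches is a hypothetical name) ignores the data arrays and any shuffling the DataLoader may apply; it only shows that 800 sites in batches of 20 give 40 complete batches, and that an incomplete trailing batch would be dropped.

```julia
# Sketch of partial=false batching: consecutive chunks of n_batch site indices,
# dropping a trailing chunk that is smaller than n_batch. No shuffling.
function site_batches(n_site, n_batch)
    chunks = collect(Iterators.partition(1:n_site, n_batch))
    return filter(c -> length(c) == n_batch, chunks)
end

batches = site_batches(800, 20)   # 40 batches of 20 sites each
first(batches)                    # sites 1:20
```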
```julia
n_site = size(y_o,2)
n_batch = 20
train_dataloader = MLUtils.DataLoader(
    (CA.getdata(xM), CA.getdata(xP), y_o, y_unc, 1:n_site),
    batchsize=n_batch, partial=false)
```

The machine-learning (ML) part predicts parameters of the posterior of site-specific
PBM parameters, given the covariates.
Here, we specify a 3-layer feed-forward neural network using the SimpleChains
framework, which works efficiently on the CPU.
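To make the layer structure concrete, here is a hypothetical plain-Julia forward pass with the same shapes as g_chain (5 covariates, two hidden tanh layers of 4 × 5 = 20 units with bias, and a logistic output layer without bias for the 2 site parameters). It is not how SimpleChains implements the network; it only shows the shapes, the (0, 1) output range, and the number of weights involved.

```julia
# Hypothetical sketch of the layer structure of g_chain (not SimpleChains code):
# n_in -> 4 n_in (tanh) -> 4 n_in (tanh) -> n_o (logistic, no bias)
sigmoid_sketch(z) = 1 / (1 + exp(-z))

function mlp_forward(W1, b1, W2, b2, W3, x)
    h1 = tanh.(W1 * x .+ b1)          # first hidden layer with bias
    h2 = tanh.(W2 * h1 .+ b2)         # second hidden layer with bias
    return sigmoid_sketch.(W3 * h2)   # output layer without bias, values in (0, 1)
end

n_in, n_o = 5, 2                      # 5 covariates, 2 site parameters (r1, K1)
n_hid = 4 * n_in
W1, b1 = randn(n_hid, n_in), randn(n_hid)
W2, b2 = randn(n_hid, n_hid), randn(n_hid)
W3 = randn(n_o, n_hid)

x_batch = randn(n_in, 20)             # 20 sites, one column per site
out = mlp_forward(W1, b1, W2, b2, W3, x_batch)   # size (2, 20)
n_weights = sum(length, (W1, b1, W2, b2, W3))    # 100 + 20 + 400 + 20 + 40 = 580
```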
```julia
n_out = length(θM) # number of parameters to predict for each site
n_input = n_covar = size(xM,1)
g_chain = SimpleChain(
    static(n_input), # input dimension (optional)
    TurboDense{true}(tanh, n_input * 4),
    TurboDense{true}(tanh, n_input * 4),
    # dense layer without bias that maps to n_out outputs in (0..1)
    TurboDense{false}(logistic, n_out)
)
# get a template of the parameter vector, ϕg0
g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)
```

The g_chain_app ChainsApplicator predicts the parameters of the posterior
approximation, given a vector of ML weights, ϕg.
During construction, an initial template of this vector is created.
This abstraction layer allows using different ML frameworks, e.g. replacing the
SimpleChains model by Flux or Lux.

In order to balance gradients, the g_chain_app ModelApplicator defined above
predicts on a scale (0..1).
Now the priors are used to translate this to the parameter range by using the
cumulative distribution.
Priors were specified at constrained scale, but the ML model predicts
parameters at unconstrained scale.
This transformation of the distribution can be worked out mathematically for
specific forms of the prior distribution.
However, for simplicity, a NormalScalingModelApplicator
is fitted to the transformed 5% and 95% quantiles of the original prior.
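For the LogNormal priors and exp transforms used here, the quantile-based fit happens to be exact: the p-quantile of LogNormal(μ, σ) is exp(μ + z_p σ), so its log is μ + z_p σ, and a normal distribution fitted to the log-transformed 5% and 95% quantiles recovers Normal(μ, σ). The sketch below (normal_from_quantiles is a hypothetical name) checks this. We assume that the applicator then maps a prediction p ∈ (0, 1) through the quantile function of this normal, which sends p = 0.5 to the center; the sketch only reproduces the fitting step.

```julia
zq = 1.6448536269514722                # standard-normal 0.95-quantile

# Hypothetical sketch: normal distribution whose 5% and 95% quantiles
# are `lower` and `upper` (the two quantiles are symmetric around the mean).
function normal_from_quantiles(lower, upper)
    center = (lower + upper) / 2
    scale = (upper - lower) / (2 * zq)
    return center, scale
end

μ_prior, σ_prior = -0.43, 0.51         # illustrative LogNormal parameters
q05 = exp(μ_prior - zq * σ_prior)      # 5% quantile at constrained scale
q95 = exp(μ_prior + zq * σ_prior)      # 95% quantile at constrained scale
lower, upper = log(q05), log(q95)      # transformed by the inverse of exp
center, scale = normal_from_quantiles(lower, upper)   # recovers μ_prior, σ_prior
```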
```julia
priorsM = Tuple(priors_dict[k] for k in keys(θM))
lowers, uppers = get_quantile_transformed(priorsM, transM)
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
```

The g_chain_scaled ModelApplicator now predicts at unconstrained scale:
it transforms logistic predictions around 0.5 to the range of
high prior probability of the parameters,
and transforms ML predictions near 0 or 1 towards the outer ranges of lower prior probability.
All the specifications above are stored in a HybridProblem structure.
Beforehand, a PBMSiteApplicator is constructed that translates an invocation
with a vector of global parameters and a matrix of site parameters into
invocations of the process-based model (PBM) defined at the beginning.
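The idea of such a site applicator can be sketched in plain Julia. The function apply_per_site below is hypothetical and not HVI's PBMSiteApplicator: it assumes NamedTuples for the global and fixed parameters, a matrix of site parameters with one row per site, and a vector of driver NamedTuples. For each site it merges the parameters by name, calls the per-site model, and collects one column of predictions per site.

```julia
# Hypothetical sketch (not HVI's PBMSiteApplicator).
dm_site(θ, x) = θ.r0 .+ θ.r1 .* x.S1 ./ (θ.K1 .+ x.S1) .* x.S2 ./ (θ.K2 .+ x.S2)

function apply_per_site(f, θP::NamedTuple, θMs::AbstractMatrix, θM_names,
                        θFix::NamedTuple, xPs)
    preds = map(axes(θMs, 1)) do i
        θM_i = NamedTuple{θM_names}(Tuple(θMs[i, :]))  # row i: parameters of site i
        θc = merge(θFix, θP, θM_i)                     # combined by name, order irrelevant
        f(θc, xPs[i])
    end
    return reduce(hcat, preds)                         # one column per site
end

θP_sk = (K2 = 2.0,)
θFix_sk = (r0 = 0.3,)
θMs = [0.5 0.2; 0.6 0.3]                               # rows: sites; columns: r1, K1
xPs = [(S1 = [0.5, 0.1], S2 = [1.0, 5.0]) for _ in 1:2]
pred_sites = apply_per_site(dm_site, θP_sk, θMs, (:r1, :K1), θFix_sk, xPs)  # (2 obs × 2 sites)
```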
```julia
f_batch = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
ϕq0 = init_hybrid_ϕq(MeanHVIApproximation(), θP, θM, transP)
prob = HybridProblem(θM, ϕq0, g_chain_scaled, ϕg0,
    f_batch, priors_dict, py,
    transM, transP, train_dataloader, n_covar, n_site, n_batch)
```

Finally, having assembled all the moving parts of HVI, we can perform the inversion.

```julia
# silence warning of no GPU backend found (because we did not import CUDA here)
ENV["MLDATADEVICES_SILENCE_WARN_NO_GPU"] = 1
using OptimizationOptimisers
import Zygote
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
(; probo, interpreters) = solve(prob, solver; rng,
    callback = callback_loss(100), # output during fitting
    epochs = 2,
);
```

The solver object is constructed given the specific stochastic optimization algorithm and the number of Monte Carlo samples that are drawn in each iteration from the predicted parameter posterior.
Then the solver is applied to the problem using solve
for a given number of iterations or epochs.
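The role of n_MC can be illustrated by a generic Monte Carlo estimate of an expected cost. The sketch below assumes a normal approximation in unconstrained space, sampled with the reparameterization ζ = μ + σε and mapped through exp. This is a common construction in variational inference but is not taken from HVI's implementation, and the names mc_expected_cost and toy_cost are hypothetical.

```julia
using Random

# Hypothetical sketch: Monte Carlo estimate of E[cost(exp(ζ))] with
# ζ ~ Normal(μ, σ²) elementwise, using ζ = μ + σ ε with ε ~ Normal(0, 1).
function mc_expected_cost(cost, μ, σ, n_MC; rng = Random.default_rng())
    total = 0.0
    for _ in 1:n_MC
        ε = randn(rng, length(μ))
        ζ = μ .+ σ .* ε        # sample at unconstrained scale
        θ = exp.(ζ)            # transform to constrained scale
        total += cost(θ)
    end
    return total / n_MC
end

toy_cost(θ) = sum(abs2, θ .- 1.0)

# With σ = 0 every sample is θ = [1, 1], so the estimate is exactly 0.
est0 = mc_expected_cost(toy_cost, [0.0, 0.0], [0.0, 0.0], 3)
est = mc_expected_cost(toy_cost, [0.0, 0.0], [0.5, 0.5], 3; rng = MersenneTwister(1))
```

A small n_MC gives a noisy but cheap estimate in each iteration; stochastic optimizers such as Adam are designed to work with such noisy gradients.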
For this tutorial, we additionally specify that the function to transfer structures to
the GPU is the identity function, so that everything stays on the CPU, and this tutorial
hence does not require a GPU or GPU libraries.

Among the return values are
- probo: a copy of the HybridProblem with updated optimized parameters, and
- interpreters: a NamedTuple with several ComponentArrayInterpreters that help to analyze the results.
So far, the process-based model was run for each site separately. For this simple model, some performance gains result from matrix computations when running the model for all sites within one batch simultaneously.

In the following, the PBM specification accepts matrices as arguments for parameters and drivers and returns a matrix of predictions. For the parameters, one row corresponds to one site. For the drivers and predictions, one column corresponds to one site.
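The broadcasting pattern used in the matrix version can be tried out with plain matrices. In this sketch (variable names are hypothetical, no ComponentArrays), drivers are (n_obs × n_site) matrices, a site parameter is a vector with one value per site, and its transpose, a row vector, is repeated across the observation rows by broadcasting. Multiplying by the Bool validity mask sets the parameter to zero wherever a driver value is missing.

```julia
# Drivers: rows are observations, columns are sites; NaN marks a missing value.
S1m = [0.5 0.5; NaN 0.1]
S2m = [1.0 1.0; 5.0 5.0]
r1_sites = [0.5, 0.6]                       # one parameter value per site

is_valid = isfinite.(S1m) .&& isfinite.(S2m)
# column vector' is a 1 × n_site row; broadcasting repeats it for each observation row
r1m = is_valid .* r1_sites'                 # zero where the driver is missing
```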
```julia
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
    # extract the drivers S1 and S2 from xP (rows: observations, columns: sites)
    S1 = view(xPc, Val(:S1), :)
    S2 = view(xPc, Val(:S2), :)
    # extract the parameters as row-repeated vectors
    # θc[:, Val(:r0)] is parameter r0 for each site in batch
    # dot-multiplication of full matrix times row-vector repeats for each observation row
    # also introduces zero for missing observations, leading to zero gradient there
    is_valid = isfinite.(S1) .&& isfinite.(S2)
    r0 = is_valid .* CA.getdata(θc[:, Val(:r0)])'
    r1 = is_valid .* CA.getdata(θc[:, Val(:r1)])'
    K1 = is_valid .* CA.getdata(θc[:, Val(:K1)])'
    K2 = is_valid .* CA.getdata(θc[:, Val(:K2)])'
    # each variable is a matrix (n_obs x n_site)
    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
end
```

Again, the function should not rely on the order of parameters but use symbolic indexing to extract the parameter vectors.
A corresponding PBMPopulationApplicator transforms calls with
partitioned global and site parameters into calls of this matrix version of the PBM.
The HVI problem needs to be updated with this new applicator.
```julia
f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
probo_sites = HybridProblem(probo; f_batch)
```

For numerical efficiency, the number of sites within one batch is part of the
PBMPopulationApplicator. The problem stores an applicator for n_batch sites;
however, an applicator for n_site_pred sites can be obtained by
create_nsite_applicator(f_batch, n_site_pred).
The fitting then continues from the problem optimized above, now using the matrix version of the PBM.
```julia
(; probo) = solve(probo_sites, solver; rng,
    callback = callback_loss(100), # output during fitting
    epochs = 20,
    #is_inferred = Val(true), # activate type-checks
);
```

Extracting useful information from the optimized HybridProblem is covered
in the following "Inspect results of fitted problem" tutorial.
In order to use the results from this tutorial in other tutorials,
the updated probo HybridProblem and the interpreters are saved to a JLD2 file.
Beforehand, the problem is updated so that it uses the redefinition DoubleMM.f_doubleMM_sites
of the PBM in module DoubleMM rather than the one in
module Main, which allows for easier reloading with JLD2.

```julia
f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
probo2 = HybridProblem(probo; f_batch)
```

```julia
using JLD2
fname = "intermediate/basic_cpu_results.jld2"
mkpath("intermediate")
if probo2 isa AbstractHybridProblem # do not save on failure above
    jldsave(fname, false, IOStream; probo=probo2, interpreters)
end
```