Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/main/java/mascot/app/beauti/NeDynamicsListInputEditor.java
Original file line number Diff line number Diff line change
Expand Up @@ -445,17 +445,22 @@ private void addParameter(RealParameter parameter, String pId, String dynamics,
}

private void removeParameters(NeDynamics neDynamics, MCMC mcmc) {
// The spec interface inputs (RealScalar / RealVector) don't extend
// StateNode, but at runtime the XML always binds a RealScalarParam /
// RealVectorParam (which do). Cast to StateNode at the call site so
// the BEAUti disconnect logic can use getID().
CompoundDistribution posterior = (CompoundDistribution) mcmc.posteriorInput.get();
if (neDynamics instanceof ConstantNe) {
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((ConstantNe) neDynamics).NeInput.get());
removeParameter(posterior, (StateNode) ((ConstantNe) neDynamics).NeInput.get());
}else if (neDynamics instanceof ExponentialNe) {
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((ExponentialNe) neDynamics).growthRateInput.get());
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((ExponentialNe) neDynamics).logNeNullInput.get());
removeParameter(posterior, (StateNode) ((ExponentialNe) neDynamics).growthRateInput.get());
removeParameter(posterior, (StateNode) ((ExponentialNe) neDynamics).logNeNullInput.get());
}else if (neDynamics instanceof Skygrowth) {
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((Skygrowth) neDynamics).NeInput.get());
removeParameter(posterior, (StateNode) ((Skygrowth) neDynamics).NeInput.get());
}else if (neDynamics instanceof LogisticNe) {
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((LogisticNe) neDynamics).capacityInput.get());
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((LogisticNe) neDynamics).carryingProportionInput.get());
removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((LogisticNe) neDynamics).growthRateInput.get());
removeParameter(posterior, (StateNode) ((LogisticNe) neDynamics).capacityInput.get());
removeParameter(posterior, (StateNode) ((LogisticNe) neDynamics).carryingProportionInput.get());
removeParameter(posterior, (StateNode) ((LogisticNe) neDynamics).growthRateInput.get());
}
}

Expand Down
39 changes: 22 additions & 17 deletions src/main/java/mascot/glmmodel/ErrorSmoothing.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import beast.base.inference.Distribution;
import beast.base.inference.State;
import beast.base.inference.StateNode;
import beast.base.inference.distribution.ParametricDistribution;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.distribution.ScalarDistribution;
import beast.base.spec.inference.parameter.IntVectorParam;
import beast.base.spec.inference.parameter.RealVectorParam;
import beast.base.spec.type.RealVector;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -19,14 +21,14 @@
"If x is multidimensional, the components of x are assumed to be independent, " +
"so the sum of log probabilities of all elements of x is returned as the prior.")
public class ErrorSmoothing extends Distribution {
final public Input<Function> m_x = new Input<>("x", "point at which the density is calculated", Validate.REQUIRED);
final public Input<RealVector<? extends Real>> m_x = new Input<>("x", "point at which the density is calculated", Validate.REQUIRED);

final public Input<ParametricDistribution> distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED);
final public Input<ScalarDistribution<?, Double>> distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED);

/**
* shadows distInput *
*/
protected ParametricDistribution dist;
protected ScalarDistribution<?, Double> dist;

@Override
public void initAndValidate() {
Expand All @@ -36,9 +38,14 @@ public void initAndValidate() {

@Override
public double calculateLogP() {
Function x = m_x.get();
// spec types enforce bounds via domain, so no explicit check needed
logP = dist.calcLogP(x);
RealVector<? extends Real> x = m_x.get();
// Components of x are treated as iid draws from `dist` β€” sum the
// scalar log-densities. The spec ScalarDistribution doesn't have a
// single-shot logP-over-a-vector method; loop manually.
logP = 0;
for (int i = 0; i < x.size(); i++) {
logP += dist.logDensity(x.get(i));
}
if (logP == Double.POSITIVE_INFINITY) {
logP = Double.NEGATIVE_INFINITY;
}
Expand Down Expand Up @@ -67,19 +74,17 @@ public void sample(State state, Random random) {
sampleConditions(state, random);

// sample distribution parameters
Function x = m_x.get();
RealVector<? extends Real> x = m_x.get();

Double[] newx;
try {
newx = dist.sample(1)[0];

// Spec ScalarDistribution.sample() returns one value per call;
// for a vector x we draw size(x) iid samples. m_x is now typed
// RealVector<? extends Real>, so only the real-vector branch
// applies (the legacy code had an IntVectorParam branch that's
// unreachable under the new Input type).
if (x instanceof RealVectorParam<?> rvp) {
for (int i = 0; i < newx.length; i++) {
rvp.set(i, newx[i]);
}
} else if (x instanceof IntVectorParam<?> ivp) {
for (int i = 0; i < newx.length; i++) {
ivp.set(i, (int)Math.round(newx[i]));
for (int i = 0; i < rvp.size(); i++) {
rvp.set(i, dist.sample().get(0));
}
}

Expand Down
21 changes: 20 additions & 1 deletion src/main/java/mascot/glmmodel/GlmModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
import beast.base.core.Input.Validate;
import beast.base.core.Loggable;
import beast.base.inference.CalculationNode;
import beast.base.inference.StateNode;
import beast.base.inference.StateNodeInitialiser;
import beast.base.spec.inference.parameter.BoolVectorParam;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.RealScalarParam;
import beast.base.spec.inference.parameter.RealVectorParam;

public abstract class GlmModel extends CalculationNode implements Loggable {
import java.util.List;

public abstract class GlmModel extends CalculationNode implements Loggable, StateNodeInitialiser {

public Input<CovariateList> covariateListInput = new Input<>("covariateList", "input of covariates", Validate.REQUIRED);
public Input<RealVectorParam<? extends Real>> scalerInput = new Input<>("scaler", "input of covariates scaler", Validate.REQUIRED);
Expand Down Expand Up @@ -84,6 +88,21 @@ public void setNrDummy(){
verticalEntries = 0;
}

// Subclasses (LogLinear, etc.) fix up the dimensions of scaler /
// indicator / error parameters during their own initAndValidate. The
// default initStateNodes is a no-op; getInitialisedStateNodes reports
// the parameters this class takes responsibility for so the framework
// can dedupe.
@Override
public void initStateNodes() {
}

@Override
public void getInitialisedStateNodes(List<StateNode> stateNodes) {
stateNodes.add(scalerInput.get());
stateNodes.add(indicatorInput.get());
if (errorInput.get() != null) stateNodes.add(errorInput.get());
if (constantErrorInput.get() != null) stateNodes.add(constantErrorInput.get());
}

}
31 changes: 14 additions & 17 deletions src/main/java/mascot/glmmodel/MaxRate.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import beast.base.core.Input.Validate;
import beast.base.inference.Distribution;
import beast.base.inference.State;
import beast.base.inference.distribution.ParametricDistribution;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.RealVectorParam;
import beast.base.spec.inference.distribution.ScalarDistribution;
import mascot.dynamics.GLM;

import java.util.ArrayList;
Expand All @@ -21,14 +19,14 @@
"so the sum of log probabilities of all elements of x is returned as the prior.")
public class MaxRate extends Distribution {
final public Input<GLM> GLMStepwiseModelInput = new Input<>("GLMmodel", "glm model input");
final public Input<ParametricDistribution> distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED);
final public Input<ScalarDistribution<?, Double>> distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED);
final public Input<Boolean> migrationOnlyInput = new Input<>("migrationOnly", "put prior only on migration rates", false);
final public Input<Boolean> NeOnlyInput = new Input<>("NeOnly", "put prior only on migration rates", false);

/**
* shadows distInput *
*/
protected ParametricDistribution dist;
protected ScalarDistribution<?, Double> dist;


@Override
Expand All @@ -42,29 +40,28 @@ public double calculateLogP() {
Double[] mig = GLMStepwiseModelInput.get().getAllCoalescentRate();
Double[] coal = GLMStepwiseModelInput.get().getAllBackwardsMigration();

RealVectorParam<Real> dCoal = new RealVectorParam<>(unbox(coal), Real.INSTANCE);
RealVectorParam<Real> dMig = new RealVectorParam<>(unbox(mig), Real.INSTANCE);

logP = 0.0;

if (migrationOnlyInput.get()){
logP += dist.calcLogP(dMig);
logP += sumLogDensity(mig);
}else{
logP += dist.calcLogP(dCoal);
logP += dist.calcLogP(dMig);
logP += sumLogDensity(coal);
logP += sumLogDensity(mig);
}
if (logP == Double.POSITIVE_INFINITY) {
logP = Double.NEGATIVE_INFINITY;
}
return logP;
}

private static double[] unbox(Double[] values) {
double[] result = new double[values.length];
for (int i = 0; i < values.length; i++) {
result[i] = values[i];
}
return result;
private double sumLogDensity(Double[] values) {
// Each entry is treated as an iid draw from `dist` (matches what
// legacy ParametricDistribution.calcLogP(Function) did over a
// multidimensional sample). Spec ScalarDistribution exposes
// logDensity per scalar; sum manually.
double sum = 0;
for (Double v : values) sum += dist.logDensity(v);
return sum;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/mascot/logger/StructuredTreeLogger.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import beast.base.evolution.tree.Tree;
import beast.base.evolution.tree.TreeInterface;
import beast.base.inference.StateNode;
import beast.base.spec.inference.parameter.BoolVectorParam;
import beast.base.spec.type.BoolVector;
import mascot.distribution.Mascot;
import mascot.ode.Euler2ndOrderTransitions;
import mascot.ode.MascotODEUpDown;
Expand Down Expand Up @@ -48,7 +48,7 @@ public class StructuredTreeLogger extends Tree implements Loggable {
public Input<BranchRateModel.Base> clockModelInput = new Input<BranchRateModel.Base>("branchratemodel", "rate to be logged with branches of the tree");
public Input<List<Function>> parameterInput = new Input<List<Function>>("metadata", "meta data to be logged with the tree nodes", new ArrayList<>());
public Input<Boolean> maxStateInput = new Input<Boolean>("maxState", "report branch lengths as substitutions (branch length times clock rate for the branch)", false);
public Input<BoolVectorParam> conditionalStateProbsInput = new Input<>("conditionalStateProbs", "report branch lengths as substitutions (branch length times clock rate for the branch)");
public Input<BoolVector> conditionalStateProbsInput = new Input<>("conditionalStateProbs", "report branch lengths as substitutions (branch length times clock rate for the branch)");
public Input<Boolean> substitutionsInput = new Input<Boolean>("substitutions", "report branch lengths as substitutions (branch length times clock rate for the branch)", false);
public Input<Integer> decimalPlacesInput = new Input<Integer>("dp", "the number of decimal places to use writing branch lengths and rates, use -1 for full precision (default = full precision)", -1);

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/mascot/logger/mappedProbLogger.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import beast.base.evolution.branchratemodel.BranchRateModel;
import beast.base.evolution.tree.Node;
import beast.base.inference.CalculationNode;
import beast.base.spec.inference.parameter.BoolVectorParam;
import beast.base.spec.type.BoolVector;
import mascot.distribution.MappedMascot;

import java.io.PrintStream;
Expand All @@ -34,7 +34,7 @@ public class mappedProbLogger extends CalculationNode implements Loggable {
"meta data to be logged with the tree nodes", new ArrayList<>());
public Input<Boolean> maxStateInput = new Input<Boolean>("maxState",
"report branch lengths as substitutions (branch length times clock rate for the branch)", false);
public Input<BoolVectorParam> conditionalStateProbsInput = new Input<>("conditionalStateProbs",
public Input<BoolVector> conditionalStateProbsInput = new Input<>("conditionalStateProbs",
"report branch lengths as substitutions (branch length times clock rate for the branch)");
public Input<Boolean> substitutionsInput = new Input<Boolean>("substitutions",
"report branch lengths as substitutions (branch length times clock rate for the branch)", false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import beast.base.core.Loggable;
import beast.base.evolution.alignment.Alignment;
import beast.base.evolution.datatype.DataType;
import beast.base.evolution.likelihood.TreeLikelihood;
import beast.base.spec.evolution.likelihood.TreeLikelihood;
import beast.base.evolution.tree.Node;
import beast.base.evolution.tree.Tree;
import beast.base.evolution.tree.TreeInterface;
Expand Down
11 changes: 4 additions & 7 deletions src/main/java/mascot/parameterdynamics/ConstantNe.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.RealScalarParam;
import beast.base.spec.type.RealScalar;

public class ConstantNe extends NeDynamics {

public Input<RealScalarParam<Real>> NeInput = new Input<>(
public Input<RealScalar<Real>> NeInput = new Input<>(
"logNe", "input of the Ne at the time of the most recent sampled ancestor", Validate.REQUIRED);

RealScalarParam<Real> Ne;
RealScalar<Real> Ne;

@Override
public void initAndValidate() {
Expand All @@ -31,9 +31,6 @@ public double getNeTime(double t) {

@Override
public boolean isDirty() {
if (Ne.somethingIsDirty())
return true;

return false;
return isDirtyInput(Ne);
}
}
18 changes: 6 additions & 12 deletions src/main/java/mascot/parameterdynamics/ExponentialNe.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.RealScalarParam;
import beast.base.spec.type.RealScalar;

public class ExponentialNe extends NeDynamics {

public Input<RealScalarParam<Real>> logNeNullInput = new Input<>(
public Input<RealScalar<Real>> logNeNullInput = new Input<>(
"NeNull", "input of the Ne at the time of the most recent sampled ancestor", Validate.REQUIRED);
public Input<RealScalarParam<Real>> growthRateInput = new Input<>(
public Input<RealScalar<Real>> growthRateInput = new Input<>(
"growthRate", "input of the growth rate", Validate.REQUIRED);

public Input<Double> minNeInput = new Input<>(
"minNe", "input of the minimal Ne", 0.0);

RealScalarParam<Real> logNeNull;
RealScalarParam<Real> growthRate;
RealScalar<Real> logNeNull;
RealScalar<Real> growthRate;

@Override
public void initAndValidate() {
Expand All @@ -39,13 +39,7 @@ public double getNeTime(double t) {

@Override
public boolean isDirty() {
if (logNeNull.somethingIsDirty())
return true;

if (growthRate.somethingIsDirty())
return true;

return false;
return isDirtyInput(logNeNull) || isDirtyInput(growthRate);
}


Expand Down
16 changes: 15 additions & 1 deletion src/main/java/mascot/parameterdynamics/LogLinearGLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import beast.base.core.Input;
import beast.base.core.Input.Validate;
import beast.base.inference.StateNode;
import beast.base.inference.StateNodeInitialiser;
import beast.base.spec.inference.parameter.BoolVectorParam;
import beast.base.spec.domain.Real;
import beast.base.spec.inference.parameter.RealScalarParam;
import beast.base.spec.inference.parameter.RealVectorParam;
import mascot.glmmodel.CovariateList;

public class LogLinearGLM extends NeDynamics {
import java.util.List;

public class LogLinearGLM extends NeDynamics implements StateNodeInitialiser {
public Input<CovariateList> covariateListInput = new Input<>("covariateList", "input of covariates", Validate.REQUIRED);
public Input<RealVectorParam<? extends Real>> scalerInput = new Input<>("scaler", "input of covariates scaler", Validate.REQUIRED);
public Input<BoolVectorParam> indicatorInput = new Input<>("indicator", "input of covariates scaler", Validate.REQUIRED);
Expand Down Expand Up @@ -131,5 +135,15 @@ public void restore() {
valuesKnown = false;
}

@Override
public void initStateNodes() {
}

@Override
public void getInitialisedStateNodes(List<StateNode> stateNodes) {
stateNodes.add(scalerInput.get());
stateNodes.add(indicatorInput.get());
if (errorInput.get() != null) stateNodes.add(errorInput.get());
}

}
Loading
Loading