Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/openlifu/bf/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class Sequence(DictMixin):
pulse_train_count: Annotated[int, OpenLIFUFieldData("Pulse train count", "Number of pulse trains in the sequence")] = 1
"""Number of pulse trains in the sequence"""

focus_order: Annotated[list[int] | None, OpenLIFUFieldData("Focus order", "Optional focus index order for each pulse")] = None
"""Optional focus index order for each pulse"""

def __post_init__(self):
if self.pulse_interval <= 0:
raise ValueError("Pulse interval must be positive")
Expand All @@ -38,6 +41,15 @@ def __post_init__(self):
raise ValueError("Pulse train interval must be greater than or equal to the total pulse interval")
if self.pulse_train_count <= 0:
raise ValueError("Pulse train count must be positive")
if self.focus_order is not None:
if len(self.focus_order) == 0:
raise ValueError("Focus order must not be empty")
if len(self.focus_order) != self.pulse_count:
raise ValueError("Focus order length must match pulse count")
if any(not isinstance(focus_index, int) for focus_index in self.focus_order):
raise TypeError("Focus order entries must be integers")
if any(focus_index < 1 for focus_index in self.focus_order):
raise ValueError("Focus order entries must be positive")

def to_table(self) -> pd.DataFrame:
"""
Expand All @@ -49,7 +61,8 @@ def to_table(self) -> pd.DataFrame:
{"Name": "Pulse Interval", "Value": self.pulse_interval, "Unit": "s"},
{"Name": "Pulse Count", "Value": self.pulse_count, "Unit": ""},
{"Name": "Pulse Train Interval", "Value": self.pulse_train_interval, "Unit": "s"},
{"Name": "Pulse Train Count", "Value": self.pulse_train_count, "Unit": ""}
{"Name": "Pulse Train Count", "Value": self.pulse_train_count, "Unit": ""},
{"Name": "Focus Order", "Value": self.focus_order, "Unit": ""}
]
return pd.DataFrame.from_records(records)

Expand Down
23 changes: 19 additions & 4 deletions src/openlifu/plan/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class Protocol:
virtual_fit_options: Annotated[VirtualFitOptions, OpenLIFUFieldData("Virtual fit options", "Configuration of the virtual fit algorithm")] = field(default_factory=VirtualFitOptions)
"""Configuration of the virtual fit algorithm"""

scaling_options: Annotated[dict, OpenLIFUFieldData("Scaling options", "Options to adjust solution scaling. By default, no additional scaling options are applied")] = field(default_factory=dict)
"""Options to adjust solution scaling. By default, no additional scaling options are applied"""

def __post_init__(self):
self.logger = logging.getLogger(__name__)

Expand All @@ -97,6 +100,7 @@ def from_dict(d : Dict[str,Any]) -> Protocol:
if "virtual_fit_options" in d:
d['virtual_fit_options'] = VirtualFitOptions.from_dict(d['virtual_fit_options'])
d["analysis_options"] = SolutionAnalysisOptions.from_dict(d.get("analysis_options", {}))
d["scaling_options"] = d.get("scaling_options", {})
return Protocol(**d)

def to_dict(self):
Expand All @@ -116,6 +120,7 @@ def to_dict(self):
"target_constraints": [tc.to_dict() for tc in self.target_constraints],
"virtual_fit_options": self.virtual_fit_options.to_dict(),
"analysis_options": self.analysis_options.to_dict(),
"scaling_options": self.scaling_options,
}

@staticmethod
Expand Down Expand Up @@ -316,8 +321,11 @@ def calc_solution(
simulation_result_aggregated: xa.Dataset = xa.Dataset()
foci: List[Point] = self.focal_pattern.get_targets(target)

if self.sequence.focus_order is not None and max(self.sequence.focus_order) > len(foci):
raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(foci)})")

# updating solution sequence if pulse mismatch
if (self.sequence.pulse_count % len(foci)) != 0:
if self.sequence.focus_order is None and (self.sequence.pulse_count % len(foci)) != 0:
self.fix_pulse_mismatch(on_pulse_mismatch, foci)
Comment on lines +324 to 329
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calc_solution() only validates focus_order via max(self.sequence.focus_order) > len(foci). Since Sequence.focus_order can be mutated after initialization (bypassing Sequence.__post_init__), this misses other invalid states (empty list, wrong length vs pulse_count, non-positive indices) and max([]) would crash with an unhelpful exception. Consider performing a full validation here (length, type/positivity, and bounds vs len(foci)) or calling a shared Sequence.validate_focus_order(num_foci=...) helper.

Copilot uses AI. Check for mistakes.
# run simulation and aggregate the results
for focus in foci:
Expand Down Expand Up @@ -364,14 +372,21 @@ def calc_solution(
raise ValueError(f"Cannot scale solution {solution.id} if simulation is not enabled!")
self.logger.info(f"Scaling solution {solution.id}...")
#TODO can analysis be an attribute of solution ?
solution.scale(self.focal_pattern, analysis_options=analysis_options)
solution.scale(self.focal_pattern, analysis_options=analysis_options, **self.scaling_options)

if simulate:
# Finally the resulting pressure is max-aggregated and intensity is mean-aggregated, over all focus points .
pnp_aggregated = solution.simulation_result['p_min'].max(dim="focal_point_index", keep_attrs=True)
ppp_aggregated = solution.simulation_result['p_max'].max(dim="focal_point_index", keep_attrs=True)
# TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times
intensity_aggregated = solution.simulation_result['intensity'].mean(dim="focal_point_index", keep_attrs=True)
focus_counts = solution.get_focus_counts()
focus_weights = xa.DataArray(
focus_counts / np.sum(focus_counts),
dims=("focal_point_index",),
coords={"focal_point_index": solution.simulation_result.coords["focal_point_index"]},
)
intensity = solution.simulation_result['intensity']
intensity_aggregated = (intensity * focus_weights).sum(dim="focal_point_index", keep_attrs=True)
intensity_aggregated.attrs.update(intensity.attrs)
Comment on lines +381 to +389
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new weighted intensity aggregation in calc_solution() is core to focus_order support, but there doesn't appear to be a unit test asserting that aggregation is weighted by focus_counts (and differs from the previous unweighted mean). Adding a focused test would help prevent regressions in the weighting logic.

Copilot uses AI. Check for mistakes.
simulation_result_aggregated = deepcopy(solution.simulation_result)
simulation_result_aggregated = simulation_result_aggregated.drop_dims("focal_point_index")
simulation_result_aggregated['p_min'] = pnp_aggregated
Expand Down
126 changes: 118 additions & 8 deletions src/openlifu/plan/solution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import base64
import heapq
import json
import logging
import tempfile
Expand Down Expand Up @@ -123,6 +124,8 @@ def __post_init__(self):
raise ValueError("Pulse train interval must be greater than or equal to the total pulse interval")
if self.sequence.pulse_train_count <= 0:
raise ValueError("Pulse train count must be positive")
if (self.sequence.focus_order is not None and len(self.foci) > 0 and max(self.sequence.focus_order) > len(self.foci)):
raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(self.foci)})")
Comment on lines +127 to +128
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solution.__post_init__ validates focus_order using max(self.sequence.focus_order), which will raise a built-in ValueError if focus_order is later mutated to [] (dataclass fields are mutable) and it only checks the upper bound. If the intent is to validate focus_order at solution creation time, consider guarding against empty lists and validating lower bound / length as well (or centralizing validation in a helper so Sequence construction and later mutations are checked consistently).

Copilot uses AI. Check for mistakes.
if len(self.foci)>0 and self.delays is not None and self.delays.shape[0] != len(self.foci):
raise ValueError(f"Delays number of foci ({self.delays.shape[0]}) does not match number of foci ({len(self.foci)})")
if len(self.foci)>0 and self.apodizations is not None and self.apodizations.shape[0] != len(self.foci):
Expand All @@ -138,6 +141,83 @@ def num_foci(self) -> int:
"""Get the number of foci"""
return len(self.foci)

def get_focus_order(self) -> np.ndarray:
"""Get the focus index order for each pulse."""
if self.sequence.focus_order is not None:
return np.array(self.sequence.focus_order)
return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_focus_order() builds the default round-robin sequence as (np.arange(pulse_count) - 1) % num_foci + 1, which starts the sequence at num_foci (e.g., for 3 foci it yields [3,1,2,...]) rather than [1,2,3,...]. This will skew focus counts/weighting whenever focus_order is not explicitly provided. Consider using np.arange(pulse_count) % num_foci + 1 (and handle num_foci()==0 with a clear error or empty result).

Suggested change
return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1
num_foci = self.num_foci()
if num_foci == 0:
raise ValueError("Cannot compute default focus order when there are no foci")
return np.arange(self.sequence.pulse_count) % num_foci + 1

Copilot uses AI. Check for mistakes.

def get_focus_counts(self) -> np.ndarray:
"""Get the number of pulses assigned to each focus."""
focus_order = self.get_focus_order()
return np.array([
np.sum(focus_order == (focus_index + 1))
for focus_index in range(self.num_foci())
])

def compute_balanced_focus_counts(self, balance_metric_values: np.ndarray, pulse_count: int) -> np.ndarray:
"""Compute per-focus pulse counts that balance a positive per-focus metric."""
balance_metric_values = np.asarray(balance_metric_values, dtype=float)
if balance_metric_values.shape != (self.num_foci(),):
raise ValueError(f"Balance metric must have one value per focus ({self.num_foci()})")
if pulse_count < self.num_foci():
raise ValueError(f"Pulse count ({pulse_count}) must be greater than or equal to number of foci ({self.num_foci()})")
if np.any(~np.isfinite(balance_metric_values)) or np.any(balance_metric_values <= 0):
raise ValueError("Balance metric values must be finite and positive")

remaining_pulses = pulse_count - self.num_foci()
counts = np.ones(self.num_foci(), dtype=int)
if remaining_pulses == 0:
return counts

weights = 1 / balance_metric_values
ideal_extra_counts = weights / np.sum(weights) * remaining_pulses
extra_counts = np.floor(ideal_extra_counts).astype(int)
counts += extra_counts

leftover_pulses = remaining_pulses - int(np.sum(extra_counts))
remainders = ideal_extra_counts - extra_counts
for focus_index in np.argsort(remainders)[::-1][:leftover_pulses]:
counts[focus_index] += 1
return counts

def build_focus_order(self, focus_counts: np.ndarray, ordering: str = "minimize_repeats") -> list[int]:
"""Build a focus order from per-focus pulse counts."""
if ordering != "minimize_repeats":
raise ValueError(f"Unsupported focus ordering '{ordering}'")
focus_counts = np.asarray(focus_counts, dtype=int)
if focus_counts.shape != (self.num_foci(),):
raise ValueError(f"Focus counts must have one value per focus ({self.num_foci()})")
if np.any(focus_counts < 0):
raise ValueError("Focus counts must be non-negative")

heap = []
for focus_index, focus_count in enumerate(focus_counts):
if focus_count > 0:
heap.append((-focus_count, focus_index + 1))

heapq.heapify(heap)
focus_order = []
previous_count = 0
previous_focus_index = None

while heap or previous_count < 0:
if not heap:
focus_order.append(previous_focus_index)
previous_count += 1
continue

focus_count, focus_index = heapq.heappop(heap)
focus_order.append(focus_index)
focus_count += 1

if previous_count < 0:
heapq.heappush(heap, (previous_count, previous_focus_index))

previous_count = focus_count
previous_focus_index = focus_index
return focus_order

def simulate(self,
params: xa.Dataset,
sim_options: SimSetup | None = None,
Expand Down Expand Up @@ -195,14 +275,16 @@ def simulate(self,
def analyze(self,
simulation_result: xa.Dataset | None = None,
options: SolutionAnalysisOptions = SolutionAnalysisOptions(),
param_constraints: Dict[str,ParameterConstraint] | None = None) -> SolutionAnalysis:
param_constraints: Dict[str,ParameterConstraint] | None = None,
focus_counts: np.ndarray | None = None) -> SolutionAnalysis:
"""Analyzes the treatment solution.

Args:
simulation_result: The simulation result dataset to analyze. If None, uses self.simulation_result.
options: A struct for solution analysis options.
param_constraints: A dictionary of parameter constraints to apply to the analysis.
The keys are the parameter names and the values are the ParameterConstraint objects.
focus_counts: Optional per-focus pulse counts to use for ITA calculations.

Returns: A struct containing the results of the analysis.
"""
Expand Down Expand Up @@ -241,7 +323,7 @@ def analyze(self,
solution_analysis.sequence_duration_s = float(self.sequence.pulse_interval * self.sequence.pulse_count * self.sequence.pulse_train_count)
else:
solution_analysis.sequence_duration_s = float(self.sequence.pulse_train_interval * self.sequence.pulse_train_count)
ita_mWcm2 = rescale_coords(self.get_ita(intensity=simulation_result['intensity'], units="mW/cm^2"), options.distance_units)
ita_mWcm2 = rescale_coords(self.get_ita(intensity=simulation_result['intensity'], units="mW/cm^2", focus_counts=focus_counts), options.distance_units)

power_W = np.zeros(self.num_foci())
TIC = np.zeros(self.num_foci())
Expand Down Expand Up @@ -422,14 +504,20 @@ def compute_scaling_factors(
def scale(
self,
focal_pattern: FocalPattern,
analysis_options: SolutionAnalysisOptions = SolutionAnalysisOptions()
analysis_options: SolutionAnalysisOptions = SolutionAnalysisOptions(),
balance_method: str | None = None,
balance_metric: str = "mainlobe_ispta_mWcm2",
ordering: str = "minimize_repeats",
) -> None:
"""
Scale the solution in-place to match the target pressure.

Args:
focal_pattern: FocalPattern
analysis_options: plan.solution.SolutionAnalysisOptions
balance_method: Optional method for balancing scaled delivery. Supported: "ispta_repeats".
balance_metric: The per-focus analysis metric used for balancing.
ordering: How to order balanced repeats. Supported: "minimize_repeats".

Returns:
analysis_scaled: the resulting plan.solution.SolutionAnalysis from scaled solution
Expand All @@ -446,6 +534,18 @@ def scale(
self.apodizations[i] = self.apodizations[i]*apod_factors[i]
self.voltage = v1

if balance_method is None:
return
if balance_method != "ispta_repeats":
raise ValueError(f"Unsupported balance method '{balance_method}'")
baseline_focus_counts = np.ones(self.num_foci(), dtype=int)
scaled_analysis = self.analyze(options=analysis_options, focus_counts=baseline_focus_counts)
if not hasattr(scaled_analysis, balance_metric):
raise ValueError(f"Unknown balance metric '{balance_metric}'")
balance_metric_values = np.array(getattr(scaled_analysis, balance_metric))
focus_counts = self.compute_balanced_focus_counts(balance_metric_values, self.sequence.pulse_count)
self.sequence.focus_order = self.build_focus_order(focus_counts, ordering=ordering)

def get_pulsetrain_dutycycle(self) -> float:
"""
Compute the pulse train dutycycle given a sequence.
Expand All @@ -471,7 +571,12 @@ def get_sequence_dutycycle(self) -> float:
sequence_duty_cycle = self.get_pulsetrain_dutycycle() * between_pulsetrain_duty_cycle
return sequence_duty_cycle

def get_ita(self, intensity: xa.DataArray | None = None, units: str = "mW/cm^2") -> xa.DataArray:
def get_ita(
self,
intensity: xa.DataArray | None = None,
units: str = "mW/cm^2",
focus_counts: np.ndarray | None = None
) -> xa.DataArray:
"""
Calculate the intensity-time-area product for a treatment solution.

Expand All @@ -480,6 +585,7 @@ def get_ita(self, intensity: xa.DataArray | None = None, units: str = "mW/cm^2")
If provided, use this intensity data array instead of the one from the simulation result.
units: str
Target units. Default "mW/cm^2".
focus_counts: Optional per-focus pulse counts. If not provided, use the sequence focus order.

Returns:
xa.DataArray
Expand All @@ -491,10 +597,14 @@ def get_ita(self, intensity: xa.DataArray | None = None, units: str = "mW/cm^2")
intensity_scaled = rescale_data_arr(self.simulation_result['intensity'], units)
pulsetrain_dutycycle = self.get_pulsetrain_dutycycle()
treatment_dutycycle = self.get_sequence_dutycycle()
pulse_seq = (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1
counts = np.zeros((1, 1, 1, self.num_foci()))
for i in range(self.num_foci()):
counts[0, 0, 0, i] = np.sum(pulse_seq == (i+1))
if focus_counts is None:
focus_counts = self.get_focus_counts()
focus_counts = np.asarray(focus_counts)
if focus_counts.shape != (self.num_foci(),):
raise ValueError(f"Focus counts must have one value per focus ({self.num_foci()})")
if np.any(focus_counts < 0):
raise ValueError("Focus counts must be non-negative")
counts = focus_counts.reshape((1, 1, 1, self.num_foci()))
intensity = intensity_scaled.copy(deep=True)
isppa_avg = np.sum(np.expand_dims(intensity.data, axis=-1) * counts, axis=-1) / np.sum(counts)
intensity.data = isppa_avg * pulsetrain_dutycycle * treatment_dutycycle
Comment on lines +600 to 610
Copy link

Copilot AI May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_ita() currently computes isppa_avg via np.sum(np.expand_dims(intensity.data, axis=-1) * counts, axis=-1) / np.sum(counts). With the current array shapes/dims (intensity has a focal_point_index dimension), this does not actually apply the per-focus weights; it effectively cancels out and leaves the focal_point_index dimension intact. That means focus_counts has no effect on ITA, undermining ISPTA balancing. Consider computing a weighted mean over the focal_point_index dimension (e.g., (intensity_scaled * focus_weights).sum(dim='focal_point_index')) and returning an ITA DataArray without focal_point_index so downstream analysis uses the treatment-averaged intensity.

Copilot uses AI. Check for mistakes.
Expand Down
41 changes: 41 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import logging
from pathlib import Path

import numpy as np
import pytest
import xarray as xa

from openlifu import Protocol, Transducer
from openlifu.bf.focal_patterns import Wheel
Expand All @@ -29,6 +31,11 @@ def example_wheel_pattern() -> Wheel:
return Wheel(num_spokes=6)

def test_to_dict_from_dict(example_protocol: Protocol):
example_protocol.scaling_options = {
"balance_method": "ispta_repeats",
"balance_metric": "mainlobe_ispta_mWcm2",
"ordering": "minimize_repeats",
}
proto_dict = example_protocol.to_dict()
new_protocol = Protocol.from_dict(proto_dict)
assert new_protocol == example_protocol
Expand Down Expand Up @@ -106,3 +113,37 @@ def test_fix_pulse_mismatch(
assert example_protocol.sequence.pulse_count == 2*num_foci
elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDDOWN:
assert example_protocol.sequence.pulse_count == num_foci


def test_calc_solution_skips_pulse_mismatch_when_focus_order_present(
example_protocol: Protocol,
example_transducer: Transducer,
example_session: Session,
mocker
):
"""Test explicit focus_order allows pulse counts that are not divisible by number of foci."""
example_protocol.focal_pattern = Wheel(num_spokes=3)
num_foci = example_protocol.focal_pattern.num_foci()
example_protocol.sequence.pulse_count = 5
example_protocol.sequence.focus_order = [1, 2, 3, 1, 2]
beamform_mock = mocker.patch.object(
example_protocol,
"beamform",
return_value=(np.zeros(len(example_transducer.elements)), np.ones(len(example_transducer.elements))),
)
fix_pulse_mismatch_mock = mocker.patch.object(example_protocol, "fix_pulse_mismatch")

solution, simulation_result_aggregated, solution_analysis = example_protocol.calc_solution(
target=example_session.targets[0],
transducer=example_transducer,
params=xa.Dataset(),
simulate=False,
scale=False,
)

assert solution.sequence.focus_order == [1, 2, 3, 1, 2]
assert solution.sequence.pulse_count == 5
assert beamform_mock.call_count == num_foci
fix_pulse_mismatch_mock.assert_not_called()
assert simulation_result_aggregated is None
assert solution_analysis is None
Loading
Loading