-
Notifications
You must be signed in to change notification settings - Fork 11
Add focus ordering and ispta balancing #449 #452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,6 +76,9 @@ class Protocol: | |
| virtual_fit_options: Annotated[VirtualFitOptions, OpenLIFUFieldData("Virtual fit options", "Configuration of the virtual fit algorithm")] = field(default_factory=VirtualFitOptions) | ||
| """Configuration of the virtual fit algorithm""" | ||
|
|
||
| scaling_options: Annotated[dict, OpenLIFUFieldData("Scaling options", "Options to adjust solution scaling. By default, no additional scaling options are applied")] = field(default_factory=dict) | ||
| """Options to adjust solution scaling. By default, no additional scaling options are applied""" | ||
|
|
||
| def __post_init__(self): | ||
| self.logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -97,6 +100,7 @@ def from_dict(d : Dict[str,Any]) -> Protocol: | |
| if "virtual_fit_options" in d: | ||
| d['virtual_fit_options'] = VirtualFitOptions.from_dict(d['virtual_fit_options']) | ||
| d["analysis_options"] = SolutionAnalysisOptions.from_dict(d.get("analysis_options", {})) | ||
| d["scaling_options"] = d.get("scaling_options", {}) | ||
| return Protocol(**d) | ||
|
|
||
| def to_dict(self): | ||
|
|
@@ -116,6 +120,7 @@ def to_dict(self): | |
| "target_constraints": [tc.to_dict() for tc in self.target_constraints], | ||
| "virtual_fit_options": self.virtual_fit_options.to_dict(), | ||
| "analysis_options": self.analysis_options.to_dict(), | ||
| "scaling_options": self.scaling_options, | ||
| } | ||
|
|
||
| @staticmethod | ||
|
|
@@ -316,8 +321,11 @@ def calc_solution( | |
| simulation_result_aggregated: xa.Dataset = xa.Dataset() | ||
| foci: List[Point] = self.focal_pattern.get_targets(target) | ||
|
|
||
| if self.sequence.focus_order is not None and max(self.sequence.focus_order) > len(foci): | ||
| raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(foci)})") | ||
|
|
||
| # updating solution sequence if pulse mismatch | ||
| if (self.sequence.pulse_count % len(foci)) != 0: | ||
| if self.sequence.focus_order is None and (self.sequence.pulse_count % len(foci)) != 0: | ||
| self.fix_pulse_mismatch(on_pulse_mismatch, foci) | ||
| # run simulation and aggregate the results | ||
| for focus in foci: | ||
|
|
@@ -364,14 +372,21 @@ def calc_solution( | |
| raise ValueError(f"Cannot scale solution {solution.id} if simulation is not enabled!") | ||
| self.logger.info(f"Scaling solution {solution.id}...") | ||
| #TODO can analysis be an attribute of solution ? | ||
| solution.scale(self.focal_pattern, analysis_options=analysis_options) | ||
| solution.scale(self.focal_pattern, analysis_options=analysis_options, **self.scaling_options) | ||
|
|
||
| if simulate: | ||
| # Finally the resulting pressure is max-aggregated and intensity is mean-aggregated, over all focus points . | ||
| pnp_aggregated = solution.simulation_result['p_min'].max(dim="focal_point_index", keep_attrs=True) | ||
| ppp_aggregated = solution.simulation_result['p_max'].max(dim="focal_point_index", keep_attrs=True) | ||
| # TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times | ||
| intensity_aggregated = solution.simulation_result['intensity'].mean(dim="focal_point_index", keep_attrs=True) | ||
| focus_counts = solution.get_focus_counts() | ||
| focus_weights = xa.DataArray( | ||
| focus_counts / np.sum(focus_counts), | ||
| dims=("focal_point_index",), | ||
| coords={"focal_point_index": solution.simulation_result.coords["focal_point_index"]}, | ||
| ) | ||
| intensity = solution.simulation_result['intensity'] | ||
| intensity_aggregated = (intensity * focus_weights).sum(dim="focal_point_index", keep_attrs=True) | ||
| intensity_aggregated.attrs.update(intensity.attrs) | ||
|
Comment on lines
+381
to
+389
|
||
| simulation_result_aggregated = deepcopy(solution.simulation_result) | ||
| simulation_result_aggregated = simulation_result_aggregated.drop_dims("focal_point_index") | ||
| simulation_result_aggregated['p_min'] = pnp_aggregated | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||||||||
| from __future__ import annotations | ||||||||||||
|
|
||||||||||||
| import base64 | ||||||||||||
| import heapq | ||||||||||||
| import json | ||||||||||||
| import logging | ||||||||||||
| import tempfile | ||||||||||||
|
|
@@ -123,6 +124,8 @@ def __post_init__(self): | |||||||||||
| raise ValueError("Pulse train interval must be greater than or equal to the total pulse interval") | ||||||||||||
| if self.sequence.pulse_train_count <= 0: | ||||||||||||
| raise ValueError("Pulse train count must be positive") | ||||||||||||
| if (self.sequence.focus_order is not None and len(self.foci) > 0 and max(self.sequence.focus_order) > len(self.foci)): | ||||||||||||
| raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(self.foci)})") | ||||||||||||
|
Comment on lines
+127
to
+128
|
||||||||||||
| if len(self.foci)>0 and self.delays is not None and self.delays.shape[0] != len(self.foci): | ||||||||||||
| raise ValueError(f"Delays number of foci ({self.delays.shape[0]}) does not match number of foci ({len(self.foci)})") | ||||||||||||
| if len(self.foci)>0 and self.apodizations is not None and self.apodizations.shape[0] != len(self.foci): | ||||||||||||
|
|
@@ -138,6 +141,83 @@ def num_foci(self) -> int: | |||||||||||
| """Get the number of foci""" | ||||||||||||
| return len(self.foci) | ||||||||||||
|
|
||||||||||||
| def get_focus_order(self) -> np.ndarray: | ||||||||||||
| """Get the focus index order for each pulse.""" | ||||||||||||
| if self.sequence.focus_order is not None: | ||||||||||||
| return np.array(self.sequence.focus_order) | ||||||||||||
| return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1 | ||||||||||||
|
||||||||||||
| return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1 | |
| num_foci = self.num_foci() | |
| if num_foci == 0: | |
| raise ValueError("Cannot compute default focus order when there are no foci") | |
| return np.arange(self.sequence.pulse_count) % num_foci + 1 |
Copilot
AI
May 1, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_ita() currently computes isppa_avg via np.sum(np.expand_dims(intensity.data, axis=-1) * counts, axis=-1) / np.sum(counts). With the current array shapes/dims (intensity has a focal_point_index dimension), this does not actually apply the per-focus weights; it effectively cancels out and leaves the focal_point_index dimension intact. That means focus_counts has no effect on ITA, undermining ISPTA balancing. Consider computing a weighted mean over the focal_point_index dimension (e.g., (intensity_scaled * focus_weights).sum(dim='focal_point_index')) and returning an ITA DataArray without focal_point_index so downstream analysis uses the treatment-averaged intensity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
calc_solution()only validatesfocus_orderviamax(self.sequence.focus_order) > len(foci). SinceSequence.focus_ordercan be mutated after initialization (bypassingSequence.__post_init__), this misses other invalid states (empty list, wrong length vspulse_count, non-positive indices) andmax([])would crash with an unhelpful exception. Consider performing a full validation here (length, type/positivity, and bounds vslen(foci)) or calling a sharedSequence.validate_focus_order(num_foci=...)helper.