Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c748cd4
Added `ctis.inverters.MartInverter`, an implementation of the Multipl…
roytsmart Apr 14, 2026
ce54305
added tutorial
roytsmart Apr 14, 2026
56f2003
added merit to results
roytsmart Apr 14, 2026
f975000
Updates to the tutorial
roytsmart Apr 14, 2026
f44513d
references
roytsmart Apr 14, 2026
995a310
adding description to the tutorial
roytsmart Apr 14, 2026
4e1ca50
More updates to tutorial
roytsmart Apr 14, 2026
a05f43c
More updates to tutorial
roytsmart Apr 14, 2026
be8cb5c
black
roytsmart Apr 15, 2026
ff852a5
ruff
roytsmart Apr 15, 2026
cfd4ede
default guess
roytsmart Apr 15, 2026
0995514
added testing infrastructure
roytsmart Apr 15, 2026
36caf1f
tutorial
roytsmart Apr 15, 2026
8c2061e
black
roytsmart Apr 15, 2026
3084026
testing
roytsmart Apr 15, 2026
c625921
black
roytsmart Apr 15, 2026
147efef
Added beginnings of mart discussion
roytsmart Apr 24, 2026
d56c0a7
update discussion.
roytsmart Apr 24, 2026
e91d1da
more tweaks to the discussion.
roytsmart Apr 24, 2026
ca10157
refs
roytsmart Apr 24, 2026
453859a
even more changes to discussion
roytsmart Apr 24, 2026
de50aef
compute correlation of residual with predicted images as a function o…
roytsmart Apr 25, 2026
0820fc5
black
roytsmart Apr 25, 2026
0138068
ruff
roytsmart Apr 25, 2026
c6b9425
tests
roytsmart Apr 25, 2026
63934f8
bump named-arrays version
roytsmart Apr 25, 2026
04982c7
coverage
roytsmart Apr 25, 2026
030472d
sphinx link
roytsmart Apr 25, 2026
6bb5ac4
use Pearson's r for now
roytsmart Apr 25, 2026
44741c7
coverage
roytsmart Apr 25, 2026
ead610e
black
roytsmart Apr 25, 2026
658a931
doc tweaks
roytsmart Apr 25, 2026
6a5e048
docs
roytsmart Apr 25, 2026
a84170f
docs
roytsmart Apr 25, 2026
0f570fa
added an iris tutorial
roytsmart Apr 25, 2026
e71aa25
add iris to deps
roytsmart Apr 25, 2026
b59ae2f
try with smaller obs
roytsmart Apr 25, 2026
7c9a7ad
try to use less memory
roytsmart Apr 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ctis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

from . import scenes
from . import instruments
from . import inverters

__all__ = [
"scenes",
"instruments",
"inverters",
]
14 changes: 13 additions & 1 deletion ctis/instruments/_instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,19 @@ def uncertainty(self) -> Callable[[na.ScalarArray], na.ScalarArray]:
for a given number of photons.
"""

@property
@abc.abstractmethod
def channel(self):
"""
Human-readable name of each independent CTIS channel.
"""

@property
@abc.abstractmethod
def axis_channel(self) -> str | tuple[str, ...]:
"""
The logical axis or axes of this instrument corresponding to
the different dispersion magnitudes and angles.
the different CTIS channels.
"""

@property
Expand Down Expand Up @@ -391,6 +398,11 @@ class IdealInstrument(
A grid of wavelength and position coordinates on the sensor plane.
"""

channel: str | na.AbstractScalar = dataclasses.MISSING
"""
Human-readable name of each independent CTIS channel.
"""

axis_channel: str | tuple[str, ...] = dataclasses.MISSING
"""
The logical axis or axes of this instrument corresponding to
Expand Down
5 changes: 4 additions & 1 deletion ctis/instruments/_instruments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@
dispersion = 200 * u.km / u.s
dispersion = (dispersion.to(**AA) - wavelength_rest) / u.pix

angle = na.linspace(0, 360, axis="channel", num=3, endpoint=False)

instrument_ideal = ctis.instruments.IdealInstrument(
area_effective=1 * u.cm**2,
timedelta_exposure=10 * u.s,
plate_scale=2 * u.arcsec / u.pix,
dispersion=dispersion,
angle=na.linspace(0, 360, axis="channel", num=3, endpoint=False),
angle=angle,
wavelength_ref=wavelength_rest,
position_ref=32 * u.pix,
coordinates_scene=coordinates_scene,
coordinates_sensor=coordinates_sensor,
channel=angle,
axis_channel="channel",
axis_wavelength="wavelength",
axis_scene_xy=("scene_x", "scene_y"),
Expand Down
19 changes: 19 additions & 0 deletions ctis/inverters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Inversion algorithms which can reconstruct scenes from observed images."""

from . import merit
from ._results import InversionResult
from ._inverters import AbstractInverter
from ._iterative import (
AbstractIterativeInverter,
MartInverter,
IterativeInversionResult,
)

__all__ = [
"merit",
"AbstractInverter",
"AbstractIterativeInverter",
"MartInverter",
"InversionResult",
"IterativeInversionResult",
]
48 changes: 48 additions & 0 deletions ctis/inverters/_inverters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import abc
import dataclasses
import named_arrays as na
import ctis
from ._results import InversionResult

__all__ = [
"AbstractInverter",
]


@dataclasses.dataclass
class AbstractInverter(
abc.ABC,
):
"""
An interface describing an algorithm which can invert CTIS observations
to yield a reconstruction of the observed scene.
"""

@property
@abc.abstractmethod
def instrument(self) -> ctis.instruments.AbstractInstrument:
"""
A model of a CTIS instrument which transforms the radiance of an observed
scene to photons measured by the sensors.
"""

@abc.abstractmethod
def __call__(
self,
images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray],
**kwargs,
) -> InversionResult:
"""
Reconstruct a scene using the observed images.

Parameters
----------
images
The observed images used to calculate the reconstruction.
Must be evaluated on the same coordinates as
:attr:`~ctis.instruments.AbstractInstrument.coordinates_sensor`
attribute of :attr:`instrument`.
kwargs
Additional keyword arguments which can be used by subclass
implementations.
"""
31 changes: 31 additions & 0 deletions ctis/inverters/_inverters_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import abc
import numpy as np
import named_arrays as na
import ctis


class AbstractTestAbstractInverter(
abc.ABC,
):

def test_instrument(self, a: ctis.inverters.AbstractInverter):
result = a.instrument
assert isinstance(result, ctis.instruments.AbstractInstrument)

def test__call__(
self,
a: ctis.inverters.AbstractInverter,
images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray],
**kwargs,
) -> ctis.inverters.InversionResult:
result = a(images, **kwargs)

assert isinstance(result, ctis.inverters.InversionResult)

assert result.solution.sum() > 0
assert isinstance(result.success, bool)
assert isinstance(result.message, str)
assert np.all(result.images == images)
assert result.inverter == a

return result
8 changes: 8 additions & 0 deletions ctis/inverters/_iterative/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ._iterative import AbstractIterativeInverter, IterativeInversionResult
from ._mart import MartInverter

__all__ = [
"AbstractIterativeInverter",
"IterativeInversionResult",
"MartInverter",
]
118 changes: 118 additions & 0 deletions ctis/inverters/_iterative/_iterative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import ClassVar
import abc
import dataclasses
import numpy as np
import astropy.units as u
import named_arrays as na
import ctis
from .. import AbstractInverter, InversionResult

__all__ = [
"AbstractIterativeInverter",
"IterativeInversionResult",
]


@dataclasses.dataclass
class AbstractIterativeInverter(
AbstractInverter,
):
"""
An abstract inversion algorithm which reconstructs an observed scene
using iterative methods.

These methods will apply some operation repeatedly until a specified
convergence criteria is met.
"""

axis_iteration: ClassVar[str] = "iteration"
"""The logical axis associated with changing iteration index."""

@property
@abc.abstractmethod
def num_iteration(self) -> int:
"""
The maximum number of iterations to perform.

If convergence is not reached before this number is exceeded,
a warning is raised and an unsuccessful result is returned.
"""

def mean_chi_squared(
self,
images_observed: na.ScalarArray,
images_predicted: na.ScalarArray,
) -> na.ScalarArray:
r"""
Evaluate :math:`\langle \chi^2 \rangle` for each observed/predicted
image pair.

Parameters
----------
images_observed
The actual measured images.
images_predicted
The images predicted by the inversion.
"""

uncertainty = self.instrument.uncertainty(images_predicted)

uncertainty = np.maximum(uncertainty, 1 * u.photon)

return ctis.inverters.merit.mean_chi_squared(
observed=images_observed,
expected=images_predicted,
uncertainty=uncertainty,
axis=self.instrument.axis_sensor_xy,
)

def correlation_residual(
self,
images_observed: na.ScalarArray,
images_predicted: na.ScalarArray,
) -> na.ScalarArray:
"""
Evaluate the correlation between the predicted images and the residual.

Parameters
----------
images_observed
The actual measured images.
images_predicted
The images predicted by the inversion.
"""
return ctis.inverters.merit.correlation_residual(
observed=images_observed,
expected=images_predicted,
axis=self.instrument.axis_sensor_xy,
)


@dataclasses.dataclass
class IterativeInversionResult(
InversionResult,
):
"""The results of an iterative inversion attempt."""

inverter: AbstractIterativeInverter

num_iteration: int
"""The number of iterations performed by the inverter."""

mean_chi_squared: na.ScalarArray
"""The mean chi squared statistic for each iteration."""

correlation_residual: na.ScalarArray
"""
The correlation between the predicted images and the residuals
for each iteration.
"""

@property
def iteration(self) -> na.ScalarArray:
"""The iteration value for each iteration."""
return na.arange(
start=0,
stop=self.num_iteration,
axis=self.inverter.axis_iteration,
)
34 changes: 34 additions & 0 deletions ctis/inverters/_iterative/_iterative_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import ctis
import named_arrays as na
from .._inverters_test import AbstractTestAbstractInverter


class AbstractTestAbstractIterativeInverter(
AbstractTestAbstractInverter,
):

def test__call__(
self,
a: ctis.inverters.AbstractIterativeInverter,
images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray],
**kwargs,
) -> ctis.inverters.IterativeInversionResult:

result = super().test__call__(
a=a,
images=images,
**kwargs,
)

axis_iteration = result.inverter.axis_iteration

assert result.iteration.size == result.num_iteration
assert result.mean_chi_squared.shape[axis_iteration] == result.num_iteration
assert result.correlation_residual.shape[axis_iteration] == result.num_iteration

return result

def test_num_iteration(self, a: ctis.inverters.AbstractIterativeInverter):
result = a.num_iteration
assert isinstance(result, int)
assert result > 0
5 changes: 5 additions & 0 deletions ctis/inverters/_iterative/_mart/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._mart import MartInverter

__all__ = [
"MartInverter",
]
Loading
Loading