Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 136 additions & 10 deletions simSPI/linear_simulator/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,144 @@ def __init__(self, config):
super(Projector, self).__init__()

self.config = config
self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.float32)
lin_coords = torch.linspace(-1.0, 1.0, self.config.side_len)
[x, y, z] = torch.meshgrid(
[
lin_coords,
]
* 3
)
coords = torch.stack([y, x, z], dim=-1)
self.register_buffer("vol_coords", coords.reshape(-1, 3))
self.space = config.space

if self.space == "real":
self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.float32)
lin_coords = torch.linspace(-1.0, 1.0, self.config.side_len)
[x, y, z] = torch.meshgrid(
[
lin_coords,
]
* 3
)
coords = torch.stack([y, x, z], dim=-1)

self.register_buffer("vol_coords", coords.reshape(-1, 3))
elif self.space == "fourier":
# Assume DC coefficient is at self.vol[n//2+1,n//2+1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment could be moved in a docstring, so that it is more visible by users --- when we will publish the documentation website with Sphinx.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This relates a bit to my comment in issue #102, we should try to standardize somewhere the conventions for Fourier space representations. Then it doesn't really need to be anywhere... I'm also not sure that this is currently accurate.

# this means that self.vol = fftshift(fft3(fftshift(real_vol)))
self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.complex64)
freq_coords = torch.fft.fftfreq(self.config.side_len, dtype=torch.float32)
[x, y] = torch.meshgrid(
[
freq_coords,
]
* 2
)
coords = torch.stack([y, x], dim=-1)
# Rescale coordinates to [-1,1] to be compatible with
# torch.nn.functional.grid_sample
coords = 2 * coords
self.register_buffer("vol_coords", coords.reshape(-1, 2))
else:
raise NotImplementedError(
f"Space type '{self.space}' " f"has not been implemented!"
)

def forward(self, rot_params, proj_axis=-1):
"""Forward method for projection.

Parameters
----------
rot_params : tensor of rotation matrices
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The description of rot_params does not strictly follow the docstring convention:

  • if it is a tensor, what is its shape
  • it should be on two lines, where the first line explains the type of datastructure and the second explains what the parameter represents.

This will become important to generate a clean documentation website.

Additionally, this description says that rot_params is a tensor; yet _forward_fourier docstring mentions a dict: which one is true?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this docstring could point to _forward_fourier docstring about the details? I find it well explained there that rot_params is a dictionary that contains a tensor of specific shape.

"""
if self.space == "real":
return self._forward_real(rot_params, proj_axis)

if self.space == "fourier":
if proj_axis != -1:
raise NotImplementedError(
"proj_axis must currently be -1 for Fourier space projection"
)
return self._forward_fourier(rot_params)
raise NotImplementedError(
f"Space type '{self.space}' " f"has not been implemented!"
)

def _forward_fourier(self, rot_params):
"""Output the tomographic projection of the volume in Fourier space.

Take a slide through the Fourier space volume whose normal is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: slice

oriented according to rot_params. The volume is assumed to be cube
represented in the fourier space. The output image follows
(batch x channel x height x width) convention of pytorch. Therefore,
a dummy channel dimension is added at the end to projection.

Parameters
----------
rot_params: dict of type str to {tensor}
Dictionary containing parameters for rotation, with keys
rotmat: str map to tensor
rotation matrix (batch_size x 3 x 3) to rotate the volume

Returns
-------
projection: tensor
Tensor containing tomographic projection in the Fourier domain
(batch_size x 1 x sidelen x sidelen)

Comments
--------
Note that the Fourier volumes are arbitrary
channel x height x width complex valued tensors,
they are not assumed to be Fourier transforms of a real valued 3D functions.

Note that the tomographic projection is interpolated on a rotated 2D grid.
The rotated 2D grid extends outside the boundaries of the 3D grid.
The values outside the boundaries are not defined in a useful way.
Therefore, in most applications, it make sense to apply a radial filter
to the sample.

"""
rotmat = rot_params["rotmat"]
batch_sz = rotmat.shape[0]

rotmat = torch.transpose(rotmat, -1, -2)
rot_vol_coords = self.vol_coords.repeat((batch_sz, 1, 1)).bmm(rotmat[:, :2, :])

# rescale the coordinates to be compatible with the edge alignment of
# torch.nn.functional.grid_sample
if self.config.side_len % 2 == 0: # even case
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comments even case and odd case might not be necessary as it is evident from the code.

rot_vol_coords = (
(rot_vol_coords + 1)
* (self.config.side_len)
/ (self.config.side_len - 1)
) - 1
else: # odd case
rot_vol_coords = (
(rot_vol_coords) * (self.config.side_len) / (self.config.side_len - 1)
)

projection = torch.empty(
(batch_sz, self.config.side_len, self.config.side_len),
dtype=torch.complex64,
)
# interpolation is decomposed to real and imaginary parts due to torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be put in a docstring?

# grid_sample type rules. Requires data and coordinates of same type.
# padding_mode="reflection" is required due to possible pathologies
# right on the border.
# however, padding_mode="zeros" is what users might expect in most
# cases other than these axis aligned cases.
padding_mode = "zeros"
projection.real = torch.nn.functional.grid_sample(
self.vol.real.repeat((batch_sz, 1, 1, 1, 1)),
rot_vol_coords[:, None, None, :, :],
align_corners=True,
padding_mode=padding_mode,
).reshape(batch_sz, self.config.side_len, self.config.side_len)

projection.imag = torch.nn.functional.grid_sample(
self.vol.imag.repeat((batch_sz, 1, 1, 1, 1)),
rot_vol_coords[:, None, None, :, :],
align_corners=True,
padding_mode=padding_mode,
).reshape(batch_sz, self.config.side_len, self.config.side_len)

projection = projection[:, None, :, :]
return projection

def _forward_real(self, rot_params, proj_axis=-1):
"""Output the tomographic projection of the volume.

First rotate the volume and then sum it along an axis.
Expand Down
Binary file modified tests/data/linear_simulator_data.npy
Binary file not shown.
Binary file modified tests/data/linear_simulator_data_cube.npy
Binary file not shown.
Binary file modified tests/data/projector_data.npy
Binary file not shown.
34 changes: 33 additions & 1 deletion tests/test_projector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test function for projector module."""

import numpy as np
import torch

from simSPI.linear_simulator.projector import Projector

Expand Down Expand Up @@ -63,15 +64,46 @@ def normalized_mse(a, b):
return (a - b).pow(2).sum().sqrt() / a.pow(2).sum().sqrt()


def test_projector():
def test_projector_real():
"""Test accuracy of projector function."""
path = "tests/data/projector_data.npy"

saved_data, config = init_data(path)
config["space"] = "real"
rot_params = saved_data["rot_params"]
projector = Projector(config)
projector.vol = saved_data["volume"]

out = projector(rot_params)
error = normalized_mse(saved_data["projector_output"], out).item()
assert (error < 0.01) == 1


def test_projector_fourier():
"""Test accuracy of projector function.

Note: corrent test only checks that the scaling is compatible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: current

"""
path = "tests/data/projector_data.npy"

saved_data, config = init_data(path)
config["space"] = "fourier"
rot_params = saved_data["rot_params"]
projector = Projector(config)
projector.vol = torch.fft.fftshift(
torch.fft.fftn(torch.fft.fftshift(saved_data["volume"], dim=[-3, -2, -1])),
dim=[-3, -2, -1],
)

sz = projector.vol.shape[0]

out = projector(rot_params)
fft_proj_out = torch.fft.fft2(
torch.fft.fftshift(saved_data["projector_output"], dim=(2, 3))
)

print(out.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should avoid prints in the tests because our logs might become overcrowded when we add more tests: could these become asserts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, maybe these could be replaced with logging, so they could be displayed if really needed?

print("ratio", sz, (fft_proj_out.real / out.real).median())
print("ratio", sz, 1 / (fft_proj_out.real[0, 0, 0, 0] / out.real[0, 0, 0, 0]))
print("ratio", sz, 1 / (fft_proj_out.real[:, 0, 0, 0] / out.real[:, 0, 0, 0]))
assert 0.01 > (fft_proj_out.real[0, 0, 0, 0] / out.real[0, 0, 0, 0] - 1).abs()