-
Notifications
You must be signed in to change notification settings - Fork 10
Feature/ft projector (draft) #124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
a25197e
ba72105
ada366e
e854004
062d0dd
fde4f09
de7c746
22b004f
92644f4
6096255
d8de646
f511945
e8f57ea
6e1e537
eb6df3e
ffea6ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,18 +21,144 @@ def __init__(self, config): | |
| super(Projector, self).__init__() | ||
|
|
||
| self.config = config | ||
| self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.float32) | ||
| lin_coords = torch.linspace(-1.0, 1.0, self.config.side_len) | ||
| [x, y, z] = torch.meshgrid( | ||
| [ | ||
| lin_coords, | ||
| ] | ||
| * 3 | ||
| ) | ||
| coords = torch.stack([y, x, z], dim=-1) | ||
| self.register_buffer("vol_coords", coords.reshape(-1, 3)) | ||
| self.space = config.space | ||
|
|
||
| if self.space == "real": | ||
| self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.float32) | ||
| lin_coords = torch.linspace(-1.0, 1.0, self.config.side_len) | ||
| [x, y, z] = torch.meshgrid( | ||
| [ | ||
| lin_coords, | ||
| ] | ||
| * 3 | ||
| ) | ||
| coords = torch.stack([y, x, z], dim=-1) | ||
|
|
||
| self.register_buffer("vol_coords", coords.reshape(-1, 3)) | ||
| elif self.space == "fourier": | ||
| # Assume DC coefficient is at self.vol[n//2+1,n//2+1] | ||
| # this means that self.vol = fftshift(fft3(fftshift(real_vol))) | ||
| self.vol = torch.rand([self.config.side_len] * 3, dtype=torch.complex64) | ||
| freq_coords = torch.fft.fftfreq(self.config.side_len, dtype=torch.float32) | ||
| [x, y] = torch.meshgrid( | ||
| [ | ||
| freq_coords, | ||
| ] | ||
| * 2 | ||
| ) | ||
| coords = torch.stack([y, x], dim=-1) | ||
| # Rescale coordinates to [-1,1] to be compatible with | ||
| # torch.nn.functional.grid_sample | ||
| coords = 2 * coords | ||
| self.register_buffer("vol_coords", coords.reshape(-1, 2)) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"Space type '{self.space}' " f"has not been implemented!" | ||
| ) | ||
|
|
||
| def forward(self, rot_params, proj_axis=-1): | ||
| """Forward method for projection. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| rot_params : tensor of rotation matrices | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The description of
This will become important to generate a clean documentation website. Additionally, this description says that rot_params is a tensor; yet
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this docstring could point to |
||
| """ | ||
| if self.space == "real": | ||
| return self._forward_real(rot_params, proj_axis) | ||
|
|
||
| if self.space == "fourier": | ||
| if proj_axis != -1: | ||
| raise NotImplementedError( | ||
| "proj_axis must currently be -1 for Fourier space projection" | ||
| ) | ||
| return self._forward_fourier(rot_params) | ||
| raise NotImplementedError( | ||
| f"Space type '{self.space}' " f"has not been implemented!" | ||
| ) | ||
|
|
||
| def _forward_fourier(self, rot_params): | ||
| """Output the tomographic projection of the volume in Fourier space. | ||
|
|
||
| Take a slide through the Fourier space volume whose normal is | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo: |
||
| oriented according to rot_params. The volume is assumed to be cube | ||
| represented in the fourier space. The output image follows | ||
| (batch x channel x height x width) convention of pytorch. Therefore, | ||
| a dummy channel dimension is added at the end to projection. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| rot_params: dict of type str to {tensor} | ||
| Dictionary containing parameters for rotation, with keys | ||
| rotmat: str map to tensor | ||
| rotation matrix (batch_size x 3 x 3) to rotate the volume | ||
|
|
||
| Returns | ||
| ------- | ||
| projection: tensor | ||
| Tensor containing tomographic projection in the Fourier domain | ||
| (batch_size x 1 x sidelen x sidelen) | ||
|
|
||
| Comments | ||
| -------- | ||
| Note that the Fourier volumes are arbitrary | ||
| channel x height x width complex valued tensors, | ||
| they are not assumed to be Fourier transforms of a real valued 3D functions. | ||
|
|
||
| Note that the tomographic projection is interpolated on a rotated 2D grid. | ||
| The rotated 2D grid extends outside the boundaries of the 3D grid. | ||
| The values outside the boundaries are not defined in a useful way. | ||
| Therefore, in most applications, it make sense to apply a radial filter | ||
| to the sample. | ||
|
|
||
| """ | ||
| rotmat = rot_params["rotmat"] | ||
| batch_sz = rotmat.shape[0] | ||
|
|
||
| rotmat = torch.transpose(rotmat, -1, -2) | ||
| rot_vol_coords = self.vol_coords.repeat((batch_sz, 1, 1)).bmm(rotmat[:, :2, :]) | ||
|
|
||
| # rescale the coordinates to be compatible with the edge alignment of | ||
| # torch.nn.functional.grid_sample | ||
| if self.config.side_len % 2 == 0: # even case | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comments |
||
| rot_vol_coords = ( | ||
| (rot_vol_coords + 1) | ||
| * (self.config.side_len) | ||
| / (self.config.side_len - 1) | ||
| ) - 1 | ||
| else: # odd case | ||
| rot_vol_coords = ( | ||
| (rot_vol_coords) * (self.config.side_len) / (self.config.side_len - 1) | ||
| ) | ||
|
|
||
| projection = torch.empty( | ||
| (batch_sz, self.config.side_len, self.config.side_len), | ||
| dtype=torch.complex64, | ||
| ) | ||
| # interpolation is decomposed to real and imaginary parts due to torch | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be put in a docstring? |
||
| # grid_sample type rules. Requires data and coordinates of same type. | ||
| # padding_mode="reflection" is required due to possible pathologies | ||
| # right on the border. | ||
| # however, padding_mode="zeros" is what users might expect in most | ||
| # cases other than these axis aligned cases. | ||
| padding_mode = "zeros" | ||
| projection.real = torch.nn.functional.grid_sample( | ||
| self.vol.real.repeat((batch_sz, 1, 1, 1, 1)), | ||
| rot_vol_coords[:, None, None, :, :], | ||
| align_corners=True, | ||
| padding_mode=padding_mode, | ||
| ).reshape(batch_sz, self.config.side_len, self.config.side_len) | ||
|
|
||
| projection.imag = torch.nn.functional.grid_sample( | ||
| self.vol.imag.repeat((batch_sz, 1, 1, 1, 1)), | ||
| rot_vol_coords[:, None, None, :, :], | ||
| align_corners=True, | ||
| padding_mode=padding_mode, | ||
| ).reshape(batch_sz, self.config.side_len, self.config.side_len) | ||
|
|
||
| projection = projection[:, None, :, :] | ||
| return projection | ||
|
|
||
| def _forward_real(self, rot_params, proj_axis=-1): | ||
| """Output the tomographic projection of the volume. | ||
|
|
||
| First rotate the volume and then sum it along an axis. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| """Test function for projector module.""" | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from simSPI.linear_simulator.projector import Projector | ||
|
|
||
|
|
@@ -63,15 +64,46 @@ def normalized_mse(a, b): | |
| return (a - b).pow(2).sum().sqrt() / a.pow(2).sum().sqrt() | ||
|
|
||
|
|
||
| def test_projector(): | ||
| def test_projector_real(): | ||
| """Test accuracy of projector function.""" | ||
| path = "tests/data/projector_data.npy" | ||
|
|
||
| saved_data, config = init_data(path) | ||
| config["space"] = "real" | ||
| rot_params = saved_data["rot_params"] | ||
| projector = Projector(config) | ||
| projector.vol = saved_data["volume"] | ||
|
|
||
| out = projector(rot_params) | ||
| error = normalized_mse(saved_data["projector_output"], out).item() | ||
| assert (error < 0.01) == 1 | ||
|
|
||
|
|
||
| def test_projector_fourier(): | ||
| """Test accuracy of projector function. | ||
|
|
||
| Note: corrent test only checks that the scaling is compatible. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo: current |
||
| """ | ||
| path = "tests/data/projector_data.npy" | ||
|
|
||
| saved_data, config = init_data(path) | ||
| config["space"] = "fourier" | ||
| rot_params = saved_data["rot_params"] | ||
| projector = Projector(config) | ||
| projector.vol = torch.fft.fftshift( | ||
| torch.fft.fftn(torch.fft.fftshift(saved_data["volume"], dim=[-3, -2, -1])), | ||
| dim=[-3, -2, -1], | ||
| ) | ||
|
|
||
| sz = projector.vol.shape[0] | ||
|
|
||
| out = projector(rot_params) | ||
| fft_proj_out = torch.fft.fft2( | ||
| torch.fft.fftshift(saved_data["projector_output"], dim=(2, 3)) | ||
| ) | ||
|
|
||
| print(out.dtype) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should avoid prints in the tests because our logs might become overcrowded when we add more tests: could these become asserts?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, maybe these could be replaced with |
||
| print("ratio", sz, (fft_proj_out.real / out.real).median()) | ||
| print("ratio", sz, 1 / (fft_proj_out.real[0, 0, 0, 0] / out.real[0, 0, 0, 0])) | ||
| print("ratio", sz, 1 / (fft_proj_out.real[:, 0, 0, 0] / out.real[:, 0, 0, 0])) | ||
| assert 0.01 > (fft_proj_out.real[0, 0, 0, 0] / out.real[0, 0, 0, 0] - 1).abs() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment could be moved in a docstring, so that it is more visible by users --- when we will publish the documentation website with Sphinx.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This relates a bit to my comment in issue #102, we should try to standardize somewhere the conventions for Fourier space representations. Then it doesn't really need to be anywhere... I'm also not sure that this is currently accurate.