Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/test-action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: test

on:
pull_request:
push:
branches:
- master

jobs:
test-linux:
runs-on: ubuntu-latest

strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
uses: actions/setup-python@v2
with:
python-version: 3.7
- name: Install dependencies
run: |
pip install -e ".[testing]"

- name: test
run: |
python -m unittest discover
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ python_requires = >=3.4
install_requires =
torch
itk==5.3rc4

[options.extras_require]
testing =
monai

[options.packages.find]
where = src
Expand Down
59 changes: 54 additions & 5 deletions src/itk_torch_transform_bridge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itk
import torch
import torch.nn.functional as F

def monai_warp_to_itk_transform(image_fixed: "itk.Image", image_moving:"itk.Image", network_shape: "[int]", network: "torch.nn.Module", **kwargs)->"itk.Transform":
tensor_fixed, tensor_moving, convert_back = itk_transform_bridge(image_fixed, image_moving, network_shape, phi_type="displacement_field", order="vector_first", **kwargs)
Expand All @@ -13,18 +14,66 @@ def grid_sample_to_itk_transform(image_fixed: "itk.Image", image_moving:"itk.Ima

return convert_back(phi)

def itk_transform_bridge(image_fixed: "itk.Image", image_moving:"itk.Image", network_shape: "[int]", **kwargs)->"(torch.Tensor, torch.Tensor, Callable[[torch.Tensor], itk.Transform])":
# Convert images to tensors
def itk_transform_bridge(image_fixed: "itk.Image", image_moving:"itk.Image", network_shape: "[int]", phi_type="displacement_field", range=(-1, 1))->"(torch.Tensor, torch.Tensor, Callable[[torch.Tensor], itk.Transform])":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may work for the types, but a cool best practice is to use for forward type declarations like this:

 import itk.support.types as itkt

[...]

def itk_transform_bridge(image_fixed: "itkt.Image", [...]

Examples:

https://github.com/InsightSoftwareConsortium/ITK/blob/c49d6379c575332e88de858cc2cf9529a40be625/Wrapping/Generators/Python/itk/support/extras.py#L1416-L1417

https://github.com/InsightSoftwareConsortium/ITK/blob/c49d6379c575332e88de858cc2cf9529a40be625/Wrapping/Generators/Python/itk/support/extras.py#L1183-L1188

to_network_space = resampling_transform(image_moving, network_shape)
from_network_space = resampling_transform(image_fixed, network_shape).GetInverse()

moving_npy = np.array(image_moving)
fixed_npy = np.array(image_fixed)

# turn images into torch Tensors: add feature and batch dimensions (each of length 1)
moving_trch = torch.Tensor(moving_npy)[None, None]
fixed_trch = torch.Tensor(fixed_npy)[None, None]

# ...

# Here we resize the input images to the shape expected by the neural network. This affects the
# pixel stride as well as the magnitude of the displacement vectors of the resulting displacement field, which
# convert_back will have to compensate for.

#TODO: it is crucial to blur before this step if we are downsampling!
moving_resized = F.interpolate(moving_trch, size=network_shape, mode="trilinear", align_corners=False)
fixed_resized = F.interpolate(fixed_trch, size=network_shape, mode="trilinear", align_corners=False)


# Create convert_back function

def convert_back(phi: "torch.Tensor") -> "itk.Transform":
phi = phi.cpu().detach()

if phi_type == "coordinate_field" and range == (-1, 1):
# itk.DeformationFieldTransform expects a displacement field, so we subtract off the identity map.
disp = (phi - )

dimension = len(network_shape_list)


# We convert the displacement field into an itk Vector Image.
scale = torch.Tensor(network_shape_list)

for _ in network_shape_list:
scale = scale[:, None]
disp *= scale

# disp is a shape [3, H, W, D] tensor with vector components in the order [vi, vj, vk]
disp_itk_format = disp.double().numpy()[list(reversed(range(dimension)))].transpose(list(range(1, dimension + 1)) + [0])
# disp_itk_format is a shape [H, W, D, 3] array with vector components in the order [vk, vj, vi]
# as expected by itk.

itk_disp_field = itk.image_from_array(disp_itk_format, is_vector=True)

deformable_transform = itk.DisplacementFieldTransform[(itk.D, dimension)].New()

deformable_transform.SetDisplacementField(itk_disp_field)

final_transform = itk.CompositeTransform[itk.D, dimension].New()

final_transform.PrependTransform(from_network_space)
final_transform.PrependTransform(deformable_transform)
final_transform.PrependTransform(to_network_space)

return itk.CompositeTransform(some_stuff)
return final_transform

return tensor_fixed, tensor_moving, convert_back
return fixed_resized, moving_resized, convert_back

def resampling_transform(image, shape) -> itk.Transform:

Expand Down
4 changes: 4 additions & 0 deletions test/test_grid_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import itk_torch_transform
import torch
import itk
import torch.nn.functional as F
6 changes: 6 additions & 0 deletions test/test_monai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch
import monai
import unittest

class TestMonaiWarp(unittest.TestCase):