Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 86 additions & 13 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,9 +2056,16 @@ def _constraints(

obj_6D = copy_to_device(self.object_6D, device="cpu")

obj_6D = gaussian_filter(
obj_6D, real_space_gaussian_filter, axes=(0, 1, 2)
) # axes only supported in cpu
# obj_6D = gaussian_filter(
# obj_6D, real_space_gaussian_filter, axes=(0, 1, 2)
# ) # axes only supported in cpu
for i in range(obj_6D.shape[3]):
for j in range(obj_6D.shape[4]):
for k in range(obj_6D.shape[5]):
obj_6D[:, :, :, i, j, k] = gaussian_filter(
obj_6D[:, :, :, i, j, k],
sigma=real_space_gaussian_filter
)

self._object = copy_to_device(
obj_6D.reshape((s[0], s[1] * s[2], s[3] * s[4] * s[5])), device=storage
Expand Down Expand Up @@ -2255,7 +2262,7 @@ def widget(
**kwargs,
):
""" """
from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label, Layout
from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label, Layout, widgets
from skimage.feature import peak_local_max
from scipy.ndimage import gaussian_filter
from py4DSTEM.visualize import return_scaled_histogram_ordering
Expand Down Expand Up @@ -2293,6 +2300,30 @@ def widget(
vmax=vmax,
)

# Buttons to set view angles
button_xy = widgets.Button(description="XY view", layout=Layout(width="100px"))
button_xz = widgets.Button(description="XZ view", layout=Layout(width="100px"))
button_yz = widgets.Button(description="YZ view", layout=Layout(width="100px"))

def on_click_xy(b):
ax2.view_init(elev=90, azim=-90) # Top-down
fig.canvas.draw_idle()

def on_click_xz(b):
ax2.view_init(elev=0, azim=-90) # Front-on
fig.canvas.draw_idle()

def on_click_yz(b):
ax2.view_init(elev=0, azim=0) # Side-on
fig.canvas.draw_idle()

button_xy.on_click(on_click_xy)
button_xz.on_click(on_click_xz)
button_yz.on_click(on_click_yz)

view_buttons = widgets.HBox([button_xy, button_xz, button_yz])


# %matplotlib ipympl

with plt.ioff():
Expand All @@ -2301,6 +2332,10 @@ def widget(
ax1 = fig.add_subplot(1, 3, 2)
ax2 = fig.add_subplot(1, 3, 3, projection="3d")

from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import axes3d
ax2.set_proj_type('ortho')

x = obj_6D.shape[0] // 2
y = obj_6D.shape[1] // 2
z = obj_6D.shape[2] // 2
Expand Down Expand Up @@ -2370,6 +2405,7 @@ def widget(
ax2.set_xlim([0, obj_6D.shape[3]])
ax2.set_ylim([0, obj_6D.shape[4]])
ax2.set_zlim([0, obj_6D.shape[5]])
set_axes_equal(ax2)

plt.tight_layout()

Expand Down Expand Up @@ -2474,6 +2510,8 @@ def update_images(
ax1.set_title("xz")
ax2.set_title("Diffraction")

set_axes_equal(ax2)

plt.tight_layout()

fig.canvas.draw_idle()
Expand Down Expand Up @@ -2580,16 +2618,51 @@ def update_images(
fig.canvas.layout.height = "400px"
fig.canvas.toolbar_position = "bottom"

widget = widgets.VBox(
[
fig.canvas,
HBox([x, y]),
HBox([z, gaussian_filter_diffraction]),
HBox([minimum_threshold, scale_intensities]),
HBox([intensities_power, block_center]),
],
)
widget = widgets.VBox([
fig.canvas,
view_buttons,
HBox([x, y]),
HBox([z, gaussian_filter_diffraction]),
HBox([minimum_threshold, scale_intensities]),
HBox([intensities_power, block_center]),
])

# widget = widgets.VBox(
# [
# fig.canvas,
# HBox([x, y]),
# HBox([z, gaussian_filter_diffraction]),
# HBox([minimum_threshold, scale_intensities]),
# HBox([intensities_power, block_center]),
# ],
# )

display(widget)

return self

def set_axes_equal(ax):
"""Set 3D plot axes to equal scale (for matplotlib >= 3.3)."""
xlim = ax.get_xlim3d()
ylim = ax.get_ylim3d()
zlim = ax.get_zlim3d()

# Calculate ranges and midpoints
x_range, x_middle = np.ptp(xlim), np.mean(xlim)
y_range, y_middle = np.ptp(ylim), np.mean(ylim)
z_range, z_middle = np.ptp(zlim), np.mean(zlim)

# Set max range
plot_radius = 0.5 * max([x_range, y_range, z_range])

ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])

# Force equal aspect
try:
ax.set_box_aspect([1, 1, 1]) # Requires matplotlib >= 3.3
except AttributeError:
pass