9 changes: 8 additions & 1 deletion py4DSTEM/io/datastructure/py4dstem/datacube.py
@@ -37,7 +37,8 @@ def __init__(
         Q_pixel_size: Optional[Union[float,list]] = 1,
         Q_pixel_units: Optional[Union[str,list]] = 'pixels',
         slicelabels: Optional[Union[bool,list]] = None,
-        calibration: Optional = None,
+        calibration: Optional[Calibration] = None,
+        stack_pointer = None,
         ):
         """
         Accepts:
@@ -100,6 +101,12 @@ def __init__(
         self.tree['calibration'].set_Q_pixel_size( Q_pixel_size )
         self.tree['calibration'].set_Q_pixel_units( Q_pixel_units )

+        # Add a stack pointer attribute for Dask-related functionality;
+        # tacking it on here for now. It also serves as a quick check for
+        # whether the datacube is backed by an on-disk stack.
+        self.stack_pointer = stack_pointer
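
The stack_pointer attribute added above doubles as a marker for lazily backed data, which is what the trailing comment in the hunk gestures at. A minimal sketch of such a check; the helper name is_dask_backed is illustrative and not part of this PR:

import dask.array as da

def is_dask_backed(datacube) -> bool:
    # A DataCube built via the mem="DASK" reader path keeps the underlying
    # h5py dataset in `stack_pointer`; eagerly loaded cubes leave it as None.
    return datacube.stack_pointer is not None and isinstance(datacube.data, da.Array)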
8 changes: 8 additions & 0 deletions py4DSTEM/io/datastructure/py4dstem/datacube_fns.py
@@ -777,11 +777,15 @@ def find_Bragg_disks(
     CUDA = False,
     CUDA_batched = True,
     distributed = None,
+    dask = False,
+    dask_params = None,
+
     _qt_progress_bar = None,
     name = 'braggvectors',
     returncalc = True,
+    **kwargs
     ):
     """
     Finds the Bragg disks by cross correlation with `template`.
@@ -879,6 +883,7 @@ def find_Bragg_disks(
             processing
             if distributed is None, which is the default, processing will be in
             serial
+        dask (bool): if True, use the dask-based detection path, with dask_params as its arguments
         _qt_progress_bar (QProgressBar instance): used only by the GUI for serial
             execution
         name (str): name for the output BraggVectors
@@ -924,8 +929,11 @@ def find_Bragg_disks(
         CUDA = CUDA,
         CUDA_batched = CUDA_batched,
         distributed = distributed,
+        dask = dask,
+        dask_params = dask_params,
         _qt_progress_bar = _qt_progress_bar,
+        **kwargs
         )
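
Assuming this function is bound as a DataCube method, as the surrounding datacube_fns module suggests, a dask-enabled call might look like the sketch below. The file name and worker count are illustrative, and the dask_params key is inferred from beta_parallel_disk_detection's dask_client_params argument later in this diff:

import numpy as np
import py4DSTEM

# lazily backed datacube, via the mem="DASK" reader path added below
datacube = py4DSTEM.io.read("experiment.h5", mem="DASK")

# stand-in probe kernel; in practice this comes from py4DSTEM's probe tools
probe_kernel = np.ones(datacube.data.shape[2:])

braggvectors = datacube.find_Bragg_disks(
    template=probe_kernel,
    dask=True,
    # merged into the call's kwargs and forwarded to the dask routine
    dask_params={"dask_client_params": {"n_workers": 4}},
)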
13 changes: 12 additions & 1 deletion py4DSTEM/io/native/legacy/read_v0_12.py
@@ -1,8 +1,11 @@
 # Reader for py4DSTEM v0.12 files

 import h5py
 import numpy as np
 from os.path import splitext, exists

+import dask.array as da
+
+
 from py4DSTEM.io.native.legacy.read_utils import is_py4DSTEM_file, get_py4DSTEM_topgroups, get_py4DSTEM_version, version_is_geq
 from py4DSTEM.io.native.legacy.read_utils_v0_12 import get_py4DSTEM_dataobject_info
 from py4DSTEM.io.datastructure import DataCube
@@ -12,6 +15,7 @@
 from py4DSTEM.io.datastructure import PointListArray
 from py4DSTEM import tqdmnd

+
 def read_v0_12(fp, **kwargs):
     """
     File reader for files written by py4DSTEM v0.12. Precise behavior is determined by which
@@ -287,8 +291,14 @@ def get_datacube_from_grp(g,mem='RAM',binfactor=1,bindtype=None):
     elif (mem, binfactor) == ("MEMMAP", 1):
         data = g['data']
+        stack_pointer = None
+    elif (mem, binfactor) == ("DASK", 1):
+        stack_pointer = g['data']
+        shape = g['data'].shape
+        data = da.from_array(stack_pointer, chunks=(1, 1, shape[2], shape[3]))
+
     name = g.name.split('/')[-1]
-    return DataCube(data=data,name=name)
+    return DataCube(data=data, name=name, stack_pointer=stack_pointer)


 def get_diffractionslice_from_grp(g):
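
The DASK branch above wraps the on-disk HDF5 dataset in a chunked dask array, one diffraction pattern per chunk, without reading anything into RAM. The same pattern in isolation; the file name and group path are illustrative:

import dask.array as da
import h5py

f = h5py.File("py4dstem_v012_file.h5", "r")  # must stay open while dask reads from it
dset = f["4DSTEM_experiment/data/datacubes/datacube_0/data"]  # hypothetical group path

# one (Qx, Qy) diffraction pattern per chunk, matching the reader above
lazy_data = da.from_array(dset, chunks=(1, 1, dset.shape[2], dset.shape[3]))

# reductions now stream from disk, e.g. a mean diffraction pattern
mean_dp = lazy_data.mean(axis=(0, 1)).compute()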
3 changes: 1 addition & 2 deletions py4DSTEM/process/diskdetection/__init__.py
@@ -2,5 +2,4 @@
 from py4DSTEM.process.diskdetection.braggvectormap import *

 #from .diskdetection_aiml import *
-#from .diskdetection_parallel_new import *
-
+from py4DSTEM.process.diskdetection.diskdetection_parallel_new import *
41 changes: 31 additions & 10 deletions py4DSTEM/process/diskdetection/diskdetection.py
@@ -4,6 +4,7 @@
 import numpy as np
 from scipy.ndimage import gaussian_filter

+
 from py4DSTEM.io.datastructure.py4dstem import DataCube, QPoints, BraggVectors
 from py4DSTEM.process.utils.get_maxima_2D import get_maxima_2D
 from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT
@@ -13,6 +14,7 @@



+
 def find_Bragg_disks(
     data,
     template,
@@ -34,9 +36,11 @@ def find_Bragg_disks(
     CUDA = False,
     CUDA_batched = True,
     distributed = None,
+    dask : bool = False,
+    dask_params : dict = None,
+
     _qt_progress_bar = None,
-    ):
+    **kws):
     """
     Finds the Bragg disks in the diffraction patterns represented by `data` by
     cross/phase correlation with `template`.
@@ -53,10 +57,10 @@ def find_Bragg_disks(
     and returns an instance or a length-N list of instances of QPoints

     For disk detection on a full DataCube, the calculation can be performed
-    on the CPU, GPU or a cluster. By default the CPU is used. If `CUDA` is set
-    to True, tries to use the GPU. If `CUDA_batched` is also set to True,
-    batches the FFT/IFFT computations on the GPU. For distribution to a cluster,
-    distributed must be set to a dictionary, with contents describing how
+    on the CPU, GPU, or using dask or ipyparallel. By default the CPU is used.
+    If `CUDA` is set to True, tries to use the GPU. If `CUDA_batched` is also set
+    to True, batches the FFT/IFFT computations on the GPU. For distribution to a
+    cluster, distributed must be set to a dictionary, with contents describing how
     distributed processing should be performed - see below for details.
@@ -141,6 +145,9 @@ def find_Bragg_disks(
         processing
         if distributed is None, which is the default, processing will be in
         serial
+    dask (bool): if True, use the dask-based detection path; dask_params must then
+        be a dictionary of arguments to pass to the dask detection function.
+        Valid arguments are (...). See docstring for (...) for details.
     _qt_progress_bar (QProgressBar instance): used only by the GUI for serial
         execution
@@ -153,6 +160,8 @@ def find_Bragg_disks(
         - a (DataCube,rx,ry) 3-tuple, returns a list of QPoints
           instances
     """
+    # TODO add checks to ensure dask and CUDA are not both requested, so the
+    # user knows which path will run

     # parse args
@@ -196,11 +205,13 @@ def find_Bragg_disks(
             mode = 'dc_GPU'
         else:
             mode = 'dc_GPU_batched'
+    elif dask:
+        mode = 'dc_dask'
     else:
         x = _parse_distributed(distributed)
         connect, data_file, cluster_path, distributed_mode = x
         if distributed_mode == 'dask':
-            mode = 'dc_dask'
+            mode = 'dc_dask_old'
         elif distributed_mode == 'ipyparallel':
             mode = 'dc_ipyparallel'
         else:
@@ -222,6 +233,9 @@ def find_Bragg_disks(
         kws['connect'] = connect
         kws['data_file'] = data_file
         kws['cluster_path'] = cluster_path
+    # dask kwargs
+    if dask_params is not None:
+        kws.update(dask_params)

     # run and return
     ans = fn(
@@ -243,7 +257,8 @@ def _get_function_dictionary():
     return ans


-
+# TODO add an extra skeleton function which imports beta_parallel and returns it if dask_cuda is added
+# TODO add ML-AI version at some point
 def _get_function_dictionary():

     d = {
@@ -252,14 +267,19 @@ def _get_function_dictionary():
         "dc_CPU" : _find_Bragg_disks_CPU,
         "dc_GPU" : _find_Bragg_disks_CUDA_unbatched,
         "dc_GPU_batched" : _find_Bragg_disks_CUDA_batched,
-        "dc_dask" : _find_Bragg_disks_dask,
+        "dc_dask_old" : _find_Bragg_disks_dask,
+        # "dc_dask" : beta_parallel_disk_detection,
+        "dc_dask" : _find_Bragg_disks_dask_lazy,
         "dc_ipyparallel" : _find_Bragg_disks_ipp,
     }

     return d


+# Lazy-import wrapper: defers importing the dask machinery until the dask path
+# is actually requested, then forwards the call on
+def _find_Bragg_disks_dask_lazy(*args, **kwargs):
+    from .diskdetection_parallel_new import beta_parallel_disk_detection
+    return beta_parallel_disk_detection(*args, **kwargs)


 # Single diffraction pattern
@@ -721,6 +741,7 @@ def _parse_distributed(distributed):

     elif "dask" in distributed:
         mode = 'dask'
+        print(type(distributed))
         if "client" in distributed["dask"]:
             connect = distributed["dask"]["client"]
         else:
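
Two routes now reach dask code: the new dask flag dispatches to 'dc_dask' (the lazy wrapper around beta_parallel_disk_detection), while the legacy distributed dict is kept as 'dc_dask_old'. A sketch of both call patterns, with datacube and probe_kernel as in the earlier sketch; the dask_params key is inferred from beta_parallel_disk_detection's signature:

# new path: boolean flag, plus a params dict merged into the call's kwargs
peaks = find_Bragg_disks(
    datacube, probe_kernel,
    dask=True,
    dask_params={"dask_client_params": {"n_workers": 4}},
)

# legacy path, now dispatched as 'dc_dask_old': a `distributed` dict with a
# "dask" entry, optionally carrying an existing client (see _parse_distributed)
from dask.distributed import Client
client = Client()
peaks = find_Bragg_disks(
    datacube, probe_kernel,
    distributed={"dask": {"client": client}},
)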
57 changes: 40 additions & 17 deletions py4DSTEM/process/diskdetection/diskdetection_parallel_new.py
@@ -7,8 +7,12 @@
 from dask import delayed
 import dask
 #import dask.bag as db
-from py4DSTEM.io.datastructure import PointListArray, PointList
-from py4DSTEM.process.diskdetection.diskdetection import _find_Bragg_disks_single_DP_FK
-from py4DSTEM.io import PointListArray, PointList, datastructure
+from py4DSTEM.io.datastructure.py4dstem import DataCube, QPoints, BraggVectors, PointListArray, PointList
+from py4DSTEM.process.diskdetection.diskdetection import _find_Bragg_disks_single
 import time
 from dask.diagnostics import ProgressBar
@@ -50,7 +54,6 @@ def register_dill_serializer():
     register_serialization_family('dill', dill_dumps, dill_loads)
     return None

-
 #### END OF SERIALISERS ####
@@ -62,7 +65,7 @@ def register_dill_serializer():
 # TODO add ML-AI version
 def _find_Bragg_disks_single_DP_FK_dask_wrapper(arr, *args, **kwargs):
     # This is needed because _find_Bragg_disks_single takes a 2D array, and these delayed arrays have the wrong shape
-    return _find_Bragg_disks_single_DP_FK(arr[0,0], *args, **kwargs)
+    return _find_Bragg_disks_single(arr[0,0], *args, **kwargs)


 #### END OF DASK WRAPPER FUNCTIONS ####
@@ -75,7 +78,7 @@ def _find_Bragg_disks_single_DP_FK_dask_wrapper(arr, *args,**kwargs):

 def beta_parallel_disk_detection(dataset,
             probe,
-            #rxmin=None, # these would allow selecting a sub section
+            #rxmin=None, # these would allow selecting a sub section # probably not a useful case
             #rxmax=None,
             #rymin=None,
             #rymax=None,
@@ -125,13 +128,21 @@ def beta_parallel_disk_detection(dataset,
     # ... dask stuff.
     #TODO add assert statements and other checks. Think about reordering operations

+    # make sure `peaks` is not passed as a keyword argument
+    assert 'peaks' not in kwargs, "peaks must not be passed as a keyword argument"
+
     # Check to see if a dask client has been passed.
+    # if no client was passed:
     if dask_client is None:
+        # if cluster parameters were passed, create a cluster and hand it to a new client
        if dask_client_params is not None:
            dask.config.set({'distributed.worker.memory.spill': False,
                'distributed.worker.memory.target': False})
            cluster = LocalCluster(**dask_client_params)
-            dask_client = Client(cluster, **dask_client_params)
+            dask_client = Client(cluster)
+        # if no parameters were passed, create a client with default values
        else:
            # AUTO MAGICALLY SET?
            # LET DASK SET?
@@ -154,8 +165,10 @@ def beta_parallel_disk_detection(dataset,
             pass

-    # Probe stuff
+    #### Probe stuff
+    # check that the probe shape matches the diffraction pattern shape
+    assert (probe.shape == dataset.data.shape[2:]), "Probe and diffraction pattern shapes are mismatched"
+
     if probe_type != "FT":
         #TODO clean up and pull out redundant parts
         #if probe.dtype != (np.complex128 or np.complex64 or np.complex256):
@@ -192,7 +205,7 @@ def beta_parallel_disk_detection(dataset,
     # loop over dataset_delayed and create a delayed disk-detection call for each probe position
     for x in np.ndindex(dataset_delayed.shape):
         temp = delayed(_find_Bragg_disks_single_DP_FK_dask_wrapper)(dataset_delayed[x],
-            probe_kernel_FT=dask_probe_delayed[0,0],
+            template=dask_probe_delayed[0,0],
             #probe_kernel_FT=delayed_probe_kernel_FT,
             *args, **kwargs) # passing through args from earlier
             #corrPower=corrPower,
@@ -207,28 +220,38 @@ def beta_parallel_disk_detection(dataset,

     output = dask_client.gather(_temp_peaks) # gather the future objects

-    coords = [('qx',float),('qy',float),('intensity',float)]
-    peaks = PointListArray(coordinates=coords, shape=dataset.data.shape[:-2])
+    dtype = [('qx',float),('qy',float),('intensity',float)]
+    peaks = PointListArray(dtype=dtype, shape=dataset.data.shape[:-2])

     # operating over a list, so we need the index (0->count) and the corresponding
     # probe positions (0->rx, 0->ry); count indexes the gathered list
-    for (count,(rx, ry)) in zip([i for i in range(dataset.data[...,0,0].size)],np.ndindex(dataset.data.shape[:-2])):
+    for count, (rx, ry) in enumerate(np.ndindex(dataset.data.shape[:-2])):
         #peaks.get_pointlist(rx, ry).add_pointlist(temp_peaks[0][count])
         #peaks.get_pointlist(rx, ry).add_pointlist(output[count][0])
-        peaks.get_pointlist(rx, ry).add_pointlist(output[count])
+        peaks.get_pointlist(rx, ry).add(output[count])

+    # create a BraggVectors object and populate its uncalibrated vectors
+    braggvectors = BraggVectors(dataset.Rshape, dataset.Qshape)
+    braggvectors._v_uncal = peaks

+    # TODO remove the ability to return the client? Revisit whether it's needed.

     # Clean up dask related stuff
     dask_client.cancel(_temp_peaks) # remove the futures from the dask workers
     del _temp_peaks # delete the local reference
     if close_dask_client:
         dask_client.close()
-        return peaks
+        return braggvectors
     elif close_dask_client == False and return_dask_client == True:
-        return peaks, dask_client
+        return braggvectors, dask_client
     elif close_dask_client and return_dask_client == False:
-        return peaks
+        return braggvectors
     else:
         print('Dask client in an unknown state; this may result in unpredictable behaviour later')
-        return peaks
+        return braggvectors

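Driving the routine directly with an externally managed client might look like the following sketch. The probe kernel is a stand-in, and close_dask_client / return_dask_client are passed explicitly since their defaults are not visible in this diff:

import numpy as np
from dask.distributed import Client, LocalCluster

import py4DSTEM
from py4DSTEM.process.diskdetection.diskdetection_parallel_new import beta_parallel_disk_detection

datacube = py4DSTEM.io.read("experiment.h5", mem="DASK")
probe_kernel = np.ones(datacube.data.shape[2:])  # stand-in probe kernel

cluster = LocalCluster(n_workers=4, threads_per_worker=1)
client = Client(cluster)

# with close_dask_client=False and return_dask_client=True, the function
# returns (braggvectors, client), per the branches above
braggvectors, client = beta_parallel_disk_detection(
    datacube,
    probe_kernel,
    dask_client=client,
    close_dask_client=False,
    return_dask_client=True,
)
client.close()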
15 changes: 15 additions & 0 deletions py4DSTEM/test/dask/diskdetection.py
@@ -0,0 +1,15 @@
+# Test dask disk detection functionality
+
+# Device use cases:
+# - local machine
+# - cluster
+
+# Storage use cases:
+# - as a dask array
+# - as a mem map
+# - in RAM
+
+# Future cases:
+# - GPU + dask
+
+
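One minimal shape this skeleton might take, sketched with pytest over the storage cases; the small_v012_file fixture is an assumption, not part of this PR:

import pytest

py4DSTEM = pytest.importorskip("py4DSTEM")

@pytest.mark.parametrize("mem", ["RAM", "MEMMAP", "DASK"])
def test_load_across_storage_modes(small_v012_file, mem):
    # `small_v012_file` is a hypothetical fixture pointing at a tiny v0.12 .h5
    datacube = py4DSTEM.io.read(small_v012_file, mem=mem)
    assert datacube.data.ndim == 4
    if mem == "DASK":
        # the DASK reader path stores the h5py dataset it wrapped
        assert datacube.stack_pointer is not None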
11 changes: 11 additions & 0 deletions py4DSTEM/test/dask/io.py
@@ -0,0 +1,11 @@
+# Test dask i/o functionality
+
+# Cases:
+# - load a datacube "normally", i.e. into memory, and then convert it to a dask array
+# - load a datacube directly from .h5 to a mapped dask array
+# - load a datacube into a numpy memmap, and then work on that as a dask array
+
+
+
+
+
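The three cases above amount to three ways of producing a dask array over the same data; a compact sketch, with illustrative file and dataset paths:

import dask.array as da
import h5py
import numpy as np

def pattern_chunks(shape):
    # one (Qx, Qy) diffraction pattern per chunk
    return (1, 1) + tuple(shape[2:])

# case 1: load into memory, then wrap
with h5py.File("data.h5", "r") as f:
    in_ram = np.array(f["path/to/data"])  # hypothetical dataset path
dask_from_ram = da.from_array(in_ram, chunks=pattern_chunks(in_ram.shape))

# case 2: map the .h5 dataset directly (the file must stay open)
f = h5py.File("data.h5", "r")
dset = f["path/to/data"]
dask_mapped = da.from_array(dset, chunks=pattern_chunks(dset.shape))

# case 3: a numpy memmap over raw binary data, wrapped as a dask array
raw = np.memmap("data.raw", dtype=np.float32, mode="r", shape=in_ram.shape)
dask_from_memmap = da.from_array(raw, chunks=pattern_chunks(raw.shape))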