Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
697 changes: 697 additions & 0 deletions STL_main/DT1_Test/Test_StlOperator.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions STL_main/DT1_Test/Test_WaveletTransform.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions STL_main/DataType1.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,21 +423,22 @@ def DT1_cov_func(array1, Fourier1, array2, Fourier2,

if mask is None and Fourier1 and Fourier2:
# Compute covariance (complex values)
cov = torch.mean(array1 * array2.conj(), dim=(-2, -1)).real
cov = torch.mean(array1 * array2.conj(), dim=(-2, -1))

else:
# We pass everything to real space
if Fourier1:
_array1 = torch.fft.ifft2(array1, norm="ortho").real
_array1 = torch.fft.ifft2(array1, norm="ortho")
else:
_array1 = array1
if Fourier2:
_array2 = torch.fft.ifft2(array2, norm="ortho").real
_array2 = torch.fft.ifft2(array2, norm="ortho")
else:
_array2 = array2
# Define mask
mask = 1 if mask is None else mask
# Compute covariance (complex values)
cov = torch.mean(_array1 * _array2 * mask, dim=(-2, -1))
cov = torch.mean(_array1 * _array2.conj() * mask, dim=(-2, -1))


return cov

Expand Down Expand Up @@ -715,7 +716,6 @@ def DT1_wavelet_conv(data, wavelet_j, Fourier, mask_MR):
- Fourier: bool
Fourier status of the convolution (True in this DT)
"""

# Pass data in Fourier if in real space
_data = data if Fourier else torch.fft.fft2(data)

Expand All @@ -733,4 +733,4 @@ def DT1_wavelet_conv(data, wavelet_j, Fourier, mask_MR):
###############################################################################

def DT1_subsampling_func_fromMR(param):
pass
pass
198 changes: 97 additions & 101 deletions STL_main/ST_Operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import numpy as np
from copy import deepcopy, copy
from StlData import StlData
from WaveletTransform import Wavelet_Operator
from ST_Statistics import ST_Statistics
Expand Down Expand Up @@ -97,18 +98,12 @@ class ST_Operator:
########################################
def __init__(self,
DT, N0, J=None, L=None, WType=None,
SC="scat_cov", jmin=None, jmax=None, dj=None,
SC="ScatCov", jmin=None, jmax=None, dj=None,
pbc=True, mask=None,
norm="S2", S2_ref=None, iso=False,
angular_ft=False, scale_ft=False,
flatten=False,
wavelet_array = None,
wavelet_array_MR = None,
dg_max = None,
j_to_dg = None,
Single_Kernel = None,
mask_st=None,
mask_opt = False,
mask_MR = None
):
'''
Expand Down Expand Up @@ -333,11 +328,12 @@ def apply(self, data,
Nb, Nc = bk.shape(data.array, 0), bk.shape(data.array, 1)

# Create a ST_statistics instance
data_st = ST_Statistics(self.DT, N0, J, L, WType,
SC, jmin, jmax, dj,
pbc, mask_MR if pass_mask else None,
Nb, Nc, self.wavelet_op)
print(vars(data_st))
data_st = ST_Statistics(
self.DT, N0, J, L, WType,
SC, jmin, jmax, dj,
pbc, mask_MR if pass_mask else None,
Nb, Nc
)
# Define the mask for conv computation if necessary
if not pbc:
mask_bc = self.construct_mask_bc(mask_MR)
Expand All @@ -349,8 +345,14 @@ def apply(self, data,
if self.SC == "ScatCov":
data_st.S1 = bk.zeros((Nb,Nc,J,L))
data_st.S2 = bk.zeros((Nb,Nc,J,L))
data_st.S3 = bk.zeros((Nb,Nc,J,J,L,L)) + bk.nan
data_st.S4 = bk.zeros((Nb,Nc,J,J,J,L,L)) + bk.nan
data_st.S3_1 = bk.zeros((Nb,Nc,J,J,L,L)) + bk.nan
data_st.S4_1 = bk.zeros((Nb,Nc,J,J,J,L,L,L)) + bk.nan
data_st.S3_2 = bk.zeros((Nb,Nc,J,J,L,L)) + bk.nan
data_st.S4_2 = bk.zeros((Nb,Nc,J,J,J,L,L,L)) + bk.nan
data_st.S3_3 = bk.zeros((Nb,Nc,J,J,L,L)) + bk.nan
data_st.S4_3 = bk.zeros((Nb,Nc,J,J,J,L,L,L)) + bk.nan
data_st.S3_4 = bk.zeros((Nb,Nc,J,J,L,L)) + bk.nan
data_st.S4_4 = bk.zeros((Nb,Nc,J,J,J,L,L,L)) + bk.nan

########################################
# ST coefficients computation
Expand All @@ -363,135 +365,129 @@ def apply(self, data,

### !! WARNING: simple code break here !! ###
# Should probably be segmented in subfunction.

########################################
### Version vanilla ###

# Vanilla version uses the following form for S3 and S4
# S3 = Cov(|I*psi1|*psi2, I*psi2)
# S4 = Cov(|I*psi1|*psi3, |I*psi2|*psi3)

# Compute first convolution and modulus

# --- Compute first convolution and modulus ---
data_l1 = self.wavelet_op.apply(data) #(Nb,Nc,J,L,N)
data_l1m = data_l1.modulus_func(copy=True) #(Nb,Nc,J,L,N)

# Compute S1 and S2
# --- Compute S1 and S2 ---
data_st.S1 = data_l1m.mean_func() #(Nb,Nc,J,L)
data_st.S2 = data_l1m.mean_func(square=True) #(Nb,Nc,J,L)

################################
### Higher order computation ###
# for j3 in range(J):
# # Scale smaller-eq to j3 whose [I*psi| will be convolved at j3
# data_l1.array = data_l1.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# data_l1m.array = data_l1m.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# # Downsample at Nj3
# data_l1.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
# data_l1m.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)

# # Compute |I*psi2|*psi3 #(Nb,Nc,j3+1,L2,L3,N3)
# data_l1m_l2 = self.wavelet_op.apply(data_l1m, j3)
################################

# --- Version vanilla ---
# Vanilla version uses the following form for S3 and S4
# S3 = Cov(|I*psi1|*psi2, I*psi2)
# S4 = Cov(|I*psi1|*psi3, |I*psi2|*psi3)

for j3 in range(J):
# Scale smaller-eq to j3 whose [I*psi| will be convolved at j3
# data_l1.array = data_l1.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# data_l1m.array = data_l1m.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
data_l1_tmp = copy(data_l1) #(Nb,Nc,j3+1,L,N)
data_l1m_tmp = copy(data_l1m)
data_l1_tmp.array = data_l1_tmp.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
data_l1m_tmp.array = data_l1m_tmp.array[:,:,:j3+1]

# Downsample at Nj3
data_l1_tmp.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
data_l1m_tmp.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)

# for j2 in range(j3+1):
# # S3(j2,j3) = Cov(|I*psi2|*psi3, I*psi3)
# data_st.S3[:,:,j2,j3,:,:] = StlData.cov_func(
# data_l1m_l2[:,:,j2], #(Nb,Nc,L2,L3,N3)
# data_l1[:,:,j3,None] #(Nb,Nc, 1,L3,N3)
# )

# for j1 in range(j2+1):
# # S4(j1,j2,j3) = Cov(|I*psi1|*psi3, |I*psi2|*psi3)
# data_st.S4[:,:,j1,j2,j3,:,:,:] = StlData.cov_func(
# data_l1m_l2[:,:,j1,:,None], #(Nb,Nc,L1, 1,L3,N3)
# data_l1m_l2[:,:,j2,None,:] #(Nb,Nc, 1,L2,L3,N3)
# )
# Compute |I*psi2|*psi3 #(Nb,Nc,j3+1,L2,L3,N3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m_tmp, j=j3)

for j2 in range(j3+1):
# S3(j2,j3) = Cov(|I*psi2|*psi3, I*psi3)
data_st.S3_1[:,:,j2,j3,:,:] = data_l1m_l2[:,:,j2].cov_func(data_l1_tmp[:,:,j3, None]) #(Nb,Nc,L2,L3,N3) x (Nb,Nc,1,L3,N3)

for j1 in range(j2+1):
# S4(j1,j2,j3) = Cov(|I*psi1|*psi3, |I*psi2|*psi3)
data_st.S4_1[:,:,j1,j2,j3,:,:,:] = data_l1m_l2[:,:,j1,:,None].cov_func(data_l1m_l2[:,:,j2,None,:]) #(Nb,Nc,L1, 1,L3,N3) x (Nb,Nc, 1,L2,L3,N3)


### Alternative Higher order computation, version batchée ###
# --- Version batchée ---

for j3 in range(J):
# Scale smaller-eq to j3 whose [I*psi| will be convolved at j3
data_l1.array = data_l1.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
data_l1m.array = data_l1m.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# Downsample at Nj3
data_l1.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
data_l1m.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
# data_l1.array = data_l1.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# data_l1m.array = data_l1m.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
data_l1_tmp = copy(data_l1) #(Nb,Nc,j3+1,L,N)
data_l1m_tmp = copy(data_l1m)
data_l1_tmp.array = data_l1_tmp.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
data_l1m_tmp.array = data_l1m_tmp.array[:,:,:j3+1]

# Downsample at Nj3
# data_l1.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
# data_l1m.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
data_l1_tmp.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
data_l1m_tmp.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)

# Compute |I*psi2|*psi3 #(Nb,Nc,j3+1,L2,L3,N3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m, j3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m_tmp, j=j3)

# S3(j2,j3) = Cov(|I*psi2|*psi3, I*psi3)
data_st.S3[:,:,:j3+1,j3,:,:] = StlData.cov_func(
data_l1m_l2[:,:,:j3+1], #(Nb,Nc,j3+1,L2,L3,N3)
data_l1[:,:,j3,None,None] #(Nb,Nc, 1, 1,L3,N3)
)

for j2 in range(j2+1):
data_st.S3_2[:,:,:j3+1,j3,:,:] = data_l1m_l2[:,:,:j3+1].cov_func(data_l1_tmp[:,:,j3,None,None]) # (Nb,Nc,j3+1,L2,L3,N3) x (Nb,Nc, 1, 1,L3,N3)

for j2 in range(j3+1):
# S4(j1,j2,j3) = Cov(|I*psi1|*psi3, |I*psi2|*psi3)
data_st.S4[:,:,:j2+1,j2,j3,:,:,:] = StlData.cov_func(
data_l1m_l2[:,:,:j2+1,:,None], #(Nb,Nc,j2+1,L1, 1,L3,N3)
data_l1m_l2[:,:,j2,None,None,:] #(Nb,Nc, 1, 1,L2,L3,N3)
)
data_st.S4_2[:,:,:j2+1,j2,j3,:,:,:] = data_l1m_l2[:,:,:j2+1,:,None].cov_func(data_l1m_l2[:,:,j2,None,None,:]) # (Nb,Nc,j2+1,L1, 1,L3,N3) x (Nb,Nc, 1, 1,L2,L3,N3)

########################################
### Version Sihao (adaptée) ###

# Vanilla version uses the following form for S3 and S4
# --- Version Sihao (adaptée) ---
# This version uses the following form for S3 and S4
# S3 = Cov(|I*psi1|, I*psi2)
# S4 = Cov(|I*psi1|, |I*psi2|*psi3)

# Compute first convolution and modulus
data_l1m = self.wavelet_op.apply(data).modulus_func() #(Nb,Nc,J,L,N)

# Compute S1 and S2
data_st.S1 = data_l1m.mean_func() #(Nb,Nc,J,L)
data_st.S2 = data_l1m.mean_func(square=True) #(Nb,Nc,J,L)

### Higher order computation ###

for j3 in range(J):
# Scale smaller-eq to j3 whose [I*psi| will be convolved at j3
data_l1m.data = data_l1m.data[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# data_l1m.array = data_l1m.array[:,:,:j3+1] #(Nb,Nc,j3+1,L,N)
# # Downsample at Nj3
# data.downsample(j_to_dg[j3]) #(Nb,Nc,N3)
# data_l1m.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
data_tmp = copy(data)
data_l1m_tmp = copy(data_l1m)
data_l1m_tmp.array = data_l1m_tmp.array[:,:,:j3+1]

# Downsample at Nj3
data.downsample(j_to_dg[j3]) #(Nb,Nc,N3)
data_l1m.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)
data_tmp.downsample(j_to_dg[j3])
data_l1m_tmp.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)

for j2 in range(j3+1):
# Compute |I*psi2|*psi3 #(Nb,Nc,L2,L3,N3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m[:,:,j2], j3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m_tmp[:,:,j2], j3)

# S3(j2,j3) = Cov(I, |I*psi2|*psi3)
data_st.S3[:,:,j2,j3,:,:] = StlData.cov_func(
data[:,:,None,None], #(Nb,Nc, 1, 1,N3)
data_l1m_l2 #(Nb,Nc,L2,L3,N3)
)
data_st.S3_3[:,:,j2,j3,:,:] = data_tmp[:,:,None,None].cov_func(data_l1m_l2) #(Nb,Nc, 1, 1,N3) x (Nb,Nc,L2,L3,N3)

for j1 in range(j2+1):
# S4(j1,j2,j3) = Cov(|I*psi1|, |I*psi2|*psi3)
data_st.S4[:,:,j1,j2,j3,:,:,:] = StlData.cov_func(
data_l1m[:,:,j1,:,None,None], #(Nb,Nc,L1, 1, 1,N3)
data_l1m_l2[:,:,None,:,:] #(Nb,Nc, 1,L2,L3,N3)
)
data_st.S4_3[:,:,j1,j2,j3,:,:,:] = data_l1m_tmp[:,:,j1,:,None,None].cov_func(data_l1m_l2[:,:,None,:,:]) #(Nb,Nc,L1, 1, 1,N3) x (Nb,Nc, 1,L2,L3,N3)


### Alternative higher order computation, version batchée ###
# --- Version Sihao (adaptée) + batchée ---

for j3 in range(J):
# Scale smaller-eq to j3 whose [I*psi| will be convolved at j3
data_l1m.data = data_l1m.data[:,:,:j3+1] #(Nb,Nc,:j3+1,L,N)
data_tmp = copy(data)
data_l1m_tmp = copy(data_l1m)
data_l1m_tmp.array = data_l1m_tmp.array[:,:,:j3+1]

# Downsample at Nj3
data.downsample(j_to_dg[j3]) #(Nb,Nc,N3)
data_l1m.downsample(j_to_dg[j3]) #(Nb,Nc,:j3+1,L,N3)
data_tmp.downsample(j_to_dg[j3])
data_l1m_tmp.downsample(j_to_dg[j3]) #(Nb,Nc,j3+1,L,N3)

# Compute |I*psi2|*psi3 #(Nb,Nc,:j3+1,L2,L3,N3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m, j3)
data_l1m_l2 = self.wavelet_op.apply(data_l1m_tmp, j3)

# S3(j2,j3) = Cov(|I*psi2|*psi3, I*psi3)
data_st.S3[:,:,:j3+1,j3,:,:] = StlData.cov_func(
data[:,:,None,None,None], #(Nb,Nc, 1, 1, 1,N3)
data_l1m_l2[:,:,:j3+1] #(Nb,Nc,j3+1,L2,L3,N3)
)
data_st.S3_4[:,:,:j3+1,j3,:,:] = data_tmp[:,:,None,None,None].cov_func(data_l1m_l2[:,:,:j3+1]) # (Nb,Nc, 1, 1, 1,N3) x (Nb,Nc,j3+1,L2,L3,N3)

for j2 in range(j3+1):
# S4(j1,j2,j3) = Cov(|I*psi1|, |I*psi2|*psi3)
data_st.S4[:,:,:j2+1,j2,j3,:,:,:] = StlData.cov_func(
data_l1m[:,:,:j2+1,:,None,None], #(Nb,Nc,j2+1,L1 1, 1,N3)
data_l1m_l2[:,:,j2,None,None], #(Nb,Nc, 1, 1,L2,L3,N3)
)
data_st.S4_4[:,:,:j2+1,j2,j3,:,:,:] = data_l1m_tmp[:,:,:j2+1,:,None,None].cov_func(data_l1m_l2[:,:,j2,None,None]) # (Nb,Nc,j2+1,L1 1, 1,N3) x (Nb,Nc, 1, 1,L2,L3,N3)

########################################
# Additional transform/compression
Expand Down
11 changes: 4 additions & 7 deletions STL_main/ST_Statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,19 @@ def __init__(self,
DT, N0, J, L, WType,
SC, jmin, jmax, dj,
pbc, mask_MR,
Nb, Nc, wavelet_op):
Nb, Nc):
'''
Constructor, see details above.
'''

# Main parameters
self.DT = DT
self.N0 = N0

# Wavelet operator
self.wavelet_op = wavelet_op

# Wavelet transform related parameters
self.J = self.wavelet_op.J
self.L = self.wavelet_op.L
self.WType = self.wavelet_op.WType
self.J = J
self.L = L
self.WType = WType

# Scattering transform related parameters
self.SC = SC
Expand Down
2 changes: 1 addition & 1 deletion STL_main/WaveletTransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def plot(self, Fourier=None):
# To be done

###########################################################################
def apply(self, data, MR=None, j=None, mask_MR=None, O_Fourier=None):
def apply(self, data, j=None, MR=None, mask_MR=None, O_Fourier=None):

'''
Compute the Wavelet Transform (WT) of data.
Expand Down
Binary file added STL_main/__pycache__/DataType1.cpython-313.pyc
Binary file not shown.
Binary file added STL_main/__pycache__/DataType2.cpython-313.pyc
Binary file not shown.
Binary file added STL_main/__pycache__/ST_Operator.cpython-313.pyc
Binary file not shown.
Binary file not shown.
Binary file added STL_main/__pycache__/StlData.cpython-313.pyc
Binary file not shown.
Binary file not shown.
Binary file added STL_main/__pycache__/backend.cpython-313.pyc
Binary file not shown.