Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 86 additions & 10 deletions tractor/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,12 @@ def _get_iv(self, sky, skyvariance, Nsky, skyderivs, Nsourceparams,
# dimension of the covariance matrix
D = len(umodels[0])
models_cov = np.zeros(shape=(D, imlist[0].data.shape[0], imlist[0].data.shape[1]))

# dictionary for tracking unique, non-overlapping sources in the FIM
uniq_source_dict = {}
# dictionary for tracking overlapping sources
overlap_dict = {}

# source params next
for i, (tim, umods, scale) in enumerate(zip(imlist, umodels, scales)):
mm = np.zeros(tim.shape)
Expand Down Expand Up @@ -435,25 +441,95 @@ def _get_iv(self, sky, skyvariance, Nsky, skyderivs, Nsourceparams,

slc = slice(y0, y0 + uh), slice(x0, x0 + uw)
slc_model = slice(um_start_y, um_start_y+(y1-y0)), slice(um_start_x, um_start_x+(x1-x0))
models_cov[ui][slc] += um.getImage()[slc_model] # add psfs
# dimension key for this model
key = (x0, y0, uh, uw)

# if this source is of the same dimension as any previous ones,
if key in uniq_source_dict:
# the previous matched model index
prev_idx = uniq_source_dict[key][0]
# the difference between the two matched models
diff_map = abs(um.getImage() - umods[prev_idx].getImage())

# check if the PSF models are identical
if np.allclose(diff_map, 0, atol=1e-5):
# no need to do anything if PSF models are identical.
# update the overlap dict
if prev_idx in overlap_dict:
overlap_dict[prev_idx].append(ui)
else:
# register in the overlap dict
overlap_dict[prev_idx] = [ui]

# if the PSF models are differnet though having the same model dimention,
# then treat them as distinct sources
else:
models_cov[ui][slc] += um.getImage()[slc_model]
# register as a separate source
uniq_source_dict[key].append(ui)

else:
# new unique source
models_cov[ui][slc] += um.getImage()[slc_model]
# register as a separate source
uniq_source_dict[key] = [ui]

# remove duplicate sources from models_cov
if overlap_dict:
models_cov = np.delete(models_cov,
np.concatenate(list(overlap_dict.values())),
axis=0)
# new dimension after removing duplicates
new_D = D - len(np.concatenate(list(overlap_dict.values())))

# if no overlaps, keep the original dimensions
else:
new_D = D

# faster implementation of the fisher information matrix ,using Einstein summation
F = np.zeros(shape=(new_D, new_D))

# double check dimension of models_cov for FIM calculation
assert models_cov.shape == (new_D, imlist[0].data.shape[0], imlist[0].data.shape[1])

# construct FIM
F[:new_D, :new_D] = -np.einsum('ijk,ljk->il',
models_cov * ie,
models_cov * ie)

# faster implementation of the fisher information matrix
F = np.zeros(shape=(D,D))
F[:D, :D] = -np.einsum('ijk,ljk->il', models_cov * ie, models_cov * ie)
if np.any(np.isnan(F)) or np.any(np.isinf(F)):
raise ValueError("Fisher matrix contains NaNs or Infs.")

# Calculate covariance matrix by inverting Fisher information matrix
# calculate covariance matrix by inverting FIM

try:
C = np.linalg.inv(F)

except np.linalg.LinAlgError as e:
# Handle the case where F is not invertible
print(f'Error: {e}. F is not invertible!')
C = np.inf + np.zeros_like(F)

# handle the case where F is not invertible, fill in NaNs
C = np.nan + np.zeros_like(F)

# might be a truncated list due to overlapping sources
var = -np.diag(C)
IV[Nsky:] = 1/var # inverse variance.

# reconstruct a variance array with the input number of sources
var_all = np.zeros(shape=D) + np.nan
var_all[np.concatenate(list(uniq_source_dict.values()))] = var

# set any overlapping sources to be nan; if no overlapping, skip
if overlap_dict:

var_all[list(overlap_dict.keys())] = np.nan

# remove any repetitive sources in the uniq source dict that are in the overlap dict
uniq_source_dict = {key: [idx for idx in value if idx not in overlap_dict] for key, value in uniq_source_dict.items()}

IV[Nsky:] = 1/var_all # inverse variance.

## what to be returned here?
## to keep track of overlapping (junk) sources, would be useful to return
## the unique and overlapping source dictionary
## uniq_source_dict, overlap_dict...

return IV

Expand Down