Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions app/entrypoints/blast-data.json
Original file line number Diff line number Diff line change
Expand Up @@ -2010,30 +2010,35 @@
"size": 720000164
},
{
"TODO": "Remove me",
"path": "sbipp/SBI_model.pt",
"version_id": "bGgAnrSKwAqAodKHhuUncVEKLs7M0pt",
"etag": "cbf5690b49ca21cddacae4ddf24f146f-4",
"size": 65007493
},
{
"TODO": "Replace me with data/sbipp/SBI_model_blast_zfix_global.pt",
"path": "sbipp/SBI_model_global.pt",
"version_id": "TqYFph0y083LP9FqdP1bBDpbnQ5G9ye",
"etag": "255f5a086f7e5dbbb322faac659f6673-4",
"size": 65007493
},
{
"TODO": "Replace me with data/sbipp/SBI_model_blast_zfree_global.pt",
"path": "sbipp/SBI_model_local.pt",
"version_id": "FtKKE82y1UkrQ1dhXkFc3HebczNWISP",
"etag": "22b59e91fa47e87f0e688f46e4d31242-4",
"size": 65009239
},
{
"TODO": "Replace with new model",
"path": "sbipp_phot/sbi_phot_global.h5",
"version_id": "TKLSBpdYgHXIg1J8js.XoTKXoRPd8RP",
"etag": "cb5e6fd738694f4da1541e4b1fa717d5-2",
"size": 1024005856
},
{
"TODO": "Replace with new model",
"path": "sbipp_phot/sbi_phot_local.h5",
"version_id": "jZiVHZ6wMlO1PLbhQbr1OA-0H-j8mqI",
"etag": "37a90f8b01643a4699b01d027c34256c-2",
Expand Down
Empty file added app/host/SBI/__init__.py
Empty file.
128 changes: 81 additions & 47 deletions app/host/SBI/run_sbi_blast.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@
}

sbi_params = {
"anpe_fname_global": f"{settings.SBIPP_ROOT}/SBI_model_global.pt", # trained sbi model
"train_fname_global": f"{settings.SBIPP_PHOT_ROOT}/sbi_phot_global.h5", # training set
"anpe_fname_local": f"{settings.SBIPP_ROOT}/SBI_model_local.pt", # trained sbi model
"train_fname_local": f"{settings.SBIPP_PHOT_ROOT}/sbi_phot_local.h5", # training set
#"anpe_fname_global": f"{settings.SBIPP_ROOT}/SBI_model_global.pt", # trained sbi model
#"train_fname_global": f"{settings.SBIPP_PHOT_ROOT}/sbi_phot_global.h5", # training set
#"anpe_fname_local": f"{settings.SBIPP_ROOT}/SBI_model_local.pt", # trained sbi model
#"train_fname_local": f"{settings.SBIPP_PHOT_ROOT}/sbi_phot_local.h5", # training set
"anpe_fname_zspec": f"{settings.SBIPP_ROOT}/SBI_model_blast_zfix_global.pt", # trained sbi model
"train_fname_zspec": f"{settings.SBIPP_PHOT_ROOT}/sbi_phot_blast_zfix_global.h5", # training set
"anpe_fname_zphot": f"{settings.SBIPP_ROOT}/SBI_model_blast_zfree_global.pt", # trained sbi model
"train_fname_zphot": f"{settings.SBIPP_PHOT_ROOT}/sbi_phot_blast_zfree_global.h5", # training set
"nhidden": 500, # architecture of the trained density estimator
"nblocks": 15, # architecture of the trained density estimator
}
Expand Down Expand Up @@ -76,7 +80,7 @@

# training set
def run_training_set():
for _fit_type in ["global", "local"]:
for _fit_type in ["zspec", "zphot"]:
data = h5py.File(sbi_params[f"train_fname_{_fit_type}"], "r")
x_train = np.array(data["theta"]) # physical parameters
y_train = np.array(data["phot"]) # fluxes & uncertainties
Expand Down Expand Up @@ -109,76 +113,104 @@ def run_training_set():
)
)
anpe._x_shape = Ut.x_shape_from_simulation(y_tensor)
if _fit_type == "global":
hatp_x_y_global = anpe.build_posterior(
if _fit_type == "zspec":
hatp_x_y_zspec = anpe.build_posterior(
p_x_y_estimator, sample_with="rejection"
)
y_train_global = y_train[:]
x_train_global = x_train[:]
elif _fit_type == "local":
hatp_x_y_local = anpe.build_posterior(
y_train_zspec = y_train[:]
x_train_zspec = x_train[:]
elif _fit_type == "zphot":
hatp_x_y_zphot = anpe.build_posterior(
p_x_y_estimator, sample_with="rejection"
)
y_train_local = y_train[:]
x_train_local = x_train[:]
y_train_zphot = y_train[:]
x_train_zphot = x_train[:]

print("""Storing training sets as data files...""")
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_global.pkl"), "wb"
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_zspec.pkl"), "wb"
) as handle:
pickle.dump(hatp_x_y_global, handle)
pickle.dump(hatp_x_y_zspec, handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_global.pkl"), "wb"
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_zspec.pkl"), "wb"
) as handle:
pickle.dump(y_train_global, handle)
pickle.dump(y_train_zspec, handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_global.pkl"), "wb"
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_zspec.pkl"), "wb"
) as handle:
pickle.dump(x_train_global, handle)
pickle.dump(x_train_zspec, handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_local.pkl"), "wb"
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_zphot.pkl"), "wb"
) as handle:
pickle.dump(hatp_x_y_local, handle)
pickle.dump(hatp_x_y_zphot, handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_local.pkl"), "wb"
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_zphot.pkl"), "wb"
) as handle:
pickle.dump(y_train_local, handle)
pickle.dump(y_train_zphot, handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_local.pkl"), "wb"
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_zphot.pkl"), "wb"
) as handle:
pickle.dump(x_train_local, handle)

pickle.dump(x_train_zphot, handle)


try:
print("""Loading training sets from data files...""")
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_global.pkl"), "rb"
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_zspec.pkl"), "rb"
) as handle:
hatp_x_y_global = pickle.load(handle)
hatp_x_y_zspec = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_global.pkl"), "rb"
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_zspec.pkl"), "rb"
) as handle:
y_train_global = pickle.load(handle)
y_train_zspec = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_global.pkl"), "rb"
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_zspec.pkl"), "rb"
) as handle:
x_train_global = pickle.load(handle)
x_train_zspec = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_local.pkl"), "rb"
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_zphot.pkl"), "rb"
) as handle:
hatp_x_y_local = pickle.load(handle)
hatp_x_y_zphot = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_local.pkl"), "rb"
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_zphot.pkl"), "rb"
) as handle:
y_train_local = pickle.load(handle)
y_train_zphot = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_local.pkl"), "rb"
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_zphot.pkl"), "rb"
) as handle:
x_train_local = pickle.load(handle)
x_train_zphot = pickle.load(handle)
print("""Training sets loaded.""")
except Exception as err:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What exactly is the failure mode you anticipate here? This seems too broad. Also, it looks like this block is executed upon import (which happens here), because it is not under a "if name=main" condition. Can this be done better?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conversation follow-up: This redundant pickle file loading code should be refactored into a function.

print(f"""Error loading training sets: {err}. Regenerating...""")
run_training_set()
print("""Loading training sets from data files...""")
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_zspec.pkl"), "rb"
) as handle:
hatp_x_y_zspec = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_zspec.pkl"), "rb"
) as handle:
y_train_zspec = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_zspec.pkl"), "rb"
) as handle:
x_train_zspec = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "hatp_x_y_zphot.pkl"), "rb"
) as handle:
hatp_x_y_zphot = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "y_train_zphot.pkl"), "rb"
) as handle:
y_train_zphot = pickle.load(handle)
with open(
os.path.join(settings.SBI_TRAINING_ROOT, "x_train_zphot.pkl"), "rb"
) as handle:
x_train_zphot = pickle.load(handle)
print("""Training sets loaded.""")


print("""Training sets generated.""")


Expand All @@ -189,7 +221,7 @@ def maggies_to_asinh(x):
return -a * math.asinh((x / 2.0) * np.exp(mu / a)) + mu


def fit_sbi_pp(observations, n_filt_cuts=True, fit_type="global"):
def fit_sbi_pp(observations, n_filt_cuts=True):
np.random.seed(100) # make results reproducible

# toy noise model
Expand Down Expand Up @@ -260,6 +292,8 @@ def fit_sbi_pp(observations, n_filt_cuts=True, fit_type="global"):
"mags_unc"
] = mags_unc ##2.5/np.log(10)*observations['maggies_unc']/observations['maggies']
obs["redshift"] = observations["redshift"]
if observations["redshift"] is not None: fit_type="zspec"
else: fit_type="zphot"
obs["wavelengths"] = wavelengths
obs["filternames"] = filternames

Expand All @@ -268,14 +302,14 @@ def fit_sbi_pp(observations, n_filt_cuts=True, fit_type="global"):
return {}, 1

# prepare to pass the reconstructed model to sbi_pp
if fit_type == "global":
sbi_params["y_train"] = y_train_global
sbi_params["theta_train"] = x_train_global
sbi_params["hatp_x_y"] = hatp_x_y_global
elif fit_type == "local":
sbi_params["y_train"] = y_train_local
sbi_params["hatp_x_y"] = hatp_x_y_local
sbi_params["theta_train"] = x_train_local
if fit_type == "zspec":
sbi_params["y_train"] = y_train_zspec
sbi_params["theta_train"] = x_train_zspec
sbi_params["hatp_x_y"] = hatp_x_y_zspec
elif fit_type == "zphot":
sbi_params["y_train"] = y_train_zphot
sbi_params["hatp_x_y"] = hatp_x_y_zphot
sbi_params["theta_train"] = x_train_zphot

# Run SBI++
chain, obs, flags = sbi_pp.sbi_pp(
Expand Down
77 changes: 44 additions & 33 deletions app/host/SBI/snrfiles/DES_g_magvsnr.txt
Original file line number Diff line number Diff line change
@@ -1,33 +1,44 @@
14.095 1498.834
14.345 1351.754
14.595 1539.793
14.845 1354.735
15.095 1127.177
15.345 825.908
15.595 665.150
15.845 842.000
16.095 728.931
16.345 676.029
16.595 562.362
16.845 488.038
17.095 383.174
17.345 394.547
17.595 349.924
17.845 320.220
18.095 242.899
18.345 192.111
18.595 193.220
18.845 169.726
19.095 142.447
19.345 155.718
19.595 130.075
19.845 111.698
20.095 98.763
20.345 92.329
20.595 75.412
20.845 72.376
21.095 69.339
21.345 56.024
21.595 45.311
21.845 44.467
22.095 32.316
12.756 100.743
13.006 100.797
13.256 100.850
13.506 100.769
13.756 100.600
14.006 100.633
14.256 100.610
14.506 100.451
14.756 100.579
15.006 100.473
15.256 100.087
15.506 98.510
15.756 99.593
16.006 99.552
16.256 99.251
16.506 98.512
16.756 97.598
17.006 97.181
17.256 96.188
17.506 94.882
17.756 93.243
18.006 91.857
18.256 90.140
18.506 88.165
18.756 85.328
19.006 81.921
19.256 78.153
19.506 71.680
19.756 68.049
20.006 65.352
20.256 59.085
20.506 52.060
20.756 49.272
21.006 41.462
21.256 38.490
21.506 33.899
21.756 28.464
22.006 25.879
22.256 20.193
22.506 19.895
22.756 16.265
23.006 14.125
23.256 12.318
23.506 8.555
77 changes: 42 additions & 35 deletions app/host/SBI/snrfiles/DES_r_magvsnr.txt
Original file line number Diff line number Diff line change
@@ -1,35 +1,42 @@
13.220 1380.660
13.470 1515.118
13.720 1440.689
13.970 1075.392
14.220 1171.863
14.470 1122.357
14.720 743.579
14.970 847.772
15.220 701.530
15.470 607.197
15.720 638.522
15.970 581.440
16.220 401.855
16.470 350.811
16.720 509.118
16.970 344.474
17.220 283.598
17.470 204.505
17.720 224.727
17.970 158.058
18.220 173.284
18.470 165.233
18.720 150.001
18.970 109.567
19.220 96.969
19.470 110.299
19.720 74.490
19.970 85.434
20.220 79.614
20.470 73.795
20.720 46.062
20.970 45.486
21.220 41.995
21.470 38.503
21.720 25.180
12.424 100.617
12.674 100.664
12.924 100.712
13.174 100.388
13.424 100.662
13.674 100.623
13.924 100.486
14.174 100.477
14.424 100.136
14.674 99.826
14.924 99.986
15.174 99.668
15.424 99.002
15.674 98.421
15.924 98.065
16.174 95.832
16.424 95.031
16.674 94.200
16.924 94.663
17.174 92.153
17.424 90.064
17.674 86.884
17.924 84.744
18.174 81.741
18.424 78.850
18.674 73.872
18.924 72.352
19.174 65.074
19.424 58.806
19.674 54.362
19.924 51.035
20.174 40.640
20.424 38.705
20.674 32.043
20.924 33.475
21.174 28.748
21.424 23.027
21.674 19.599
21.924 16.171
22.174 14.494
22.424 12.818
22.674 11.141
Loading