Skip to content

Commit 6401e6e

Browse files
committed
Updates to allow correlated errors from var.pkl file
1 parent 17bf5c1 commit 6401e6e

4 files changed

Lines changed: 98 additions & 78 deletions

File tree

pipt/loop/ensemble.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,8 @@ def _org_data_var(self):
448448
self.datavar[i][datatype[j]].append(var_value[c])
449449
else:
450450
self.datavar[i][datatype[j]].append(var_value)
451+
elif datavar[i][datatype[j]][0].lower() == 'emp':
452+
self.datavar[i][datatype[j]].append(datavar[i][datatype[j]][1])
451453
else:
452454
print('\n\033[1;31mERROR: Cannot read data variance from pkl file! The first entry in the pkl file must be either "rel" or "abs"!\033[1;m')
453455
sys.exit()
@@ -541,12 +543,17 @@ def set_observations(self):
541543
# enObs: samples from N(0,Cd)
542544
enObs = cholesky(self.cov_data).T @ np.random.randn(self.cov_data.shape[0], self.ne)
543545
else:
544-
enObs = at.extract_tot_empirical_cov(
545-
self.datavar,
546-
self.assim_index,
547-
self.list_datatypes,
548-
self.ne
549-
)
546+
# Extract assim indices
547+
if isinstance(self.assim_index[1], list):
548+
l_prim = [int(x) for x in self.assim_index[1]]
549+
else:
550+
l_prim = [int(self.assim_index[1])]
551+
552+
# Concatenate datavar in the same manner as aug_obs_pred_data
553+
enObs = np.concatenate(tuple(
554+
self.datavar[el][dat] for el in l_prim for dat in self.list_datatypes
555+
if self.datavar[el][dat] is not None
556+
))
550557

551558
# Screen data if required
552559
if ('screendata' in self.keys_da) and (self.keys_da['screendata'] == 'yes'):
@@ -558,7 +565,7 @@ def set_observations(self):
558565
)
559566

560567
# Center the ensemble of perturbed observed data
561-
enObs = vecObs[:, np.newaxis] - enObs
568+
# enObs = vecObs[:, np.newaxis] - enObs
562569
self.cov_data = np.var(enObs, ddof=1, axis=1)
563570
self.scale_data = np.sqrt(self.cov_data)
564571

pipt/update_schemes/enkf.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ def __init__(self, keys_da, keys_en, sim):
4040

4141
# At the moment, the iterative loop is treated as an iterative smoother and thus we check if assim. indices
4242
# are given as in the Simultaneous loop.
43-
self.check_assimindex_sequential()
43+
self.check_assimindex_simultaneous()
44+
45+
self.assim_index = [self.keys_da['obsname'], self.keys_da['assimindex'][0]]
46+
self.list_datatypes, self.list_act_datatypes = at.get_list_data_types(self.obs_data, self.assim_index)
47+
4448

4549
# Extract no. assimilation steps from MDA keyword in DATAASSIM part of init. file and set this equal to
4650
# the number of iterations plus one. Need one additional because the iter=0 is the prior run.
@@ -56,11 +60,11 @@ def __init__(self, keys_da, keys_en, sim):
5660
else:
5761
self.trunc_energy = 0.98
5862

59-
self.state_scaling = at.calc_scaling(
60-
self.prior_enX,
61-
self.list_states,
62-
self.prior_info
63-
)
63+
# Get the perturbed observations and observation scaling
64+
self.vecObs, self.enObs = self.set_observations()
65+
self.enObs_conv = deepcopy(self.enObs)
66+
67+
self._ext_scaling()
6468

6569
def calc_analysis(self):
6670
"""
@@ -74,32 +78,32 @@ def calc_analysis(self):
7478
np.concatenate(self.keys_da['assimindex']))]
7579
list_datatypes, list_active_dataypes = at.get_list_data_types(
7680
self.obs_data, assim_index)
77-
if not hasattr(self, 'cov_data'):
78-
self.full_cov_data = at.gen_covdata(
79-
self.datavar, assim_index, list_datatypes)
80-
else:
81-
self.full_cov_data = self.cov_data
82-
83-
#obs_data_vector, pred_data = at.aug_obs_pred_data(
84-
# self.obs_data, self.pred_data, assim_index, list_datatypes)
81+
# if not hasattr(self, 'cov_data'):
82+
# self.full_cov_data = at.gen_covdata(
83+
# self.datavar, assim_index, list_datatypes)
84+
# else:
85+
# self.full_cov_data = self.cov_data
86+
87+
# #obs_data_vector, pred_data = at.aug_obs_pred_data(
88+
# # self.obs_data, self.pred_data, assim_index, list_datatypes)
8589

86-
vecObs, enPred = at.aug_obs_pred_data(
90+
_, enPred = at.aug_obs_pred_data(
8791
self.obs_data,
8892
self.pred_data,
8993
assim_index,
9094
list_datatypes
9195
)
9296

93-
# Generate realizations of the observed data
94-
generator = Cholesky() # Initialize GeoStat class for generating realizations
95-
self.enObs = generator.gen_real(
96-
vecObs,
97-
self.full_cov_data,
98-
self.ne
99-
)
97+
# # Generate realizations of the observed data
98+
# generator = Cholesky() # Initialize GeoStat class for generating realizations
99+
# self.enObs = generator.gen_real(
100+
# vecObs,
101+
# self.full_cov_data,
102+
# self.ne
103+
# )
100104

101105
# Calc. misfit for the initial iteration
102-
data_misfit = at.calc_objectivefun(self.enObs, enPred, self.full_cov_data)
106+
data_misfit = at.calc_objectivefun(self.enObs, enPred, self.scale_data)
103107

104108
# Store the (mean) data misfit (also for conv. check)
105109
self.data_misfit = np.mean(data_misfit)
@@ -119,27 +123,36 @@ def calc_analysis(self):
119123
self.obs_data, self.assim_index)
120124

121125
# Augment observed and predicted data
122-
self.vecObs, self.enPred = at.aug_obs_pred_data(
123-
self.obs_data,
124-
self.pred_data,
125-
self.assim_index,
126-
self.list_datatypes
127-
)
128-
129-
self.cov_data = at.gen_covdata(
130-
self.datavar,
131-
self.assim_index,
132-
self.list_datatypes
133-
)
134-
135-
generator = Cholesky() # Initialize GeoStat class for generating realizations
136-
self.data_random_state = deepcopy(np.random.get_state())
137-
self.enObs, self.scale_data = generator.gen_real(
138-
self.vecObs,
139-
self.cov_data,
140-
self.ne,
141-
return_chol=True
142-
)
126+
if ('emp_cov' in self.keys_da) and (self.keys_da['emp_cov'] == 'yes'):
127+
_, self.enPred = at.aug_obs_pred_data(
128+
self.obs_data,
129+
self.pred_data,
130+
self.assim_index,
131+
self.list_datatypes
132+
)
133+
else:
134+
self.vecObs, self.enPred = at.aug_obs_pred_data(
135+
self.obs_data,
136+
self.pred_data,
137+
self.assim_index,
138+
self.list_datatypes
139+
)
140+
141+
self.cov_data = at.gen_covdata(
142+
self.datavar,
143+
self.assim_index,
144+
self.list_datatypes
145+
)
146+
147+
generator = Cholesky() # Initialize GeoStat class for generating realizations
148+
self.data_random_state = deepcopy(np.random.get_state())
149+
self.enObs, self.scale_data = generator.gen_real(
150+
self.vecObs,
151+
self.cov_data,
152+
self.ne,
153+
return_chol=True
154+
)
155+
143156
self.E = np.dot(self.enObs, self.proj)
144157

145158
if 'localanalysis' in self.keys_da:

pipt/update_schemes/es.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def check_convergence(self):
5252
list_datatypes)
5353

5454
data_misfit = at.calc_objectivefun(
55-
self.full_real_obs_data, pred_data, self.full_cov_data)
55+
self.enObs, pred_data, self.scale_data)
5656
self.data_misfit = np.mean(data_misfit)
5757
self.data_misfit_std = np.std(data_misfit)
5858

pipt/update_schemes/update_methods_ns/approx_update.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -189,34 +189,34 @@ def update(self, enX, enY, enE, **kwargs):
189189

190190
else:
191191

192-
if ('emp_cov' in self.keys_da) and (self.keys_da['emp_cov'] == 'yes'):
192+
# if ('emp_cov' in self.keys_da) and (self.keys_da['emp_cov'] == 'yes'):
193193

194-
# Scale and center the ensemble matrices: enX and enE
195-
enXcentered = self.scale(enX - np.mean(enX, 1)[:,None], self.state_scaling)
196-
enEcentered = self.scale(enE - np.mean(enE, 1)[:,None], self.scale_data)
197-
198-
Sinv = np.diag(1/Sd)
199-
X0 = Sinv @ Ud.T @ enEcentered
200-
eigval, eigvec = np.linalg.eig(X0 @ X0.T)
201-
202-
# Calculate and scale difference between observations and predictions (residuals)
203-
enRes = self.scale(enE - enY, self.scale_data)
204-
205-
# Compute the update step
206-
X1 = (Ud @ Sinv @ eigvec).T @ enRes
207-
X2 = solve((self.lam + 1) * np.diag(eigval) + np.eye(len(eigval)), X1)
208-
X3 = np.dot(VTd.T, eigvec) @ X2
209-
self.step = np.dot(self.state_scaling[:, None]*enXcentered, X3)
194+
# # Scale and center the ensemble matrecies: enX and enE
195+
# enXcentered = self.scale(enX - np.mean(enX, 1)[:,None], self.state_scaling)
196+
# enEcentered = self.scale(enE - np.mean(enE, 1)[:,None], self.scale_data)
197+
198+
# Sinv = np.diag(1/Sd)
199+
# X0 = Sinv @ Ud.T @ enEcentered
200+
# eigval, eigvec = np.linalg.eig(X0 @ X0.T)
201+
202+
# # Calculate and scale difference between observations and predictions (residuals)
203+
# enRes = self.scale(enE - enY, self.scale_data)
204+
205+
# # Compute the update step
206+
# X1 = (Ud @ Sinv @ eigvec).T @ enRes
207+
# X2 = solve((self.lam + 1) * np.diag(eigval) + np.eye(len(eigval)), X1)
208+
# X3 = np.dot(VTd.T, eigvec) @ X2
209+
# self.step = np.dot(self.state_scaling[:, None]*enXcentered, X3)
210210

211-
else:
212-
enXcentered = self.scale(np.dot(enX, self.proj), self.state_scaling)
213-
enRes = self.scale(enE - enY, self.scale_data)
214-
215-
# Compute the update step
216-
X1 = Ud.T @ enRes
217-
X2 = solve((self.lam + 1)*np.eye(Sd.size) + np.diag(Sd**2), X1)
218-
X3 = VTd.T @ np.diag(Sd) @ X2
219-
self.step = np.dot(self.state_scaling[:, None] * enXcentered, X3)
211+
# else:
212+
enXcentered = self.scale(np.dot(enX, self.proj), self.state_scaling)
213+
enRes = self.scale(enE - enY, self.scale_data)
214+
215+
# Compute the update step
216+
X1 = Ud.T @ enRes
217+
X2 = solve((self.lam + 1)*np.eye(Sd.size) + np.diag(Sd**2), X1)
218+
X3 = VTd.T @ np.diag(Sd) @ X2
219+
self.step = np.dot(self.state_scaling[:, None] * enXcentered, X3)
220220

221221

222222
def scale(self, data, scaling):

0 commit comments

Comments
 (0)