Skip to content

Commit 033d680

Browse files
authored
Merge pull request #32 from OpenPIV/copilot/fix-s2n-validation-logic
Fix GPU S2N validation: shape mismatch crash and inverted threshold logic
2 parents 3b00193 + b5405a1 commit 033d680

3 files changed

Lines changed: 6 additions & 4 deletions

File tree

openpiv/gpu/process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,7 @@ def _validate_fields(self, u, v, dp_u, dp_v):
11461146
if self.num_validation_iters == 0:
11471147
return u, v, val_locations
11481148
if "s2n" in self.validation_method:
1149-
s2n_ratio = self._corr_gpu.s2n_ratio
1149+
s2n_ratio = self._corr_gpu.s2n_ratio.reshape(u.shape)
11501150

11511151
# Create the validation object.
11521152
self._validation_gpu = Validation(

openpiv/gpu/validation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,10 @@ def _s2n_validation(self, s2n_ratio):
448448
return
449449
s2n_tol = log10(self.s2n_tol)
450450

451-
sig2noise_tol = s2n_ratio / DTYPE_f(s2n_tol)
452-
self.val_locations = _local_validation(sig2noise_tol, 1, self.val_locations)
451+
# Mark invalid where s2n_ratio < s2n_tol, i.e. (s2n_tol - s2n_ratio) > 0.
452+
self.val_locations = _local_validation(
453+
DTYPE_f(s2n_tol) - s2n_ratio, 0, self.val_locations
454+
)
453455

454456
def _mask_val_locations(self):
455457
"""Removes masked locations from the validation locations."""

openpiv/test/gpu/test_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test_validation_median_num_validation_locations(validation_gpu, peaks_reshap
205205
def test_validation_s2n_validation(validation_gpu, s2n_ratio):
206206
tol = log10(validation.S2N_TOL)
207207

208-
val_locations = validation._local_validation(s2n_ratio / tol, 1).get()
208+
val_locations = validation._local_validation(DTYPE_f(tol) - s2n_ratio, 0).get()
209209
validation_gpu._s2n_validation(s2n_ratio)
210210
val_locations_gpu = validation_gpu.val_locations.get()
211211

0 commit comments

Comments
 (0)