Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/pymegdec/stimulus_nested_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ def _expected_outer_participants(outer_path: Path) -> set[int]:
return {int(token.removeprefix("p")) for token in match.group("outer_label").split("-")}


def _row_outer_participants(path: Path, rows: list[dict[str, str]]) -> set[int]:
def _row_participants(path: Path, rows: list[dict[str, str]], *, participant_column: str = "test_participant") -> set[int]:
participants: set[int] = set()
for row_number, row in enumerate(rows, start=2):
raw_participant = row.get("test_participant")
raw_participant = row.get(participant_column)
if raw_participant is None or str(raw_participant).strip() == "":
raise NestedMatrixShardError(f"{path.name} row {row_number} is missing test_participant.")
raise NestedMatrixShardError(f"{path.name} row {row_number} is missing {participant_column}.")
try:
participants.add(int(raw_participant))
except ValueError as exc:
raise NestedMatrixShardError(
f"{path.name} row {row_number} has non-integer test_participant={raw_participant!r}."
f"{path.name} row {row_number} has non-integer {participant_column}={raw_participant!r}."
) from exc
return participants

Expand Down Expand Up @@ -115,13 +115,24 @@ def validate_nested_matrix_shards(
for outer_path in outer_paths:
participant_row_sets: list[tuple[str, Path, set[int]]] = []
if require_complete_outer_participants:
participant_row_sets.append(("outer", outer_path, _row_outer_participants(outer_path, _read_rows(outer_path))))
participant_row_sets.append(("outer", outer_path, _row_participants(outer_path, _read_rows(outer_path))))
for kind in required_kinds:
sidecar_path = _shard_path(outer_path, kind)
if not sidecar_path.exists():
missing.append(f"bundle={bundle} outer={outer_path.name} missing={sidecar_path.name}")
elif require_complete_outer_participants:
participant_row_sets.append((kind, sidecar_path, _row_outer_participants(sidecar_path, _read_rows(sidecar_path))))
participant_column = "outer_test_participant" if kind == "inner_validation" else "test_participant"
participant_row_sets.append(
(
kind,
sidecar_path,
_row_participants(
sidecar_path,
_read_rows(sidecar_path),
participant_column=participant_column,
),
)
)
if require_complete_outer_participants:
expected_participants = _expected_outer_participants(outer_path)
for kind, path, actual_participants in participant_row_sets:
Expand Down
66 changes: 58 additions & 8 deletions tests/test_stimulus_nested_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ def _selected_row(participant: int) -> dict:
}


def _inner_validation_row(outer_participant: int, validation_participant: int, *, candidate_index: int = 1) -> dict:
return {
"outer_test_participant": outer_participant,
"test_participant": validation_participant,
"inner_validation_participant": validation_participant,
"candidate_index": candidate_index,
"accuracy": 0.2,
"balanced_accuracy": 0.2,
"chance_accuracy": 0.0625,
"window_center_s": 0.175,
"window_size_s": 0.1,
"window_start_s": 0.125,
"window_stop_s": 0.225,
"feature_mode": "sensor_flat",
"normalization": "subject_baseline_whiten",
"alignment": "none",
"classifier": "multinomial-logistic",
"classifier_param": 1,
"components_pca": 64,
"max_trials_per_class_per_participant": 10,
"label_shuffle_control": False,
"label_shuffle_seed": 0,
}


def _prediction_rows(participant: int) -> list[dict]:
return [
{
Expand Down Expand Up @@ -133,10 +158,7 @@ def test_strict_validation_rejects_incomplete_participant_chunks(self) -> None:
stem = tmp_path / "nested-matrix-logreg-p1-p2" / "matrix_logreg_p1-p2"
_write_csv(stem.with_name(f"{stem.name}_outer.csv"), [_outer_row(1, balanced=0.1)])
_write_csv(stem.with_name(f"{stem.name}_selected.csv"), [_selected_row(1)])
_write_csv(
stem.with_name(f"{stem.name}_inner_validation.csv"),
[{"test_participant": 1, "candidate_index": 1, "balanced_accuracy": 0.1}],
)
_write_csv(stem.with_name(f"{stem.name}_inner_validation.csv"), [_inner_validation_row(1, 3)])
_write_csv(stem.with_name(f"{stem.name}_predictions.csv"), _prediction_rows(1))

shards = discover_nested_matrix_shards(tmp_path)
Expand All @@ -158,17 +180,45 @@ def test_strict_validation_rejects_incomplete_participant_chunks(self) -> None:
expected_shard_count=1,
)

def test_strict_validation_uses_outer_participant_for_inner_validation_sidecar(self) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
stem = tmp_path / "nested-matrix-logreg-p1-p2" / "matrix_logreg_p1-p2"
_write_csv(
stem.with_name(f"{stem.name}_outer.csv"),
[_outer_row(1, balanced=0.10), _outer_row(2, balanced=0.20)],
)
_write_csv(stem.with_name(f"{stem.name}_selected.csv"), [_selected_row(1), _selected_row(2)])
_write_csv(
stem.with_name(f"{stem.name}_inner_validation.csv"),
[
_inner_validation_row(1, 3),
_inner_validation_row(2, 3),
],
)
_write_csv(stem.with_name(f"{stem.name}_predictions.csv"), [*_prediction_rows(1), *_prediction_rows(2)])

artifacts = aggregate_nested_matrix_outputs(
tmp_path,
tmp_path / "out",
output_stem="nested_matrix",
signflip_permutations=0,
strict_shards=True,
expected_shard_count=1,
)

self.assertEqual(len(artifacts["outer"]), 2)
self.assertEqual(len(artifacts["inner_validation"]), 2)
self.assertTrue((tmp_path / "out" / "nested_matrix_inner_validation.csv").exists())

def test_aggregates_nested_matrix_shards_and_recomputes_bundle_summary(self) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
for participant, balanced in [(1, 0.10), (2, 0.20)]:
stem = tmp_path / f"nested-matrix-logreg-p{participant}" / f"matrix_logreg_p{participant}"
_write_csv(stem.with_name(f"{stem.name}_outer.csv"), [_outer_row(participant, balanced=balanced)])
_write_csv(stem.with_name(f"{stem.name}_selected.csv"), [_selected_row(participant)])
_write_csv(
stem.with_name(f"{stem.name}_inner_validation.csv"),
[{"test_participant": participant, "candidate_index": participant, "balanced_accuracy": balanced}],
)
_write_csv(stem.with_name(f"{stem.name}_inner_validation.csv"), [_inner_validation_row(participant, 3)])
_write_csv(stem.with_name(f"{stem.name}_predictions.csv"), _prediction_rows(participant))

artifacts = aggregate_nested_matrix_outputs(
Expand Down