Skip to content

Commit e608125

Browse files
committed
Require trial context on all participant-visible phases
1 parent ba98ed5 commit e608125

5 files changed

Lines changed: 178 additions & 7 deletions

File tree

psyflow/contracts/v0.1.0/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ These contracts define practical standards for building auditable psyflow/TAPS t
1414
- Runtime entrypoint pattern (`main.py`)
1515
- Trial runtime pattern (`src/run_trial.py`)
1616
- phase/stage labels are task-defined (generic), not MID-specific
17+
- every participant-visible phase should emit `set_trial_context(...)` before `show(...)` or `capture_response(...)`
1718
- participant-facing text localization should be config-driven (avoid hardcoded runtime text)
1819
- Responder/sampler plugin standards (`config_scripted_sim.yaml`, `config_sampler_sim.yaml`, `responders/`)
1920
- README metadata rows

psyflow/contracts/v0.1.0/responder_context.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ fail_on_literal_text_attr_assign: true
3737
notes:
3838
- phase labels are task-specific; validator should not hard-fail on phase naming style.
3939
- run_trial should execute an auditable trial flow and return serializable trial_data.
40+
- every participant-visible phase should emit set_trial_context(...) before show(...) or capture_response(...); non-visible bookkeeping is the only exception.
4041
- participant-facing text should be config-defined for localization portability.

psyflow/contracts/v0.1.0/run_trial_pattern.md

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,19 @@ def run_trial(win, kb, settings, condition, stim_bank, controller, trigger_runti
1010
trial_data = {}
1111

1212
# preparatory phase (task-specific naming)
13-
StimUnit("cue", win, kb, runtime=trigger_runtime) \
14-
.add_stim(stim_bank.get("fixation")) \
15-
.show(duration=settings.cue_duration)
13+
cue = StimUnit("cue", win, kb, runtime=trigger_runtime).add_stim(stim_bank.get("fixation"))
14+
set_trial_context(
15+
cue,
16+
trial_id=1,
17+
phase="cue",
18+
deadline_s=settings.cue_duration,
19+
valid_keys=[],
20+
block_id=block_id,
21+
condition_id=str(condition),
22+
task_factors={"condition": str(condition), "block_idx": block_idx},
23+
stim_id="fixation",
24+
)
25+
cue.show(duration=settings.cue_duration)
1626

1727
# response window phase (name can be choice/decision/offer/probe/...)
1828
choice = StimUnit("choice", win, kb, runtime=trigger_runtime).add_stim(stim_bank.get("choice_screen"))
@@ -29,9 +39,19 @@ def run_trial(win, kb, settings, condition, stim_bank, controller, trigger_runti
2939
choice.capture_response(keys=settings.key_list, duration=1.2)
3040

3141
# outcome/feedback phase
32-
StimUnit("feedback", win, kb, runtime=trigger_runtime) \
33-
.add_stim(stim_bank.get("feedback")) \
34-
.show(duration=1.0)
42+
feedback = StimUnit("feedback", win, kb, runtime=trigger_runtime).add_stim(stim_bank.get("feedback"))
43+
set_trial_context(
44+
feedback,
45+
trial_id=1,
46+
phase="feedback",
47+
deadline_s=1.0,
48+
valid_keys=[],
49+
block_id=block_id,
50+
condition_id=str(condition),
51+
task_factors={"condition": str(condition), "block_idx": block_idx},
52+
stim_id="feedback",
53+
)
54+
feedback.show(duration=1.0)
3555

3656
return trial_data
3757
```
@@ -40,6 +60,7 @@ Key requirements:
4060
- `run_trial(...)` function present
4161
- task flow is auditable with task-specific stage names (no MID-only naming requirement)
4262
- at least one response window uses `set_trial_context(...)` + `capture_response(...)`
63+
- every participant-visible phase/screen emits `set_trial_context(...)` before `show(...)` or `capture_response(...)`
4364
- context includes `trial_id`, `phase`, `deadline_s`, `valid_keys`
4465
- returns serializable trial-level data
4566
- participant-facing labels/text/options are sourced from config stimuli (via `StimBank`) rather than hardcoded in `run_trial.py`

psyflow/templates/cookiecutter-psyflow/{{cookiecutter.project_name}}/references/task_logic_audit.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
## 7. Architecture Decisions (Auditability)
3434

3535
- `main.py` provides one auditable run flow across human/qa/sim.
36-
- `run_trial.py` sets trial context before response capture for simulation auditability.
36+
- `run_trial.py` sets trial context before every participant-visible phase for simulation and plotting auditability.
3737
- Replace template notes with task-specific design decisions and rationale.
3838

3939
## 8. Inference Log

psyflow/validate.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,146 @@ def _check_run_trial_text_localization(text: str, cfg: dict[str, Any]) -> list[s
369369
return issues
370370

371371

372+
def _find_visible_show_without_context(text: str) -> list[str]:
373+
try:
374+
tree = ast.parse(text)
375+
except SyntaxError:
376+
return []
377+
378+
run_trial_fn = None
379+
for node in tree.body:
380+
if isinstance(node, ast.FunctionDef) and node.name == "run_trial":
381+
run_trial_fn = node
382+
break
383+
if run_trial_fn is None:
384+
return []
385+
386+
context_units: set[str] = set()
387+
unit_stims: dict[str, list[str]] = {}
388+
unit_labels: dict[str, str] = {}
389+
warnings: list[str] = []
390+
391+
def _stmt_call(stmt: ast.stmt) -> ast.Call | None:
392+
if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call):
393+
return stmt.value
394+
if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call):
395+
return stmt.value
396+
return None
397+
398+
def _kw_expr(call: ast.Call, key: str) -> str:
399+
for kw in call.keywords:
400+
if kw.arg == key:
401+
try:
402+
return ast.unparse(kw.value)
403+
except Exception: # noqa: BLE001
404+
return ""
405+
return ""
406+
407+
def _node_name(node: ast.AST) -> str:
408+
return node.id if isinstance(node, ast.Name) else ""
409+
410+
def _add_stim_exprs(node: ast.AST) -> list[str]:
411+
out: list[str] = []
412+
current = node
413+
while isinstance(current, ast.Call):
414+
if isinstance(current.func, ast.Attribute) and current.func.attr == "add_stim" and current.args:
415+
try:
416+
out.append(ast.unparse(current.args[0]))
417+
except Exception: # noqa: BLE001
418+
pass
419+
current = current.func.value
420+
elif isinstance(current.func, ast.Attribute):
421+
current = current.func.value
422+
else:
423+
break
424+
out.reverse()
425+
return out
426+
427+
def _unit_label_expr(node: ast.AST) -> str:
428+
current = node
429+
while isinstance(current, ast.Call):
430+
if isinstance(current.func, ast.Name) and current.func.id in {"StimUnit", "make_unit"}:
431+
label = _kw_expr(current, "unit_label")
432+
if label:
433+
return label
434+
if current.args:
435+
try:
436+
return ast.unparse(current.args[0])
437+
except Exception: # noqa: BLE001
438+
return ""
439+
return ""
440+
if isinstance(current.func, ast.Attribute):
441+
current = current.func.value
442+
else:
443+
break
444+
return ""
445+
446+
def _walk(stmts: list[ast.stmt]) -> None:
447+
for stmt in stmts:
448+
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
449+
var = stmt.targets[0].id
450+
label = _unit_label_expr(stmt.value)
451+
if label:
452+
unit_labels[var] = label
453+
stim_exprs = _add_stim_exprs(stmt.value)
454+
if stim_exprs:
455+
unit_stims.setdefault(var, []).extend(stim_exprs)
456+
457+
if isinstance(stmt, ast.If):
458+
_walk(stmt.body)
459+
_walk(stmt.orelse)
460+
continue
461+
if isinstance(stmt, (ast.For, ast.AsyncFor, ast.While, ast.With, ast.AsyncWith)):
462+
_walk(stmt.body)
463+
_walk(getattr(stmt, "orelse", []))
464+
continue
465+
if isinstance(stmt, ast.Try):
466+
_walk(stmt.body)
467+
for handler in stmt.handlers:
468+
_walk(handler.body)
469+
_walk(stmt.orelse)
470+
_walk(stmt.finalbody)
471+
continue
472+
473+
call = _stmt_call(stmt)
474+
if call is None:
475+
continue
476+
if isinstance(call.func, ast.Name) and call.func.id == "set_trial_context" and call.args:
477+
unit_var = _node_name(call.args[0])
478+
if unit_var:
479+
context_units.add(unit_var)
480+
continue
481+
if not (isinstance(call.func, ast.Attribute) and call.func.attr == "show"):
482+
continue
483+
484+
base = call.func.value
485+
unit_var = _node_name(base)
486+
stim_exprs = list(unit_stims.get(unit_var, [])) if unit_var else []
487+
for expr in _add_stim_exprs(base):
488+
if expr not in stim_exprs:
489+
stim_exprs.append(expr)
490+
if not stim_exprs:
491+
continue
492+
if unit_var and unit_var in context_units:
493+
continue
494+
495+
label = _unit_label_expr(base)
496+
if not label and unit_var:
497+
label = unit_labels.get(unit_var, "")
498+
token = str(label or unit_var or "show() phase").strip().strip("'\"")
499+
warnings.append(token)
500+
501+
_walk(run_trial_fn.body)
502+
deduped: list[str] = []
503+
seen: set[str] = set()
504+
for item in warnings:
505+
if not item or item in seen:
506+
continue
507+
seen.add(item)
508+
deduped.append(item)
509+
return deduped
510+
511+
372512
def _looks_like_mid_identity(value: Any) -> bool:
373513
low = str(value or "").strip().lower()
374514
if not low:
@@ -1128,6 +1268,12 @@ def _check_responder_context(task_dir: Path, cfg: dict[str, Any]) -> ContractRes
11281268
fails.append(f"Missing required run_trial token: {s}")
11291269

11301270
fails.extend(_check_run_trial_text_localization(text, cfg))
1271+
visible_without_context = _find_visible_show_without_context(text)
1272+
for token in visible_without_context:
1273+
warns.append(
1274+
"Participant-visible phase appears to call show() without preceding set_trial_context(...): "
1275+
f"{token}"
1276+
)
11311277

11321278
req_any = list(cfg.get("required_strings_any") or [])
11331279
if req_any and not any(str(s) in text for s in req_any):
@@ -1256,6 +1402,8 @@ def _check_responder_context(task_dir: Path, cfg: dict[str, Any]) -> ContractRes
12561402

12571403
if fails:
12581404
suggestions.append("Call set_trial_context(...) with required fields before response windows.")
1405+
if visible_without_context:
1406+
suggestions.append("Emit set_trial_context(...) for every participant-visible phase, not only response windows.")
12591407
if warns:
12601408
suggestions.append("Include condition/block/task_factors context for richer simulation and audits.")
12611409
return _result(name, fails, warns, suggestions)

0 commit comments

Comments
 (0)