|
1 | 1 | from dataclasses import dataclass, field |
2 | 2 | from typing import List, Optional, Any, Dict |
3 | 3 | from math import ceil |
| 4 | +import math |
4 | 5 | import random |
5 | 6 | import hashlib |
6 | 7 | from datetime import datetime |
@@ -40,6 +41,7 @@ class TaskSettings: |
40 | 41 |
|
41 | 42 | # --- Trial logic --- |
42 | 43 | conditions: List[str] = field(default_factory=list) |
| 44 | + condition_weights: Any = None |
43 | 45 | block_seed: Optional[List[int]] = None |
44 | 46 |
|
45 | 47 | # --- Seeding strategy --- |
@@ -81,6 +83,65 @@ def set_block_seed(self, seed_base: Optional[int]): |
81 | 83 | rng = random.Random(seed_base) |
82 | 84 | self.block_seed = [rng.randint(0, 99999) for _ in range(self.total_blocks)] |
83 | 85 |
|
| 86 | + def resolve_condition_weights(self) -> list[float] | None: |
| 87 | + """Resolve and validate optional condition weights. |
| 88 | +
|
| 89 | + Returns |
| 90 | + ------- |
| 91 | + list[float] | None |
| 92 | + A weight vector aligned to ``self.conditions`` when |
| 93 | + ``self.condition_weights`` is configured; otherwise ``None`` to |
| 94 | + indicate even/default generation. |
| 95 | + """ |
| 96 | + raw = getattr(self, "condition_weights", None) |
| 97 | + if raw is None: |
| 98 | + return None |
| 99 | + |
| 100 | + if not isinstance(self.conditions, list): |
| 101 | + raise TypeError("conditions must be a list when condition_weights is provided.") |
| 102 | + |
| 103 | + labels = [str(c) for c in self.conditions] |
| 104 | + if not labels: |
| 105 | + raise ValueError("conditions must be non-empty when condition_weights is provided.") |
| 106 | + |
| 107 | + values: list[Any] |
| 108 | + if isinstance(raw, dict): |
| 109 | + keyed = {str(k): v for k, v in raw.items()} |
| 110 | + missing = [label for label in labels if label not in keyed] |
| 111 | + extra = [key for key in keyed if key not in labels] |
| 112 | + if missing: |
| 113 | + raise ValueError(f"condition_weights missing entries for condition(s): {missing}") |
| 114 | + if extra: |
| 115 | + raise ValueError(f"condition_weights contains unknown condition key(s): {extra}") |
| 116 | + values = [keyed[label] for label in labels] |
| 117 | + elif isinstance(raw, (list, tuple)): |
| 118 | + if len(raw) != len(labels): |
| 119 | + raise ValueError( |
| 120 | + "condition_weights length mismatch: expected " |
| 121 | + f"{len(labels)} for conditions {labels}, got {len(raw)}" |
| 122 | + ) |
| 123 | + values = list(raw) |
| 124 | + else: |
| 125 | + raise TypeError("condition_weights must be null, list/tuple, or mapping keyed by condition label.") |
| 126 | + |
| 127 | + weights: list[float] = [] |
| 128 | + for i, value in enumerate(values): |
| 129 | + try: |
| 130 | + w = float(value) |
| 131 | + except Exception as exc: |
| 132 | + raise TypeError( |
| 133 | + f"condition_weights[{i}] could not be parsed as number: {value!r}" |
| 134 | + ) from exc |
| 135 | + if not math.isfinite(w): |
| 136 | + raise ValueError(f"condition_weights[{i}] must be finite, got {w!r}") |
| 137 | + if w <= 0: |
| 138 | + raise ValueError(f"condition_weights[{i}] must be > 0, got {w!r}") |
| 139 | + weights.append(w) |
| 140 | + |
| 141 | + if sum(weights) <= 0: |
| 142 | + raise ValueError(f"condition_weights sum must be > 0, got {weights}") |
| 143 | + return weights |
| 144 | + |
84 | 145 | def add_subinfo(self, subinfo: Dict[str, Any]): |
85 | 146 | """ |
86 | 147 | Add subject-specific information and set seed/file names accordingly. |
|
0 commit comments