|
| 1 | +from typing import Literal |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import numpy.typing as npt |
| 5 | + |
| 6 | +from python_mg._lib_name import Lexicon |
| 7 | + |
| 8 | + |
def grammar_f1(
    preds: npt.NDArray[np.float64],
    correct: npt.NDArray[np.bool_],
) -> dict[str, npt.NDArray[np.float64]]:
    """Compute grammaticality precision, recall and F1 from log-probabilities.

    Args:
        preds: Log-probabilities over the vocabulary; the last axis is the
            vocabulary axis. Leading axes (batch, position, ...) are arbitrary.
        correct: Boolean mask of the same shape as ``preds``; ``True`` marks
            vocabulary items that are grammatical continuations.

    Returns:
        Dict with keys ``"f1"``, ``"precision"`` and ``"recall"``, each an
        array of ``preds.shape`` minus the last axis. F1 is defined as 0
        where precision and recall are both 0 (rather than NaN).

    Raises:
        ValueError: If ``preds`` and ``correct`` have different shapes.
    """
    if preds.shape != correct.shape:
        raise ValueError("correct and preds must have matching shapes")

    # Log-scores of grammatical items only; ungrammatical slots get -inf so
    # they contribute zero mass under exp. Computed once and reused below.
    masked_good = np.where(correct, preds, -np.inf)

    # Precision: total probability mass assigned to grammatical continuations.
    precision: npt.NDArray[np.float64] = np.exp(
        np.logaddexp.reduce(masked_good, axis=-1)
    )

    # Log of the combined mass on ungrammatical continuations; keepdims so it
    # broadcasts against the per-item scores on the vocabulary axis.
    total_bad: npt.NDArray[np.float64] = np.logaddexp.reduce(
        np.where(~correct, preds, -np.inf), axis=-1, keepdims=True
    )
    # A grammatical item is "recalled" when its individual score beats the
    # combined mass of all ungrammatical items.
    better_than_bad = np.where(masked_good > total_bad, 1.0, 0.0)

    recall = np.where(correct, better_than_bad, 0.0).sum(axis=-1) / correct.sum(
        axis=-1
    )

    # Harmonic mean, with the conventional F1 = 0 when precision + recall == 0.
    # The plain expression produced NaN (0/0) there.
    denom = precision + recall
    f1 = np.divide(
        2 * precision * recall,
        denom,
        out=np.zeros_like(denom),
        where=denom > 0,
    )

    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
| 36 | + |
| 37 | + |
def grammar_f1_from_strings(
    lexicon: Lexicon,
    tokens: npt.NDArray[np.int_],
    preds: npt.NDArray[np.float64],
    category: str,
    min_log_prob: float | None = -128.0,
    move_prob: float = 0.5,
    max_steps: int | None = 64,
    n_beams: int | None = 256,
    reduction: Literal["none", "sentence_mean"] = "sentence_mean",
) -> dict[str, npt.NDArray[np.float64]]:
    """Score token sequences for grammaticality F1 against a lexicon.

    The lexicon enumerates which vocabulary items are legal continuations
    at each position; ``grammar_f1`` then scores ``preds`` against that
    boolean mask.

    Args:
        lexicon: Grammar used to enumerate legal continuations.
        tokens: Integer token ids; the sequence runs along the last axis.
        preds: Per-position log-probabilities to be scored.
        category: Syntactic category the sequences must parse as.
        min_log_prob: Pruning threshold forwarded to the lexicon search.
        move_prob: Move probability forwarded to the lexicon search.
        max_steps: Step limit forwarded to the lexicon search.
        n_beams: Beam count forwarded to the lexicon search.
        reduction: ``"sentence_mean"`` averages each metric over the
            non-special positions of every sentence; ``"none"`` returns the
            per-position values unchanged.

    Returns:
        Dict with keys ``"f1"``, ``"precision"`` and ``"recall"``.

    Raises:
        ValueError: If ``reduction`` is not one of the accepted values.
    """
    # Drop the final position: there is no next-token prediction for it.
    legal_mask = lexicon.token_continuations(
        tokens,
        category,
        min_log_prob=min_log_prob,
        move_prob=move_prob,
        max_steps=max_steps,
        n_beams=n_beams,
    )[..., :-1, :]

    metrics = grammar_f1(preds, legal_mask)

    if reduction == "sentence_mean":
        # Exclude positions whose input token id is 1 or 2 from the mean
        # (presumably special tokens such as padding/EOS -- confirm against
        # the tokenizer's vocabulary).
        prefix = tokens[..., :-1]
        keep = (prefix != 2) & (prefix != 1)
        n_kept = keep.sum(axis=-1)
        metrics = {
            name: np.where(keep, values, 0.0).sum(axis=-1) / n_kept
            for name, values in metrics.items()
        }
    elif reduction != "none":
        raise ValueError(f'"{reduction}" is not a valid reduction')

    return metrics
0 commit comments