Skip to content

Commit 2f9e7f9

Browse files
Added nice metrics by default
1 parent 205ed91 commit 2f9e7f9

File tree

2 files changed

+93
-15
lines changed

2 files changed

+93
-15
lines changed

example.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from python_mg import Lexicon
2+
from python_mg.metrics import grammar_f1, grammar_f1_from_strings
23
import numpy as np
34
import numpy.typing as npt
45

@@ -37,26 +38,28 @@
3738
for i in range(len(batch)):
3839
z[i, : len(batch[i])] = batch[i]
3940

40-
cont = lexicon.token_continuations(z, "C")[:, :-1, :]
41+
cont = lexicon.token_continuations(z, "C")
4142

42-
out = np.eye(len(tokens))[z]
43+
out = np.eye(len(tokens))[z[:, 1:]]
44+
out = np.log(out / out.sum(axis=-1, keepdims=True))
4345

46+
print(grammar_f1_from_strings(lexicon, z, out, "C"))
4447

4548
for i in range(len(z[0]) - 1):
4649
print(lexicon.detokenize(batch[0]))
4750
print([rev_tokens[s] for s in cont[0, i, :].nonzero()[0]])
4851

4952

50-
# for p in lexicon.generate_grammar("C", max_strings=50):
51-
# print(p)
52-
# tokens = p.tokens()
53-
# print(tokens)
54-
# print(lexicon.detokenize(tokens))
55-
# print(lexicon.detokenize(tokens.tolist()))
56-
# print(lexicon.parse_tokens(tokens, "C"))
57-
# print(p.latex())
58-
# print(p.log_prob())
59-
# print(p.prob())
60-
# tree = p.to_tree()
61-
# print(tree.normal_string())
62-
# print(tree.base_string())
53+
for p in lexicon.generate_grammar("C", max_strings=50):
54+
print(p)
55+
tokens = p.tokens()
56+
print(tokens)
57+
print(lexicon.detokenize(tokens))
58+
print(lexicon.detokenize(tokens.tolist()))
59+
print(lexicon.parse_tokens(tokens, "C"))
60+
print(p.latex())
61+
print(p.log_prob())
62+
print(p.prob())
63+
tree = p.to_tree()
64+
print(tree.normal_string())
65+
print(tree.base_string())

python/python_mg/metrics.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Literal
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from python_mg._lib_name import Lexicon
7+
8+
9+
def grammar_f1(
    preds: npt.NDArray[np.float64],
    correct: npt.NDArray[np.bool_],
) -> dict[str, npt.NDArray[np.float64]]:
    """Score log-space predictions against a mask of acceptable choices.

    The last axis enumerates the candidates at one position; every leading
    axis is treated as a batch dimension and is preserved in the result.

    Args:
        preds: Scores combined in log space (``logaddexp``), i.e. treated as
            log-probabilities over the last axis.
        correct: Mask with the same shape as ``preds``; truthy entries mark
            the acceptable candidates. It is coerced to ``bool`` so that a
            0/1 integer mask behaves like a boolean one.

    Returns:
        A dict with three arrays shaped like ``preds`` minus its last axis:

        * ``"precision"``: total probability mass placed on correct entries.
        * ``"recall"``: fraction of correct entries whose own score is
          strictly greater than the combined score of all incorrect entries.
          It is NaN where a row has no correct entry (0/0).
        * ``"f1"``: harmonic mean of the two, defined as 0.0 where both are
          0. NaN recall propagates to NaN here.

    Raises:
        ValueError: If ``preds`` and ``correct`` have different shapes.
    """
    preds = np.asarray(preds)
    # ``~`` is a bitwise NOT: on an integer 0/1 mask it yields -1/-2, which
    # are both truthy, so the mask must be boolean before it is inverted.
    correct = np.asarray(correct, dtype=bool)
    if preds.shape != correct.shape:
        raise ValueError("correct and preds must have matching shapes")

    # Scores of the correct entries only; everything else contributes
    # exp(-inf) == 0 to the log-sum-exp.
    correct_preds = np.where(correct, preds, -np.inf)

    precision: npt.NDArray[np.float64] = np.exp(  # pyright: ignore[reportAny]
        np.logaddexp.reduce(correct_preds, axis=-1)  # pyright: ignore[reportAny]
    )

    # Combined log-mass of the incorrect entries, kept broadcastable against
    # the per-entry scores.
    total_bad: npt.NDArray[np.float64] = (  # pyright: ignore[reportAny]
        np.logaddexp.reduce(np.where(~correct, preds, -np.inf), axis=-1, keepdims=True)
    )
    # Incorrect entries were set to -inf above and the comparison is strict,
    # so they can never be counted here.
    better_than_bad = correct_preds > total_bad

    # 0/0 is expected for degenerate rows; handle it explicitly instead of
    # emitting RuntimeWarnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = better_than_bad.sum(axis=-1) / correct.sum(axis=-1)
        denom = precision + recall
        f1 = np.where(denom == 0, 0.0, (2 * precision * recall) / denom)

    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
36+
37+
38+
def grammar_f1_from_strings(
    lexicon: Lexicon,
    tokens: npt.NDArray[np.int_],
    preds: npt.NDArray[np.float64],
    category: str,
    min_log_prob: float | None = -128.0,
    move_prob: float = 0.5,
    max_steps: int | None = 64,
    n_beams: int | None = 256,
    reduction: Literal["none", "sentence_mean"] = "sentence_mean",
) -> dict[str, npt.NDArray[np.float64]]:
    """Compute ``grammar_f1`` of ``preds`` against the grammar's continuations.

    The acceptable-token mask comes from ``lexicon.token_continuations`` with
    its final position along the second-to-last axis dropped. ``preds`` must
    therefore have the shape of that sliced mask, otherwise ``grammar_f1``
    raises ``ValueError``.

    Args:
        lexicon: Grammar used to compute the valid continuations.
        tokens: Token ids; the last axis indexes positions in a sequence.
        preds: Log-space scores to evaluate, one vector per scored position.
        category: Passed through to ``lexicon.token_continuations``.
        min_log_prob: Passed through to ``lexicon.token_continuations``.
        move_prob: Passed through to ``lexicon.token_continuations``.
        max_steps: Passed through to ``lexicon.token_continuations``.
        n_beams: Passed through to ``lexicon.token_continuations``.
        reduction: ``"none"`` returns the per-position scores unchanged.
            ``"sentence_mean"`` averages each metric over the last axis,
            counting only positions whose token id is neither 1 nor 2.

    Returns:
        A dict with keys ``"f1"``, ``"precision"`` and ``"recall"``. With
        ``"sentence_mean"`` a sequence that has no counted position yields
        NaN (0/0).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Reject a bad ``reduction`` before paying for the continuation search.
    if reduction not in ("none", "sentence_mean"):
        raise ValueError(
            f'"{reduction}" is not a valid reduction'
        )  # pyright: ignore[reportUnreachable]

    # Drop the final position so the mask has one entry per scored position.
    conts = lexicon.token_continuations(
        tokens,
        category,
        min_log_prob=min_log_prob,
        move_prob=move_prob,
        max_steps=max_steps,
        n_beams=n_beams,
    )[..., :-1, :]

    scores = grammar_f1(preds, conts)
    if reduction == "none":
        return scores

    # NOTE(review): ids 1 and 2 are presumably special (non-content) tokens
    # such as end/padding markers -- confirm against the tokenizer.
    context = tokens[..., :-1]
    mask = (context != 2) & (context != 1)  # pyright: ignore[reportAny]
    n_scored = mask.sum(axis=-1)  # pyright: ignore[reportAny]

    # ``np.where`` (rather than multiplying by the mask) also discards NaN
    # scores at excluded positions instead of letting them poison the sum.
    return {
        name: np.where(mask, values, 0.0).sum(axis=-1)  # pyright: ignore[reportAny]
        / n_scored
        for name, values in scores.items()
    }

0 commit comments

Comments
 (0)