-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathshannon_words_build.py
More file actions
62 lines (49 loc) · 2.19 KB
/
shannon_words_build.py
File metadata and controls
62 lines (49 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import lmdb, pickle, nltk, json, os
from nltk.corpus import brown
from pathlib import Path
from collections import defaultdict
from evaluation_function.models.utils import shard_for
from lf_toolkit.evaluation import Result, Params
# NOTE(review): CPython reads PYTHONHASHSEED once at interpreter startup, so
# this assignment does not change str/bytes hashing in the current process;
# only child processes started afterwards inherit it. If shard_for relies on
# the built-in hash(), shard assignment may differ between runs -- confirm.
os.environ["PYTHONHASHSEED"] = "0"

# Sentence-boundary padding tokens.
START = "<s>"
END = "</s>"

# Storage root: the MODEL_DIR environment variable wins; otherwise fall back
# to the sharded-LMDB folder under BASE_DIR. The directory is created at
# import time.
BASE_DIR = Path("evaluation_function/models")
MODEL_DIR = Path(os.getenv("MODEL_DIR", BASE_DIR.joinpath("storage", "lmdb_sharded")))
MODEL_DIR.mkdir(parents=True, exist_ok=True)
def corpus_sents(limit=None):
    """Yield Brown-corpus sentences as lists of lower-cased tokens.

    Args:
        limit: Stop index of the slice taken from ``brown.sents()``;
            ``None`` keeps every sentence.

    Yields:
        One list per sentence, each token passed through ``.lower()``.
    """
    yield from (
        [token.lower() for token in sentence]
        for sentence in brown.sents()[:limit]
    )
def build_sharded_lmdb(n=4, n_shards=64, map_size=2**28):
    """Count n-gram continuations over the corpus and store them in sharded LMDBs.

    Each sentence from ``corpus_sents`` is padded with ``n - 1`` START tokens
    and, when ``n > 1``, one END token. Every window of ``n`` tokens is split
    into a context (the first ``n - 1`` tokens, as a tuple) and the token that
    follows it; counts are grouped by the shard that
    ``shard_for(ctx, n_shards)`` selects.

    Written under ``MODEL_DIR / f"ngrams_{n}"``:
      * ``shard_XX.lmdb`` -- one LMDB environment per shard, mapping
        ``pickle.dumps(ctx)`` to ``pickle.dumps({next_token: count})``.
      * ``index.json`` -- shard number -> shard path.

    Existing entries in a shard are overwritten key by key; nothing is
    deleted first.

    Args:
        n: Order of the n-grams (context length is ``n - 1``).
        n_shards: Number of LMDB environments the contexts are spread over.
        map_size: ``map_size`` passed to ``lmdb.open`` for every shard.
    """
    n_dir = MODEL_DIR / f"ngrams_{n}"
    n_dir.mkdir(parents=True, exist_ok=True)

    # Paths are computed once and reused for both the writes and the index.
    shard_paths = [str(n_dir / f"shard_{i:02d}.lmdb") for i in range(n_shards)]

    print(f"Building {n}-grams into {n_shards} shards...")
    counts = [defaultdict(lambda: defaultdict(int)) for _ in range(n_shards)]

    # The padding is the same for every sentence, so build it once.
    head = [START] * (n - 1)
    tail = [END] if n > 1 else []
    for sent in corpus_sents(limit=None):
        s = head + sent + tail
        for i in range(len(s) - n + 1):
            ctx = tuple(s[i:i + n - 1])
            nxt = s[i + n - 1]
            counts[shard_for(ctx, n_shards)][ctx][nxt] += 1

    # Write per shard. Each environment is opened only while it is being
    # written, and the ``with`` block closes it even if a write fails, so no
    # LMDB handles stay open during counting or leak on an error.
    # NOTE(review): pickle output for equal tuples can differ depending on
    # whether elements are the same object (pickle memoization); readers
    # must build lookup keys compatibly -- confirm against the reader.
    for shard, path in enumerate(shard_paths):
        with lmdb.open(path, map_size=map_size) as env:
            with env.begin(write=True) as txn:
                for ctx, nexts in counts[shard].items():
                    txn.put(pickle.dumps(ctx), pickle.dumps(dict(nexts)))
        print(f"✅ shard {shard:02d} written with {len(counts[shard])} contexts")

    index = dict(enumerate(shard_paths))
    with open(n_dir / "index.json", "w") as f:
        json.dump(index, f, indent=2)
    print(f"✅ index written to {n_dir/'index.json'}")
def run(response, answer, params: Params) -> Result:
    """Build the sharded n-gram stores for every order from 2 up to ``n_max``.

    The Brown corpus is downloaded first (quietly). ``response`` and
    ``answer`` are accepted but not used.

    Args:
        response: Unused.
        answer: Unused.
        params: Mapping read with ``.get``; ``"n_max"`` (default 7) is the
            highest n-gram order to build.

    Returns:
        A ``Result`` marked correct with a single "general" feedback item.
    """
    nltk.download("brown", quiet=True)

    highest_order = params.get("n_max", 7)
    for order in range(2, highest_order + 1):
        print('Building for n=', order)
        build_sharded_lmdb(n=order, n_shards=64)
        print('Complete for n=', order)

    feedback = [("general", "Complete.")]
    return Result(is_correct=True, feedback_items=feedback)
# Script entry point: with empty params, run() falls back to its default n_max.
if __name__ == "__main__":
    run(response=None, answer=None, params={})