Skip to content

Commit b09234f

Browse files
Peter Johnson
authored and committed
shannon words random choices
1 parent 1515e28 commit b09234f

1 file changed

Lines changed: 22 additions & 19 deletions

File tree

evaluation_function/models/shannon_words_ngram.py

Lines changed: 22 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@ def query_sharded(n, context):
4343
data = txn.get(pickle.dumps(tuple(context)))
4444
if not data:
4545
print(f"Context {context} not found in shard {shard}.")
46-
print(index)
4746
return pickle.loads(data) if data else None
4847

4948
def generate(start="", max_len=20, n=None, dev=False):
@@ -53,26 +52,31 @@ def generate(start="", max_len=20, n=None, dev=False):
5352
need = n-1
5453
ctx = tuple((([START]*need) + start_tokens)[-need:]) if need else ()
5554
out = start_tokens[:]
56-
for _ in range(max_len):
57-
res = query_sharded(n, ctx)
58-
next_word = max(res, key=res.get) if res else None
59-
60-
if next_word in (None, END):
61-
out.append('#')
62-
break
63-
out.append(next_word)
64-
if need:
65-
ctx = tuple((list(ctx)+[next_word])[-need:])
66-
return " ".join(out)
55+
if max_len == 0:
56+
next_word = query_sharded(n, ctx)
57+
output_str = "\n".join(f"{v} {k}" for k, v in sorted(next_word.items(), key=lambda x: x[1], reverse=True))
58+
return output_str
59+
else:
60+
for _ in range(max_len):
61+
res = query_sharded(n, ctx)
62+
if res is None or res == END or not res:
63+
out.append('#')
64+
break
65+
words = list(res.keys())
66+
probs = list(res.values())
67+
next_word = random.choices(words, weights=probs, k=1)[0]
68+
out.append(next_word)
69+
if need:
70+
ctx = tuple((list(ctx)+[next_word])[-need:])
71+
return " ".join(out)
6772

6873
def run(response, answer, params:Params) -> Result:
6974
output=[]
70-
word_count = params.get("word_count", 10)
71-
if word_count == "random":
72-
word_count = random.randint(3,15)
73-
response_used = isinstance(response, str)
74-
context = response if response_used else "the general" # Default context
7575
context_window = params.get("context_window", 3) or 3
76+
context = response if isinstance(response, str) else "the general" # Default context
77+
word_count = params.get("word_count", 10)
78+
word_count = random.randint(3,15) if word_count == "random" else word_count
79+
7680
try:
7781
output.append(generate(context,word_count,context_window,dev=params.get("dev", False)))
7882
except Exception as e:
@@ -86,8 +90,7 @@ def run(response, answer, params:Params) -> Result:
8690
"traceback": tb,
8791
}
8892
preface = 'Context window: '+str(context_window)+', Word count: '+str(word_count)+'. Output: <br>'
89-
feedback_items = [("general", preface + ' '.join(output))]
90-
#feedback_items.append("| Answer not an integer; used default context window") if not response_used else None
93+
feedback_items = [("general", preface + ' '.join(output).replace("</s>", "").replace("<s>", "").strip())]
9194
is_correct = True
9295
print(feedback_items)
9396
return Result(is_correct=is_correct,feedback_items=feedback_items)

0 commit comments

Comments
 (0)