@@ -43,7 +43,6 @@ def query_sharded(n, context):
4343 data = txn .get (pickle .dumps (tuple (context )))
4444 if not data :
4545 print (f"Context { context } not found in shard { shard } ." )
46- print (index )
4746 return pickle .loads (data ) if data else None
4847
4948def generate (start = "" , max_len = 20 , n = None , dev = False ):
@@ -53,26 +52,31 @@ def generate(start="", max_len=20, n=None, dev=False):
5352 need = n - 1
5453 ctx = tuple ((([START ]* need ) + start_tokens )[- need :]) if need else ()
5554 out = start_tokens [:]
56- for _ in range (max_len ):
57- res = query_sharded (n , ctx )
58- next_word = max (res , key = res .get ) if res else None
59-
60- if next_word in (None , END ):
61- out .append ('#' )
62- break
63- out .append (next_word )
64- if need :
65- ctx = tuple ((list (ctx )+ [next_word ])[- need :])
66- return " " .join (out )
55+ if max_len == 0 :
56+ next_word = query_sharded (n , ctx )
57+ output_str = "\n " .join (f"{ v } { k } " for k , v in sorted (next_word .items (), key = lambda x : x [1 ], reverse = True ))
58+ return output_str
59+ else :
60+ for _ in range (max_len ):
61+ res = query_sharded (n , ctx )
62+ if res is None or res == END or not res :
63+ out .append ('#' )
64+ break
65+ words = list (res .keys ())
66+ probs = list (res .values ())
67+ next_word = random .choices (words , weights = probs , k = 1 )[0 ]
68+ out .append (next_word )
69+ if need :
70+ ctx = tuple ((list (ctx )+ [next_word ])[- need :])
71+ return " " .join (out )
6772
6873def run (response , answer , params :Params ) -> Result :
6974 output = []
70- word_count = params .get ("word_count" , 10 )
71- if word_count == "random" :
72- word_count = random .randint (3 ,15 )
73- response_used = isinstance (response , str )
74- context = response if response_used else "the general" # Default context
7575 context_window = params .get ("context_window" , 3 ) or 3
76+ context = response if isinstance (response , str ) else "the general" # Default context
77+ word_count = params .get ("word_count" , 10 )
78+ word_count = random .randint (3 ,15 ) if word_count == "random" else word_count
79+
7680 try :
7781 output .append (generate (context ,word_count ,context_window ,dev = params .get ("dev" , False )))
7882 except Exception as e :
@@ -86,8 +90,7 @@ def run(response, answer, params:Params) -> Result:
8690 "traceback" : tb ,
8791 }
8892 preface = 'Context window: ' + str (context_window )+ ', Word count: ' + str (word_count )+ '. Output: <br>'
89- feedback_items = [("general" , preface + ' ' .join (output ))]
90- #feedback_items.append("| Answer not an integer; used default context window") if not response_used else None
93+ feedback_items = [("general" , preface + ' ' .join (output ).replace ("</s>" , "" ).replace ("<s>" , "" ).strip ())]
9194 is_correct = True
9295 print (feedback_items )
9396 return Result (is_correct = is_correct ,feedback_items = feedback_items )
0 commit comments