Skip to content

Commit fbf1d69

Browse files
authored
Merge pull request #4 from OpenMLRL/dev
Dev
2 parents 62f4429 + 2b69dc2 commit fbf1d69

3 files changed

Lines changed: 20 additions & 23 deletions

File tree

configs/config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ data:
1212

1313
collab:
1414
mode: TAKE_JOB # ONE | TAKE_JOB
15-
num_agents: 3 # used when mode=TAKE_JOB
15+
num_agents: 2 # used when mode=TAKE_JOB
1616

1717
external:
1818
mode: code_feedback # plain | plain_simple | code_feedback
@@ -24,15 +24,15 @@ trainer:
2424
num_train_epochs: 3
2525
per_device_train_batch_size: 1
2626
# Learning rate for optimizer (alias: lr)
27-
learning_rate: 1.7e-5
27+
learning_rate: 1e-5
2828
logging_steps: 50
2929
save_steps: 200
3030
num_generations: 3
3131
# Per-agent generation cap; increase if outputs truncate.
32-
max_new_tokens: 660
32+
max_new_tokens: 600
3333
temperature: 0.25
3434
top_p: 0.90
35-
num_turns: 2
35+
num_turns: 1
3636
# PPO-related (CoMLRL MAGRPO) options
3737
# Whether to normalize advantages when updating policy
3838
# normalize_advantage: true

rewards/CE_reward.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -394,20 +394,7 @@ def collect_calls(fn: "ast.FunctionDef") -> Set[str]:
394394
_count_pass_syntax = 0
395395

396396
def get_reward_function(strategy, num_agents: int) -> Callable[..., List[float]]:
397-
"""Return a reward function implementing the redesigned lv1+lv2+lv3 scoring.
398-
399-
- V = total number of class methods requiring implementation
400-
- lv1 = 2 * (|union of chosen methods across agents| / V)
401-
Special case: if coverage < 1/2 then reward = 0 for this sample
402-
- lv2 = total-picks control:
403-
Let S = Σ_i |A_i| be the total number of functions generated by all agents.
404-
* If S ≥ 2V+2: terminate this sample early with total reward = -INF (=-1)
405-
* If 0 <= S <= V: lv2(S) = 2 - 3 * ((S - V)^2) / V^2 (assuming V>0)
406-
* If V < S <= 2V+2: lv2(S) = 2 - 3 * ((S - V)^2) / (V + 1)^2
407-
- lv3 = balance based on variance of |A_i| around t = V/N, with
408-
MSD = (1/N) * Σ (s_i - t)^2 and MSD_max = (1/N) * V^2 * (1 - 1/N),
409-
R_bal = max(0, 1 - MSD/(MSD_max + eps))
410-
Total reward = lv1 + lv2 + lv3
397+
"""Return a reward function
411398
"""
412399

413400
def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
@@ -447,9 +434,12 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
447434
INF = 1
448435
_count_total += 1
449436

437+
V_set: Set[str] = set(method_names)
438+
V = len(V_set)
439+
450440
# Early penalty: penalize by number of agents picking zero functions or all V functions (k * -0.5 * INF) and skip
451441
try:
452-
zeros = sum(1 for s in A_sets if (len(s) if s is not None else 0) == 0)
442+
zeros = sum(1 for s in A_sets if (len(s) if s is not None else 0) in (0, V))
453443
if zeros > 0:
454444
rewards.append(-INF * 0.5 * zeros)
455445
continue
@@ -460,8 +450,7 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
460450
_count_pass_lv0 += 1
461451

462452
# New reward rules (lv1 + lv2)
463-
V_set: Set[str] = set(method_names)
464-
V = len(V_set)
453+
465454
if V <= 0:
466455
rewards.append(-INF)
467456
continue
@@ -477,11 +466,13 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
477466
continue
478467

479468
lv1 = 2.0 * coverage_ratio
469+
if coverage_ratio == 1.0:
470+
lv1 += 0.5 # bonus for full coverage
480471

481472
# lv2: constrain total picks S = sum_i |A_i|
482473
S_total = sum(len(s) for s in A_sets)
483474
# Early termination if total picks reach or exceed 2V
484-
if S_total > 2 * V + 2:
475+
if S_total >= 2 * V:
485476
rewards.append(-INF)
486477
continue
487478

@@ -522,7 +513,7 @@ def reward_wrapper(*agent_completions, batch_items=None, prompts=None):
522513
sum_J += J
523514
N_pairs += 1
524515
mean_J = (sum_J / N_pairs) if N_pairs > 0 else 0.0
525-
jaccard_term = 1.0 * (1.0 - 1.5 * mean_J)
516+
jaccard_term = 1.0 * (1.0 - 2.0 * mean_J)
526517
# if jaccard_term > 2.0:
527518
# jaccard_term = 2.0
528519
# elif jaccard_term < -2.0:

utils/prompting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def build_take_job_prompt(
8181
SKELETON START
8282
{skeleton.strip()}
8383
SKELETON END
84+
85+
As a final reminder, please select a **non-empty, proper subset** of {v_braced} to implement.
86+
87+
We recommend choosing a consecutive block of methods that either starts at the beginning of {v_braced} or ends at its last method (DO NOT limit yourself to only the beginning).
88+
89+
Take particular care not to select all methods for implementation!
8490
"""
8591
).strip()
8692

0 commit comments

Comments
 (0)