Skip to content

Commit 17ece51

Browse files
committed
fix: align balance description and preserve duplicates in sampling
Address Copilot review by matching schema wording with actual deterministic ordering and preventing unconditional signature-based de-duplication during final sampling, so enable_dedup=false semantics remain effective. Made-with: Cursor
1 parent 253b05d commit 17ece51

2 files changed

Lines changed: 31 additions & 11 deletions

File tree

src/autocode_mcp/tools/problem.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def input_schema(self) -> dict:
230230
},
231231
"enable_balance": {
232232
"type": "boolean",
233-
"description": "启用平衡分布:在已满足「至少一半为 extreme/tle」后,将剩余名额在各非极限类型间尽量均衡分配;关闭时剩余名额按确定性签名顺序填充",
233+
"description": "启用平衡分布:在已满足「至少一半为 extreme/tle」后,将剩余名额在各非极限类型间尽量均衡分配;关闭时剩余名额按确定性的 (type_param, signature) 顺序填充",
234234
"default": True,
235235
},
236236
"oversample_ratio": {
@@ -604,17 +604,18 @@ def _balance_and_sample(
604604
)
605605

606606
result: list[CandidateTest] = []
607-
used_sig: set[str] = set()
607+
selected_ids: set[int] = set()
608608

609609
for c in extreme_pool:
610610
if len(result) >= need_limit:
611611
break
612-
if c.signature in used_sig:
612+
cid = id(c)
613+
if cid in selected_ids:
613614
continue
614615
result.append(c)
615-
used_sig.add(c.signature)
616+
selected_ids.add(cid)
616617

617-
remaining = [c for c in candidates if c.signature not in used_sig]
618+
remaining = [c for c in candidates if id(c) not in selected_ids]
618619
need_more = target_count - len(result)
619620
if need_more <= 0:
620621
return result[:target_count]
@@ -637,10 +638,11 @@ def _balance_and_sample(
637638
for i, type_param in enumerate(type_order):
638639
count = base_count + (1 if i < rem else 0)
639640
for c in by_type[type_param][:count]:
640-
if c.signature in used_sig:
641+
cid = id(c)
642+
if cid in selected_ids:
641643
continue
642644
result.append(c)
643-
used_sig.add(c.signature)
645+
selected_ids.add(cid)
644646
if len(result) >= target_count:
645647
break
646648
if len(result) >= target_count:
@@ -650,18 +652,20 @@ def _balance_and_sample(
650652
for c in sorted(remaining, key=lambda c: (c.type_param, c.signature)):
651653
if len(result) >= target_count:
652654
break
653-
if c.signature in used_sig:
655+
cid = id(c)
656+
if cid in selected_ids:
654657
continue
655658
result.append(c)
656-
used_sig.add(c.signature)
659+
selected_ids.add(cid)
657660
else:
658661
for c in sorted(remaining, key=lambda c: (c.type_param, c.signature)):
659662
if len(result) >= target_count:
660663
break
661-
if c.signature in used_sig:
664+
cid = id(c)
665+
if cid in selected_ids:
662666
continue
663667
result.append(c)
664-
used_sig.add(c.signature)
668+
selected_ids.add(cid)
665669

666670
return result[:target_count]
667671

tests/test_tools/test_problem.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,22 @@ def mk(type_param: str, sig: str) -> CandidateTest:
673673
assert sum(1 for x in out11 if x.type_param in ("3", "4")) >= 6
674674

675675

676+
def test_balance_and_sample_keeps_duplicates_when_dedup_disabled():
677+
"""采样函数不应按 signature 强制去重(由 enable_dedup 控制前置候选)。"""
678+
tool = ProblemGenerateTestsTool()
679+
680+
dup1 = CandidateTest("in-a", "out", "3", "same")
681+
dup2 = CandidateTest("in-b", "out", "3", "same")
682+
dup3 = CandidateTest("in-c", "out", "2", "same")
683+
dup4 = CandidateTest("in-d", "out", "1", "same")
684+
candidates = [dup1, dup2, dup3, dup4]
685+
686+
out = tool._balance_and_sample(candidates, 4, balance_remainder=False)
687+
assert len(out) == 4
688+
assert out.count(dup1) == 1
689+
assert out.count(dup2) == 1
690+
691+
676692
@pytest.mark.asyncio
677693
async def test_problem_generate_tests_balance():
678694
"""测试平衡分布功能。"""

0 commit comments

Comments
 (0)