Skip to content

Commit a7ef7fc

Browse files
committed
fix: 代码审查问题修复
- 修复: stress_test 中 returncode 拼写错误应为 return_code - 修复: 写入文件时添加 encoding=utf-8 - 修复: file_ops 中添加路径遍历防护 - 修复: testlib.h 模板不存在时返回错误而非静默创建占位符 - 重构: 提取公共 _run_process 逻辑以消除 compiler 中的代码重复
1 parent b1c07c3 commit a7ef7fc

File tree

4 files changed

+90
-108
lines changed

4 files changed

+90
-108
lines changed

src/autocode_mcp/tools/file_ops.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
文件操作工具。
33
"""
4+
45
import os
56

67
from .base import Tool, ToolResult
@@ -47,6 +48,15 @@ async def execute(self, path: str, problem_dir: str | None = None) -> ToolResult
4748
else:
4849
full_path = path
4950

51+
# 规范化路径并防止路径遍历攻击
52+
full_path = os.path.normpath(os.path.abspath(full_path))
53+
54+
# 如果指定了 problem_dir,确保文件在该目录内
55+
if problem_dir:
56+
problem_dir = os.path.normpath(os.path.abspath(problem_dir))
57+
if not full_path.startswith(problem_dir + os.sep) and full_path != problem_dir:
58+
return ToolResult.fail("Access denied: path outside problem directory")
59+
5060
if not os.path.exists(full_path):
5161
return ToolResult.fail(f"File not found: {path}")
5262

@@ -116,8 +126,19 @@ async def execute(
116126
else:
117127
full_path = path
118128

119-
# 确保目录存在
129+
# 规范化路径并防止路径遍历攻击
120130
dir_path = os.path.dirname(full_path)
131+
if dir_path:
132+
dir_path = os.path.normpath(os.path.abspath(dir_path))
133+
134+
# 如果指定了 problem_dir,确保文件在该目录内
135+
if problem_dir:
136+
problem_dir = os.path.normpath(os.path.abspath(problem_dir))
137+
full_path = os.path.normpath(os.path.abspath(full_path))
138+
if not full_path.startswith(problem_dir + os.sep) and full_path != problem_dir:
139+
return ToolResult.fail("Access denied: path outside problem directory")
140+
141+
# 确保目录存在
121142
if dir_path:
122143
os.makedirs(dir_path, exist_ok=True)
123144

src/autocode_mcp/tools/problem.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Problem 工具组 - 题目管理。
33
"""
4+
45
import os
56
import shutil
67

@@ -76,10 +77,11 @@ async def execute(
7677
dest_testlib = os.path.join(problem_dir, "files", "testlib.h")
7778
shutil.copy2(template_testlib, dest_testlib)
7879
else:
79-
# 如果模板不存在,创建一个占位符
80-
dest_testlib = os.path.join(problem_dir, "files", "testlib.h")
81-
with open(dest_testlib, "w") as f:
82-
f.write("// testlib.h - Please download from https://github.com/MikeMirzayanov/testlib\n")
80+
return ToolResult.fail(
81+
f"testlib.h template not found at {template_testlib}. "
82+
"Please download from https://github.com/MikeMirzayanov/testlib "
83+
"and place it in the templates/ directory."
84+
)
8385

8486
# 创建基础 README.md
8587
readme_path = os.path.join(problem_dir, "statements", "README.md")
@@ -178,23 +180,29 @@ async def execute(
178180
test_configs.extend([("1", "1", "1", "10", "1", "3")] * 3)
179181

180182
# 随机数据
181-
test_configs.extend([
182-
("2", "2", "10", "100", "1", "3"),
183-
("2", "2", "100", "1000", "1", "3"),
184-
("2", "2", "1000", "5000", "1", "3"),
185-
("2", "2", "5000", "10000", "1", "3"),
186-
])
183+
test_configs.extend(
184+
[
185+
("2", "2", "10", "100", "1", "3"),
186+
("2", "2", "100", "1000", "1", "3"),
187+
("2", "2", "1000", "5000", "1", "3"),
188+
("2", "2", "5000", "10000", "1", "3"),
189+
]
190+
)
187191

188192
# 大数据
189-
test_configs.extend([
190-
("3", "3", "100000", "200000", "1", "1"),
191-
("3", "3", "150000", "200000", "1", "1"),
192-
])
193+
test_configs.extend(
194+
[
195+
("3", "3", "100000", "200000", "1", "1"),
196+
("3", "3", "150000", "200000", "1", "1"),
197+
]
198+
)
193199

194200
# 边界数据
195-
test_configs.extend([
196-
("4", "4", "10", "50", "1", "3"),
197-
])
201+
test_configs.extend(
202+
[
203+
("4", "4", "10", "50", "1", "3"),
204+
]
205+
)
198206

199207
for i, _config in enumerate(test_configs[:test_count], 1):
200208
test_file = os.path.join(tests_dir, f"{i:02d}.in")

src/autocode_mcp/tools/stress_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
基于论文框架,比较 sol.cpp 和 brute.cpp 的输出。
55
"""
6+
67
import os
78
import tempfile
89

@@ -105,7 +106,7 @@ async def execute(
105106
with open(input_path) as f:
106107
input_data = f.read()
107108
val_result = await run_binary(val_exe, input_data, timeout=timeout)
108-
if val_result.returncode != 0:
109+
if val_result.return_code != 0:
109110
validator_failed = True
110111
last_input = input_data
111112
failed_round = i
@@ -185,7 +186,7 @@ async def _generate_input(
185186
"stderr": gen_result.stderr,
186187
}
187188

188-
with open(input_path, "w") as f:
189+
with open(input_path, "w", encoding="utf-8") as f:
189190
f.write(gen_result.stdout)
190191

191192
return {"success": True}

src/autocode_mcp/utils/compiler.py

Lines changed: 40 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- 资源限制
77
- 临时目录隔离
88
"""
9+
910
import asyncio
1011
import os
1112
import shutil
@@ -20,6 +21,7 @@
2021
@dataclass
2122
class CompileResult:
2223
"""编译结果。"""
24+
2325
success: bool
2426
binary_path: str | None = None
2527
error: str | None = None
@@ -30,6 +32,7 @@ class CompileResult:
3032
@dataclass
3133
class RunResult:
3234
"""执行结果。"""
35+
3336
success: bool
3437
return_code: int = -1
3538
stdout: str = ""
@@ -127,10 +130,7 @@ async def compile_cpp(
127130
)
128131

129132
try:
130-
stdout, stderr = await asyncio.wait_for(
131-
process.communicate(),
132-
timeout=timeout
133-
)
133+
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
134134
except TimeoutError:
135135
process.kill()
136136
await process.wait()
@@ -166,58 +166,40 @@ async def compile_cpp(
166166
)
167167

168168

169-
async def run_binary(
170-
binary_path: str,
169+
async def _run_process(
170+
cmd: list[str],
171171
stdin: str = "",
172172
timeout: int = 5,
173173
memory_mb: int = 256,
174174
) -> RunResult:
175-
"""
176-
运行二进制文件,带超时和内存限制。
177-
178-
Args:
179-
binary_path: 二进制文件路径
180-
stdin: 标准输入
181-
timeout: 超时时间(秒)
182-
memory_mb: 内存限制(MB),仅 Linux 有效
183-
184-
Returns:
185-
RunResult: 执行结果
186-
"""
187-
if not os.path.exists(binary_path):
188-
return RunResult(
189-
success=False,
190-
error=f"Binary not found: {binary_path}",
191-
)
192-
175+
"""运行进程的公共逻辑。"""
193176
import time
177+
194178
start_time = time.time()
195179

196180
try:
197-
# Windows 不支持 ulimit,仅使用 timeout
198181
if sys.platform == "win32":
199182
process = await asyncio.create_subprocess_exec(
200-
binary_path,
183+
*cmd,
201184
stdin=asyncio.subprocess.PIPE,
202185
stdout=asyncio.subprocess.PIPE,
203186
stderr=asyncio.subprocess.PIPE,
204187
)
205188
else:
206-
# Linux: 使用 prlimit 设置内存限制
207-
# 注意:需要 prlimit 命令可用
208189
memory_bytes = memory_mb * 1024 * 1024
209190
process = await asyncio.create_subprocess_exec(
210-
"prlimit", f"--as={memory_bytes}", f"--data={memory_bytes}",
211-
binary_path,
191+
"prlimit",
192+
f"--as={memory_bytes}",
193+
f"--data={memory_bytes}",
194+
*cmd,
212195
stdin=asyncio.subprocess.PIPE,
213196
stdout=asyncio.subprocess.PIPE,
214197
stderr=asyncio.subprocess.PIPE,
215198
)
216199

217200
try:
218201
stdout, stderr = await asyncio.wait_for(
219-
process.communicate(input=stdin.encode("utf-8")),
220-
timeout=timeout
202+
process.communicate(input=stdin.encode("utf-8") if stdin else None), timeout=timeout
221203
)
222204
except TimeoutError:
223205
process.kill()
@@ -242,7 +224,7 @@ async def run_binary(
242224
except FileNotFoundError:
243225
return RunResult(
244226
success=False,
245-
error=f"Binary not found or prlimit unavailable: {binary_path}",
227+
error=f"Binary not found or prlimit unavailable: {cmd[0]}",
246228
)
247229
except Exception as e:
248230
return RunResult(
@@ -251,19 +233,17 @@ async def run_binary(
251233
)
252234

253235

254-
async def run_binary_with_args(
236+
async def run_binary(
255237
binary_path: str,
256-
args: list[str],
257238
stdin: str = "",
258239
timeout: int = 5,
259240
memory_mb: int = 256,
260241
) -> RunResult:
261242
"""
262-
运行二进制文件并传递命令行参数
243+
运行二进制文件,带超时和内存限制
263244
264245
Args:
265246
binary_path: 二进制文件路径
266-
args: 命令行参数列表
267247
stdin: 标准输入
268248
timeout: 超时时间(秒)
269249
memory_mb: 内存限制(MB),仅 Linux 有效
@@ -277,65 +257,37 @@ async def run_binary_with_args(
277257
error=f"Binary not found: {binary_path}",
278258
)
279259

280-
import time
281-
start_time = time.time()
282-
283-
try:
284-
if sys.platform == "win32":
285-
process = await asyncio.create_subprocess_exec(
286-
binary_path,
287-
*args,
288-
stdin=asyncio.subprocess.PIPE,
289-
stdout=asyncio.subprocess.PIPE,
290-
stderr=asyncio.subprocess.PIPE,
291-
)
292-
else:
293-
memory_bytes = memory_mb * 1024 * 1024
294-
process = await asyncio.create_subprocess_exec(
295-
"prlimit", f"--as={memory_bytes}", f"--data={memory_bytes}",
296-
binary_path,
297-
*args,
298-
stdin=asyncio.subprocess.PIPE,
299-
stdout=asyncio.subprocess.PIPE,
300-
stderr=asyncio.subprocess.PIPE,
301-
)
260+
return await _run_process([binary_path], stdin, timeout, memory_mb)
302261

303-
try:
304-
stdout, stderr = await asyncio.wait_for(
305-
process.communicate(input=stdin.encode("utf-8") if stdin else None),
306-
timeout=timeout
307-
)
308-
except TimeoutError:
309-
process.kill()
310-
await process.wait()
311-
return RunResult(
312-
success=False,
313-
timed_out=True,
314-
error=f"Execution timeout after {timeout}s",
315-
time_ms=int((time.time() - start_time) * 1000),
316-
)
317262

318-
elapsed_ms = int((time.time() - start_time) * 1000)
263+
async def run_binary_with_args(
264+
binary_path: str,
265+
args: list[str],
266+
stdin: str = "",
267+
timeout: int = 5,
268+
memory_mb: int = 256,
269+
) -> RunResult:
270+
"""
271+
运行二进制文件并传递命令行参数。
319272
320-
return RunResult(
321-
success=process.returncode == 0,
322-
return_code=process.returncode,
323-
stdout=stdout.decode("utf-8", errors="replace"),
324-
stderr=stderr.decode("utf-8", errors="replace"),
325-
time_ms=elapsed_ms,
326-
)
273+
Args:
274+
binary_path: 二进制文件路径
275+
args: 命令行参数列表
276+
stdin: 标准输入
277+
timeout: 超时时间(秒)
278+
memory_mb: 内存限制(MB),仅 Linux 有效
327279
328-
except FileNotFoundError:
329-
return RunResult(
330-
success=False,
331-
error=f"Binary not found or prlimit unavailable: {binary_path}",
332-
)
333-
except Exception as e:
280+
Returns:
281+
RunResult: 执行结果
282+
"""
283+
if not os.path.exists(binary_path):
334284
return RunResult(
335285
success=False,
336-
error=f"Execution error: {str(e)}",
286+
error=f"Binary not found: {binary_path}",
337287
)
338288

289+
return await _run_process([binary_path, *args], stdin, timeout, memory_mb)
290+
339291

340292
async def compile_all(
341293
problem_dir: str,

0 commit comments

Comments
 (0)