-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapi_tool.py
More file actions
685 lines (579 loc) · 36.3 KB
/
api_tool.py
File metadata and controls
685 lines (579 loc) · 36.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
# api_tool.py
import os
import json
import zipfile
import urllib.request
import urllib.error
import http.client
import subprocess
import shutil
import asyncio
import tempfile
import time
import contextlib
from aiohttp import web
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
CUSTOM_NODES_DIR = os.path.dirname(THIS_DIR)
MAX_ZIP_SIZE = 2 * 1024 * 1024 * 1024 # 2GB ZIP 包上限
def _is_local_request(request):
"""检查请求是否来自本机,保护敏感安装接口"""
remote = request.remote or ""
if remote in ("127.0.0.1", "localhost", "::1"):
return True
return False
def _prepare_git_env():
"""准备 Git 环境变量,防止 git 弹窗要求输入密码导致后台卡死"""
env = os.environ.copy()
env["GIT_TERMINAL_PROMPT"] = "0"
env["GCM_INTERACTIVE"] = "never" # 禁止 Git Credential Manager 交互
env["GIT_CONFIG_NOSYSTEM"] = "1" # 禁用系统级 gitconfig(可能配置了 credential.helper)
env["GCM_PROVIDER"] = "" # 清空 GCM provider,阻止任何凭证提供者
env["GIT_ASKPASS"] = "" # 禁用 Git ASKPASS 外部程序
env["SSH_ASKPASS"] = "" # 禁用 SSH ASKPASS 外部程序
return env
async def install_tool_handler(request):
# NOTE: 与 install_tool_stream_handler 共享核心安装逻辑(URL校验、双链路容灾、Git克隆),如需修改请同步
if not _is_local_request(request):
return web.json_response({"error": "Forbidden: local access only"}, status=403)
data = await request.json()
item_url = data.get("url")
item_id = data.get("id")
account = data.get("account") # 用户身份凭证
if not item_url or not account:
return web.json_response({"error": "缺少下载凭证或链接"}, status=400)
# 校验 URL 是否为有效的 Git 仓库地址
valid_git_hosts = ["github.com", "gitlab.com", "gitee.com", "bitbucket.org", "kkgithub.com"]
if not any(host in item_url for host in valid_git_hosts):
return web.json_response({"error": "该资源链接不是有效的 Git 仓库地址,无法自动安装。请前往资源原始页面手动下载。"}, status=400)
target_dir_name = item_url.rstrip("/").split("/")[-1].replace(".git", "")
clone_target_path = os.path.join(CUSTOM_NODES_DIR, target_dir_name)
# 清理残留机制。如果文件夹已存在,说明可能是旧的无 .git 残缺安装,直接移除
if os.path.exists(clone_target_path):
try:
shutil.rmtree(clone_target_path)
except Exception as e:
return web.json_response({"error": f"目录 {target_dir_name} 已存在且被占用,无法自动清理,请先手动删除。错误: {str(e)}"}, status=400)
try:
# 🚀 核心升级:双链路容灾机制
# 链路 A:使用目前最稳定的域名级直接替换镜像
mirror_url = item_url.replace("https://kkgithub.com", "https://github.com")
env = _prepare_git_env()
try:
print(f"正在尝试通过加速镜像 Clone: {mirror_url}")
subprocess.run(
["git", "-c", "credential.helper=", "clone", "--depth", "1", "--single-branch", "--no-tags", mirror_url, clone_target_path],
capture_output=True,
text=True,
check=True,
env=env,
timeout=1200 # 20分钟超时
)
print("✅ 镜像 Git Clone 安装成功!保留了完整的版本控制 (.git)。")
return web.json_response({"status": "success"})
except subprocess.TimeoutExpired:
return web.json_response({"error": "安装超时:仓库过大或网络异常,请检查网络后重试"}, status=504)
except subprocess.CalledProcessError as e1:
print(f"⚠️ 镜像源不可用或发生冲突,系统正在自动无缝回退至直连: {item_url}")
# 清理刚才克隆到一半可能留下的残缺空文件夹
if os.path.exists(clone_target_path):
shutil.rmtree(clone_target_path)
# 链路 B:官方直连 (专门照顾开了科学上网/全局代理的用户)
subprocess.run(
["git", "-c", "credential.helper=", "clone", "--depth", "1", "--single-branch", "--no-tags", item_url, clone_target_path],
capture_output=True,
text=True,
check=True,
env=env,
timeout=1200 # 20分钟超时
)
print("✅ 直连 Git Clone 安装成功!")
return web.json_response({"status": "success"})
except subprocess.TimeoutExpired:
return web.json_response({"error": "安装超时:仓库过大或网络异常,请检查网络后重试"}, status=504)
except FileNotFoundError:
# 拦截用户电脑根本没装 Git 的情况
return web.json_response({"error": "系统中未检测到 Git,请先安装 Git 环境才能下载插件!"}, status=500)
except subprocess.CalledProcessError as e2:
# 两条链路都失败了的最终兜底
error_msg = e2.stderr or e2.stdout
return web.json_response({"error": f"Git Clone 失败,镜像与直连均不可用。请检查网络或开启代理: {error_msg}"}, status=500)
except Exception as e:
# 兜底异常拦截
return web.json_response({"error": f"安装过程发生未知失败: {str(e)}"}, status=500)
async def install_private_tool_handler(request):
"""本地 API:针对付费/私有库,通过云端鉴权代理下载 ZIP 包,流式写入临时文件并磁盘解压覆盖
NOTE: 与 install_private_tool_stream_handler 共享核心安装逻辑(ZIP下载、重试、解压),如需修改请同步
"""
if not _is_local_request(request):
return web.json_response({"error": "Forbidden: local access only"}, status=403)
data = await request.json()
item_url = data.get("url")
item_id = data.get("id")
account = data.get("account")
if not item_url or not account or not item_id:
return web.json_response({"error": "缺少核心鉴权参数"}, status=400)
target_dir_name = item_url.rstrip("/").split("/")[-1].replace(".git", "")
extract_target_path = os.path.join(CUSTOM_NODES_DIR, target_dir_name)
tmp_path = None
try:
proxy_api_url = "https://zhiwei666-comfyui-ranking-api.hf.space/api/proxy_github_zip"
payload = json.dumps({"url": item_url, "item_id": item_id, "account": account}).encode("utf-8")
headers = {'Content-Type': 'application/json'}
# ZIP 下载(最多重试3次)
max_retries = 3
for attempt in range(max_retries):
try:
req = urllib.request.Request(proxy_api_url, data=payload, headers=headers)
print(f"[ComfyUI-Ranking] 🔒 正在向云端发起私有资产鉴权与加密拉取: {item_id}" + (f"(第{attempt+1}次尝试)" if attempt > 0 else ""))
response = await asyncio.to_thread(lambda: urllib.request.urlopen(req, timeout=600))
try:
content_length = int(response.headers.get('Content-Length', 0))
# 磁盘空间检查(仅第一次)
if attempt == 0 and content_length > 0:
required_space = content_length * 4
free_space = shutil.disk_usage(CUSTOM_NODES_DIR).free
if free_space < required_space:
free_gb = free_space / (1024**3)
need_gb = required_space / (1024**3)
return web.json_response({"error": f"磁盘空间不足:需要约 {need_gb:.1f}GB,当前剩余 {free_gb:.1f}GB"}, status=500)
# 流式下载到临时文件,避免大文件 OOM
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
tmp_path = tmp_file.name
downloaded = 0
chunk_size = 1024 * 1024 # 1MB 分块
while True:
chunk = await asyncio.to_thread(response.read, chunk_size)
if not chunk:
break
# 检查是否为错误响应(仅检查第一个 chunk)
if downloaded == 0 and not chunk.startswith(b'PK\x03\x04'):
await asyncio.to_thread(os.unlink, tmp_path)
tmp_path = None
try:
text = chunk[:2000].decode('utf-8', errors='ignore')
error_data = json.loads(text)
err_msg = error_data.get("detail", error_data.get("error", "云端返回非ZIP内容"))
except (json.JSONDecodeError, ValueError):
preview = chunk[:200].decode('utf-8', errors='ignore').strip()
err_msg = f"云端返回非ZIP内容(可能是认证失败或仓库不可达): {preview[:100]}"
return web.json_response({"error": f"云端拒绝访问或拉取失败: {err_msg}"}, status=403)
tmp_file.write(chunk)
downloaded += len(chunk)
if downloaded > MAX_ZIP_SIZE:
await asyncio.to_thread(os.unlink, tmp_path)
tmp_path = None
return web.json_response({"error": f"ZIP 文件体积超过安全上限 ({MAX_ZIP_SIZE // (1024*1024*1024)}GB)"}, status=413)
finally:
response.close()
# 下载成功,跳出重试循环
break
except (http.client.IncompleteRead, urllib.error.URLError, ConnectionResetError, TimeoutError) as e:
# 清理失败的临时文件
if tmp_path and os.path.exists(tmp_path):
try:
await asyncio.to_thread(os.unlink, tmp_path)
except:
pass
tmp_path = None
if attempt < max_retries - 1:
wait_time = [2, 5][attempt]
print(f"[ComfyUI-Ranking] ⚠️ ZIP 下载失败(第{attempt+1}次),{wait_time}秒后重试: {e}")
await asyncio.sleep(wait_time)
else:
print(f"[ComfyUI-Ranking] ❌ ZIP 下载失败(已重试{max_retries}次): {e}")
return web.json_response({"status": "error", "message": f"下载失败(网络不稳定,已重试{max_retries}次): {str(e)}"}, status=500)
print("[ComfyUI-Ranking] ✅ 成功接收云端安全 ZIP 数据流,执行热覆盖解压...")
# 从临时文件磁盘解压(不占内存)
with zipfile.ZipFile(tmp_path) as zip_ref:
namelist = zip_ref.namelist()
if not namelist:
return web.json_response({"error": "下载的压缩包结构为空"}, status=500)
top_level_dir = namelist[0].split('/')[0] + '/'
# 🚀 核心修复 1:执行纯净更新!在确认 ZIP 完好无损后,先彻底抹除旧版本文件夹,防止残留的废弃 .py 文件引发报错
if os.path.exists(extract_target_path):
try:
shutil.rmtree(extract_target_path)
except Exception as e:
# 🚀 核心修复 2:拦截 Windows 下 Python 文件被 ComfyUI 进程死锁的情况
return web.json_response({"error": "旧版本文件正在被 ComfyUI 进程占用,无法覆盖更新。请彻底关闭控制台黑框,重新启动 ComfyUI 后再点击更新。"}, status=500)
os.makedirs(extract_target_path, exist_ok=True)
for member in namelist:
if member.startswith(top_level_dir):
target_path = member.replace(top_level_dir, "", 1)
if not target_path: continue
# 防止路径穿越攻击 - 第一层防御
if ".." in target_path or target_path.startswith("/") or target_path.startswith("\\"):
print(f"[ComfyUI-Ranking] ⚠️ 跳过不安全路径: {target_path}")
continue
# 防止路径穿越攻击 - 第二层防御:使用 normpath 规范化检查
abs_target = os.path.normpath(os.path.join(extract_target_path, target_path))
abs_base = os.path.normpath(extract_target_path)
if not abs_target.startswith(abs_base):
print(f"[ComfyUI-Ranking] ⚠️ 跳过不安全路径: {target_path}")
continue
source = zip_ref.open(member)
dest_path = os.path.join(extract_target_path, target_path)
if member.endswith('/'):
os.makedirs(dest_path, exist_ok=True)
else:
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
with open(dest_path, "wb") as target:
shutil.copyfileobj(source, target) # 强制写入纯净新文件
print(f"[ComfyUI-Ranking] 🎉 私有插件 {target_dir_name} 静默更新/安装完成!无 .git 目录残留。")
return web.json_response({"status": "success"})
except urllib.error.HTTPError as e:
err_msg = e.read().decode('utf-8', errors='ignore')
return web.json_response({"error": f"拉取中断: {err_msg}"}, status=500)
except zipfile.BadZipFile:
return web.json_response({"error": "下载的文件不是有效的ZIP格式,可能原因:云端代理认证失败、私有仓库权限不足或已删除、GitHub 返回了错误页面。请检查仓库地址和密钥是否正确。"}, status=500)
except Exception as e:
return web.json_response({"error": f"本地解压覆盖异常: {str(e)}"}, status=500)
finally:
# 清理临时文件,捕获异常防止覆盖响应(Windows下需重试避免WinError 32)
if tmp_path and os.path.exists(tmp_path):
for _retry in range(5):
try:
await asyncio.to_thread(os.unlink, tmp_path)
break
except Exception:
await asyncio.sleep(0.5)
else:
print(f"[ComfyUI-Ranking] ⚠️ 临时文件清理失败(已重试5次,不影响安装结果): {tmp_path}")
async def install_tool_stream_handler(request):
"""SSE 流式接口:通过 Git Clone 下载插件,实时推送进度
NOTE: 与 install_tool_handler 共享核心安装逻辑(URL校验、双链路容灾、Git克隆),如需修改请同步
"""
if not _is_local_request(request):
return web.json_response({"error": "Forbidden: local access only"}, status=403)
data = await request.json()
item_url = data.get("url")
item_id = data.get("id")
account = data.get("account")
resp = web.StreamResponse(
status=200,
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
}
)
await resp.prepare(request)
async def send_progress(stage, progress, message, status=None):
event = {"stage": stage, "progress": progress, "message": message}
if status:
event["status"] = status
await resp.write(f"data: {json.dumps(event, ensure_ascii=False)}\n\n".encode('utf-8'))
try:
if not item_url or not account:
await send_progress("error", -1, "缺少下载凭证或链接", "error")
await resp.write_eof()
return resp
# 校验 URL 是否为有效的 Git 仓库地址
valid_git_hosts = ["github.com", "gitlab.com", "gitee.com", "bitbucket.org", "kkgithub.com"]
if not any(host in item_url for host in valid_git_hosts):
await send_progress("error", -1, "该资源链接不是有效的 Git 仓库地址,无法自动安装。请前往资源原始页面手动下载。", "error")
await resp.write_eof()
return resp
await send_progress("validate", 5, "校验安装参数...")
target_dir_name = item_url.rstrip("/").split("/")[-1].replace(".git", "")
clone_target_path = os.path.join(CUSTOM_NODES_DIR, target_dir_name)
await send_progress("cleanup", 15, "清理残留目录...")
if os.path.exists(clone_target_path):
try:
shutil.rmtree(clone_target_path)
except Exception as e:
await send_progress("error", -1, f"目录 {target_dir_name} 已存在且被占用,无法自动清理,请先手动删除。错误: {str(e)}", "error")
await resp.write_eof()
return resp
env = _prepare_git_env()
mirror_url = item_url.replace("https://kkgithub.com", "https://github.com")
await send_progress("git_mirror", 25, "尝试镜像源克隆...")
try:
await send_progress("git_cloning", 50, "正在克隆仓库(浅克隆模式)...")
proc = await asyncio.create_subprocess_exec(
"git", "-c", "credential.helper=", "clone", "--depth", "1", "--single-branch", "--no-tags", mirror_url, clone_target_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env
)
# 使用心跳机制等待克隆完成,防止连接被代理/浏览器认为空闲而断开
clone_task = asyncio.create_task(proc.communicate())
start_time = time.time()
while not clone_task.done():
await asyncio.sleep(15) # 每15秒检查一次
if clone_task.done():
break
elapsed = time.time() - start_time
if elapsed > 1200: # 总超时20分钟
proc.kill()
await proc.wait()
clone_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await clone_task
await send_progress("error", -1, "安装超时:仓库过大或网络异常(已等待20分钟)", "error")
await resp.write_eof()
return resp
progress_pct = min(30 + int(elapsed / 1200 * 50), 79)
minutes = int(elapsed // 60)
seconds = int(elapsed % 60)
await send_progress("git_cloning", progress_pct, f"正在克隆仓库(浅克隆模式)... 已等待 {minutes}分{seconds}秒")
stdout, stderr = await clone_task
if proc.returncode != 0:
raise subprocess.CalledProcessError(
proc.returncode, ["git", "-c", "credential.helper=", "clone", "--depth", "1", "--single-branch", "--no-tags", mirror_url, clone_target_path],
output=stdout, stderr=stderr
)
await send_progress("complete", 100, "✅ 安装成功!", "success")
except subprocess.CalledProcessError as e1:
await send_progress("git_fallback", 55, "镜像失败,切换直连源...")
if os.path.exists(clone_target_path):
shutil.rmtree(clone_target_path)
await send_progress("git_direct", 70, "正在直连克隆(浅克隆模式)...")
proc = await asyncio.create_subprocess_exec(
"git", "-c", "credential.helper=", "clone", "--depth", "1", "--single-branch", "--no-tags", item_url, clone_target_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env
)
# 使用心跳机制等待克隆完成,防止连接被代理/浏览器认为空闲而断开
clone_task = asyncio.create_task(proc.communicate())
start_time = time.time()
while not clone_task.done():
await asyncio.sleep(15)
if clone_task.done():
break
elapsed = time.time() - start_time
if elapsed > 1200:
proc.kill()
await proc.wait()
clone_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await clone_task
await send_progress("error", -1, "安装超时:仓库过大或网络异常(已等待20分钟)", "error")
await resp.write_eof()
return resp
progress_pct = min(30 + int(elapsed / 1200 * 50), 79)
minutes = int(elapsed // 60)
seconds = int(elapsed % 60)
await send_progress("git_cloning", progress_pct, f"正在克隆仓库(浅克隆模式)... 已等待 {minutes}分{seconds}秒")
stdout, stderr = await clone_task
if proc.returncode != 0:
raise subprocess.CalledProcessError(
proc.returncode, ["git", "-c", "credential.helper=", "clone", "--depth", "1", "--single-branch", "--no-tags", item_url, clone_target_path],
output=stdout, stderr=stderr
)
await send_progress("complete", 100, "✅ 安装成功!", "success")
except FileNotFoundError:
await send_progress("error", -1, "系统中未检测到 Git,请先安装 Git 环境才能下载插件!", "error")
except subprocess.CalledProcessError as e2:
error_msg = e2.stderr.decode('utf-8', errors='ignore') if e2.stderr else (e2.stdout.decode('utf-8', errors='ignore') if e2.stdout else "")
await send_progress("error", -1, f"Git Clone 失败,镜像与直连均不可用。请检查网络或开启代理: {error_msg}", "error")
except Exception as e:
await send_progress("error", -1, f"安装过程发生未知失败: {str(e)}", "error")
await resp.write_eof()
return resp
async def install_private_tool_stream_handler(request):
"""SSE 流式接口:针对付费/私有库,通过云端鉴权代理下载 ZIP 包,流式写入临时文件并磁盘解压覆盖
NOTE: 与 install_private_tool_handler 共享核心安装逻辑(ZIP下载、重试、解压),如需修改请同步
"""
if not _is_local_request(request):
return web.json_response({"error": "Forbidden: local access only"}, status=403)
data = await request.json()
item_url = data.get("url")
item_id = data.get("id")
account = data.get("account")
resp = web.StreamResponse(
status=200,
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Accel-Buffering': 'no',
}
)
await resp.prepare(request)
async def send_progress(stage, progress, message, status=None):
event = {"stage": stage, "progress": progress, "message": message}
if status:
event["status"] = status
await resp.write(f"data: {json.dumps(event, ensure_ascii=False)}\n\n".encode('utf-8'))
tmp_path = None
try:
if not item_url or not account or not item_id:
await send_progress("error", -1, "缺少核心鉴权参数", "error")
await resp.write_eof()
return resp
await send_progress("validate", 5, "校验安装参数...")
await send_progress("auth", 15, "验证购买权限...")
target_dir_name = item_url.rstrip("/").split("/")[-1].replace(".git", "")
extract_target_path = os.path.join(CUSTOM_NODES_DIR, target_dir_name)
proxy_api_url = "https://zhiwei666-comfyui-ranking-api.hf.space/api/proxy_github_zip"
payload = json.dumps({"url": item_url, "item_id": item_id, "account": account}).encode("utf-8")
headers = {'Content-Type': 'application/json'}
# ZIP 下载(最多重试3次)
max_retries = 3
for attempt in range(max_retries):
try:
req = urllib.request.Request(proxy_api_url, data=payload, headers=headers)
await send_progress("downloading", 30, "从云端下载资源包..." + (f"(第{attempt+1}次尝试)" if attempt > 0 else ""))
with urllib.request.urlopen(req, timeout=600) as response:
content_length = int(response.headers.get('Content-Length', 0))
# 磁盘空间检查(仅第一次)
if attempt == 0 and content_length > 0:
required_space = content_length * 4
free_space = shutil.disk_usage(CUSTOM_NODES_DIR).free
if free_space < required_space:
free_gb = free_space / (1024**3)
need_gb = required_space / (1024**3)
await send_progress("error", -1, f"磁盘空间不足:需要约 {need_gb:.1f}GB,当前剩余 {free_gb:.1f}GB", "error")
await resp.write_eof()
return resp
# 流式下载到临时文件,避免大文件 OOM
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as tmp_file:
tmp_path = tmp_file.name
downloaded = 0
chunk_size = 1024 * 1024 # 1MB 分块
last_progress_time = time.time()
while True:
# 使用异步读取避免阻塞事件循环,配合心跳机制防止慢网络下连接空闲断开
read_task = asyncio.create_task(asyncio.to_thread(response.read, chunk_size))
# 等待读取完成,每15秒发送一次心跳保活
while not read_task.done():
done, _ = await asyncio.wait([read_task], timeout=15)
if done:
break
# 读取超过15秒未完成,发送心跳保活
current_time = time.time()
if current_time - last_progress_time >= 15:
mb_done = downloaded / (1024 * 1024)
if content_length > 0:
pct = 30 + int(downloaded / content_length * 30)
mb_total = content_length / (1024 * 1024)
await send_progress("downloading", pct, f"下载中... {mb_done:.1f}MB / {mb_total:.1f}MB")
else:
pct = min(30 + int(downloaded / (200 * 1024 * 1024) * 30), 59)
await send_progress("downloading", pct, f"下载中... 已接收 {mb_done:.1f}MB")
last_progress_time = current_time
chunk = read_task.result()
if not chunk:
break
# 检查是否为错误响应(仅检查第一个 chunk)
if downloaded == 0 and not chunk.startswith(b'PK\x03\x04'):
await asyncio.to_thread(os.unlink, tmp_path)
tmp_path = None
try:
text = chunk[:2000].decode('utf-8', errors='ignore')
error_data = json.loads(text)
err_msg = error_data.get("detail", error_data.get("error", "云端返回非ZIP内容"))
except (json.JSONDecodeError, ValueError):
preview = chunk[:200].decode('utf-8', errors='ignore').strip()
err_msg = f"云端返回非ZIP内容(可能是认证失败或仓库不可达): {preview[:100]}"
await send_progress("error", -1, f"云端拒绝访问或拉取失败: {err_msg}", "error")
await resp.write_eof()
return resp
tmp_file.write(chunk)
downloaded += len(chunk)
if downloaded > MAX_ZIP_SIZE:
await asyncio.to_thread(os.unlink, tmp_path)
tmp_path = None
await send_progress("error", -1, f"ZIP 文件体积超过安全上限 ({MAX_ZIP_SIZE // (1024*1024*1024)}GB)", "error")
await resp.write_eof()
return resp
# 每5MB或每10秒更新一次下载进度
current_time = time.time()
if downloaded % (5 * 1024 * 1024) < chunk_size or current_time - last_progress_time >= 10:
mb_done = downloaded / (1024 * 1024)
if content_length > 0:
pct = 30 + int(downloaded / content_length * 30) # 30%~60%
mb_total = content_length / (1024 * 1024)
await send_progress("downloading", pct, f"下载中... {mb_done:.1f}MB / {mb_total:.1f}MB")
else:
# 无法获取总大小时,显示已下载量,进度缓慢递增(按200MB估算)
pct = min(30 + int(downloaded / (200 * 1024 * 1024) * 30), 59)
await send_progress("downloading", pct, f"下载中... 已接收 {mb_done:.1f}MB")
last_progress_time = current_time
# 下载成功,跳出重试循环
break
except (http.client.IncompleteRead, urllib.error.URLError, ConnectionResetError, TimeoutError) as e:
if tmp_path and os.path.exists(tmp_path):
try:
await asyncio.to_thread(os.unlink, tmp_path)
except:
pass
tmp_path = None
if attempt < max_retries - 1:
wait_time = [2, 5][attempt]
await send_progress("downloading", 10, f"⚠️ 下载中断,{wait_time}秒后重试(第{attempt+2}次)...")
await asyncio.sleep(wait_time)
else:
await send_progress("error", -1, f"下载失败(网络不稳定,已重试{max_retries}次): {str(e)}", "error")
await resp.write_eof()
return resp
await send_progress("download_done", 60, "下载完成,准备解压...")
# 从临时文件磁盘解压(不占内存)
with zipfile.ZipFile(tmp_path) as zip_ref:
namelist = zip_ref.namelist()
if not namelist:
await send_progress("error", -1, "下载的压缩包结构为空", "error")
await resp.write_eof()
return resp
top_level_dir = namelist[0].split('/')[0] + '/'
await send_progress("extracting", 75, "解压安装文件...")
if os.path.exists(extract_target_path):
try:
shutil.rmtree(extract_target_path)
except Exception as e:
await send_progress("error", -1, "旧版本文件正在被 ComfyUI 进程占用,无法覆盖更新。请彻底关闭控制台黑框,重新启动 ComfyUI 后再点击更新。", "error")
await resp.write_eof()
return resp
os.makedirs(extract_target_path, exist_ok=True)
total_files = len(namelist)
processed = 0
for member in namelist:
if member.startswith(top_level_dir):
target_path = member.replace(top_level_dir, "", 1)
if not target_path:
continue
# 防止路径穿越攻击 - 第一层防御
if ".." in target_path or target_path.startswith("/") or target_path.startswith("\\"):
continue
# 防止路径穿越攻击 - 第二层防御:使用 normpath 规范化检查
abs_target = os.path.normpath(os.path.join(extract_target_path, target_path))
abs_base = os.path.normpath(extract_target_path)
if not abs_target.startswith(abs_base):
continue
source = zip_ref.open(member)
dest_path = os.path.join(extract_target_path, target_path)
if member.endswith('/'):
os.makedirs(dest_path, exist_ok=True)
else:
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
with open(dest_path, "wb") as target:
shutil.copyfileobj(source, target)
processed += 1
# 每解压 50 个文件推送一次进度
if processed % 50 == 0:
progress_pct = 75 + int(processed / total_files * 15) # 75%~90%
await send_progress("installing", progress_pct, f"写入目标目录... {processed}/{total_files}")
await send_progress("complete", 100, "✅ 安装成功!", "success")
except urllib.error.HTTPError as e:
err_msg = e.read().decode('utf-8', errors='ignore')
await send_progress("error", -1, f"拉取中断: {err_msg}", "error")
except zipfile.BadZipFile:
await send_progress("error", -1, "下载的文件不是有效的ZIP格式,可能原因:云端代理认证失败、私有仓库权限不足或已删除、GitHub 返回了错误页面。请检查仓库地址和密钥是否正确。", "error")
except Exception as e:
await send_progress("error", -1, f"本地解压覆盖异常: {str(e)}", "error")
finally:
# 清理临时文件,捕获异常防止覆盖响应(Windows下需重试避免WinError 32)
if tmp_path and os.path.exists(tmp_path):
for _retry in range(5):
try:
await asyncio.to_thread(os.unlink, tmp_path)
break
except Exception:
await asyncio.sleep(0.5)
else:
print(f"[ComfyUI-Ranking] ⚠️ 临时文件清理失败(已重试5次,不影响安装结果): {tmp_path}")
await resp.write_eof()
return resp