Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion comfy/ldm/cosmos/cosmos_tokenizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def is_odd(n: int) -> bool:


def nonlinearity(x):
    """Swish/SiLU activation, equivalent to x * sigmoid(x)."""
    return torch.nn.functional.silu(x)


def Normalize(in_channels, num_groups=32):
Expand Down
2 changes: 1 addition & 1 deletion comfy/ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_timestep_embedding(timesteps, embedding_dim):

def nonlinearity(x):
    # swish: x * sigmoid(x), computed via the fused SiLU implementation
    silu = torch.nn.functional.silu
    return silu(x)


def Normalize(in_channels, num_groups=32):
Expand Down
7 changes: 4 additions & 3 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio

import torch
Expand Down Expand Up @@ -891,7 +891,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__

async def validate_prompt(prompt_id, prompt):
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
Expand All @@ -915,7 +915,8 @@ async def validate_prompt(prompt_id, prompt):
return (False, error, [], {})

if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
if partial_execution_list is None or x in partial_execution_list:
outputs.add(x)

if len(outputs) == 0:
error = {
Expand Down
7 changes: 6 additions & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,12 @@ async def post_prompt(request):
if "prompt" in json_data:
prompt = json_data["prompt"]
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
valid = await execution.validate_prompt(prompt_id, prompt)

partial_execution_targets = None
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]

valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
extra_data = json_data["extra_data"]
Expand Down
15 changes: 14 additions & 1 deletion tests/inference/test_async_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient
from tests.inference.test_execution import ComfyClient, run_warmup


@pytest.mark.execution
Expand All @@ -24,6 +24,7 @@ def _server(self, args_pytest, request):
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
if use_lru:
Expand Down Expand Up @@ -82,6 +83,9 @@ def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder)

def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
"""Test that multiple async nodes execute in parallel."""
# Warmup execution to ensure server is fully initialized
run_warmup(client)

g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

Expand Down Expand Up @@ -148,6 +152,9 @@ def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder)

def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes with lazy evaluation."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_lazy")

g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
Expand Down Expand Up @@ -305,6 +312,9 @@ def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphB

def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
"""Test that async nodes are properly cached."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_cache")

g = builder
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
Expand All @@ -324,6 +334,9 @@ def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder

def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
"""Test async nodes within dynamically generated prompts."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_dynamic")

g = builder
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
Expand Down
Loading
Loading