Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/distributed_level_runtime.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def my_orch(w, args):
payload.callable = chip_callable.buffer_ptr()
payload.args = task_args.__ptr__()
payload.block_dim = 24
r = w.submit(WorkerType.CHIP, payload, outputs=[64])
r = w.submit(WorkerType.NEXT_LEVEL, payload, outputs=[64])

# SubWorker task: runs Python callable, depends on chip output
sub_p = WorkerPayload()
Expand All @@ -254,7 +254,7 @@ def my_orch(w, args):
args_list.append(a.__ptr__())

# 1 DAG node, 4 chips execute in parallel
w.submit(WorkerType.CHIP, payload, args_list=args_list, outputs=[out_size])
w.submit(WorkerType.NEXT_LEVEL, payload, args_list=args_list, outputs=[out_size])
```

### Why It's Uniform
Expand Down
23 changes: 10 additions & 13 deletions python/bindings/dist_worker_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ namespace nb = nanobind;

inline void bind_dist_worker(nb::module_ &m) {
// --- WorkerType ---
nb::enum_<WorkerType>(m, "WorkerType")
.value("CHIP", WorkerType::CHIP)
.value("SUB", WorkerType::SUB)
.value("DIST", WorkerType::DIST);
nb::enum_<WorkerType>(m, "WorkerType").value("NEXT_LEVEL", WorkerType::NEXT_LEVEL).value("SUB", WorkerType::SUB);

// --- TaskState ---
nb::enum_<TaskState>(m, "TaskState")
Expand Down Expand Up @@ -167,27 +164,27 @@ inline void bind_dist_worker(nb::module_ &m) {
)

.def(
"add_chip_worker",
"add_next_level_worker",
[](DistWorker &self, DistWorker &w) {
self.add_worker(WorkerType::CHIP, &w);
self.add_worker(WorkerType::NEXT_LEVEL, &w);
},
nb::arg("worker"), "Add a lower-level DistWorker as a CHIP sub-worker (for L4+)."
nb::arg("worker"), "Add a lower-level DistWorker as a NEXT_LEVEL sub-worker."
)

.def(
"add_chip_worker_native",
"add_next_level_worker",
[](DistWorker &self, ChipWorker &w) {
self.add_worker(WorkerType::CHIP, &w);
self.add_worker(WorkerType::NEXT_LEVEL, &w);
},
nb::arg("worker"), "Add a ChipWorker (_ChipWorker) as a CHIP sub-worker (for L3)."
nb::arg("worker"), "Add a ChipWorker as a NEXT_LEVEL sub-worker."
)

.def(
"add_chip_process",
"add_next_level_worker",
[](DistWorker &self, DistChipProcess &w) {
self.add_worker(WorkerType::CHIP, &w);
self.add_worker(WorkerType::NEXT_LEVEL, &w);
},
nb::arg("worker"), "Add a forked ChipProcess as a CHIP sub-worker (process-isolated)."
nb::arg("worker"), "Add a forked process as a NEXT_LEVEL sub-worker."
)

.def(
Expand Down
34 changes: 20 additions & 14 deletions python/simpler/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
w.init()

def my_orch(w, args):
r = w.submit(WorkerType.CHIP, chip_payload, inputs=[...], outputs=[64])
r = w.submit(WorkerType.NEXT_LEVEL, chip_payload, inputs=[...], outputs=[64])
w.submit(WorkerType.SUB, sub_payload(cid), inputs=[r.outputs[0].ptr])

w.run(Task(orch=my_orch, args=my_args))
Expand Down Expand Up @@ -235,6 +235,8 @@ def __init__(self, level: int, **config) -> None:

def register(self, fn: Callable) -> int:
"""Register a callable for SubWorker use. Must be called before init()."""
if self.level < 3:
raise RuntimeError("Worker.register() is only available at level 3+")
if self._initialized:
raise RuntimeError("Worker.register() must be called before init()")
cid = len(self._callable_registry)
Expand Down Expand Up @@ -365,7 +367,7 @@ def _start_level3(self) -> None:
for shm in self._chip_shms:
cp = DistChipProcess(_mailbox_addr(shm), self._l3_args_size)
self._dist_chip_procs.append(cp)
dw.add_chip_process(cp)
dw.add_next_level_worker(cp)

for shm in self._shms:
sw = DistSubWorker(_mailbox_addr(shm))
Expand All @@ -391,19 +393,8 @@ def run(self, task_or_payload, args=None, **kwargs) -> None:
if self.level == 2:
assert self._chip_worker is not None
if isinstance(task_or_payload, WorkerPayload):
from .task_interface import ChipCallConfig # noqa: PLC0415

config = ChipCallConfig()
config.block_dim = task_or_payload.block_dim
config.aicpu_thread_num = task_or_payload.aicpu_thread_num
config.enable_profiling = task_or_payload.enable_profiling
self._chip_worker.run(
task_or_payload.callable, # type: ignore[arg-type]
task_or_payload.args,
config,
)
self._run_l2_from_payload(task_or_payload)
else:
# run(callable, args, **kwargs)
self._chip_worker.run(task_or_payload, args, **kwargs)
else:
self._start_level3()
Expand All @@ -412,6 +403,21 @@ def run(self, task_or_payload, args=None, **kwargs) -> None:
task.orch(self, task.args)
self._dist_worker.drain()

def _run_l2_from_payload(self, payload: WorkerPayload) -> None:
    """Run one WorkerPayload on the local ChipWorker (L2 only).

    Copies the launch settings carried by ``payload`` (``block_dim``,
    ``aicpu_thread_num`` and ``enable_profiling``) onto a fresh
    ``ChipCallConfig`` and then calls ``self._chip_worker.run`` with the
    payload's ``callable``, its ``args`` and that config.

    Args:
        payload: Payload whose callable, args and launch settings are
            forwarded to the chip worker.
    """
    # Function-scope import, kept local as the PLC0415 suppression indicates.
    from .task_interface import ChipCallConfig  # noqa: PLC0415

    chip = self._chip_worker
    assert chip is not None

    call_cfg = ChipCallConfig()
    # Mirror each launch setting from the payload, in declaration order.
    for field in ("block_dim", "aicpu_thread_num", "enable_profiling"):
        setattr(call_cfg, field, getattr(payload, field))

    # NOTE(review): payload.callable / payload.args appear to be raw pointer
    # values elsewhere in this change (buffer_ptr() / __ptr__()); confirm that
    # ChipWorker.run accepts them in this form.
    chip.run(payload.callable, payload.args, call_cfg)  # type: ignore[arg-type]
Comment on lines +408 to +419
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The `_run_l2_from_payload` helper currently passes raw pointers (integers) from `WorkerPayload` to `self._chip_worker.run()`. However, `ChipWorker.run()` expects higher-level objects, not raw memory addresses. This will lead to a `TypeError` at runtime. Since `WorkerPayload` is designed to hold raw pointers for cross-process dispatch, you should use the `run_raw` method on the underlying `_impl` instance, which is designed to handle these pointers directly.

        assert self._chip_worker is not None
        self._chip_worker._impl.run_raw(
            payload.callable,
            payload.args,
            payload.block_dim,
            payload.aicpu_thread_num,
            payload.enable_profiling,
        )


# ------------------------------------------------------------------
# Orchestration API (called from inside orch functions at L3+)
# ------------------------------------------------------------------
Expand Down
12 changes: 6 additions & 6 deletions src/common/distributed/dist_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void DistScheduler::start(const Config &cfg) {
threads.push_back(std::move(wt));
}
};
make_threads(cfg_.chip_workers, chip_threads_);
make_threads(cfg_.next_level_workers, next_level_threads_);
make_threads(cfg_.sub_workers, sub_threads_);

stop_requested_.store(false, std::memory_order_relaxed);
Expand All @@ -95,11 +95,11 @@ void DistScheduler::stop() {

if (sched_thread_.joinable()) sched_thread_.join();

for (auto &wt : chip_threads_)
for (auto &wt : next_level_threads_)
wt->stop();
for (auto &wt : sub_threads_)
wt->stop();
chip_threads_.clear();
next_level_threads_.clear();
sub_threads_.clear();

running_.store(false, std::memory_order_release);
Expand Down Expand Up @@ -157,7 +157,7 @@ void DistScheduler::run() {
// Exit when stop requested and all workers idle
if (stop_requested_.load(std::memory_order_acquire)) {
bool any_busy = false;
for (auto &wt : chip_threads_)
for (auto &wt : next_level_threads_)
if (!wt->idle()) {
any_busy = true;
break;
Expand Down Expand Up @@ -268,15 +268,15 @@ void DistScheduler::dispatch_ready() {
}

WorkerThread *DistScheduler::pick_idle(WorkerType type) {
auto &threads = (type == WorkerType::CHIP) ? chip_threads_ : sub_threads_;
auto &threads = (type == WorkerType::NEXT_LEVEL) ? next_level_threads_ : sub_threads_;
for (auto &wt : threads) {
if (wt->idle()) return wt.get();
}
return nullptr;
}

std::vector<WorkerThread *> DistScheduler::pick_n_idle(WorkerType type, int n) {
auto &threads = (type == WorkerType::CHIP) ? chip_threads_ : sub_threads_;
auto &threads = (type == WorkerType::NEXT_LEVEL) ? next_level_threads_ : sub_threads_;
std::vector<WorkerThread *> result;
result.reserve(n);
for (auto &wt : threads) {
Expand Down
6 changes: 3 additions & 3 deletions src/common/distributed/dist_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ class DistScheduler {
DistTaskSlotState *slots;
int32_t num_slots;
DistReadyQueue *ready_queue;
std::vector<IWorker *> chip_workers; // WorkerType::CHIP
std::vector<IWorker *> sub_workers; // WorkerType::SUB
std::vector<IWorker *> next_level_workers; // WorkerType::NEXT_LEVEL
std::vector<IWorker *> sub_workers; // WorkerType::SUB
// Called when a task reaches CONSUMED (TensorMap cleanup + ring release).
std::function<void(DistTaskSlot)> on_consumed_cb;
};
Expand All @@ -104,7 +104,7 @@ class DistScheduler {
Config cfg_;

// Per-worker threads
std::vector<std::unique_ptr<WorkerThread>> chip_threads_;
std::vector<std::unique_ptr<WorkerThread>> next_level_threads_;
std::vector<std::unique_ptr<WorkerThread>> sub_threads_;

// Shared completion queue (WorkerThread → Scheduler)
Expand Down
7 changes: 3 additions & 4 deletions src/common/distributed/dist_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ using DistTaskSlot = int32_t;
// =============================================================================

enum class WorkerType : int32_t {
CHIP = 0, // ChipWorker: L2 hardware device
SUB = 1, // SubWorker: fork/shm Python function
DIST = 2, // DistWorker: lower-level node (L4+)
NEXT_LEVEL = 0, // Next-level Worker (L3→ChipWorker, L4→DistWorker(L3), …)
SUB = 1, // SubWorker: fork/shm Python function
};

// =============================================================================
Expand All @@ -75,7 +74,7 @@ enum class TaskState : int32_t {

struct WorkerPayload {
DistTaskSlot task_slot = DIST_INVALID_SLOT;
WorkerType worker_type = WorkerType::CHIP;
WorkerType worker_type = WorkerType::NEXT_LEVEL;

// --- ChipWorker fields (set in PR 2-2) ---
const void *callable = nullptr; // ChipCallable buffer ptr
Expand Down
4 changes: 2 additions & 2 deletions src/common/distributed/dist_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ DistWorker::~DistWorker() {

void DistWorker::add_worker(WorkerType type, IWorker *worker) {
if (initialized_) throw std::runtime_error("DistWorker: add_worker after init");
if (type == WorkerType::CHIP || type == WorkerType::DIST) chip_workers_.push_back(worker);
if (type == WorkerType::NEXT_LEVEL) next_level_workers_.push_back(worker);
else sub_workers_.push_back(worker);
}

Expand All @@ -38,7 +38,7 @@ void DistWorker::init() {
cfg.slots = slots_.get();
cfg.num_slots = DIST_TASK_WINDOW_SIZE;
cfg.ready_queue = &ready_queue_;
cfg.chip_workers = chip_workers_;
cfg.next_level_workers = next_level_workers_;
cfg.sub_workers = sub_workers_;
cfg.on_consumed_cb = [this](DistTaskSlot slot) {
on_consumed(slot);
Expand Down
6 changes: 3 additions & 3 deletions src/common/distributed/dist_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
* Usage (L3 host worker, instantiated from Python via nanobind):
*
* DistWorker dw(level=3);
* dw.add_worker(WorkerType::CHIP, chip_worker_ptr);
* dw.add_worker(WorkerType::NEXT_LEVEL, chip_worker_ptr);
* dw.add_worker(WorkerType::SUB, sub_worker_ptr);
* dw.init();
*
Expand All @@ -32,7 +32,7 @@
* dw.execute(); // blocks until all submitted tasks complete
*
* // When used as an IWorker by a higher-level DistWorker (L4+):
* parent.add_worker(WorkerType::DIST, &dw);
* parent.add_worker(WorkerType::NEXT_LEVEL, &dw);
* // parent scheduler calls dw.dispatch() / dw.poll()
*/

Expand Down Expand Up @@ -107,7 +107,7 @@ class DistWorker : public IWorker {
DistOrchestrator orchestrator_;
DistScheduler scheduler_;

std::vector<IWorker *> chip_workers_;
std::vector<IWorker *> next_level_workers_;
std::vector<IWorker *> sub_workers_;

// --- Drain support ---
Expand Down
4 changes: 2 additions & 2 deletions tests/st/a2a3/tensormap_and_ringbuffer/test_l3_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def run_dag(w, callables, task_args, config):
callables.keep(chip_args) # prevent GC before drain

chip_p = WorkerPayload()
chip_p.worker_type = WorkerType.CHIP
chip_p.worker_type = WorkerType.NEXT_LEVEL
chip_p.callable = callables.vector_kernel.buffer_ptr()
chip_p.args = chip_args.__ptr__()
chip_p.block_dim = config.block_dim
chip_p.aicpu_thread_num = config.aicpu_thread_num
chip_result = w.submit(WorkerType.CHIP, chip_p, inputs=[], outputs=[task_args.f.numel() * 4])
chip_result = w.submit(WorkerType.NEXT_LEVEL, chip_p, inputs=[], outputs=[task_args.f.numel() * 4])

sub_p = WorkerPayload()
sub_p.worker_type = WorkerType.SUB
Expand Down
4 changes: 2 additions & 2 deletions tests/st/a2a3/tensormap_and_ringbuffer/test_l3_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def run_dag(w, callables, task_args, config):
callables.keep(args0, args1) # prevent GC before drain

chip_p = WorkerPayload()
chip_p.worker_type = WorkerType.CHIP
chip_p.worker_type = WorkerType.NEXT_LEVEL
chip_p.callable = callables.vector_kernel.buffer_ptr()
chip_p.block_dim = config.block_dim
chip_p.aicpu_thread_num = config.aicpu_thread_num
group_result = w.submit(WorkerType.CHIP, chip_p, args_list=[args0.__ptr__(), args1.__ptr__()], outputs=[4])
group_result = w.submit(WorkerType.NEXT_LEVEL, chip_p, args_list=[args0.__ptr__(), args1.__ptr__()], outputs=[4])

sub_p = WorkerPayload()
sub_p.worker_type = WorkerType.SUB
Expand Down
4 changes: 2 additions & 2 deletions tests/ut/cpp/test_dist_orchestrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ struct OrchestratorFixture : public ::testing::Test {
// Submit a NEXT_LEVEL task with the given input/output specs.
DistSubmitResult submit_chip(const std::vector<DistInputSpec> &inputs, const std::vector<DistOutputSpec> &outputs) {
WorkerPayload p;
p.worker_type = WorkerType::CHIP;
return orch.submit(WorkerType::CHIP, p, inputs, outputs);
p.worker_type = WorkerType::NEXT_LEVEL;
return orch.submit(WorkerType::NEXT_LEVEL, p, inputs, outputs);
}
};

Expand Down
24 changes: 12 additions & 12 deletions tests/ut/cpp/test_dist_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct SchedulerFixture : public ::testing::Test {
cfg.slots = slots.get();
cfg.num_slots = N;
cfg.ready_queue = &rq;
cfg.chip_workers = {&chip_worker};
cfg.next_level_workers = {&chip_worker};
cfg.on_consumed_cb = [this](DistTaskSlot s) {
orch.on_consumed(s);
std::lock_guard<std::mutex> lk(consumed_mu);
Expand All @@ -124,8 +124,8 @@ struct SchedulerFixture : public ::testing::Test {

DistSubmitResult submit_chip(const std::vector<DistInputSpec> &inputs, const std::vector<DistOutputSpec> &outputs) {
WorkerPayload p;
p.worker_type = WorkerType::CHIP;
return orch.submit(WorkerType::CHIP, p, inputs, outputs);
p.worker_type = WorkerType::NEXT_LEVEL;
return orch.submit(WorkerType::NEXT_LEVEL, p, inputs, outputs);
}

void wait_consumed(DistTaskSlot slot, int timeout_ms = 500) {
Expand Down Expand Up @@ -213,7 +213,7 @@ struct GroupSchedulerFixture : public ::testing::Test {
cfg.slots = slots.get();
cfg.num_slots = N;
cfg.ready_queue = &rq;
cfg.chip_workers = {&worker_a, &worker_b};
cfg.next_level_workers = {&worker_a, &worker_b};
cfg.on_consumed_cb = [this](DistTaskSlot s) {
orch.on_consumed(s);
std::lock_guard<std::mutex> lk(consumed_mu);
Expand Down Expand Up @@ -247,10 +247,10 @@ TEST_F(GroupSchedulerFixture, GroupDispatchesToNWorkers) {
int dummy_args_1 = 1;

WorkerPayload p;
p.worker_type = WorkerType::CHIP;
p.worker_type = WorkerType::NEXT_LEVEL;
std::vector<const void *> args_list = {&dummy_args_0, &dummy_args_1};

auto res = orch.submit_group(WorkerType::CHIP, p, args_list, {}, {{64}});
auto res = orch.submit_group(WorkerType::NEXT_LEVEL, p, args_list, {}, {{64}});
DistTaskSlot slot = res.task_slot;

// Both workers should receive dispatches
Expand All @@ -274,9 +274,9 @@ TEST_F(GroupSchedulerFixture, GroupDispatchesToNWorkers) {
TEST_F(GroupSchedulerFixture, GroupCompletesOnlyWhenAllDone) {
int d0 = 0, d1 = 1;
WorkerPayload p;
p.worker_type = WorkerType::CHIP;
p.worker_type = WorkerType::NEXT_LEVEL;

auto res = orch.submit_group(WorkerType::CHIP, p, {&d0, &d1}, {}, {});
auto res = orch.submit_group(WorkerType::NEXT_LEVEL, p, {&d0, &d1}, {}, {});
DistTaskSlot slot = res.task_slot;

worker_a.wait_running();
Expand All @@ -297,15 +297,15 @@ TEST_F(GroupSchedulerFixture, GroupDependencyChain) {
// Task B depends on A's output — B stays PENDING until group A finishes.
int d0 = 0, d1 = 1;
WorkerPayload pa;
pa.worker_type = WorkerType::CHIP;
pa.worker_type = WorkerType::NEXT_LEVEL;

auto a = orch.submit_group(WorkerType::CHIP, pa, {&d0, &d1}, {}, {{128}});
auto a = orch.submit_group(WorkerType::NEXT_LEVEL, pa, {&d0, &d1}, {}, {{128}});
uint64_t a_out = reinterpret_cast<uint64_t>(a.outputs[0].ptr);

// Submit B depending on A's output
WorkerPayload pb;
pb.worker_type = WorkerType::CHIP;
auto b = orch.submit(WorkerType::CHIP, pb, {{a_out}}, {});
pb.worker_type = WorkerType::NEXT_LEVEL;
auto b = orch.submit(WorkerType::NEXT_LEVEL, pb, {{a_out}}, {});
EXPECT_EQ(slots[b.task_slot].state.load(), TaskState::PENDING);

// Complete group A
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/py/test_dist_worker/test_host_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def orch(hw, _args):
# no workers — submit with empty workers list isn't useful here;
# instead verify that submit() allocates output buffers correctly
# by using a SUB worker that immediately signals done
p.worker_type = WorkerType.CHIP # no CHIP workers — task stays RUNNING
p.worker_type = WorkerType.NEXT_LEVEL # no NEXT_LEVEL workers — task stays RUNNING
# For output allocation test, just verify DistSubmitResult has outputs
# We re-init with sub workers for a real execution test
pass
Expand Down
Loading