Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -13,6 +13,11 @@
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#if !defined(_WIN32)
#include <sys/mman.h>
#include <unistd.h>
#include <cerrno>
#endif
#include <cctype>
#include <cstdio>
#include <cstring>

Expand Down Expand Up @@ -275,6 +280,45 @@
method_name.empty() ? "so_blob" : method_name + "_so_blob";

const NamedDataMap* named_data_map = context.get_named_data_map();
ET_CHECK_OR_RETURN_ERROR(
named_data_map != nullptr,
Internal,
"CudaBackend requires a NamedDataMap for weight loading");

// Prefetch the weights blob — trigger async readahead so pages are
// resident by the time update_constants_from_blob memcpy's them.
// This overlaps disk I/O with the .so write + dlopen.
std::string weights_blob_key =
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
#if !defined(_WIN32)
    // Best-effort prefetch of the weights blob (POSIX only).
    //
    // Looks the blob up in the named data map and hints the kernel with
    // madvise(MADV_WILLNEED) over the page range that contains it. Every
    // failure path is non-fatal: the outcome is only logged and init
    // continues either way.
    //
    // NOTE(review): the buffer obtained here is released at the end of this
    // scope. This presumably relies on the data being backed by a file
    // mapping that stays valid afterwards; if the loader instead copies the
    // blob into heap memory, this lookup would read the whole blob eagerly
    // and the hint would have no benefit -- confirm against the data loader
    // in use.
    {
        auto prefetch_buf = named_data_map->get_data(weights_blob_key.c_str());

        // sysconf reports failure as -1, so validate the result before
        // building an alignment mask from it.
        const long page_size_raw = sysconf(_SC_PAGESIZE);

        if (prefetch_buf.ok() && prefetch_buf->data() != nullptr &&
            prefetch_buf->size() > 0 && page_size_raw > 0) {
            const auto page_size = static_cast<uintptr_t>(page_size_raw);
            const auto addr = reinterpret_cast<uintptr_t>(prefetch_buf->data());

            // madvise needs a page-aligned start address: round the start
            // down to its page boundary and extend the length by the same
            // amount so the end of the blob is still covered.
            const uintptr_t aligned_addr = addr & ~(page_size - 1);
            const size_t aligned_size =
                prefetch_buf->size() + static_cast<size_t>(addr - aligned_addr);

            if (madvise(
                    reinterpret_cast<void*>(aligned_addr),
                    aligned_size,
                    MADV_WILLNEED) != 0) {
                // Save errno immediately so nothing executed while logging
                // can overwrite it.
                const int madvise_errno = errno;
                ET_LOG(
                    Info,
                    "CudaBackend::init - madvise(MADV_WILLNEED) failed for %s: %s",
                    weights_blob_key.c_str(),
                    std::strerror(madvise_errno));
            } else {
                ET_LOG(
                    Info,
                    "CudaBackend::init - Prefetching %s (%.1f MB)",
                    weights_blob_key.c_str(),
                    prefetch_buf->size() / (1024.0 * 1024.0));
            }
        }
    }
#endif

auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str());
ET_CHECK_OR_RETURN_ERROR(
aoti_dso_buffer.ok(),
Expand Down Expand Up @@ -338,9 +382,8 @@

handle->container_handle = container_handle;

// Look into named data map for constant data
std::string weights_blob_key =
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
// Look into named data map for constant data (key computed above for
// prefetch)
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
Expand Down
Loading