[Store] pub_tensor for multiple replica #1103
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the Mooncake Store's Python integration by introducing a pub_tensor method for publishing tensors with replication configurations.
Code Review
This pull request introduces a pub_tensor method to allow publishing tensors with replication configurations, which is a great addition. The changes also include good code quality improvements like consistent use of namespaces. However, I've found a few issues. There's a critical issue in the test updates for allocation_strategy_test.cpp where the tests no longer correctly test the preferred segment allocation logic because the code under test wasn't updated. Additionally, in store_py.cpp, the tensor type check is fragile and there's significant code duplication in tensor processing logic that could be refactored for better maintainability. Please see my detailed comments.
```diff
      empty_allocators_by_name;
  std::vector<std::shared_ptr<BufferAllocatorBase>> empty_allocators;
- ReplicateConfig config{1, false, "local"};
+ ReplicateConfig config{1, false, {"local"}};
```
This change updates the test to initialize the new preferred_segments field in ReplicateConfig. However, the RandomAllocationStrategy that is being tested here has not been updated to use this new field. It still relies on the deprecated preferred_segment field, which is now left empty by this test initialization.
As a result, this test and others in this file that check preferred segment logic are no longer effective. The allocation strategy will not use the preferred segment, and the test will fall back to random allocation, not testing the intended logic.
To fix this, RandomAllocationStrategy needs to be updated to use config.preferred_segments instead of config.preferred_segment. Since allocation_strategy.h is not part of this PR, this change makes the tests incorrect.
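A minimal sketch of the update being described, under simplified assumptions: only `ReplicateConfig.preferred_segments` comes from this PR; the allocator map, the other field names, and `pick_segment` are illustrative stand-ins, not the real `RandomAllocationStrategy` API.

```cpp
#include <cassert>
#include <iterator>
#include <map>
#include <random>
#include <string>
#include <vector>

// Stand-in for the PR's ReplicateConfig; field names other than
// preferred_segments are illustrative guesses.
struct ReplicateConfig {
    size_t replica_num{1};
    bool soft_pin{false};
    std::vector<std::string> preferred_segments;  // replaces preferred_segment
};

// Choose a segment for one replica: honor that replica's preferred
// segment when it names a known allocator, otherwise fall back to a
// uniformly random segment (the existing random behavior).
// Assumes allocators_by_name is non-empty.
std::string pick_segment(const std::map<std::string, int>& allocators_by_name,
                         const ReplicateConfig& config, size_t replica_idx,
                         std::mt19937& rng) {
    if (replica_idx < config.preferred_segments.size()) {
        const std::string& preferred = config.preferred_segments[replica_idx];
        if (allocators_by_name.count(preferred) > 0) {
            return preferred;
        }
    }
    std::uniform_int_distribution<size_t> dist(
        0, allocators_by_name.size() - 1);
    auto it = allocators_by_name.begin();
    std::advance(it, dist(rng));
    return it->first;
}
```

With the strategy reading `config.preferred_segments`, the test initialization `ReplicateConfig config{1, false, {"local"}}` would exercise the preferred-segment path again.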
```cpp
if (!(tensor.attr("__class__")
          .attr("__name__")
          .cast<std::string>()
          .find("Tensor") != std::string::npos)) {
    LOG(ERROR) << "Input is not a PyTorch tensor";
    return -static_cast<int>(ErrorCode::INVALID_PARAMS);
}
```
The current method of checking if the object is a PyTorch tensor by searching for the substring "Tensor" in its class name is fragile. For example, a user-defined class named MyTensor would incorrectly pass this check. A more robust approach is to use py::isinstance.
Consider using py::isinstance(tensor, torch_module().attr("Tensor")) for a more reliable type check.
```cpp
if (!py::isinstance(tensor, torch_module().attr("Tensor"))) {
    LOG(ERROR) << "Input is not a PyTorch tensor";
    return -static_cast<int>(ErrorCode::INVALID_PARAMS);
}
```

```cpp
int pub_tensor(const std::string &key, pybind11::object tensor,
               const ReplicateConfig &config = ReplicateConfig{}) {
    if (!store_ || !store_->client_) {
        LOG(ERROR) << "Client is not initialized";
        return -static_cast<int>(ErrorCode::INVALID_PARAMS);
    }

    // Validate segment preferences
    if (!config.preferred_segments.empty() &&
        config.preferred_segments.size() != config.replica_num) {
        LOG(ERROR) << "Preferred segments size ("
                   << config.preferred_segments.size()
                   << ") must match replica_num (" << config.replica_num
                   << ")";
        return -static_cast<int>(ErrorCode::INVALID_PARAMS);
    }

    try {
        if (!(tensor.attr("__class__")
                  .attr("__name__")
                  .cast<std::string>()
                  .find("Tensor") != std::string::npos)) {
            LOG(ERROR) << "Input is not a PyTorch tensor";
            return -static_cast<int>(ErrorCode::INVALID_PARAMS);
        }

        uintptr_t data_ptr = tensor.attr("data_ptr")().cast<uintptr_t>();
        size_t numel = tensor.attr("numel")().cast<size_t>();
        size_t element_size = tensor.attr("element_size")().cast<size_t>();
        size_t tensor_size = numel * element_size;

        pybind11::object shape_obj = tensor.attr("shape");
        pybind11::object dtype_obj = tensor.attr("dtype");

        TensorDtype dtype_enum = get_tensor_dtype(dtype_obj);
        if (dtype_enum == TensorDtype::UNKNOWN) {
            LOG(ERROR) << "Unsupported tensor dtype!";
            return -static_cast<int>(ErrorCode::INVALID_PARAMS);
        }

        pybind11::tuple shape_tuple =
            pybind11::cast<pybind11::tuple>(shape_obj);
        int32_t ndim = static_cast<int32_t>(shape_tuple.size());
        if (ndim > 4) {
            LOG(ERROR) << "Tensor has more than 4 dimensions: " << ndim;
            return -static_cast<int>(ErrorCode::INVALID_PARAMS);
        }

        TensorMetadata metadata;
        metadata.dtype = static_cast<int32_t>(dtype_enum);
        metadata.ndim = ndim;

        for (int i = 0; i < 4; i++) {
            if (i < ndim) {
                metadata.shape[i] = shape_tuple[i].cast<int32_t>();
            } else {
                metadata.shape[i] = -1;
            }
        }

        // Section with GIL released
        py::gil_scoped_release release_gil;
        char *buffer = reinterpret_cast<char *>(data_ptr);
        char *metadata_buffer = reinterpret_cast<char *>(&metadata);
        std::vector<std::span<const char>> values;
        values.emplace_back(
            std::span<const char>(metadata_buffer, sizeof(TensorMetadata)));
        values.emplace_back(std::span<const char>(buffer, tensor_size));

        // Use put_parts to put metadata and tensor together with custom
        // config
        auto put_result = store_->put_parts_internal(key, values, config);
        if (!put_result) {
            return -static_cast<int>(put_result.error());
        }

        return 0;
    } catch (const pybind11::error_already_set &e) {
        LOG(ERROR) << "Failed to access tensor data: " << e.what();
        return -static_cast<int>(ErrorCode::INVALID_PARAMS);
    }
}
```
There is significant code duplication for tensor validation and metadata extraction between this new pub_tensor function, the existing put_tensor function (lines 279-347), and batch_put_tensor (lines 432-584). This makes the code harder to maintain, as any changes to this logic will need to be applied in multiple places.
To improve maintainability, consider extracting this common logic into a private helper function within the MooncakeStorePyWrapper class. This helper could take a py::object and return an std::optional or tl::expected containing the extracted TensorMetadata, data_ptr, and tensor_size, or an error. Both pub_tensor and put_tensor could then call this helper to simplify their implementations and eliminate redundant code.
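A sketch of the helper this comment suggests, under simplifying assumptions: the real helper would take a `py::object` and read attributes via pybind11, and might return `tl::expected` with an `ErrorCode`; here the dtype enum and shape are passed in directly and `std::optional` stands in for the error channel, to show the shared metadata-building step.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Mirrors the metadata layout used in the PR: dtype, ndim, and a shape
// array padded to 4 entries with -1 for unused dimensions.
struct TensorMetadata {
    int32_t dtype;
    int32_t ndim;
    int32_t shape[4];
};

// Simplified stand-in for the suggested private helper. The real version
// would extract dtype_enum and shape from a py::object via tensor.attr(...).
std::optional<TensorMetadata> make_tensor_metadata(
    int32_t dtype_enum, const std::vector<int32_t>& shape) {
    if (shape.size() > 4) {
        return std::nullopt;  // tensors with more than 4 dims unsupported
    }
    TensorMetadata md{};
    md.dtype = dtype_enum;
    md.ndim = static_cast<int32_t>(shape.size());
    for (int i = 0; i < 4; i++) {
        md.shape[i] = (i < md.ndim) ? shape[i] : -1;
    }
    return md;
}
```

`pub_tensor`, `put_tensor`, and `batch_put_tensor` could then share one such helper instead of each repeating the 4-dimension check and the -1 padding inline.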
```cpp
        .cast<std::string>()
        .find("Tensor") != std::string::npos)) {
LOG(ERROR) << "Input is not a PyTorch tensor";
return -static_cast<int>(ErrorCode::INVALID_PARAMS);
```
The ErrorCode here is already a negative value, so there’s no need to negate it again. There are a few other places in the code with the same issue — I’ll submit a PR to fix them. Ideally, we should handle this uniformly inside a function like to_py_ret.
This issue in existing functions is fixed in this PR: #1129.
You could also use the `to_py_ret` helper.
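Per the comments above, a `to_py_ret` helper already exists in the codebase; its exact signature is not shown in this thread, so the following is only an illustrative sketch of the uniform handling being suggested, with an assumed `ErrorCode` layout whose values are already negative.

```cpp
#include <cassert>

// Illustrative ErrorCode with negative values, as the reviewer describes;
// the real enum lives in the Mooncake codebase.
enum class ErrorCode : int {
    OK = 0,
    INVALID_PARAMS = -2,
};

// Sketch of converting an ErrorCode to a Python-facing int return value
// in one place: always emit a non-positive value, whether the enum stores
// its codes negative or positive. Callers then return to_py_ret(code)
// instead of hand-writing -static_cast<int>(code).
int to_py_ret(ErrorCode code) {
    int v = static_cast<int>(code);
    return v <= 0 ? v : -v;  // normalize the sign uniformly
}
```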
```cpp
    }
}

int pub_tensor(const std::string &key, pybind11::object tensor,
```
Maybe add some tests in scripts/test_tensor_api.py to verify that pub/sub works as expected.