[tx] Refactor TinkerEngine to use backend architecture #787
Conversation
Introduces a clean separation between engine (orchestration) and backend (computation):

**New files in `backends/`:**
- `backend.py`: AbstractBackend interface defining the contract
- `native.py`: NativeBackend implementation (extracted from engine.py)
- `utils.py`: Shared utilities (log_timing, pad, pad_batch)
- `__init__.py`: Module exports

**Engine responsibilities (engine.py):**
- Database operations (futures, checkpoints)
- Request validation (`_filter_valid_requests`)
- Data extraction from requests (`_prepare_model_pass_batch`, `_prepare_sample_batch`)
- File I/O (checkpoint download/upload)
- Orchestration of batch processing

**Backend responsibilities (native.py):**
- Model initialization and state management
- JAX/Flax computation (forward, backward, gradient accumulation)
- Optimizer creation and updates
- Checkpoint data extraction/insertion

**New types in types.py:**
- `PreparedModelPassBatch`: Batch data for forward/backward ops
- `PreparedSampleBatch`: Batch data for sampling ops

This is a purely structural refactor - no functional changes to computation logic.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
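To make the contract concrete, a minimal sketch of what an interface like `AbstractBackend` could look like is shown below. The method names and signatures here are illustrative assumptions based on the responsibilities listed above, not the project's actual API:

```python
from abc import ABC, abstractmethod


class AbstractBackend(ABC):
    """Contract between the orchestrating engine and a computation backend.

    Hypothetical sketch: forward_backward and sample mirror the two batch
    kinds the engine prepares (model-pass and sampling ops).
    """

    @abstractmethod
    def forward_backward(self, batch):
        """Run a forward/backward pass on a prepared batch and accumulate gradients."""

    @abstractmethod
    def sample(self, batch):
        """Generate tokens for a prepared sample batch."""
```

Any concrete backend (such as the NativeBackend described above) would subclass this and supply the actual JAX/Flax computation.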
Code Review
This pull request introduces a significant and well-executed refactoring of the TinkerEngine to use a backend architecture. The clean separation between the engine for orchestration and the backend for computation greatly improves the code's structure, clarity, and maintainability. The new AbstractBackend interface is well-defined, and the NativeBackend correctly encapsulates the JAX/Flax computation logic. The changes in TinkerEngine make it much cleaner and easier to follow. I've identified one potential high-severity issue regarding a division by zero and a couple of medium-severity suggestions for improving code clarity.
- Added `metrics` property in `TinkerEngine` for backward compatibility with backend metrics.
- Introduced `configure_adapter` method in `AbstractBackend` to streamline LoRA adapter configuration.
- Updated `NativeBackend` to implement the new `configure_adapter` method, replacing the previous `update_adapter_config` call.

These changes improve the modularity and maintainability of the codebase while ensuring compatibility with existing metrics functionality.
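A backward-compatible `metrics` property of the kind described can be sketched as a simple delegation to the backend. This is an illustrative sketch; the attribute layout of the real classes may differ:

```python
class TinkerEngine:
    """Sketch: the engine keeps its public metrics attribute but the
    backend now owns the underlying values."""

    def __init__(self, backend):
        self.backend = backend

    @property
    def metrics(self):
        # Delegate to the backend so existing callers of engine.metrics
        # keep working after the refactor.
        return self.backend.metrics
```

Existing code that reads `engine.metrics` is unaffected, even though the storage moved into the backend.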
- Replace create_optimizer + configure_adapter with register_model
- Move optimizer storage from engine to backend (self.optimizers dict)
- Remove optimizer params from checkpoint methods
- Add min_seq_len param to round_up_seq_len calls

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
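To illustrate what a `min_seq_len` parameter buys in a helper like `round_up_seq_len`, here is a hypothetical sketch. The power-of-two bucketing is an assumption for illustration; the actual utility in the repo may round differently:

```python
def round_up_seq_len(seq_len: int, min_seq_len: int = 32) -> int:
    """Round a sequence length up to the next power of two, at least min_seq_len.

    Bucketing lengths into a small set of sizes limits the number of distinct
    batch shapes, which keeps JAX recompilation in check. The minimum floor
    avoids compiling tiny one-off shapes for very short sequences.
    """
    n = max(seq_len, min_seq_len)
    # Next power of two >= n
    return 1 << (n - 1).bit_length()
```

With this shape of helper, very short sequences all land in the `min_seq_len` bucket instead of each triggering their own compiled shape.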
Refactor backend to manage optimizers internally
Added use_eager_sharding and NaN prevention
@pcmoritz can you take a look? You said you wanted to add pytorch compat at some point. I have another backend in another PR with maxtext working with a lot of advanced sharding options (got 100k sequence length working on a TPU v5p-8 with 30B-A3B).
# Conflicts:
#   skyrl-tx/tx/tinker/engine.py
@pcmoritz TODO: Port FSDP changes back into the native backend
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
/gemini review
Code Review
This pull request is a significant and well-executed refactoring that introduces a backend architecture to TinkerEngine. The separation of concerns between the engine (orchestration) and the backend (computation) is clear and greatly improves the codebase's structure and maintainability. The new AbstractBackend interface is well-defined, and the NativeBackend correctly encapsulates the JAX/Flax computation logic.
I've identified a critical issue where request validation was inadvertently dropped for batch processing methods, which could lead to crashes. I've also found a medium-severity issue regarding a potential division-by-zero in gradient calculation.
Overall, this is an excellent structural improvement. Once the identified issues are addressed, this will be a great addition to the codebase.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
/gemini review
Code Review
This pull request is a significant and well-executed refactoring of TinkerEngine. It introduces a clean separation between the engine (orchestration) and a new backend layer (computation) by creating an AbstractBackend and a NativeBackend implementation. This change greatly improves the modularity and maintainability of the codebase, making it easier to add new backends in the future as mentioned in the PR description.

The logic has been correctly moved from the engine to the native backend, and the engine's role is now clearly focused on orchestration tasks like DB operations and request handling. The introduction of PreparedModelPassBatch and PreparedSampleBatch types for communication between the engine and backend is a good design choice.

While described as a purely structural refactor, I noticed it also includes a subtle but important bug fix in the loss calculation to prevent division by zero, which is a great improvement. I've added a couple of suggestions for further minor performance and robustness enhancements. Overall, this is excellent work that significantly improves the architecture.
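The two batch types used for engine-to-backend communication could be sketched roughly as follows. The field names here are illustrative guesses; the actual definitions live in `tx/tinker/types.py`:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class PreparedModelPassBatch:
    """Data the engine hands to the backend for forward/backward ops.

    Hypothetical fields for illustration only.
    """
    input_ids: Any      # token ids, padded to a common length
    targets: Any        # labels used by the loss
    model_ids: list     # which adapter/model each row belongs to


@dataclass
class PreparedSampleBatch:
    """Data the engine hands to the backend for sampling ops."""
    input_ids: Any
    max_new_tokens: int
```

The key design point is that the engine does all request parsing and produces plain data containers, so a backend never has to touch the request/DB layer.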
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
/gemini review
Code Review
This pull request introduces a significant and well-executed structural refactoring by separating the TinkerEngine's orchestration logic from the computation logic, which is now handled by a new backend architecture. The introduction of AbstractBackend and NativeBackend creates a clean separation of concerns, making the system more modular and extensible. The related test and benchmark files have been updated correctly to reflect these architectural changes.

My review comments focus on improving the robustness of the new NativeBackend by making a utility function safe against division-by-zero and replacing assert statements with proper ValueError exceptions for input validation. Overall, this is a high-quality refactoring that greatly improves the codebase's structure.
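The two robustness suggestions in this review could be addressed along these lines. This is a sketch under assumed names (`masked_mean`, `validate_batch` are hypothetical); the actual utilities in `backends/utils.py` may differ:

```python
def masked_mean(values, mask):
    """Mean over masked elements, safe when the mask is all zeros."""
    denom = sum(mask)
    if denom == 0:
        return 0.0  # avoid division by zero on an empty mask
    return sum(v * m for v, m in zip(values, mask)) / denom


def validate_batch(seq_lens):
    """Use ValueError instead of assert so validation survives `python -O`."""
    if not seq_lens:
        raise ValueError("batch must contain at least one sequence")
    if any(n <= 0 for n in seq_lens):
        raise ValueError("sequence lengths must be positive")
```

Asserts are stripped when Python runs with optimizations enabled, which is why explicit exceptions are preferred for validating external input.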
@OhadRubin Thanks a lot for your PR! I did some refactoring, mostly making the AbstractBackend more canonical by removing the jax dependency, and streamlining the methods to line up very closely with the tinker specification (in particular, I removed the adapter_index handling from the engine, since different backends might handle it differently). This way gives more flexibility to the backend implementation; e.g. we would like to add a backend that is implemented using skyrl-train as part of the ongoing SkyRL tinkerification effort. If backends want to share this code, we can do so via a utilities file.

Do you want to have a look at the changes and let me know if that works for you (e.g. if you can implement #788 with this API) or if you have any suggestions for improvement? Thanks again a lot for the contribution, this will also help with making the jax backend multi-node, which I'm planning to work on soon :)
@pcmoritz Looks good! I'll align my draft PR to it. A few questions:
Thanks for the feedback!
Let me know if you have thoughts about these things; in the meantime I'll go ahead and rename the
I pushed a first version of the refactor; it will need a little more refinement though. I'll get to that in a few hours :)
/gemini review
Code Review
This pull request is a well-executed refactoring that introduces a clean separation between the TinkerEngine and a new backend architecture. The introduction of AbstractBackend and the JaxBackend implementation successfully decouples orchestration from computation, which will make the system more modular and extensible. The configuration has also been nicely refactored to move backend-specific settings into a backend_config dictionary.
I have a couple of suggestions for further improvement regarding adapter index allocation and a leaky abstraction in the engine, but overall this is a high-quality structural change.
```python
# TODO: This leaks the abstraction by accessing backend-specific config.
# We should find a better way to handle this going forward.
if isinstance(self.backend, JaxBackend) and self.backend.config.sample_max_num_sequences > 0:
    batchable = batchable[: self.backend.config.sample_max_num_sequences]
```
As the TODO mentions, this is a leaky abstraction. To resolve this, you could move the batch capping logic into the backend.
Here's a potential approach:
1. Add a method to `AbstractBackend` to allow backends to cap the batch size:

   ```python
   # In tx/tinker/backends/backend.py
   class AbstractBackend(ABC):
       ...
       def cap_sample_batch(self, requests: list) -> list:
           """Cap the number of sample requests in a batch according to backend constraints."""
           return requests
   ```

2. Override this in `JaxBackend`:

   ```python
   # In tx/tinker/backends/jax.py
   class JaxBackend(AbstractBackend):
       ...
       def cap_sample_batch(self, requests: list) -> list:
           if self.config.sample_max_num_sequences > 0:
               return requests[:self.config.sample_max_num_sequences]
           return requests
   ```

3. Update `TinkerEngine.find_batchable_sample` to use this new backend method, removing the `isinstance` check:

   ```python
   # In tx/tinker/engine.py
   ...
   for op in sample_ops:
       ...
       if ...:
           batchable.append(op)
   batchable = self.backend.cap_sample_batch(batchable)
   return {str(f.request_id): ... for f in batchable}
   ```
This would remove the `isinstance` check and knowledge of backend-specific configuration from the engine.
This is not necessarily better since then this leaks into every single backend. I would anticipate that most non JaxBackend backends will use #568 anyways. Maybe the right solution going forward will be to remove sampling from the backend, but we will revisit this once we have at least one more backend.
@pcmoritz
Yeah I think it's fine, in my maxtext backend I'm using vllm anyway by default.
That's why I originally called it the native backend, because it's the only one that gets to leak abstractions haha.
P.S. I will have my maxtext backend ported over to the new interface in the next week or so.
Btw, I did have to add LoRA support to the maxtext implementation, so right now it depends on my fork, but I will open a PR to the maxtext repo and see how it goes.
Summary
(TL;DR: I got maxtext working as a backend with a lot of models and a lot of sharding options etc. I'm splitting it into multiple PRs; if we could integrate this, I could add my maxtext backend, and it would allow a lot of flexibility.)
Introduces a clean separation between engine (orchestration) and backend (computation) in TinkerEngine.
New files in `backends/`

- `backend.py` - AbstractBackend interface
- `jax.py` - JaxBackend implementation (extracted from engine.py)
- `utils.py` - Shared utilities (log_timing, pad, pad_batch)

New types

- `PreparedModelPassBatch` - Batch data for forward/backward ops
- `PreparedSampleBatch` - Batch data for sampling ops

Test plan
This is a purely structural refactor - no functional changes to computation logic. Line-by-line comparison confirms identical behavior.
🤖 Generated with Claude Code