
Conversation


@OhadRubin commented Dec 17, 2025

Summary

(TL;DR: I got maxtext working as a backend with a lot of models and a lot of sharding options, and I'm splitting it into multiple PRs. If we could integrate this one, I could add my maxtext backend, which would allow a lot of flexibility.)
Introduces a clean separation between engine (orchestration) and backend (computation) in TinkerEngine.

  • Engine handles: DB operations, request validation, data extraction, file I/O, orchestration
  • Backend handles: Model state, JAX/Flax computation, gradient accumulation, optimizer updates

New files in backends/

  • backend.py - AbstractBackend interface (sketched below)
  • jax.py - JaxBackend implementation (extracted from engine.py)
  • utils.py - Shared utilities (log_timing, pad, pad_batch)
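
For illustration, here is a minimal sketch of what an AbstractBackend along these lines might look like. Apart from create_model (discussed later in this thread) and the operations named in the test plan, the method names and signatures are assumptions, not the actual tx API:

```python
from abc import ABC, abstractmethod
from typing import Any


class AbstractBackend(ABC):
    """Owns model state and computation; the engine only orchestrates."""

    @abstractmethod
    def create_model(self, model_id: str, config: dict) -> None:
        """Initialize model state (and its optimizer) for a LoRA adapter."""

    @abstractmethod
    def forward_backward(self, batch: Any) -> dict:
        """Run a forward/backward pass over a PreparedModelPassBatch."""

    @abstractmethod
    def optim_step(self, model_id: str) -> None:
        """Apply accumulated gradients via the model's optimizer."""

    @abstractmethod
    def sample(self, batch: Any) -> dict:
        """Generate completions for a PreparedSampleBatch."""

    @abstractmethod
    def save_checkpoint(self, model_id: str, path: str) -> None:
        """Extract checkpoint data for the engine to upload."""

    @abstractmethod
    def load_checkpoint(self, model_id: str, path: str) -> None:
        """Insert checkpoint data downloaded by the engine."""
```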

New types

  • PreparedModelPassBatch - Batch data for forward/backward ops
  • PreparedSampleBatch - Batch data for sampling ops (both types are sketched below)
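
A hypothetical shape for these two types, assuming plain dataclasses; the actual fields in types.py may differ:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PreparedModelPassBatch:
    model_ids: list[str]      # which adapter each row belongs to
    input_ids: np.ndarray     # [batch, seq] token ids, padded
    targets: np.ndarray       # [batch, seq] labels for the loss
    loss_weights: np.ndarray  # [batch, seq] per-token weights / mask


@dataclass
class PreparedSampleBatch:
    model_ids: list[str]
    prompts: list[list[int]]  # unpadded prompt token ids
    max_tokens: list[int]     # per-request generation budget
    temperatures: list[float]
```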

Test plan

  • Verify forward/backward batch processing produces identical results
  • Verify sampling produces identical results
  • Verify checkpoint save/load works correctly
  • Verify optimizer step applies gradients correctly

This is a purely structural refactor - no functional changes to computation logic. Line-by-line comparison confirms identical behavior.

🤖 Generated with Claude Code

Introduces a clean separation between engine (orchestration) and backend (computation):

**New files in `backends/`:**
- `backend.py`: AbstractBackend interface defining the contract
- `native.py`: NativeBackend implementation (extracted from engine.py)
- `utils.py`: Shared utilities (log_timing, pad, pad_batch)
- `__init__.py`: Module exports

**Engine responsibilities (engine.py):**
- Database operations (futures, checkpoints)
- Request validation (`_filter_valid_requests`)
- Data extraction from requests (`_prepare_model_pass_batch`, `_prepare_sample_batch`)
- File I/O (checkpoint download/upload)
- Orchestration of batch processing

**Backend responsibilities (native.py):**
- Model initialization and state management
- JAX/Flax computation (forward, backward, gradient accumulation; see the sketch after this list)
- Optimizer creation and updates
- Checkpoint data extraction/insertion
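
As a toy illustration of the accumulate-then-step pattern the backend owns (this is not the actual tx code; loss_fn is a stand-in for the real model):

```python
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adamw(1e-4)


def loss_fn(params, batch):
    preds = batch["x"] @ params["w"]  # stand-in for the real forward pass
    return jnp.mean((preds - batch["y"]) ** 2)


@jax.jit
def accumulate(params, grad_acc, batch):
    # Forward/backward for one micro-batch; add into the running grad sum.
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(jnp.add, grad_acc, grads)


@jax.jit
def optim_step(params, opt_state, grad_acc, num_batches):
    # Average the accumulated grads, then apply a single optimizer update.
    grads = jax.tree_util.tree_map(lambda g: g / num_batches, grad_acc)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state


params = {"w": jnp.zeros((4, 1))}
opt_state = optimizer.init(params)
grad_acc = jax.tree_util.tree_map(jnp.zeros_like, params)
```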

**New types in types.py:**
- `PreparedModelPassBatch`: Batch data for forward/backward ops
- `PreparedSampleBatch`: Batch data for sampling ops

This is a purely structural refactor - no functional changes to computation logic.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@OhadRubin changed the title Refactor TinkerEngine to use backend architecture → [tx] Refactor TinkerEngine to use backend architecture Dec 17, 2025

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-executed refactoring of the TinkerEngine to use a backend architecture. The clean separation between the engine for orchestration and the backend for computation greatly improves the code's structure, clarity, and maintainability. The new AbstractBackend interface is well-defined, and the NativeBackend correctly encapsulates the JAX/Flax computation logic. The changes in TinkerEngine make it much cleaner and easier to follow. I've identified one potential high-severity issue regarding a division by zero and a couple of medium-severity suggestions for improving code clarity.

- Added `metrics` property in `TinkerEngine` for backward compatibility with backend metrics.
- Introduced `configure_adapter` method in `AbstractBackend` to streamline LoRA adapter configuration.
- Updated `NativeBackend` to implement the new `configure_adapter` method, replacing the previous `update_adapter_config` call.

These changes improve the modularity and maintainability of the codebase while ensuring compatibility with existing metrics functionality.
OhadRubin and others added 3 commits December 20, 2025 07:04
- Replace create_optimizer + configure_adapter with register_model
- Move optimizer storage from engine to backend (self.optimizers dict)
- Remove optimizer params from checkpoint methods
- Add min_seq_len param to round_up_seq_len calls

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Refactor backend to manage optimizers internally
added use_eager_sharding and nan-prevention

@OhadRubin commented Dec 20, 2025

@pcmoritz can you take a look? You said you wanted to add pytorch compat at some point. I got another backend in a separate PR with maxtext working with a lot of advanced sharding options (got a 100k sequence length working on a TPU v5p-8 with 30B-A3B).
(piglatin.py works with this, btw; I haven't tested other things yet)

@pcmoritz added the tx label Dec 22, 2025

@pcmoritz TODO: Port FSDP changes back into the native backend

@pcmoritz

/gemini review


@gemini-code-assist bot left a comment


Code Review

This pull request is a significant and well-executed refactoring that introduces a backend architecture to TinkerEngine. The separation of concerns between the engine (orchestration) and the backend (computation) is clear and greatly improves the codebase's structure and maintainability. The new AbstractBackend interface is well-defined, and the NativeBackend correctly encapsulates the JAX/Flax computation logic.

I've identified a critical issue where request validation was inadvertently dropped for batch processing methods, which could lead to crashes. I've also found a medium-severity issue regarding a potential division-by-zero in gradient calculation.

Overall, this is an excellent structural improvement. Once the identified issues are addressed, this will be a great addition to the codebase.

pcmoritz and others added 3 commits December 26, 2025 07:44
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@pcmoritz

/gemini review


@gemini-code-assist bot left a comment


Code Review

This pull request is a significant and well-executed refactoring of TinkerEngine. It introduces a clean separation between the engine (orchestration) and a new backend layer (computation) by creating an AbstractBackend and a NativeBackend implementation. This change greatly improves the modularity and maintainability of the codebase, making it easier to add new backends in the future as mentioned in the PR description. The logic has been correctly moved from the engine to the native backend, and the engine's role is now clearly focused on orchestration tasks like DB operations and request handling. The introduction of PreparedModelPassBatch and PreparedSampleBatch types for communication between the engine and backend is a good design choice. While described as a purely structural refactor, I noticed it also includes a subtle but important bug fix in the loss calculation to prevent division by zero, which is a great improvement. I've added a couple of suggestions for further minor performance and robustness enhancements. Overall, this is excellent work that significantly improves the architecture.

pcmoritz and others added 2 commits December 27, 2025 01:16
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@pcmoritz

/gemini review


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-executed structural refactoring by separating the TinkerEngine's orchestration logic from the computation logic, which is now handled by a new backend architecture. The introduction of AbstractBackend and NativeBackend creates a clean separation of concerns, making the system more modular and extensible. The related test and benchmark files have been updated correctly to reflect these architectural changes. My review comments focus on improving the robustness of the new NativeBackend by making a utility function safe against division-by-zero and replacing assert statements with proper ValueError exceptions for input validation. Overall, this is a high-quality refactoring that greatly improves the codebase's structure.


@pcmoritz commented Dec 27, 2025

@OhadRubin Thanks a lot for your PR! I did some refactoring, mostly making the AbstractBackend more canonical by removing the jax dependency and streamlining the methods to line up closely with the tinker specification (in particular, I removed the adapter_index handling from the engine, since different backends might handle it differently). This gives more flexibility to the backend implementation; e.g., we would like to add a backend implemented using skyrl-train as part of the ongoing SkyRL tinkerification effort. If backends want to share this code, we can do so via a utilities file?

Do you want to have a look at the changes and let me know if that works for you (e.g. if you can implement #788 with this API) or if you have any suggestions for improvement? Thanks again a lot for the contribution; this will also help make the jax backend multi-node, which I'm planning to work on soon :)

@OhadRubin

@pcmoritz Looks good! I'll align my draft PR to it.

A few questions:

  • Config standardization: my maxtext draft has maxtext_config_str with some duplicate args (LoRA rank, sharding). Worth thinking about how to unify?

  • register_model → create_model, but unregister_model was removed without LoRA zeroing. Intentional? If not, that's a bug! Related: should the backend have a has_capacity interface and manage its own slots, or should the engine handle this? My use case is a single machine with max_loras=1; I had LRU eviction so I don't need to restart when switching models (see the sketch after this list).

  • For multiple backends on different machines - do you see tx adding a layer between engine and backend to track which node has which model?
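
A minimal sketch of the LRU eviction idea from the second question, assuming the backend tracks its own adapter slots; none of these names exist in the tx codebase:

```python
from collections import OrderedDict


class AdapterSlots:
    """Evict the least-recently-used adapter when all slots are taken."""

    def __init__(self, max_loras: int = 1):
        self.max_loras = max_loras
        self._slots: OrderedDict[str, int] = OrderedDict()  # model_id -> slot

    def acquire(self, model_id: str, load_fn, unload_fn) -> int:
        if model_id in self._slots:
            self._slots.move_to_end(model_id)  # mark as recently used
            return self._slots[model_id]
        if len(self._slots) >= self.max_loras:
            evicted, slot = self._slots.popitem(last=False)  # drop LRU adapter
            unload_fn(evicted)  # e.g. zero out its LoRA weights
        else:
            slot = len(self._slots)
        load_fn(model_id, slot)
        self._slots[model_id] = slot
        return slot
```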


@pcmoritz commented Dec 27, 2025

Thanks for the feedback!

  • For the configuration, let's introduce a --backend flag (a string that selects the backend; maybe "jax" for the current native backend, "maxtext" for the backend you implemented, and "skyrl(-train)" for the SkyRL-train backend) and a --backend-config flag that holds the config for the backend and is up to the backend to interpret. I feel like the best approach would be to standardize on JSON, with each backend defining a pydantic type to parse and validate the config (see the sketch after this list). I can augment the current PR with that and also rename the "native" backend to "jax".

  • For register / create model and cleanup: Cleanup is currently not implemented, but can be done since we have the health check pings from the client; when the client is not active any more, we can destroy the model. In a follow-up PR we can add a delete_model or destroy_model function to the backend that does this. It is better to do this as a follow-up, though, to keep this one mostly a refactoring PR.

  • For multiple models / multiple nodes: What I had in mind is a 1:1 correspondence between engines and base models (e.g. each engine hosts a single base model), and each engine can be multi-node to shard the model. We will need to implement support for connecting the API server to multiple engines (this will require code changes), and there will need to be a way to orchestrate multiple engines, e.g. with K8s or Ray (that one will mostly require some helper scripts and documentation).
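
A minimal sketch of the proposed --backend / --backend-config flow, assuming JSON strings validated by a per-backend pydantic model; the config fields here are illustrative (they echo sample_max_num_sequences and use_eager_sharding from elsewhere in this thread):

```python
import json

from pydantic import BaseModel


class JaxBackendConfig(BaseModel):
    """Config the "jax" backend would define and validate itself."""
    sample_max_num_sequences: int = 0
    use_eager_sharding: bool = False


def load_backend_config(backend: str, raw: str) -> BaseModel:
    # Each backend registers its own pydantic type; the engine stays agnostic.
    registry = {"jax": JaxBackendConfig}
    return registry[backend].model_validate(json.loads(raw))


# e.g. --backend jax --backend-config '{"sample_max_num_sequences": 8}'
config = load_backend_config("jax", '{"sample_max_num_sequences": 8}')
```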

Let me know if you have thoughts about these things, in the meantime I'll go ahead and rename the native backend to jax and introduce the --backend and --backend-config flag :)

@pcmoritz

I pushed a first version of the refactor, it will need a little more refinement though, I'll get to that in a few hours :)

@pcmoritz

/gemini review


@gemini-code-assist bot left a comment


Code Review

This pull request is a well-executed refactoring that introduces a clean separation between the TinkerEngine and a new backend architecture. The introduction of AbstractBackend and the JaxBackend implementation successfully decouples orchestration from computation, which will make the system more modular and extensible. The configuration has also been nicely refactored to move backend-specific settings into a backend_config dictionary.

I have a couple of suggestions for further improvement regarding adapter index allocation and a leaky abstraction in the engine, but overall this is a high-quality structural change.

Comment on lines +306 to +309

```python
# TODO: This leaks the abstraction by accessing backend-specific config.
# We should find a better way to handle this going forward.
if isinstance(self.backend, JaxBackend) and self.backend.config.sample_max_num_sequences > 0:
    batchable = batchable[: self.backend.config.sample_max_num_sequences]
```
@gemini-code-assist bot commented (severity: medium):

As the TODO mentions, this is a leaky abstraction. To resolve this, you could move the batch capping logic into the backend.

Here's a potential approach:

1. Add a method to AbstractBackend to allow backends to cap the batch size:

```python
# In tx/tinker/backends/backend.py
class AbstractBackend(ABC):
    ...
    def cap_sample_batch(self, requests: list) -> list:
        """Cap the number of sample requests in a batch according to backend constraints."""
        return requests
```

2. Override this in JaxBackend:

```python
# In tx/tinker/backends/jax.py
class JaxBackend(AbstractBackend):
    ...
    def cap_sample_batch(self, requests: list) -> list:
        if self.config.sample_max_num_sequences > 0:
            return requests[: self.config.sample_max_num_sequences]
        return requests
```

3. Update TinkerEngine.find_batchable_sample to use this new backend method, removing the isinstance check:

```python
# In tx/tinker/engine.py
...
for op in sample_ops:
    ...
    if ...:
        batchable.append(op)

batchable = self.backend.cap_sample_batch(batchable)

return {str(f.request_id): ... for f in batchable}
```

This would remove the isinstance check and knowledge of backend-specific configuration from the engine.

@pcmoritz replied:

This is not necessarily better, since then this leaks into every single backend. I would anticipate that most non-JaxBackend backends will use #568 anyway. Maybe the right solution going forward will be to remove sampling from the backend, but we will revisit this once we have at least one more backend.

@OhadRubin replied:

@pcmoritz
Yeah, I think it's fine; in my maxtext backend I'm using vllm by default anyway.
That's why I originally called it the native backend, because it's the only one that gets to leak abstractions haha.
P.S. I will have my maxtext backend ported over to the new interface in the next week or so.
Btw, I did have to add LoRA support to the maxtext implementation, so right now it depends on my fork, but I will open a PR to the maxtext repo and see how it goes.

@pcmoritz merged commit 351524c into NovaSky-AI:main Dec 29, 2025
4 checks passed