Skip to content

Add Schur completment and its mat-free mode#35

Open
zitongzhan wants to merge 32 commits into
releasefrom
memory-issue-swp
Open

Add Schur completment and its mat-free mode#35
zitongzhan wants to merge 32 commits into
releasefrom
memory-issue-swp

Conversation

@zitongzhan
Copy link
Copy Markdown
Collaborator

This pull request introduces significant improvements to the optimizer infrastructure, focusing on enhanced memory profiling, a new Schur complement optimizer, and better support for matrix-free operations.

Optimizer Enhancements

  • Added a new Schur optimizer class in bae.optim.optimizer, implementing the Schur complement method with support for both standard and matrix-free normal equations, block Jacobi preconditioning, and efficient memory usage.

  • Updated the LM optimizer to support a matrix_free_normal mode, allowing for more efficient computation and memory usage in large-scale problems.

  • Add a custom TrustRegion class that supports Warp, especially for use with the Schur optimizer.

Sparse Matrix and PyOps Improvements

  • Improved sparse matrix operations, including fixes to inv_op for correct tensor creation and a new test block in py_ops.py for diagonal operations on CUDA.

Comment thread bae/sparse/warp_wrappers.py Fixed
Comment thread bae/optim/optimizer.py Fixed
Comment thread bae/sparse/py_ops.py Fixed
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces high-performance Triton kernels for sparse BSR operations, including matrix-vector multiplication, matrix-matrix multiplication, and transposition. It also implements a matrix-free NormalMatVec operator and a new Schur complement-based optimizer to improve the efficiency of bundle adjustment tasks. The bundle adjustment example was updated with CUDA memory snapshotting and Warp mempool reporting. Review feedback highlights a critical issue where in-place diagonal modifications in the LM and Schur optimizers cause damping factors to accumulate incorrectly during step rejections. Additionally, the reviewer recommends removing performance-hindering torch.cuda.empty_cache() calls, addressing potential divisions by zero in the Conjugate Gradient solver, and cleaning up redundant or commented-out code.

Comment thread bae/optim/optimizer.py
diag_scale *= 1.0 + pg['damping']
A.set_damping(diag_scale - 1.0)
else:
diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping']))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The diagonal_op_ function performs an in-place multiplication on the matrix A. Since this is inside the while loop, if a step is rejected and the loop repeats, the damping will be applied cumulatively (e.g., $(1+\lambda_1)(1+\lambda_2)...$) instead of being applied to the original $J^T J$ diagonal. This deviates from the standard Levenberg-Marquardt algorithm and can lead to excessively aggressive damping. Consider cloning the matrix or resetting the diagonal before applying damping in each iteration.

Comment thread bae/optim/optimizer.py
Comment thread bae/optim/optimizer.py
R = R.tensor()
else:
R = R.detach()
torch.cuda.empty_cache()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling torch.cuda.empty_cache() inside the optimization step is generally discouraged as it triggers a GPU synchronization and can significantly degrade performance. If memory management is a concern, it's better to optimize tensor lifecycles or use a dedicated memory pool. If this was added for debugging memory usage, it should be removed before merging.


Ap = matvec(p)
Ap_flat = Ap.reshape(-1)
alpha = (rz / torch.dot(p.reshape(-1), Ap_flat)).item()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Potential division by zero if torch.dot(p.reshape(-1), Ap_flat) is zero (e.g., if the matrix is singular or not positive definite). While $J^T J$ is positive semi-definite, numerical issues or zero curvature directions could cause this to be zero. Consider adding a small epsilon or a check for numerical stability.


rz_new = torch.dot(r_flat, z_flat)
beta = (rz_new / rz).item()
p.mul_(beta).add_(z)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Potential division by zero if rz is zero. Although the convergence check at line 692 should ideally terminate the loop if the residual is zero, a safety check for rz before division is recommended to prevent NaN values in case of numerical instability.

Comment thread bae/sparse/warp_wrappers.py Outdated
Comment thread bae/utils/pysolvers.py Outdated
zitongzhan and others added 3 commits May 23, 2026 20:35
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants