Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
9490ff8
add normal matvec and memory profiler
zitongzhan Dec 13, 2025
9c90aca
print peak cuda allocation
zitongzhan Dec 23, 2025
6256e79
add warp memory pool report
zitongzhan Dec 28, 2025
3a5ce9b
use `A._get_Jt` when matrix_free_normal
zitongzhan Dec 28, 2025
0064146
add back schur by warp's matmul
zitongzhan Dec 28, 2025
acd1b3c
safely import cudss
zitongzhan Jan 14, 2026
91c8ade
Add future plans section to README
zitongzhan Dec 20, 2025
19774c3
add normal matvec and memory profiler
zitongzhan Dec 13, 2025
4ca9c86
print peak cuda allocation
zitongzhan Dec 23, 2025
b71f1a3
add warp memory pool report
zitongzhan Dec 28, 2025
d678867
use `A._get_Jt` when matrix_free_normal
zitongzhan Dec 28, 2025
d127b88
add back schur by warp's matmul
zitongzhan Dec 28, 2025
fa9ab70
Merge branch 'schur-matmul' of github.com:zitongzhan/bae_private into…
zitongzhan Jan 26, 2026
6619808
Merge remote-tracking branch 'upstream/release' into schur-matmul
SEOKWOOPARK Apr 13, 2026
3e4761d
Preventing TrustRegion from accepting diverging steps
SEOKWOOPARK Apr 20, 2026
5d9e2b2
fix(optimizer/LM): Remove redundant solver calls so matrix_free_norma…
SEOKWOOPARK Apr 29, 2026
e34bea2
feat(optim/Schur): Add Matrix-Free path and matrix_free_normal branch
SEOKWOOPARK Apr 29, 2026
5f4f093
Resolving conflict with release branch in README
SEOKWOOPARK Apr 29, 2026
f64d00b
Version up to 0.2.1
SEOKWOOPARK Apr 29, 2026
40798f1
Fix deprecated function in Warp
SEOKWOOPARK May 23, 2026
165104d
Replace Warp with Triton kernels and adjust corresponding codes
SEOKWOOPARK May 23, 2026
b305f81
Remove codes relevant to Chunk
SEOKWOOPARK May 23, 2026
3a97f9e
Merge branch 'release' into memory-issue-swp
SEOKWOOPARK May 24, 2026
a0b4b8b
Remove ba_helpers.py
SEOKWOOPARK May 24, 2026
f46fb74
Fix a conflict in ba_example.py
SEOKWOOPARK May 24, 2026
48ad787
Potential fix for pull request finding 'Variable defined multiple times'
zitongzhan May 24, 2026
8cc6eb3
Potential fix for pull request finding 'Unused local variable'
zitongzhan May 24, 2026
074b931
minimize diff
zitongzhan May 24, 2026
4746522
restore pysolvers
zitongzhan May 24, 2026
7f3ea3d
revert import shuffle
zitongzhan May 24, 2026
d3e24d9
restore LM
zitongzhan May 24, 2026
04908d9
fix import order ba example
zitongzhan May 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tmp_debug/*
task.md
*.pickle
*.png
.warp_cache/*
.DS_Store
tmp/*
examples/module/pgo/data/*
Expand Down
132 changes: 127 additions & 5 deletions ba_example.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,46 @@
from time import perf_counter
from datetime import datetime
from pathlib import Path

import pypose as pp
import torch
import torch.nn as nn
import warp as wp
from pypose.autograd.function import psjac

from datapipes.bal_loader import get_problem
from bae.optim import LM
from bae.optim.optimizer import Schur
from bae.optim.triton_kernel import sparse_bsr_mv
from bae.utils.pysolvers import PCG

TARGET_DATASET = "trafalgar"
TARGET_PROBLEM = "problem-257-65132-pre"
# other options:
# TARGET_DATASET = "ladybug"
# TARGET_PROBLEM = "problem-1723-156502-pre"
# TARGET_DATASET = "dubrovnik"
# TARGET_PROBLEM = "problem-356-226730-pre"
# TARGET_DATASET = "final"
# TARGET_PROBLEM = "problem-13682-4456117-pre"
# TARGET_DATASET = "venice"
# TARGET_PROBLEM = "problem-1778-993923-pre"

DEVICE = "cuda"
OPTIMIZE_INTRINSICS = True
NUM_CAMERA_PARAMS = 10 if OPTIMIZE_INTRINSICS else 7
REPORT_WARP_MEMPOOL = True


def _format_bytes(num_bytes: int) -> str:
sign = "-" if num_bytes < 0 else ""
size = float(abs(num_bytes))
units = ["B", "KiB", "MiB", "GiB", "TiB"]
for unit in units:
if size < 1024.0 or unit == units[-1]:
break
size /= 1024.0
if unit == "B":
return f"{sign}{int(size)} {unit}"
return f"{sign}{size:.2f} {unit}"


@psjac
Expand Down Expand Up @@ -54,7 +75,51 @@ def least_square_error(camera_params, points, cidx, pidx, observes):
return torch.sum(loss**2, dim=-1).mean()


class TrustRegion(pp.optim.strategy.TrustRegion):
    """Trust-region strategy that never rewards a non-improving step.

    Overrides ``update`` of the PyPose strategy so that a step whose loss did
    not decrease, or whose denominator term is non-positive, is scored with a
    fixed quality of ``-1.0`` instead of a computed ratio.
    """

    def update(self, pg, last, loss, J, D, R, *args, **kwargs):
        """Update ``pg['radius']``, ``pg['down']`` and ``pg['damping']`` in place.

        Args:
            pg: Parameter-group mapping; reads ``damping``, ``high``, ``low``,
                ``up``, ``down`` and ``factor``, writes ``radius``, ``down``
                and ``damping``.
            last: Loss before the step.
            loss: Loss after the step.
            J: Indexable collection of Jacobian blocks, one per entry of ``D``.
            D: Iterable of step blocks; each is flattened before use.
            R: Residual, reshaped with ``view_as`` to match the summed ``J @ D``.
            **kwargs: If ``Jwp`` is present and not ``None`` it replaces ``J``
                and the products go through ``sparse_bsr_mv``.
        """
        sparse_blocks = kwargs.get("Jwp")
        use_sparse = sparse_blocks is not None
        blocks = sparse_blocks if use_sparse else J

        # Accumulate sum_i J[i] @ D[i] over all blocks.
        # NOTE(review): D must be non-empty, otherwise the sum stays None.
        product = None
        for idx, step in enumerate(D):
            flat_step = step.flatten()
            if use_sparse:
                term = sparse_bsr_mv(blocks[idx], flat_step.contiguous()).flatten()
            else:
                term = blocks[idx] @ flat_step
            product = term if product is None else product + term

        product = product[..., None]
        # presumably the model-predicted loss reduction -- TODO confirm
        denom = -(product.mT @ (2 * R.view_as(product) + product)).squeeze()

        # Steps that fail to lower the loss, or have a non-positive
        # denominator, get a fixed negative quality so they land in the
        # shrink branch below.
        rejected = loss >= last or denom <= 0
        quality = -1.0 if rejected else (last - loss) / denom

        pg['radius'] = 1.0 / pg['damping']
        if quality > pg['high']:
            # Above the high threshold: grow the radius, reset the shrink factor.
            pg['radius'] = pg['up'] * pg['radius']
            pg['down'] = self.down
        elif quality > pg['low']:
            # Between the thresholds: keep the radius, reset the shrink factor.
            pg['down'] = self.down
        else:
            # At or below the low threshold: shrink, and shrink harder next time.
            pg['radius'] = pg['radius'] * pg['down']
            pg['down'] = pg['down'] * pg['factor']

        # Clamp both values into [self.min, self.max] before deriving damping.
        lower, upper = self.min, self.max
        pg['down'] = max(lower, min(pg['down'], upper))
        pg['radius'] = max(lower, min(pg['radius'], upper))
        pg['damping'] = 1.0 / pg['radius']


def main():
file_name = f"{TARGET_DATASET}.{TARGET_PROBLEM}"
cuda_device = torch.device(DEVICE) if DEVICE.startswith("cuda") else None
memory_snapshot_path = None
warp_device = None
warp_mempool_start_current = None
warp_mempool_start_high = None

dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET)
print(f"Fetched {TARGET_PROBLEM} from {TARGET_DATASET}")

Expand All @@ -69,13 +134,37 @@ def main():
"pidx": dataset["point_index_of_observations"],
}

if cuda_device is not None and torch.cuda.is_available():
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_dir = Path("memory_traces")
snapshot_dir.mkdir(exist_ok=True)
memory_snapshot_path = snapshot_dir / f"{file_name}_cuda_memory_{timestamp}.pickle"
torch.cuda.memory._record_memory_history(
enabled="all",
context="all",
stacks="python",
device=cuda_device,
clear_history=True,
)

if REPORT_WARP_MEMPOOL and DEVICE.startswith("cuda"):
try:
if wp.is_cuda_available():
warp_device = wp.get_device("cuda:0" if DEVICE == "cuda" else DEVICE)
if not wp.is_mempool_enabled(warp_device):
wp.set_mempool_enabled(warp_device, True)
warp_mempool_start_current = wp.get_mempool_used_mem_current(warp_device)
warp_mempool_start_high = wp.get_mempool_used_mem_high(warp_device)
except Exception as e:
print(f"Warning: failed to query Warp mempool stats: {e}")

model = Residual(
dataset["camera_params"][:, :NUM_CAMERA_PARAMS].clone(),
dataset["points_3d"].clone(),
).to(DEVICE)
strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4)
strategy = TrustRegion(up=2.0, down=0.5**4)
solver = PCG(tol=1e-4, maxiter=250)
optimizer = LM(model, strategy=strategy, solver=solver, reject=30)
optimizer = Schur(model, strategy=strategy, solver=solver, reject=30, matrix_free_normal=True)

print('Loss:', least_square_error(
model.pose,
Expand All @@ -87,15 +176,48 @@ def main():

print("Initial loss", optimizer.model.loss(input, None).item())

if cuda_device is not None and torch.cuda.is_available():
torch.cuda.synchronize(cuda_device)
torch.cuda.reset_peak_memory_stats(cuda_device)

start = perf_counter()
for idx in range(20):
loss = optimizer.step(input)
print("Iteration", idx, "loss", loss.item(), "time", perf_counter() - start)

torch.cuda.synchronize()
if cuda_device is not None and torch.cuda.is_available():
torch.cuda.synchronize(cuda_device)
end = perf_counter()
print("Time", end - start)

if memory_snapshot_path:
torch.cuda.synchronize(cuda_device)
torch.cuda.memory._dump_snapshot(str(memory_snapshot_path))
print(f"CUDA memory snapshot saved to {memory_snapshot_path}")

if cuda_device is not None and torch.cuda.is_available():
peak_allocated = torch.cuda.max_memory_allocated(cuda_device)
try:
peak_reserved = torch.cuda.max_memory_reserved(cuda_device)
except AttributeError:
peak_reserved = torch.cuda.max_memory_cached(cuda_device)
print(f"Peak CUDA memory allocated: {_format_bytes(peak_allocated)}")
print(f"Peak CUDA memory reserved: {_format_bytes(peak_reserved)}")

if warp_device is not None and warp_mempool_start_current is not None:
try:
warp_current = wp.get_mempool_used_mem_current(warp_device)
warp_high = wp.get_mempool_used_mem_high(warp_device)
print(f"Warp CUDA mempool current: {_format_bytes(warp_current)} "
f"(Δ {_format_bytes(warp_current - warp_mempool_start_current)})"
)
print(
f"Warp CUDA mempool high-water: {_format_bytes(warp_high)} "
f"(Δ {_format_bytes(warp_high - warp_mempool_start_high)})"
)
except Exception as e:
print(f"Warning: failed to query Warp mempool stats: {e}")

print('Ending loss:', least_square_error(
model.pose,
model.points,
Expand Down
Loading
Loading