Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7c838a7
Enable separable compilation to solve link issue
horizon-blue Dec 3, 2025
b3ed035
Smoke test to make sure that forward function runs
horizon-blue Dec 3, 2025
c80bc5a
Parametrized forward rendering smoke test to check all variants of re…
horizon-blue Dec 4, 2025
73fb4b9
Include an example of getting image back as JAX array
horizon-blue Dec 4, 2025
2e65e7a
Try out GenMetaballs' forward kernel & compare with FMB in Jupyter
horizon-blue Dec 5, 2025
b6656a0
add jupyter to development env
mugamma Dec 6, 2025
aeb5319
update readme
mugamma Dec 6, 2025
39c6786
initial version of debug notebook
mugamma Dec 6, 2025
f0cb720
clean up and indexing fixes
mugamma Dec 7, 2025
285c873
update tests
mugamma Dec 7, 2025
90acd90
Fix iterator stopping criteria
horizon-blue Dec 7, 2025
d1cf1ca
Reorder camera fields to make x appears first
horizon-blue Dec 7, 2025
84bfdf7
use cholesky decomp of precision
mugamma Dec 7, 2025
0cd8a2c
fix confidence computation
mugamma Dec 7, 2025
a46defd
add negativity check in depth computation
mugamma Dec 7, 2025
8032856
update demo notebook
mugamma Dec 8, 2025
bacec1d
pull arijit's code + get rid of transform3d
mugamma Dec 8, 2025
bcaa92b
fix coordinate issue
mugamma Dec 8, 2025
02aacd7
fixed forward
arijit-dasgupta Dec 9, 2025
debc17c
fixed forward
arijit-dasgupta Dec 9, 2025
bf38ea8
Disable failing tests for now
horizon-blue Dec 9, 2025
0c7e429
Clean notebook's output cells
horizon-blue Dec 9, 2025
fd945fd
Merge branch 'master' into xiaoyan/forward
horizon-blue Dec 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
**/__pycache__/**
**/.ipynb_checkpoints/**
build/
/data/
/output/
Expand Down
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ nanobind_add_module(
# Link the core library to the bindings module
target_link_libraries(_genmetaballs_bindings PRIVATE genmetaballs_core)

# Enable CUDA separable compilation for device code linking
set_target_properties(_genmetaballs_bindings PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
)

# Install the extension into the Python package directory
install(TARGETS _genmetaballs_bindings LIBRARY DESTINATION genmetaballs)

Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ pixi install

### Development Setup

For development:
For development, make sure Mesa is installed and then set up hooks.

```bash
sudo apt install mesa-common-dev
pixi install
pixi run dev-setup
pixi run dev-setup # set-up hooks
```

The `dev-setup` task sets up [pre-commit](https://pre-commit.com/) git hooks:
Expand Down
68 changes: 58 additions & 10 deletions genmetaballs/src/cuda/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include "core/camera.cuh"
#include "core/confidence.cuh"
#include "core/fmb.cuh"
#include "core/forward.cuh"
#include "core/geometry.cuh"
#include "core/getter.cuh"
#include "core/image.cuh"
#include "core/intersector.cuh"
#include "core/utils.cuh"
Expand All @@ -25,6 +27,8 @@ template <MemoryLocation location>
void bind_image_view(nb::module_& m, const char* name);
template <MemoryLocation location>
void bind_fmb_scene(nb::module_& m, const char* name);
template <typename Blender, typename Confidence>
void bind_render_fmbs(nb::module_& m, const char* name);

NB_MODULE(_genmetaballs_bindings, m) {

Expand All @@ -41,7 +45,7 @@ NB_MODULE(_genmetaballs_bindings, m) {
[](const ZeroParameterConfidence& c) { return nb::str("ZeroParameterConfidence()"); });

nb::class_<TwoParameterConfidence>(confidence, "TwoParameterConfidence")
.def(nb::init<float, float>())
.def(nb::init<float, float>(), nb::arg("beta4"), nb::arg("beta5"))
.def_ro("beta4", &TwoParameterConfidence::beta4)
.def_ro("beta5", &TwoParameterConfidence::beta5)
.def("get_confidence", &TwoParameterConfidence::get_confidence, nb::arg("sumexpd"),
Expand All @@ -67,10 +71,26 @@ NB_MODULE(_genmetaballs_bindings, m) {
.def("cov_inv_apply", &FMB::cov_inv_apply,
"apply the inverse covariance matrix to the given vector", nb::arg("vec"))
.def("quadratic_form", &FMB::quadratic_form,
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"));
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"))
.def("__repr__", [](const FMB& self) {
return nb::str("FMB(pose={}, extent={})").format(self.get_pose(), self.get_extent());
});
bind_fmb_scene<MemoryLocation::HOST>(fmb, "CPUFMBScene");
bind_fmb_scene<MemoryLocation::DEVICE>(fmb, "GPUFMBScene");

/*
* Forward (rendering) module bindings
*/
nb::module_ forward = m.def_submodule("forward", "Forward rendering of FMBs");
bind_render_fmbs<FourParameterBlender, ZeroParameterConfidence>(
forward, "render_fmbs_four_param_zero_confidence");
bind_render_fmbs<ThreeParameterBlender, TwoParameterConfidence>(
forward, "render_fmbs_three_param_two_confidence");
bind_render_fmbs<ThreeParameterBlender, ZeroParameterConfidence>(
forward, "render_fmbs_three_param_zero_confidence");
bind_render_fmbs<FourParameterBlender, TwoParameterConfidence>(
forward, "render_fmbs_four_param_two_confidence");

/*
* Geometry module bindings
*/
Expand Down Expand Up @@ -99,9 +119,21 @@ NB_MODULE(_genmetaballs_bindings, m) {
.def(nb::init<>())
.def_static("from_quat", &Rotation::from_quat, "Create rotation from quaternion",
nb::arg("x"), nb::arg("y"), nb::arg("z"), nb::arg("w"))
.def_prop_ro(
"quat",
[](const Rotation& self) {
auto quat = self.get_quat();
return std::tuple{quat.x, quat.y, quat.z, quat.w};
},
"Get quaternion components as (x, y, z, w)")
.def("apply", &Rotation::apply, "Apply rotation to vector", nb::arg("vec"))
.def("compose", &Rotation::compose, "Compose with another rotation", nb::arg("rot"))
.def("inv", &Rotation::inv, "Inverse rotation");
.def("inv", &Rotation::inv, "Inverse rotation")
.def("__repr__", [](const Rotation& self) {
auto quat = self.get_quat();
return nb::str("Rotation(x={}, y={}, z={}, w={})")
.format(quat.x, quat.y, quat.z, quat.w);
});

nb::class_<Pose>(geometry, "Pose")
.def(nb::init<>())
Expand All @@ -112,15 +144,17 @@ NB_MODULE(_genmetaballs_bindings, m) {
.def_prop_ro("tran", &Pose::get_tran, "get the translation component")
.def("apply", &Pose::apply, "Apply pose to vector", nb::arg("vec"))
.def("compose", &Pose::compose, "Compose with another pose", nb::arg("pose"))
.def("inv", &Pose::inv, "Inverse pose");

.def("inv", &Pose::inv, "Inverse pose")
.def("__repr__", [](const Pose& self) {
return nb::str("Pose(rot={}, tran={})").format(self.get_rot(), self.get_tran());
});
/*
* Camera module bindings
*/
nb::module_ camera = m.def_submodule("camera", "Camera intrinsics and extrinsics");
nb::class_<Intrinsics>(camera, "Intrinsics")
.def(nb::init<uint32_t, uint32_t, float, float, float, float>(), nb::arg("height"),
nb::arg("width"), nb::arg("fx"), nb::arg("fy"), nb::arg("cx"), nb::arg("cy"))
.def(nb::init<uint32_t, uint32_t, float, float, float, float>(), nb::arg("width"),
nb::arg("height"), nb::arg("fx"), nb::arg("fy"), nb::arg("cx"), nb::arg("cy"))
.def_ro("height", &Intrinsics::height)
.def_ro("width", &Intrinsics::width)
.def_ro("fx", &Intrinsics::fx)
Expand All @@ -129,7 +163,11 @@ NB_MODULE(_genmetaballs_bindings, m) {
.def_ro("cy", &Intrinsics::cy)
.def("get_ray_direction", &Intrinsics::get_ray_direction,
"Get the direction of the ray going through pixel (px, py) in camera frame",
nb::arg("px"), nb::arg("py"));
nb::arg("px"), nb::arg("py"))
.def("__repr__", [](const Intrinsics& self) {
return nb::str("Intrinsics(width={}, height={}, fx={}, fy={}, cx={}, cy={})")
.format(self.width, self.height, self.fx, self.fy, self.cx, self.cy);
});

/*
* Image module bindings
Expand Down Expand Up @@ -163,7 +201,8 @@ NB_MODULE(_genmetaballs_bindings, m) {
// blender submodule
nb::module_ blender = m.def_submodule("blender");
nb::class_<FourParameterBlender>(blender, "FourParameterBlender")
.def(nb::init<float, float, float, float>())
.def(nb::init<float, float, float, float>(), nb::arg("beta1"), nb::arg("beta2"),
nb::arg("beta3"), nb::arg("eta"))
.def_ro("beta1", &FourParameterBlender::beta1)
.def_ro("beta2", &FourParameterBlender::beta2)
.def_ro("beta3", &FourParameterBlender::beta3)
Expand All @@ -176,7 +215,7 @@ NB_MODULE(_genmetaballs_bindings, m) {
});

nb::class_<ThreeParameterBlender>(blender, "ThreeParameterBlender")
.def(nb::init<float, float, float>())
.def(nb::init<float, float, float>(), nb::arg("beta1"), nb::arg("beta2"), nb::arg("eta"))
.def_ro("beta1", &ThreeParameterBlender::beta1)
.def_ro("beta2", &ThreeParameterBlender::beta2)
.def_ro("eta", &ThreeParameterBlender::eta)
Expand Down Expand Up @@ -273,3 +312,12 @@ void bind_fmb_scene(nb::module_& m, const char* name) {
return nb::str("{}(size={})").format(name, scene.size());
});
}

// Register one instantiation of `render_fmbs` as a Python-callable function.
//
// The getter and intersector template arguments are fixed here to
// AllGetter<MemoryLocation::DEVICE> and LinearIntersector; only the Blender
// and Confidence types vary per binding.
//
// m:    module (or submodule) the function is added to
// name: Python-visible name of the function
//
// The keyword-argument names exposed to Python are, in order:
// fmbs, blender, confidence, intr, extr, img.
template <typename Blender, typename Confidence>
void bind_render_fmbs(nb::module_& m, const char* name) {
m.def(name,
&render_fmbs<AllGetter<MemoryLocation::DEVICE>, LinearIntersector, Blender, Confidence>,
"Render the given FMB scene into the provided image view", nb::arg("fmbs"),
nb::arg("blender"), nb::arg("confidence"), nb::arg("intr"), nb::arg("extr"),
nb::arg("img"));
}
4 changes: 2 additions & 2 deletions genmetaballs/src/cuda/core/blender.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct ThreeParameterBlender {
float beta2;
float eta;

CUDA_CALLABLE __forceinline__ float blend(float t, float d) const {
return expf((beta1 * d) - ((beta2 / eta) * t));
CUDA_CALLABLE __forceinline__ float blend(float tmp, float d) const {
return expf((beta1 * tmp) - ((beta2 / eta) * d));
}
};
3 changes: 2 additions & 1 deletion genmetaballs/src/cuda/core/camera.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ CUDA_CALLABLE PixelCoordRange::Iterator& PixelCoordRange::Iterator::operator++()
}

CUDA_CALLABLE bool PixelCoordRange::Sentinel::operator==(const Iterator& it) const {
return it.py >= py_end;
// stop if we reach the end of rows, or if the range is empty
return it.py >= py_end || it.px_start >= it.px_end || it.py_start >= py_end;
}

CUDA_CALLABLE PixelCoordRange::Iterator PixelCoordRange::begin() const {
Expand Down
4 changes: 2 additions & 2 deletions genmetaballs/src/cuda/core/camera.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "utils.cuh"

struct Intrinsics {
uint32_t height; // in x direction
uint32_t width; // in y direction
uint32_t width; // in x direction
uint32_t height; // in y direction
float fx;
float fy;
float cx;
Expand Down
10 changes: 9 additions & 1 deletion genmetaballs/src/cuda/core/fmb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@ CUDA_CALLABLE __forceinline__ Vec3D vecdiv(const Vec3D u, const Vec3D v) {
return {u.x / v.x, u.y / v.y, u.z / v.z};
}

// CUDA_CALLABLE Vec3D FMB::cov_inv_apply(const Vec3D vec) const {
// const auto rot = pose_.get_rot();
// return rot.inv().apply(vecdiv(rot.apply(vec), extent_));
// }

CUDA_CALLABLE Vec3D FMB::cov_inv_apply(const Vec3D vec) const {
const auto rot = pose_.get_rot();
return rot.inv().apply(vecdiv(rot.apply(vec), extent_));
// Wanted to add more info here.
// Basically, the order of the operations has been swapped so that it looks like this:
// R @ diag(1/extent) @ R^T @ vec. However, I don't think this fixes everything.
return rot.apply(vecdiv(rot.inv().apply(vec), extent_));
}

CUDA_CALLABLE float FMB::quadratic_form(const Vec3D vec) const {
Expand Down
8 changes: 4 additions & 4 deletions genmetaballs/src/cuda/core/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3
const dim3 block_dim, const dim3 grid_dim,
const Intrinsics& intr) {
// compute the number of pixels each thread should process
const auto num_pixels_x = int_ceil_div(intr.height, grid_dim.x * block_dim.x);
const auto num_pixels_y = int_ceil_div(intr.width, grid_dim.y * block_dim.y);
const auto num_pixels_x = int_ceil_div(intr.width, grid_dim.x * block_dim.x);
const auto num_pixels_y = int_ceil_div(intr.height, grid_dim.y * block_dim.y);
const auto start_x = (block_idx.x * block_dim.x + thread_idx.x) * num_pixels_x;
const auto start_y = (block_idx.y * block_dim.y + thread_idx.y) * num_pixels_y;
return PixelCoordRange{.px_start = start_x,
.px_end = min(start_x + num_pixels_x, intr.height),
.px_end = min(start_x + num_pixels_x, intr.width),
.py_start = start_y,
.py_end = min(start_y + num_pixels_y, intr.width)};
.py_end = min(start_y + num_pixels_y, intr.height)};
}
45 changes: 26 additions & 19 deletions genmetaballs/src/cuda/core/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,38 @@ CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3
const Intrinsics& intr);

template <typename Getter, typename Intersector, typename Blender, typename Confidence>
__global__ void render_kernel(const Getter fmb_getter, const Blender blender,
Confidence const* confidence, Intrinsics const intr, Pose const* extr,
ImageView<MemoryLocation::DEVICE> img) {
__global__ void render_kernel(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Blender& blender,
const Confidence& confidence, const Intrinsics& intr,
const Pose& extr, ImageView<MemoryLocation::DEVICE> img) {
auto pixel_coords = get_pixel_coords(threadIdx, blockIdx, blockDim, gridDim, intr);
auto fmb_getter = Getter(fmbs, extr);

for (const auto& [px, py] : pixel_coords) {
float w0 = 0.0f, tf = 0.0f, sumexpd = 0.0f;
for (const auto [px, py] : pixel_coords) {
float depth_denom = 0.0f, depth_numer = 0.0f, conf_tmp = 0.0f;
auto ray = intr.get_ray_direction(px, py);
for (const auto& fmb : fmb_getter->get_metaballs(ray)) {
const auto& [t, d] = Intersector::intersect(fmb, ray, extr);
auto w = blender->blend(t, d, fmb, ray);
sumexpd += exp(d); // numerically unstable. use logsumexp
tf += t;
w0 += w;
for (const auto& [fmb, lambda] : fmb_getter.get_metaballs(ray)) {
// d: intersection point along the ray
// q: square of Mahalanobis distance at intersection point
const auto& [d, q] = Intersector::intersect(fmb, ray, extr);
auto tmp = -0.5f * q + lambda;
// the next check is needed to match the reference implementation
// even though it is not in the paper.
auto w_tilde = d > 0 ? blender.blend(tmp, d) : 1e-20f;
conf_tmp += exp(tmp); // numerically unstable. use logsumexp
depth_numer += d * w_tilde;
depth_denom += w_tilde;
}
img.confidence[px][py] = confidence->get_confidence(sumexpd);
img.depth[px][py] = tf / w0;
// the indexing is done this way because the underlying array2ds use
// ij indexing, whereas the pixels use xy indexing
img.confidence[intr.height - py - 1][px] = confidence.get_confidence(conf_tmp);
img.depth[intr.height - py - 1][px] = depth_numer / depth_denom;
}
}

template <typename Getter, typename Intersector, typename Blender, typename Confidence>
void render_fmbs(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Intrinsics& intr,
const Pose& extr) {
// initialize the fmb_getter
auto fmb_getter = Getter(fmbs, extr);
auto& kernel = render_kernel<Getter, Intersector, Blender, Confidence>;
kernel<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmb_getter, fmbs, intr, extr);
void render_fmbs(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Blender& blender,
const Confidence& confidence, const Intrinsics& intr, const Pose& extr,
ImageView<MemoryLocation::DEVICE> img) {
render_kernel<Getter, Intersector, Blender, Confidence>
<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmbs, blender, confidence, intr, extr, img);
}
4 changes: 4 additions & 0 deletions genmetaballs/src/cuda/core/geometry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ CUDA_CALLABLE Rotation Rotation::from_quat(float x, float y, float z, float w) {
return Rotation{{x / modulus, y / modulus, z / modulus, w / modulus}};
}

// Accessor for the rotation's stored quaternion.
// Returns a const reference to the `unit_quat_` member (no copy is made), so
// the reference is only valid while this Rotation object is alive.
CUDA_CALLABLE const float4& Rotation::get_quat() const {
return unit_quat_;
}

CUDA_CALLABLE Vec3D Rotation::apply(const Vec3D vec) const {
// v' = q * v * q^(-1) for unit quaternions
// where q^(-1) = (-x, -y, -z, w)
Expand Down
2 changes: 2 additions & 0 deletions genmetaballs/src/cuda/core/geometry.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public:

static CUDA_CALLABLE Rotation from_quat(float x, float y, float z, float w);

CUDA_CALLABLE const float4& get_quat() const;

CUDA_CALLABLE Vec3D apply(const Vec3D vec) const;

CUDA_CALLABLE Rotation compose(const Rotation& rot) const;
Expand Down
5 changes: 3 additions & 2 deletions genmetaballs/src/cuda/core/intersector.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "fmb.cuh"
#include "geometry.cuh"
#include "utils.cuh"

// implement equation (6) in the paper
class LinearIntersector {
Expand All @@ -16,7 +17,7 @@ public:
const auto v = cam_pose.get_rot().apply(ray);
const auto cov_inv_v = fmb.cov_inv_apply(v);
const auto cam_tran = cam_pose.get_tran();
const auto t = dot(fmb.get_mean() - cam_tran, cov_inv_v) / dot(v, cov_inv_v);
return {t, fmb.quadratic_form(cam_tran + t * v)};
const auto d = dot(fmb.get_mean() - cam_tran, cov_inv_v) / dot(v, cov_inv_v);
return {d, fmb.quadratic_form(cam_tran + d * v)};
}
};
Loading