
Preserve input memory location / dtype for NN Descent#1928

Open
jinsolp wants to merge 6 commits into rapidsai:main from jinsolp:nnd-keep-input-data-mem

Conversation


@jinsolp jinsolp commented Mar 18, 2026

Closes #1901

Previous Code

  • We almost always allocated device-side fp16 arrays. This was for...
    • allowing wmma usage
    • allowing data modification for CosineExpanded preprocessing

Current PR Changes

  • No logical changes apart from removing the dispatch of fp32 input to either fp32 or fp16 distance computation. This dispatch is removed; we now default to the input type (e.g. fp32 stays fp32). The one exception is when compress_to_fp16=True and the input type is fp32: in this case we convert to fp16 to exploit wmma.
  • Reducing redundant memory:
    • We only allocate device-side arrays corresponding to the input dtype if the input is not device-accessible (half-type arrays are allocated for fp32 input when compress_to_fp16=True).
    • Remove preprocessing for the CosineExpanded metric (because we don't want to allocate additional device-side data arrays) and do the computation inside the calculate_metric function instead.
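The dtype rule above can be summarized in a few lines. This is a minimal sketch, not code from the PR; the function name `choose_compute_dtype` is hypothetical, and the actual logic lives in the CUDA/C++ NN Descent implementation:

```python
import numpy as np

def choose_compute_dtype(input_dtype, compress_to_fp16):
    """Hypothetical sketch of the dtype rule described in this PR:
    the compute dtype follows the input dtype, except that fp32
    input is compressed to fp16 when compress_to_fp16 is set
    (to exploit wmma)."""
    if input_dtype == np.float32 and compress_to_fp16:
        return np.float16
    return input_dtype

# fp32 stays fp32 unless compression is explicitly requested
assert choose_compute_dtype(np.float32, False) == np.float32
assert choose_compute_dtype(np.float32, True) == np.float16
# other input types are preserved as-is
assert choose_compute_dtype(np.float16, False) == np.float16
```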

Peak memory usage Changes

  • food data (5M x 384) = 7.25 GiB

  • sports data (13M x 284) = 18.55 GiB

  • Notice that for FP32->FP16 Device (meaning the data is already on device), the previous code allocates a new fp16 array, resulting in higher GPU memory usage. This PR converts to fp16 on-the-fly (at the cost of a small time overhead) instead of allocating new fp16 memory for that.
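A rough back-of-the-envelope check (assuming dense row-major data with no padding) shows why the extra fp16 copy matters for the 5M x 384 food dataset: the copy alone costs roughly half the size of the fp32 input.

```python
# Arithmetic sketch (assumption: dense row-major data, no padding).
rows, dim = 5_000_000, 384
GiB = 1024 ** 3

fp32_bytes = rows * dim * 4        # fp32 input already resident on device
fp16_copy_bytes = rows * dim * 2   # extra fp16 copy made by the previous code

print(f"fp32 input:      {fp32_bytes / GiB:.2f} GiB")       # ~7.15 GiB
print(f"extra fp16 copy: {fp16_copy_bytes / GiB:.2f} GiB")  # ~3.58 GiB
```

On-the-fly conversion avoids that ~3.58 GiB allocation entirely, which is consistent with the peak-memory numbers reported above.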

[performance_metrics benchmark chart]

Performance Changes

  • Conversion Overhead: On-the-fly conversion introduces negligible overhead.
  • Cosine Metric: Now reads L2 norms inside the calculate_metric function, aligning with the access pattern used by the L2 distance metric. This adds minimal overhead (e.g. 18.2937 s previously vs 18.7598 s for 5M x 384 data).
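The reason preprocessing can be dropped is that normalizing the data up front and dividing by precomputed L2 norms inside the metric are mathematically equivalent. The actual kernels are CUDA; the numpy sketch below (variable names are illustrative, not from the PR) just checks that equivalence:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8)).astype(np.float32)
Y = rng.standard_normal((3, 8)).astype(np.float32)

# Previous approach (conceptually): preprocess by normalizing a *copy*
# of the data, then compute 1 - dot on the normalized arrays.
Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
Yn = Y / np.linalg.norm(Y, axis=1, keepdims=True)
d_preprocessed = 1.0 - Xn @ Yn.T

# This PR (conceptually): keep the raw data, precompute L2 norms once,
# and divide inside the per-pair metric computation instead.
x_norms = np.linalg.norm(X, axis=1)
y_norms = np.linalg.norm(Y, axis=1)
d_in_metric = 1.0 - (X @ Y.T) / np.outer(x_norms, y_norms)

assert np.allclose(d_preprocessed, d_in_metric, atol=1e-5)
```

The trade-off is the extra norm reads per distance evaluation (the small timing overhead noted above) in exchange for not allocating a normalized copy of the dataset.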

@jinsolp jinsolp self-assigned this Mar 18, 2026
@jinsolp jinsolp requested review from a team as code owners March 18, 2026 01:54
@jinsolp jinsolp added the "breaking" and "improvement" labels Mar 18, 2026
@jinsolp jinsolp marked this pull request as draft March 18, 2026 01:55

copy-pr-bot bot commented Mar 18, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jinsolp jinsolp marked this pull request as ready for review March 20, 2026 00:15
@jinsolp jinsolp changed the title [WIP] Preserve input memory location / dtype for NN Descent Preserve input memory location / dtype for NN Descent Mar 20, 2026

Labels

  • breaking: Introduces a breaking change
  • improvement: Improves an existing functionality


Successfully merging this pull request may close these issues.

Compute distances in NN Descent kernels in native types
