Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ use std::sync::Arc;
use crate::aggregates::group_values::multi_group_by::Nulls;
use crate::aggregates::group_values::multi_group_by::{GroupColumn, nulls_equal_to};
use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
use ahash::RandomState;
use arrow::array::{Array as _, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
use datafusion_common::Result;
use datafusion_common::hash_utils::{HashValue, combine_hashes};
use itertools::izip;

/// An implementation of [`GroupColumn`] for booleans
Expand Down Expand Up @@ -191,6 +193,60 @@ impl<const NULLABLE: bool> GroupColumn for BooleanGroupValueBuilder<NULLABLE> {

Arc::new(BooleanArray::new(new_builder.finish(), first_n_nulls))
}

/// Returns `true` when rows `row_a` and `row_b` of `array` hold the same
/// boolean value.
///
/// When the builder is `NULLABLE`, two null rows are treated as equal and a
/// null row never equals a non-null row. Validity is not consulted at all
/// for non-nullable builders.
fn input_rows_equal(&self, array: &ArrayRef, row_a: usize, row_b: usize) -> bool {
    if NULLABLE {
        match (array.is_null(row_a), array.is_null(row_b)) {
            // Both rows carry a value: fall through to the value comparison.
            (false, false) => {}
            // At least one null: equal only if both are null.
            (lhs_null, rhs_null) => return lhs_null && rhs_null,
        }
    }

    let values = array.as_boolean();
    values.value(row_a) == values.value(row_b)
}

/// Hashes the boolean at `row` of `array`.
///
/// * For a `NULLABLE` builder, a null row returns `current_hash` unchanged.
/// * If `rehash` is `true`, the row's hash is folded into `current_hash`
///   with `combine_hashes`; otherwise the row's own hash is returned.
fn hash_input_row(
    &self,
    array: &ArrayRef,
    row: usize,
    random_state: &RandomState,
    rehash: bool,
    current_hash: u64,
) -> u64 {
    let row_is_null = NULLABLE && array.is_null(row);
    if row_is_null {
        return current_hash;
    }

    let row_hash = array.as_boolean().value(row).hash_one(random_state);
    match rehash {
        true => combine_hashes(row_hash, current_hash),
        false => row_hash,
    }
}

/// Marks `boundaries[row] = true` for every `row >= 1` whose value differs
/// from the value at `row - 1`.
///
/// Entries that are already `true` are left alone, and no entry is ever
/// reset to `false`. When the builder is `NULLABLE` and the array contains
/// nulls, a null/non-null transition counts as a difference while two
/// consecutive nulls do not.
fn compute_boundaries(&self, array: &ArrayRef, boundaries: &mut [bool]) {
    let values = array.as_boolean();
    let row_count = array.len();

    let must_check_nulls = NULLABLE && array.null_count() > 0;
    if !must_check_nulls {
        // Fast path: validity never needs to be consulted.
        for row in 1..row_count {
            if boundaries[row] {
                continue;
            }
            if values.value(row - 1) != values.value(row) {
                boundaries[row] = true;
            }
        }
        return;
    }

    for row in 1..row_count {
        if boundaries[row] {
            continue;
        }
        let changed = match (array.is_null(row - 1), array.is_null(row)) {
            (true, true) => false,
            (false, false) => values.value(row - 1) != values.value(row),
            // Exactly one side is null.
            _ => true,
        };
        if changed {
            boundaries[row] = true;
        }
    }
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use arrow::array::{
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType};
use datafusion_common::utils::proxy::VecAllocExt;
use ahash::RandomState;
use datafusion_common::hash_utils::{HashValue, combine_hashes};
use datafusion_common::{Result, exec_datafusion_err};
use datafusion_physical_expr_common::binary_map::{INITIAL_BUFFER_CAPACITY, OutputType};
use itertools::izip;
Expand Down Expand Up @@ -426,6 +428,95 @@ where
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
}
}

/// Returns `true` when rows `row_a` and `row_b` of `array` hold equal values.
///
/// Two null rows compare equal; a null row never equals a non-null row.
/// The array is read as binary or UTF-8 according to `self.output_type`.
fn input_rows_equal(&self, array: &ArrayRef, row_a: usize, row_b: usize) -> bool {
    match (array.is_null(row_a), array.is_null(row_b)) {
        // Both rows carry a value: fall through to the value comparison.
        (false, false) => {}
        // At least one null: equal only if both are null.
        (lhs_null, rhs_null) => return lhs_null && rhs_null,
    }

    match self.output_type {
        OutputType::Binary => {
            let values = array.as_bytes::<GenericBinaryType<O>>();
            let (lhs, rhs) = (values.value(row_a), values.value(row_b));
            lhs == rhs
        }
        OutputType::Utf8 => {
            let values = array.as_bytes::<GenericStringType<O>>();
            let (lhs, rhs) = (values.value(row_a), values.value(row_b));
            lhs == rhs
        }
        _ => unreachable!("View types should use `ArrowBytesViewMap`"),
    }
}

/// Hashes the value at `row` of `array`.
///
/// * A null row returns `current_hash` unchanged.
/// * If `rehash` is `true`, the row's hash is folded into `current_hash`
///   with `combine_hashes`; otherwise the row's own hash is returned.
///
/// The array is read as binary or UTF-8 according to `self.output_type`.
fn hash_input_row(
    &self,
    array: &ArrayRef,
    row: usize,
    random_state: &RandomState,
    rehash: bool,
    current_hash: u64,
) -> u64 {
    if array.is_null(row) {
        return current_hash;
    }

    let row_hash = match self.output_type {
        OutputType::Binary => array
            .as_bytes::<GenericBinaryType<O>>()
            .value(row)
            .hash_one(random_state),
        OutputType::Utf8 => array
            .as_bytes::<GenericStringType<O>>()
            .value(row)
            .hash_one(random_state),
        _ => unreachable!("View types should use `ArrowBytesViewMap`"),
    };

    match rehash {
        true => combine_hashes(row_hash, current_hash),
        false => row_hash,
    }
}

fn compute_boundaries(&self, array: &ArrayRef, boundaries: &mut [bool]) {
let has_nulls = array.null_count() > 0;
match self.output_type {
OutputType::Binary => {
let arr = array.as_bytes::<GenericBinaryType<O>>();
for row in 1..array.len() {
if boundaries[row] {
continue;
}
if has_nulls {
let prev_null = array.is_null(row - 1);
let curr_null = array.is_null(row);
if prev_null != curr_null || (!prev_null && arr.value(row - 1) != arr.value(row)) {
boundaries[row] = true;
}
} else if arr.value(row - 1) != arr.value(row) {
boundaries[row] = true;
}
}
}
OutputType::Utf8 => {
let arr = array.as_bytes::<GenericStringType<O>>();
for row in 1..array.len() {
if boundaries[row] {
continue;
}
if has_nulls {
let prev_null = array.is_null(row - 1);
let curr_null = array.is_null(row);
if prev_null != curr_null || (!prev_null && arr.value(row - 1) != arr.value(row)) {
boundaries[row] = true;
}
} else if arr.value(row - 1) != arr.value(row) {
boundaries[row] = true;
}
}
}
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
}
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ use crate::aggregates::group_values::multi_group_by::{
use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
use arrow::array::{Array, ArrayRef, AsArray, ByteView, GenericByteViewArray, make_view};
use arrow::buffer::{Buffer, ScalarBuffer};
use ahash::RandomState;
use arrow::datatypes::ByteViewType;
use datafusion_common::Result;
use datafusion_common::hash_utils::{HashValue, combine_hashes};
use itertools::izip;
use std::marker::PhantomData;
use std::mem::{replace, size_of};
Expand Down Expand Up @@ -577,6 +579,146 @@ impl<B: ByteViewType> GroupColumn for ByteViewGroupValueBuilder<B> {
/// Forwards `n` to `take_n_inner` and returns the array it produces.
fn take_n(&mut self, n: usize) -> ArrayRef {
self.take_n_inner(n)
}

/// Returns `true` when rows `row_a` and `row_b` of `array` hold equal bytes.
///
/// Two null rows compare equal; a null row never equals a non-null row.
/// For non-null rows the comparison works on the raw views:
/// differing lengths are unequal, views of at most 12 bytes are compared
/// as whole `u128` values, and longer values compare the prefix first and
/// then the full bytes stored in the data buffers.
fn input_rows_equal(&self, array: &ArrayRef, row_a: usize, row_b: usize) -> bool {
    match (array.is_null(row_a), array.is_null(row_b)) {
        // Both rows carry a value: fall through to the view comparison.
        (false, false) => {}
        // At least one null: equal only if both are null.
        (lhs_null, rhs_null) => return lhs_null && rhs_null,
    }

    let view_array = array.as_byte_view::<B>();
    let lhs_raw = view_array.views()[row_a];
    let rhs_raw = view_array.views()[row_b];

    // The low 32 bits of a view hold the value length.
    let len = lhs_raw as u32;
    if len != rhs_raw as u32 {
        return false;
    }

    if len <= 12 {
        // Inlined values: the whole value lives inside the view itself.
        return lhs_raw == rhs_raw;
    }

    // Out-of-line values: a cheap prefix check before touching the buffers.
    let lhs = ByteView::from(lhs_raw);
    let rhs = ByteView::from(rhs_raw);
    if lhs.prefix != rhs.prefix {
        return false;
    }

    let buffers = view_array.data_buffers();
    let byte_len = len as usize;
    let lhs_start = lhs.offset as usize;
    let rhs_start = rhs.offset as usize;

    // SAFETY: assumes the views of non-null rows reference an existing data
    // buffer and an in-bounds `offset..offset + len` range, as expected of a
    // well-formed view array. NOTE(review): confirm callers only pass
    // validated arrays.
    let lhs_bytes = unsafe {
        buffers
            .get_unchecked(lhs.buffer_index as usize)
            .get_unchecked(lhs_start..lhs_start + byte_len)
    };
    // SAFETY: same assumption as for the left-hand row.
    let rhs_bytes = unsafe {
        buffers
            .get_unchecked(rhs.buffer_index as usize)
            .get_unchecked(rhs_start..rhs_start + byte_len)
    };

    lhs_bytes == rhs_bytes
}

/// Hashes the raw bytes of the value at `row` of `array`.
///
/// * A null row returns `current_hash` unchanged.
/// * If `rehash` is `true`, the row's hash is folded into `current_hash`
///   with `combine_hashes`; otherwise the row's own hash is returned.
///
/// The value is always hashed as a `&[u8]`, for both string and binary
/// views.
fn hash_input_row(
    &self,
    array: &ArrayRef,
    row: usize,
    random_state: &RandomState,
    rehash: bool,
    current_hash: u64,
) -> u64 {
    if array.is_null(row) {
        return current_hash;
    }

    // Borrow the bytes from the array itself. The earlier implementation
    // copied an inlined view into a block-local `u128` and built a slice
    // pointing at that local, so the slice dangled by the time it was
    // hashed. `value` handles inlined and out-of-line values alike and
    // returns data that lives as long as `array`.
    let arr = array.as_byte_view::<B>();
    let bytes: &[u8] = arr.value(row).as_ref();

    let h = bytes.hash_one(random_state);
    if rehash {
        combine_hashes(h, current_hash)
    } else {
        h
    }
}

/// Marks `boundaries[row] = true` for every `row >= 1` whose value differs
/// from the value at `row - 1`.
///
/// Entries that are already `true` are left alone, and no entry is ever
/// reset to `false`. A null/non-null transition counts as a difference;
/// two consecutive nulls do not. Non-null neighbours are compared through
/// their raw views: length first, then either the whole view (values of at
/// most 12 bytes) or the prefix followed by the full buffer contents.
fn compute_boundaries(&self, array: &ArrayRef, boundaries: &mut [bool]) {
    let view_array = array.as_byte_view::<B>();
    let views = view_array.views();
    let has_nulls = array.null_count() > 0;
    let buffers = view_array.data_buffers();

    for row in 1..array.len() {
        if boundaries[row] {
            continue;
        }

        if has_nulls {
            match (array.is_null(row - 1), array.is_null(row)) {
                // Two nulls in a row belong to the same group.
                (true, true) => continue,
                // Both rows carry a value: compare the views below.
                (false, false) => {}
                // Exactly one side is null.
                _ => {
                    boundaries[row] = true;
                    continue;
                }
            }
        }

        let before = views[row - 1];
        let after = views[row];
        // The low 32 bits of a view hold the value length.
        let len = before as u32;

        let changed = if len != after as u32 {
            true
        } else if len <= 12 {
            // Inlined values: the whole value lives inside the view itself.
            before != after
        } else {
            let prev = ByteView::from(before);
            let curr = ByteView::from(after);
            if prev.prefix != curr.prefix {
                true
            } else {
                let byte_len = len as usize;
                let prev_start = prev.offset as usize;
                let curr_start = curr.offset as usize;
                // SAFETY: assumes the views of non-null rows reference an
                // existing data buffer and an in-bounds
                // `offset..offset + len` range, as expected of a well-formed
                // view array. NOTE(review): confirm callers only pass
                // validated arrays.
                let prev_bytes = unsafe {
                    buffers
                        .get_unchecked(prev.buffer_index as usize)
                        .get_unchecked(prev_start..prev_start + byte_len)
                };
                // SAFETY: same assumption as for the previous row.
                let curr_bytes = unsafe {
                    buffers
                        .get_unchecked(curr.buffer_index as usize)
                        .get_unchecked(curr_start..curr_start + byte_len)
                };
                prev_bytes != curr_bytes
            }
        };

        if changed {
            boundaries[row] = true;
        }
    }
}
}

#[cfg(test)]
Expand Down
Loading
Loading