Skip to content
Open
584 changes: 569 additions & 15 deletions c/driver/postgresql/copy/postgres_copy_writer_test.cc

Large diffs are not rendered by default.

281 changes: 225 additions & 56 deletions c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <algorithm>
#include <charconv>
#include <cinttypes>
#include <limits>
Expand Down Expand Up @@ -224,82 +225,140 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale)
: precision_{precision}, scale_{scale} {}

// PostgreSQL NUMERIC Binary Format:
// ===================================
// PostgreSQL stores NUMERIC values in a variable-length binary format:
// - ndigits (int16): Number of base-10000 digits stored
// - weight (int16): Position of the first digit group relative to decimal point
// (weight can be negative for small fractional numbers)
// - sign (int16): kNumericPos (0x0000) or kNumericNeg (0x4000)
// - dscale (int16): Number of decimal digits after the decimal point (display scale)
// - digits[]: Array of int16 values, each 0-9999 (base-10000 representation)
//
// Value calculation: sum(digits[i] * 10000^(weight - i)) * 10^(-dscale)
//
// Example 1: 12300 (from Arrow Decimal value=123, scale=-2)
// - Logical representation: "12300"
// - Grouped in base-10000: [1][2300]
// - ndigits=2, weight=1, sign=0x0000, dscale=0, digits=[1, 2300]
// - Calculation: 1*10000^1 + 2300*10000^0 = 10000 + 2300 = 12300
//
// Example 2: 123.45 (from Arrow Decimal value=12345, scale=2)
// - Logical representation: "123.45"
// - Integer part "123", fractional part "45"
// - Grouped in base-10000: [123][4500] (fractional part right-padded)
// - ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
// - Calculation: 123*10000^0 + 4500*10000^(-1) = 123 + 0.45 = 123.45
//
// Example 3: 0.00123 (from Arrow Decimal value=123, scale=5)
// - Logical representation: "0.00123"
// - Integer part "0", fractional part "00123"
// - Grouped in base-10000: [123] (leading zeros skipped via negative weight)
// - ndigits=1, weight=-1, sign=0x0000, dscale=5, digits=[123]
// - Calculation: 123*10000^(-1) * 10^0 = 0.0123, but dscale=5 means display as
// 0.00123

ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
struct ArrowDecimal decimal;
ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);

const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg;

// Number of decimal digits per Postgres digit
constexpr int kDecDigits = 4;
std::vector<int16_t> pg_digits;
int16_t weight = -(scale_ / kDecDigits);
int16_t dscale = scale_;
bool seen_decimal = scale_ == 0;
bool truncating_trailing_zeros = true;

char decimal_string[max_decimal_digits_ + 1];
int digits_remaining = DecimalToString<bitwidth_>(&decimal, decimal_string);
do {
const int start_pos =
digits_remaining < kDecDigits ? 0 : digits_remaining - kDecDigits;
const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
const std::string_view substr{decimal_string + start_pos, len};
int16_t val{};
std::from_chars(substr.data(), substr.data() + substr.size(), val);

if (val == 0) {
if (!seen_decimal && truncating_trailing_zeros) {
dscale -= kDecDigits;
}
} else {
pg_digits.insert(pg_digits.begin(), val);
if (!seen_decimal && truncating_trailing_zeros) {
if (val % 1000 == 0) {
dscale -= 3;
} else if (val % 100 == 0) {
dscale -= 2;
} else if (val % 10 == 0) {
dscale -= 1;
}
}
truncating_trailing_zeros = false;
}
digits_remaining -= kDecDigits;
if (digits_remaining <= 0) {
break;
}
weight++;

if (start_pos <= static_cast<int>(std::strlen(decimal_string)) - scale_) {
seen_decimal = true;
}
} while (true);

int16_t ndigits = pg_digits.size();
int32_t field_size_bytes = sizeof(ndigits) + sizeof(weight) + sizeof(sign) +
// Convert decimal to string and split into integer/fractional parts
// Example transformation for Arrow Decimal(value=12345, scale=2) representing 123.45:
// Input: decimal.value = 12345, scale_ = 2
// After DecimalToString: raw_decimal_string = "12345", original_digits = 5
// After SplitDecimalParts: parts.integer_part = "123"
// parts.fractional_part = "45"
// parts.effective_scale = 2
char raw_decimal_string[max_decimal_digits_ + 1];
int original_digits = DecimalToString<bitwidth_>(&decimal, raw_decimal_string);
DecimalParts parts = SplitDecimalParts(raw_decimal_string, original_digits, scale_);

// Group into PostgreSQL base-10000 representation
// After GroupIntegerDigits: int_digits = [123], weight = 0
// (groups "123" right-to-left: "123" → 123, only 1 group so weight = 0)
auto [int_digits, weight] = GroupIntegerDigits(parts.integer_part);

// After GroupFractionalDigits: frac_digits = [4500], final_weight = 0
// (groups "45" left-to-right with right-padding: "45" → "4500" → 4500)
auto [frac_digits, final_weight] =
GroupFractionalDigits(parts.fractional_part, weight, !parts.integer_part.empty());

// Combine digit arrays
// After combining: all_digits = [123, 4500]
std::vector<int16_t> all_digits = int_digits;
all_digits.insert(all_digits.end(), frac_digits.begin(), frac_digits.end());

// Calculate display scale by counting trailing zeros in the DECIMAL STRING
// For our example: frac_part="45" has 0 trailing zeros, effective_scale=2
// So dscale = 2 - 0 = 2 (2 fractional digits to display)
int trailing_zeros = 0;
for (int j = parts.fractional_part.length() - 1;
j >= 0 && parts.fractional_part[j] == '0'; j--) {
trailing_zeros++;
}
int16_t dscale = std::max<int16_t>(0, parts.effective_scale - trailing_zeros);

// Optimize: remove trailing zero digit groups from fractional part
int n_int_digit_groups = int_digits.size();
while (static_cast<int>(all_digits.size()) > n_int_digit_groups &&
all_digits.back() == 0) {
all_digits.pop_back();
}

// Handle zero special case
if (all_digits.empty()) {
final_weight = 0;
dscale = 0;
} else if (static_cast<int>(all_digits.size()) <= n_int_digit_groups) {
// All fractional digits were removed
dscale = 0;
}

if (dscale < 0) dscale = 0;

// Write PostgreSQL NUMERIC binary format to buffer
// Final values for our example: ndigits = 2
// final_weight = 0
// sign = 0x0000
// dscale = 2
// digits = [123, 4500]
// Binary output represents: 123 * 10000^0 + 4500 * 10000^(-1) = 123 + 0.45 = 123.45
int16_t ndigits = all_digits.size();
int32_t field_size_bytes = sizeof(ndigits) + sizeof(final_weight) + sizeof(sign) +
sizeof(dscale) + ndigits * sizeof(int16_t);

NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, final_weight, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));

const size_t pg_digit_bytes = sizeof(int16_t) * pg_digits.size();
const size_t pg_digit_bytes = sizeof(int16_t) * all_digits.size();
NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes));
for (auto pg_digit : pg_digits) {
for (auto pg_digit : all_digits) {
WriteUnsafe<int16_t>(buffer, pg_digit);
}

return ADBC_STATUS_OK;
}

private:
// returns the length of the string
// Helper struct for organizing data flow between functions
struct DecimalParts {
std::string integer_part; // e.g., "12300" or "123"
std::string fractional_part; // e.g., "45" or "00123"
int effective_scale; // Scale after handling negative values
};

// Helper function implementations for decimal-to-PostgreSQL NUMERIC conversion

// Convert decimal to string (absolute value, no sign)
// Returns the length of the string
template <int32_t DEC_WIDTH>
int DecimalToString(struct ArrowDecimal* decimal, char* out) {
int DecimalToString(struct ArrowDecimal* decimal, char* out) const {
constexpr size_t nwords = (DEC_WIDTH == 128) ? 2 : 4;
uint8_t tmp[DEC_WIDTH / 8];
ArrowDecimalGetBytes(decimal, tmp);
Expand All @@ -322,10 +381,9 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
for (size_t i = 0; i < DEC_WIDTH; i++) {
int carry;

carry = (buf[nwords - 1] >= 0x7FFFFFFFFFFFFFFF);
carry = (buf[nwords - 1] > 0x7FFFFFFFFFFFFFFF);
for (size_t j = nwords - 1; j > 0; j--) {
buf[j] =
((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j - 1] >= 0x7FFFFFFFFFFFFFFF);
buf[j] = ((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j - 1] > 0x7FFFFFFFFFFFFFFF);
}
buf[0] = ((buf[0] << 1) & 0xFFFFFFFFFFFFFFFF);

Expand All @@ -350,6 +408,117 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
return ndigits;
}

DecimalParts SplitDecimalParts(const char* decimal_digits, int digit_count,
int scale) const {
// Virtual zeros represent the logical zeros appended for negative scale
// Example: value=123, scale=-2 → "123" with 2 virtual zeros = "12300"
const int virtual_zeros = (scale < 0) ? -scale : 0;
const int effective_scale = (scale < 0) ? 0 : scale;
const int total_logical_digits = digit_count + virtual_zeros;

// Calculate split point
const int n_int_digits = total_logical_digits > effective_scale
? total_logical_digits - effective_scale
: 0;
const int n_frac_digits = total_logical_digits - n_int_digits;

DecimalParts parts;
parts.effective_scale = effective_scale;

// Extract integer part
if (n_int_digits > 0) {
if (n_int_digits <= digit_count) {
// Integer part is within the original digits
parts.integer_part.assign(decimal_digits, n_int_digits);
} else {
// Integer part includes all original digits + virtual zeros
parts.integer_part.assign(decimal_digits, digit_count);
parts.integer_part.append(virtual_zeros, '0');
}
}

// Extract fractional part (only exists if scale > 0)
if (n_int_digits == 0 && total_logical_digits < effective_scale) {
// Small fractional: 0.00123 needs leading zeros
parts.fractional_part.assign(effective_scale - total_logical_digits, '0');
parts.fractional_part.append(decimal_digits, digit_count);
} else if (n_frac_digits > 0 && n_int_digits < digit_count) {
// Fractional part from remaining digits (virtual zeros don't appear in fractional
// part)
parts.fractional_part.assign(decimal_digits + n_int_digits,
digit_count - n_int_digits);
}

return parts;
}

std::pair<std::vector<int16_t>, int16_t> GroupIntegerDigits(
const std::string& int_part) const {
constexpr int kDecDigits = 4;
std::vector<int16_t> digits;

if (int_part.empty()) {
return {digits, -1}; // weight = -1 for pure fractional numbers
}

// Calculate weight: ceil(length / 4) - 1
int16_t weight = (int_part.length() + kDecDigits - 1) / kDecDigits - 1;

// Group right-to-left in chunks of 4
int i = int_part.length();
while (i > 0) {
int chunk_size = std::min(i, kDecDigits);
std::string_view chunk =
std::string_view(int_part).substr(i - chunk_size, chunk_size);

int16_t val{};
std::from_chars(chunk.data(), chunk.data() + chunk.size(), val);

// Skip trailing zeros
if (val != 0 || !digits.empty()) {
digits.insert(digits.begin(), val);
}
i -= chunk_size;
}

return {digits, weight};
}

std::pair<std::vector<int16_t>, int16_t> GroupFractionalDigits(
const std::string& frac_part, int16_t initial_weight, bool has_integer_part) const {
constexpr int kDecDigits = 4;
std::vector<int16_t> digits;
int16_t weight = initial_weight;

if (frac_part.empty()) {
return {digits, weight};
}

bool skip_leading_zeros = !has_integer_part;

// Group left-to-right in chunks of 4, right-padding last chunk
for (size_t i = 0; i < frac_part.length(); i += kDecDigits) {
int chunk_size = std::min(kDecDigits, static_cast<int>(frac_part.length() - i));
std::string chunk_str = frac_part.substr(i, chunk_size);

// Right-pad to 4 digits (e.g., "45" → "4500")
chunk_str.resize(kDecDigits, '0');

int16_t val{};
std::from_chars(chunk_str.data(), chunk_str.data() + chunk_str.size(), val);

if (skip_leading_zeros && val == 0) {
// Skip leading zero groups in fractional part (e.g., 0.0012 → skip "0012")
weight--;
} else {
digits.push_back(val);
skip_leading_zeros = false;
}
}

return {digits, weight};
}

static constexpr uint16_t kNumericPos = 0x0000;
static constexpr uint16_t kNumericNeg = 0x4000;
static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256;
Expand Down
19 changes: 0 additions & 19 deletions c/driver/postgresql/validation/queries/ingest/decimal.toml

This file was deleted.

45 changes: 45 additions & 0 deletions c/driver/postgresql/validation/queries/ingest/decimal.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// part: expected_schema
{
"format": "+s",
"children": [
{
"name": "idx",
"format": "l",
"flags": ["nullable"]
},
{
"name": "value",
"format": "u",
"flags": ["nullable"],
"metadata": {
"ARROW:extension:name": "arrow.opaque",
"ARROW:extension:metadata": "{\"type_name\": \"numeric\", \"vendor_name\": \"PostgreSQL\"}"
}
}
]
}

// part: expected

{"idx": 0, "value": "0"}
{"idx": 1, "value": "123.45"}
{"idx": 2, "value": "-123.45"}
{"idx": 3, "value": "9999999.99"}
{"idx": 4, "value": "-9999999.99"}
Loading
Loading