Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/ipc/metadata/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ export class Message<T extends MessageHeader = any> {
const bodyLength: bigint = _message.bodyLength()!;
const version: MetadataVersion = _message.version();
const headerType: MessageHeader = _message.headerType();
const message = new Message(bodyLength, version, headerType);
const metadata = decodeMessageCustomMetadata(_message);
const message = new Message(bodyLength, version, headerType, undefined, metadata);
message._createHeader = decodeMessageHeader(_message, headerType);
return message;
}
Expand All @@ -98,22 +99,35 @@ export class Message<T extends MessageHeader = any> {
} else if (message.isDictionaryBatch()) {
headerOffset = DictionaryBatch.encode(b, message.header() as DictionaryBatch);
}

// Encode custom metadata if present (must be done before startMessage)
const customMetadataOffset = !(message.metadata && message.metadata.size > 0) ? -1 :
_Message.createCustomMetadataVector(b, [...message.metadata].map(([k, v]) => {
const key = b.createString(`${k}`);
const val = b.createString(`${v}`);
_KeyValue.startKeyValue(b);
_KeyValue.addKey(b, key);
_KeyValue.addValue(b, val);
return _KeyValue.endKeyValue(b);
}));

_Message.startMessage(b);
_Message.addVersion(b, MetadataVersion.V5);
_Message.addHeader(b, headerOffset);
_Message.addHeaderType(b, message.headerType);
_Message.addBodyLength(b, BigInt(message.bodyLength));
if (customMetadataOffset !== -1) { _Message.addCustomMetadata(b, customMetadataOffset); }
_Message.finishMessageBuffer(b, _Message.endMessage(b));
return b.asUint8Array();
}

/** @nocollapse */
public static from(header: Schema | RecordBatch | DictionaryBatch, bodyLength = 0) {
public static from(header: Schema | RecordBatch | DictionaryBatch, bodyLength = 0, metadata?: Map<string, string>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the second metadata argument is only relevant if we're serializing a RecordBatch message, can we just use its metadata field instead?

Suggested change
public static from(header: Schema | RecordBatch | DictionaryBatch, bodyLength = 0, metadata?: Map<string, string>) {
public static from(header: Schema | RecordBatch | DictionaryBatch, bodyLength = 0) {

if (header instanceof Schema) {
return new Message(0, MetadataVersion.V5, MessageHeader.Schema, header);
}
if (header instanceof RecordBatch) {
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.RecordBatch, header);
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.RecordBatch, header, metadata);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.RecordBatch, header, metadata);
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.RecordBatch, header, header.metadata);

}
if (header instanceof DictionaryBatch) {
return new Message(bodyLength, MetadataVersion.V5, MessageHeader.DictionaryBatch, header);
Expand All @@ -126,24 +140,27 @@ export class Message<T extends MessageHeader = any> {
protected _bodyLength: number;
protected _version: MetadataVersion;
protected _compression: BodyCompression | null;
protected _metadata: Map<string, string>;
public get type() { return this.headerType; }
public get version() { return this._version; }
public get headerType() { return this._headerType; }
public get compression() { return this._compression; }
public get bodyLength() { return this._bodyLength; }
public get metadata() { return this._metadata; }
declare protected _createHeader: MessageHeaderDecoder;
public header() { return this._createHeader<T>(); }
public isSchema(): this is Message<MessageHeader.Schema> { return this.headerType === MessageHeader.Schema; }
public isRecordBatch(): this is Message<MessageHeader.RecordBatch> { return this.headerType === MessageHeader.RecordBatch; }
public isDictionaryBatch(): this is Message<MessageHeader.DictionaryBatch> { return this.headerType === MessageHeader.DictionaryBatch; }

constructor(bodyLength: bigint | number, version: MetadataVersion, headerType: T, header?: any) {
constructor(bodyLength: bigint | number, version: MetadataVersion, headerType: T, header?: any, metadata?: Map<string, string>) {
this._version = version;
this._headerType = headerType;
this.body = new Uint8Array(0);
this._compression = header?.compression;
header && (this._createHeader = () => header);
this._bodyLength = bigIntToNumber(bodyLength);
this._metadata = metadata || new Map();
}
}

Expand Down Expand Up @@ -468,6 +485,17 @@ function decodeCustomMetadata(parent?: _Schema | _Field | null) {
return data;
}

/**
 * Read the `custom_metadata` key/value pairs from a flatbuffers Message
 * table into a plain `Map`. Entries whose key is null/undefined are
 * skipped; values are taken as-is (matching the field-level
 * `decodeCustomMetadata` convention in this file).
 * @ignore
 */
function decodeMessageCustomMetadata(message: _Message) {
    const metadata = new Map<string, string>();
    const total = Math.trunc(message.customMetadataLength());
    for (let index = 0; index < total; index++) {
        const pair = message.customMetadata(index);
        const pairKey = pair && pair.key();
        if (pair && pairKey != null) {
            metadata.set(pairKey, pair.value()!);
        }
    }
    return metadata;
}

/** @ignore */
function decodeIndexType(_type: _Int) {
return new Int(_type.isSigned(), _type.bitWidth() as IntBitWidth);
Expand Down
12 changes: 6 additions & 6 deletions src/ipc/reader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ abstract class RecordBatchReaderImpl<T extends TypeMap = any> implements RecordB
return this;
}

protected _loadRecordBatch(header: metadata.RecordBatch, body: Uint8Array): RecordBatch<T> {
protected _loadRecordBatch(header: metadata.RecordBatch, body: Uint8Array, messageMetadata?: Map<string, string>): RecordBatch<T> {
let children: Data<any>[];
if (header.compression != null) {
const codec = compressionRegistry.get(header.compression.type);
Expand All @@ -379,7 +379,7 @@ abstract class RecordBatchReaderImpl<T extends TypeMap = any> implements RecordB
}

const data = makeData({ type: new Struct(this.schema.fields), length: header.length, children });
return new RecordBatch(this.schema, data);
return new RecordBatch(this.schema, data, messageMetadata);
}

protected _loadDictionaryBatch(header: metadata.DictionaryBatch, body: Uint8Array) {
Expand Down Expand Up @@ -512,7 +512,7 @@ class RecordBatchStreamReaderImpl<T extends TypeMap = any> extends RecordBatchRe
this._recordBatchIndex++;
const header = message.header();
const buffer = reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return { done: false, value: recordBatch };
} else if (message.isDictionaryBatch()) {
this._dictionaryIndex++;
Expand Down Expand Up @@ -587,7 +587,7 @@ class AsyncRecordBatchStreamReaderImpl<T extends TypeMap = any> extends RecordBa
this._recordBatchIndex++;
const header = message.header();
const buffer = await reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return { done: false, value: recordBatch };
} else if (message.isDictionaryBatch()) {
this._dictionaryIndex++;
Expand Down Expand Up @@ -640,7 +640,7 @@ class RecordBatchFileReaderImpl<T extends TypeMap = any> extends RecordBatchStre
if (message?.isRecordBatch()) {
const header = message.header();
const buffer = this._reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return recordBatch;
}
}
Expand Down Expand Up @@ -714,7 +714,7 @@ class AsyncRecordBatchFileReaderImpl<T extends TypeMap = any> extends AsyncRecor
if (message?.isRecordBatch()) {
const header = message.header();
const buffer = await this._reader.readMessageBody(message.bodyLength);
const recordBatch = this._loadRecordBatch(header, buffer);
const recordBatch = this._loadRecordBatch(header, buffer, message.metadata);
return recordBatch;
}
}
Expand Down
31 changes: 27 additions & 4 deletions src/ipc/writer.ts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the Message reads the metadata from the RecordBatch, this can all be simplified.

Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,18 @@ export class RecordBatchWriter<T extends TypeMap = any> extends ReadableInterop<
return this;
}

public write(payload?: Table<T> | RecordBatch<T> | Iterable<RecordBatch<T>> | null) {
/**
* Write a RecordBatch to the stream with optional custom metadata.
* @param payload The RecordBatch, Table, or iterable of RecordBatches to write
* @param customMetadata Optional custom metadata to attach to the message (only used when payload is a single RecordBatch)
*/
public write(payload?: Table<T> | RecordBatch<T> | Iterable<RecordBatch<T>> | null, customMetadata?: Map<string, string>): void;
// Overload for UnderlyingSink compatibility (used by DOM streams)
public write(chunk: RecordBatch<T>, controller: WritableStreamDefaultController): void;
public write(payload?: Table<T> | RecordBatch<T> | Iterable<RecordBatch<T>> | null, customMetadataOrController?: Map<string, string> | WritableStreamDefaultController) {
// Determine if second argument is customMetadata (Map) or controller (WritableStreamDefaultController)
const customMetadata = customMetadataOrController instanceof Map ? customMetadataOrController : undefined;

let schema: Schema<T> | null = null;

if (!this._sink) {
Expand All @@ -207,7 +218,7 @@ export class RecordBatchWriter<T extends TypeMap = any> extends ReadableInterop<

if (payload instanceof RecordBatch) {
if (!(payload instanceof _InternalEmptyPlaceholderRecordBatch)) {
this._writeRecordBatch(payload);
this._writeRecordBatch(payload, customMetadata);
}
} else if (payload instanceof Table) {
this.writeAll(payload.batches);
Expand Down Expand Up @@ -273,10 +284,12 @@ export class RecordBatchWriter<T extends TypeMap = any> extends ReadableInterop<
return nBytes > 0 ? this._write(new Uint8Array(nBytes)) : this;
}

protected _writeRecordBatch(batch: RecordBatch<T>) {
protected _writeRecordBatch(batch: RecordBatch<T>, customMetadata?: Map<string, string>) {
const { byteLength, nodes, bufferRegions, buffers, variadicBufferCounts } = this._assembleRecordBatch(batch);
const recordBatch = new metadata.RecordBatch(batch.numRows, nodes, bufferRegions, this._compression, variadicBufferCounts);
const message = Message.from(recordBatch, byteLength);
// Merge batch.metadata with customMetadata (customMetadata takes precedence)
const mergedMetadata = mergeMetadata(batch.metadata, customMetadata);
const message = Message.from(recordBatch, byteLength, mergedMetadata);
return this
._writeDictionaries(batch)
._writeMessage(message)
Expand Down Expand Up @@ -589,3 +602,13 @@ function recordBatchToJSON(records: RecordBatch) {
'columns': columns
}, null, 2);
}

/**
 * Combine two metadata maps into a single new Map, with entries from
 * `override` replacing same-key entries from `base`. Returns `undefined`
 * when neither input has any entries, so callers can skip emitting an
 * empty custom_metadata vector entirely.
 * @ignore
 */
function mergeMetadata(base?: Map<string, string>, override?: Map<string, string>): Map<string, string> | undefined {
    const baseIsEmpty = !(base && base.size > 0);
    const overrideIsEmpty = !(override && override.size > 0);
    if (baseIsEmpty && overrideIsEmpty) { return undefined; }
    // Later entries in the iterable win, so override takes precedence.
    return new Map([...(base ?? []), ...(override ?? [])]);
}
25 changes: 17 additions & 8 deletions src/recordbatch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ export interface RecordBatch<T extends TypeMap = any> {
export class RecordBatch<T extends TypeMap = any> {

constructor(columns: { [P in keyof T]: Data<T[P]> });
constructor(schema: Schema<T>, data?: Data<Struct<T>>);
constructor(schema: Schema<T>, data?: Data<Struct<T>>, metadata?: Map<string, string>);
constructor(...args: any[]) {
switch (args.length) {
case 3:
case 2: {
[this.schema] = args;
if (!(this.schema instanceof Schema)) {
Expand All @@ -60,7 +61,8 @@ export class RecordBatch<T extends TypeMap = any> {
nullCount: 0,
type: new Struct<T>(this.schema.fields),
children: this.schema.fields.map((f) => makeData({ type: f.type, nullCount: 0 }))
})
}),
this._metadata = new Map()
] = args;
if (!(this.data instanceof Data)) {
throw new TypeError('RecordBatch constructor expects a [Schema, Data] pair.');
Expand All @@ -84,17 +86,24 @@ export class RecordBatch<T extends TypeMap = any> {
const schema = new Schema<T>(fields);
const data = makeData({ type: new Struct<T>(fields), length, children, nullCount: 0 });
[this.schema, this.data] = ensureSameLengthData<T>(schema, data.children as Data<T[keyof T]>[], length);
this._metadata = new Map();
break;
}
default: throw new TypeError('RecordBatch constructor expects an Object mapping names to child Data, or a [Schema, Data] pair.');
}
}

protected _dictionaries?: Map<number, Vector>;
protected _metadata: Map<string, string>;

public readonly schema: Schema<T>;
public readonly data: Data<Struct<T>>;

/**
* Custom metadata for this RecordBatch.
*/
public get metadata() { return this._metadata; }

public get dictionaries() {
return this._dictionaries || (this._dictionaries = collectDictionaries(this.schema.fields, this.data.children));
}
Expand Down Expand Up @@ -188,7 +197,7 @@ export class RecordBatch<T extends TypeMap = any> {
*/
public slice(begin?: number, end?: number): RecordBatch<T> {
const [slice] = new Vector([this.data]).slice(begin, end).data;
return new RecordBatch(this.schema, slice);
return new RecordBatch(this.schema, slice, this._metadata);
}

/**
Expand Down Expand Up @@ -240,7 +249,7 @@ export class RecordBatch<T extends TypeMap = any> {
schema = new Schema(fields, new Map(this.schema.metadata));
data = makeData({ type: new Struct<T>(fields), children });
}
return new RecordBatch(schema, data);
return new RecordBatch(schema, data, this._metadata);
}

/**
Expand All @@ -259,7 +268,7 @@ export class RecordBatch<T extends TypeMap = any> {
children[index] = this.data.children[index] as Data<T[K]>;
}
}
return new RecordBatch(schema, makeData({ type, length: this.numRows, children }));
return new RecordBatch(schema, makeData({ type, length: this.numRows, children }), this._metadata);
}

/**
Expand All @@ -272,7 +281,7 @@ export class RecordBatch<T extends TypeMap = any> {
const schema = this.schema.selectAt<K>(columnIndices);
const children = columnIndices.map((i) => this.data.children[i]).filter(Boolean);
const subset = makeData({ type: new Struct(schema.fields), length: this.numRows, children });
return new RecordBatch<{ [P in keyof K]: K[P] }>(schema, subset);
return new RecordBatch<{ [P in keyof K]: K[P] }>(schema, subset, this._metadata);
}

// Initialize this static property via an IIFE so bundlers don't tree-shake
Expand Down Expand Up @@ -347,9 +356,9 @@ function collectDictionaries(fields: Field[], children: readonly Data[], diction
* @private
*/
export class _InternalEmptyPlaceholderRecordBatch<T extends TypeMap = any> extends RecordBatch<T> {
constructor(schema: Schema<T>) {
constructor(schema: Schema<T>, metadata?: Map<string, string>) {
const children = schema.fields.map((f) => makeData({ type: f.type }));
const data = makeData({ type: new Struct<T>(schema.fields), nullCount: 0, children });
super(schema, data);
super(schema, data, metadata || new Map());
}
}
Binary file added test/data/test_message_metadata.arrow
Binary file not shown.
97 changes: 97 additions & 0 deletions test/unit/ipc/reader/message-metadata-tests.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import { readFileSync } from 'node:fs';
import path from 'node:path';
import { tableFromIPC, RecordBatch } from 'apache-arrow';

// Path to the test file with message-level metadata
// Use process.cwd() since tests are run from project root
const testFilePath = path.resolve(process.cwd(), 'test/data/test_message_metadata.arrow');

// NOTE(review): every expected value below (key names, 'Hello 世界 🌍 مرحبا',
// the batch_%04d id format, the JSON payload) is pinned to the contents of the
// binary fixture test/data/test_message_metadata.arrow — if the fixture is
// regenerated, these literals must be kept in sync with its producer script.
describe('RecordBatch message metadata', () => {
// Read and decode the fixture once; every test below shares this table.
const buffer = readFileSync(testFilePath);
const table = tableFromIPC(buffer);

test('should read RecordBatch metadata from IPC file', () => {
// The fixture is expected to contain exactly three record batches,
// each carrying message-level custom metadata.
expect(table.batches).toHaveLength(3);

for (let i = 0; i < table.batches.length; i++) {
const batch = table.batches[i];
expect(batch).toBeInstanceOf(RecordBatch);
expect(batch.metadata).toBeInstanceOf(Map);
expect(batch.metadata.size).toBeGreaterThan(0);

// Verify specific metadata keys exist
expect(batch.metadata.has('batch_index')).toBe(true);
expect(batch.metadata.has('batch_id')).toBe(true);
expect(batch.metadata.has('producer')).toBe(true);

// Verify batch_index matches the batch position
expect(batch.metadata.get('batch_index')).toBe(String(i));
expect(batch.metadata.get('batch_id')).toBe(`batch_${String(i).padStart(4, '0')}`);
}
});

// Metadata values are UTF-8 strings; round-tripping multi-byte code
// points (CJK, emoji, RTL text) exercises the string decode path.
test('should read unicode metadata values', () => {
const batch = table.batches[0];
expect(batch.metadata.has('unicode_test')).toBe(true);
expect(batch.metadata.get('unicode_test')).toBe('Hello 世界 🌍 مرحبا');
});

// An empty string value must be preserved as '' — distinct from the
// key being absent.
test('should handle empty metadata values', () => {
const batch = table.batches[0];
expect(batch.metadata.has('optional_field')).toBe(true);
expect(batch.metadata.get('optional_field')).toBe('');
});

// Metadata values are opaque strings; structured payloads like JSON
// must survive the round trip unmodified.
test('should read JSON metadata values', () => {
const batch = table.batches[0];
expect(batch.metadata.has('batch_info_json')).toBe(true);
const jsonStr = batch.metadata.get('batch_info_json')!;
const parsed = JSON.parse(jsonStr);
expect(parsed.batch_number).toBe(0);
expect(parsed.processing_stage).toBe('final');
expect(parsed.tags).toEqual(['validated', 'complete']);
});

// Derived batches (slice/select/selectAt) should carry the source
// batch's message metadata forward unchanged.
describe('metadata preservation', () => {
test('should preserve metadata through slice()', () => {
const batch = table.batches[0];
const sliced = batch.slice(0, 2);
expect(sliced.metadata).toBeInstanceOf(Map);
expect(sliced.metadata.size).toBe(batch.metadata.size);
expect(sliced.metadata.get('batch_index')).toBe(batch.metadata.get('batch_index'));
});

test('should preserve metadata through select()', () => {
const batch = table.batches[0];
const selected = batch.select(['id', 'name']);
expect(selected.metadata).toBeInstanceOf(Map);
expect(selected.metadata.size).toBe(batch.metadata.size);
expect(selected.metadata.get('batch_index')).toBe(batch.metadata.get('batch_index'));
});

test('should preserve metadata through selectAt()', () => {
const batch = table.batches[0];
const selectedAt = batch.selectAt([0, 1]);
expect(selectedAt.metadata).toBeInstanceOf(Map);
expect(selectedAt.metadata.size).toBe(batch.metadata.size);
expect(selectedAt.metadata.get('batch_index')).toBe(batch.metadata.get('batch_index'));
});
});
});
Loading