Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 150 additions & 74 deletions src/main/java/org/apache/datasketches/count/CountMinSketch.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,30 @@
import org.apache.datasketches.common.Family;
import org.apache.datasketches.common.SketchesArgumentException;
import org.apache.datasketches.common.SketchesException;
import org.apache.datasketches.common.Util;
import org.apache.datasketches.common.positional.PositionalSegment;
import org.apache.datasketches.hash.MurmurHash3;
import org.apache.datasketches.tuple.Util;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;
import java.util.Random;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED;
import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED;
import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED;


/**
* Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan.
* This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens.
*
* The CountMin sketch is a probabilistic data structure that provides frequency estimates for items
* in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array,
* providing approximate counts with configurable error bounds.
*
* Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf
*/
public class CountMinSketch {
private final byte numHashes_;
private final int numBuckets_;
Expand All @@ -39,6 +54,9 @@ public class CountMinSketch {
private final long[] sketchArray_;
private long totalWeight_;

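// Per-thread 8-byte scratch MemorySegment reused by longToBytes for the long-to-bytes conversion.
// Note: longToBytes writes it with JAVA_LONG_UNALIGNED, whose byte order is the platform's native
// order (not an explicitly chosen one), and it still copies the bytes into a new array on every call.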
private static final ThreadLocal<MemorySegment> LONG_SEGMENT =
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[Long.BYTES]));

private enum Flag {
IS_EMPTY;
Expand All @@ -57,35 +75,64 @@ int mask() {
* @param seed The base hash seed
*/
CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
numHashes_ = numHashes;
numBuckets_ = numBuckets;
seed_ = seed;
hashSeeds_ = new long[numHashes];
sketchArray_ = new long[numHashes * numBuckets];
totalWeight_ = 0;
// Validate numHashes
if (numHashes <= 0) {
throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes);
}

// Validate numBuckets with clear mathematical justification
if (numBuckets <= 0) {
throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets);
}
if (numBuckets < 3) {
throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " +
"With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets));
}

// Check for potential overflow in array size calculation
// Use long arithmetic to detect overflow before casting
final long totalSize = (long) numHashes * (long) numBuckets;
if (totalSize > Integer.MAX_VALUE) {
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets
+ " = " + totalSize + " > " + Integer.MAX_VALUE);
}

// This check is to ensure later compatibility with a Java implementation whose maximum size can only
// be 2^31-1. We check only against 2^30 for simplicity.
if (numBuckets * numHashes >= 1 << 30) {
throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
"Try reducing either the number of buckets or the number of hash functions.");
if (totalSize >= (1L << 30)) {
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
+ " = " + totalSize + " elements (~" + String.format("%d", totalSize * Long.BYTES / (1024 * 1024 * 1024)) + " GB). "
+ "Consider reducing numHashes or numBuckets.");
}

Random rand = new Random(seed);
numHashes_ = numHashes;
numBuckets_ = numBuckets;
seed_ = seed;
hashSeeds_ = new long[numHashes];
sketchArray_ = new long[(int) totalSize];
totalWeight_ = 0;

final Random rand = new Random(seed);
for (int i = 0; i < numHashes; i++) {
hashSeeds_[i] = rand.nextLong();
}
}

private long[] getHashes(byte[] item) {
long[] updateLocations = new long[numHashes_];
/**
 * Returns the 8 bytes of {@code value} as a newly allocated array.
 *
 * <p>The value is written into the calling thread's scratch segment ({@code LONG_SEGMENT}) and then
 * copied out with {@code toArray}, so every call still returns a fresh {@code byte[]}. The write uses
 * {@code JAVA_LONG_UNALIGNED}, which the JDK defines with the platform's native byte order; the
 * resulting byte sequence (and any hash computed from it) is therefore not pinned to one endianness.
 *
 * @param value the long to convert
 * @return a new 8-byte array holding {@code value} in native byte order
 */
private static byte[] longToBytes(final long value) {
  final MemorySegment scratch = LONG_SEGMENT.get();
  // Offset 0: the scratch segment is exactly Long.BYTES wide, so the value fills it completely.
  scratch.set(JAVA_LONG_UNALIGNED, 0L, value);
  return scratch.toArray(JAVA_BYTE);
}


private long[] getHashes(final byte[] item) {
final long[] updateLocations = new long[numHashes_];

for (int i = 0; i < numHashes_; i++) {
long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
final long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_);
}

Expand Down Expand Up @@ -145,11 +192,11 @@ public double getRelativeError() {
* @param confidence The desired confidence level between 0 and 1.
* @return Suggested number of hash functions.
*/
public static byte suggestNumHashes(double confidence) {
public static byte suggestNumHashes(final double confidence) {
if (confidence < 0 || confidence > 1) {
throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
}
int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
final int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
return (byte) Math.min(value, 127);
}

Expand All @@ -158,7 +205,7 @@ public static byte suggestNumHashes(double confidence) {
* @param relativeError The desired relative error.
* @return Suggested number of buckets.
*/
public static int suggestNumBuckets(double relativeError) {
public static int suggestNumBuckets(final double relativeError) {
if (relativeError < 0.) {
throw new SketchesException("Relative error must be at least 0.");
}
Expand All @@ -171,8 +218,7 @@ public static int suggestNumBuckets(double relativeError) {
* @param weight The weight of the item.
*/
public void update(final long item, final long weight) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
update(longByte, weight);
update(longToBytes(item), weight);
}

/**
Expand All @@ -199,8 +245,8 @@ public void update(final byte[] item, final long weight) {
}

totalWeight_ += weight > 0 ? weight : -weight;
long[] hashLocations = getHashes(item);
for (long h : hashLocations) {
final long[] hashLocations = getHashes(item);
for (final long h : hashLocations) {
sketchArray_[(int) h] += weight;
}
}
Expand All @@ -211,8 +257,7 @@ public void update(final byte[] item, final long weight) {
* @return Estimated frequency.
*/
public long getEstimate(final long item) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
return getEstimate(longByte);
return getEstimate(longToBytes(item));
}

/**
Expand All @@ -239,10 +284,11 @@ public long getEstimate(final byte[] item) {
return 0;
}

long[] hashLocations = getHashes(item);
final long[] hashLocations = getHashes(item);
long res = sketchArray_[(int) hashLocations[0]];
for (long h : hashLocations) {
res = Math.min(res, sketchArray_[(int) h]);
// Start from index 1 to avoid processing first element twice
for (int i = 1; i < hashLocations.length; i++) {
res = Math.min(res, sketchArray_[(int) hashLocations[i]]);
}

return res;
Expand All @@ -254,8 +300,7 @@ public long getEstimate(final byte[] item) {
* @return Upper bound of estimated frequency.
*/
public long getUpperBound(final long item) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
return getUpperBound(longByte);
return getUpperBound(longToBytes(item));
}

/**
Expand All @@ -268,8 +313,8 @@ public long getUpperBound(final String item) {
return 0;
}

byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
return getUpperBound(strByte);
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
return getUpperBound(strByte);
}

/**
Expand All @@ -291,8 +336,7 @@ public long getUpperBound(final byte[] item) {
* @return Lower bound of estimated frequency.
*/
public long getLowerBound(final long item) {
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
return getLowerBound(longByte);
return getLowerBound(longToBytes(item));
}

/**
Expand All @@ -305,7 +349,7 @@ public long getLowerBound(final String item) {
return 0;
}

byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
return getLowerBound(strByte);
}

Expand All @@ -327,8 +371,8 @@ public void merge(final CountMinSketch other) {
throw new SketchesException("Cannot merge a sketch with itself");
}

boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() &&
getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
final boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_()
&& getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();

if (!acceptableConfig) {
throw new SketchesException("Incompatible sketch configuration.");
Expand All @@ -342,39 +386,56 @@ public void merge(final CountMinSketch other) {
}

/**
* Serializes the sketch into the provided ByteBuffer.
* @param buf The ByteBuffer to write into.
* Returns the serialized size in bytes.
*/
private int getSerializedSizeBytes() {
  final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES;
  if (isEmpty()) {
    // An empty sketch is written as the preamble only.
    return preambleBytes;
  }
  // Preamble + one long for the total weight + one long per counter.
  // Computed in long arithmetic: the constructor admits up to 2^30 - 1 counters, so
  // sketchArray_.length * Long.BYTES can exceed Integer.MAX_VALUE and would silently wrap
  // (yielding a wrong or negative array size in toByteArray) if multiplied as int.
  final long totalBytes = (long) preambleBytes + Long.BYTES + ((long) sketchArray_.length * Long.BYTES);
  if (totalBytes > Integer.MAX_VALUE) {
    throw new SketchesException("Serialized sketch size of " + totalBytes
        + " bytes exceeds the maximum byte array size of " + Integer.MAX_VALUE + ".");
  }
  return (int) totalBytes;
}


/**
* Returns the sketch as a byte array.
*/
public void serialize(ByteArrayOutputStream buf) {
public byte[] toByteArray() {
final int serializedSizeBytes = getSerializedSizeBytes();
final byte[] bytes = new byte[serializedSizeBytes];
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes));

// Long 0
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
buf.write((byte) preambleLongs);
posSeg.setByte((byte) preambleLongs);
final int serialVersion = 1;
buf.write((byte) serialVersion);
posSeg.setByte((byte) serialVersion);
final int familyId = Family.COUNTMIN.getID();
buf.write((byte) familyId);
posSeg.setByte((byte) familyId);
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
buf.write((byte)flagsByte);
posSeg.setByte((byte) flagsByte);
final int NULL_32 = 0;
buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array());
posSeg.setInt(NULL_32);

// Long 1
buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array());
buf.write(numHashes_);
short hashSeed = Util.computeSeedHash(seed_);
buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array());
posSeg.setInt(numBuckets_);
posSeg.setByte(numHashes_);
final short hashSeed = Util.computeSeedHash(seed_);
posSeg.setShort(hashSeed);
final byte NULL_8 = 0;
buf.write(NULL_8);
posSeg.setByte(NULL_8);

if (isEmpty()) {
return;
return bytes;
}

final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array();
buf.writeBytes(totWeightByte);
posSeg.setLong(totalWeight_);

for (long w: sketchArray_) {
buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array());
for (final long w: sketchArray_) {
posSeg.setLong(w);
}

return bytes;
}

/**
Expand All @@ -384,36 +445,51 @@ public void serialize(ByteArrayOutputStream buf) {
* @return The deserialized CountMinSketch.
*/
public static CountMinSketch deserialize(final byte[] b, final long seed) {
ByteBuffer buf = ByteBuffer.allocate(b.length);
buf.put(b);
buf.flip();

final byte preambleLongs = buf.get();
final byte serialVersion = buf.get();
final byte familyId = buf.get();
final byte flagsByte = buf.get();
final int NULL_32 = buf.getInt();
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b));

final byte preambleLongs = posSeg.getByte();
final byte serialVersion = posSeg.getByte();
final byte familyId = posSeg.getByte();
final byte flagsByte = posSeg.getByte();
posSeg.getInt(); // skip NULL_32

// Validate serialization format
final int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs();
if (preambleLongs != expectedPreambleLongs) {
throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs
+ ", actual " + preambleLongs);
}
final int expectedSerialVersion = 1;
if (serialVersion != expectedSerialVersion) {
throw new SketchesArgumentException("Serial version mismatch: expected " + expectedSerialVersion
+ ", actual " + serialVersion);
}
final int expectedFamilyId = Family.COUNTMIN.getID();
if (familyId != expectedFamilyId) {
throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId
+ ", actual " + familyId);
}

final int numBuckets = buf.getInt();
final byte numHashes = buf.get();
final short seedHash = buf.getShort();
final byte NULL_8 = buf.get();
final int numBuckets = posSeg.getInt();
final byte numHashes = posSeg.getByte();
final short seedHash = posSeg.getShort();
posSeg.getByte(); // skip NULL_8

if (seedHash != Util.computeSeedHash(seed)) {
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
+ String.valueOf(Util.computeSeedHash(seed)));
throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", "
+ Util.computeSeedHash(seed));
}

CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0;
if (empty) {
return cms;
}
long w = buf.getLong();
final long w = posSeg.getLong();
cms.totalWeight_ = w;

for (int i = 0; i < cms.sketchArray_.length; i++) {
cms.sketchArray_[i] = buf.getLong();
cms.sketchArray_[i] = posSeg.getLong();
}

return cms;
Expand Down
Loading