Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions c++/src/BlockBuffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
*/

#include "BlockBuffer.hh"
#include "Utils.hh"
#include "orc/OrcFile.hh"
#include "orc/Writer.hh"

#include <algorithm>
#include <stdexcept>

namespace orc {

Expand Down Expand Up @@ -51,10 +53,19 @@ namespace orc {
if (currentSize_ < currentCapacity_) {
Block emptyBlock(blocks_[currentSize_ / blockSize_] + currentSize_ % blockSize_,
blockSize_ - currentSize_ % blockSize_);
currentSize_ = (currentSize_ / blockSize_ + 1) * blockSize_;
uint64_t nextBlockNumber = currentSize_ / blockSize_ + 1;
uint64_t nextSize = 0;
if (multiplyWithOverflow(nextBlockNumber, blockSize_, &nextSize)) {
throw std::length_error("Block buffer size overflow");
}
currentSize_ = nextSize;
return emptyBlock;
} else {
resize(currentSize_ + blockSize_);
uint64_t nextSize = 0;
if (addWithOverflow(currentSize_, blockSize_, &nextSize)) {
throw std::length_error("Block buffer size overflow");
}
resize(nextSize);
return Block(blocks_.back(), blockSize_);
}
}
Expand All @@ -70,10 +81,19 @@ namespace orc {

void BlockBuffer::reserve(uint64_t newCapacity) {
while (currentCapacity_ < newCapacity) {
uint64_t nextCapacity = 0;
if (addWithOverflow(currentCapacity_, blockSize_, &nextCapacity)) {
throw std::length_error("Block buffer capacity overflow");
}
char* newBlockPtr = memoryPool_.malloc(blockSize_);
if (newBlockPtr != nullptr) {
blocks_.push_back(newBlockPtr);
currentCapacity_ += blockSize_;
try {
blocks_.push_back(newBlockPtr);
} catch (...) {
memoryPool_.free(newBlockPtr);
throw;
}
currentCapacity_ = nextCapacity;
} else {
break;
}
Expand Down
2 changes: 1 addition & 1 deletion c++/src/BlockBuffer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ namespace orc {
* Get the number of blocks that are fully or partially occupied
*/
uint64_t getBlockNumber() const {
  // Round up with a quotient plus a remainder test instead of
  // (currentSize_ + blockSize_ - 1) / blockSize_: the intermediate sum in the
  // latter form can wrap around for sizes close to the uint64_t maximum.
  const uint64_t fullBlocks = currentSize_ / blockSize_;
  const bool hasPartialBlock = (currentSize_ % blockSize_) != 0;
  return fullBlocks + (hasPartialBlock ? 1 : 0);
}

uint64_t size() const {
Expand Down
107 changes: 85 additions & 22 deletions c++/src/ColumnReader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
#include "DictionaryLoader.hh"
#include "RLE.hh"
#include "SchemaEvolution.hh"
#include "Utils.hh"
#include "orc/Exceptions.hh"
#include "orc/Int128.hh"

#include <math.h>
#include <iostream>
#include <string>
#include <type_traits>

namespace orc {
Expand Down Expand Up @@ -127,6 +129,26 @@ namespace orc {
}
}

// Accumulates `length` into `*total`, rejecting values that cannot be added safely.
//
// @param total       running sum, updated in place only when the addition is accepted
// @param length      value to add; must not be negative
// @param columnKind  column type name spliced into the error message
// @throws ParseError if `length` is negative, if the unsigned addition overflows,
//                    or if the new sum is larger than the maximum int64_t value
void addLengthToTotal(uint64_t* total, int64_t length, const char* columnKind) {
  // Upper bound for the sum: the largest value that still fits in an int64_t.
  const uint64_t signedLimit = static_cast<uint64_t>((std::numeric_limits<int64_t>::max)());

  if (length < 0) {
    throw ParseError(std::string("Negative length in ") + columnKind + " column");
  }

  uint64_t sum = 0;
  const bool wrapped = addWithOverflow(*total, static_cast<uint64_t>(length), &sum);
  // `sum` is only inspected when the addition itself did not overflow.
  if (wrapped || sum > signedLimit) {
    throw ParseError(std::string("Length overflow in ") + columnKind + " column");
  }

  *total = sum;
}

// Adds one to the counter stored at `counts[tag]`.
//
// @param counts  array of per-child counters; the entry at `tag` is updated in place
// @param tag     index of the counter to bump
// @throws ParseError if adding one to the counter would overflow int64_t;
//                    the counter is left unchanged in that case
void incrementUnionChildCount(int64_t* counts, size_t tag) {
  int64_t* const slot = counts + tag;
  int64_t bumped = 0;
  const bool overflowed = addWithOverflow(*slot, static_cast<int64_t>(1), &bumped);
  if (overflowed) {
    throw ParseError("Union child count overflow");
  }
  *slot = bumped;
}

template <typename BatchType>
class BooleanColumnReader : public ColumnReader {
private:
Expand Down Expand Up @@ -428,11 +450,24 @@ namespace orc {
uint64_t numValues) {
numValues = ColumnReader::skip(numValues);

if (static_cast<size_t>(bufferEnd_ - bufferPointer_) >= bytesPerValue_ * numValues) {
bufferPointer_ += bytesPerValue_ * numValues;
if (numValues > static_cast<uint64_t>((std::numeric_limits<size_t>::max)())) {
throw ParseError("Double column skip size overflow");
}
size_t bytesToSkip = 0;
if (multiplyWithOverflow(static_cast<size_t>(bytesPerValue_), static_cast<size_t>(numValues),
&bytesToSkip)) {
throw ParseError("Double column skip size overflow");
}
if (bytesToSkip == 0) {
return numValues;
}

size_t bufferedBytes =
bufferPointer_ == nullptr ? 0 : static_cast<size_t>(bufferEnd_ - bufferPointer_);
if (bufferedBytes >= bytesToSkip) {
bufferPointer_ += bytesToSkip;
} else {
size_t sizeToSkip =
bytesPerValue_ * numValues - static_cast<size_t>(bufferEnd_ - bufferPointer_);
size_t sizeToSkip = bytesToSkip - bufferedBytes;
const size_t cap = static_cast<size_t>(std::numeric_limits<int>::max());
while (sizeToSkip != 0) {
size_t step = sizeToSkip > cap ? cap : sizeToSkip;
Expand Down Expand Up @@ -498,7 +533,7 @@ namespace orc {
if (!stream->Next(&chunk, &length)) {
throw ParseError("bad read in readFully");
}
if (posn + length > bufferSize) {
if (length < 0 || length > bufferSize - posn) {
throw ParseError("Corrupt dictionary blob in StringDictionaryColumn");
}
memcpy(buffer + posn, chunk, static_cast<size_t>(length));
Expand Down Expand Up @@ -666,7 +701,12 @@ namespace orc {
while (done < numValues) {
uint64_t step = std::min(BUFFER_SIZE, static_cast<size_t>(numValues - done));
lengthRle_->next(buffer, step, nullptr);
totalBytes += computeSize(buffer, nullptr, step);
size_t stepBytes = computeSize(buffer, nullptr, step);
size_t nextTotalBytes = 0;
if (addWithOverflow(totalBytes, stepBytes, &nextTotalBytes)) {
throw ParseError("String length overflow in StringDirectColumn");
}
totalBytes = nextTotalBytes;
done += step;
}
if (totalBytes <= lastBufferLength_) {
Expand Down Expand Up @@ -694,17 +734,38 @@ namespace orc {
// Sums the entries of `lengths` that belong to non-null values.
//
// @param lengths    per-value lengths; `numValues` entries are read
// @param notNull    optional mask; when non-null, only entries whose mask byte is
//                   non-zero are counted, otherwise every entry is counted
// @param numValues  number of entries to examine
// @return the total of the counted lengths
// @throws ParseError if any counted length is negative, or if the running total
//                    overflows size_t
size_t StringDirectColumnReader::computeSize(const int64_t* lengths, const char* notNull,
                                             uint64_t numValues) {
  size_t totalLength = 0;
  bool hasNegativeLength = false;
  bool hasLengthOverflow = false;
  // Problems are recorded in flags and reported after the scan rather than thrown
  // from inside the loop. Each length is added exactly once, through this checked
  // helper; an unchecked `totalLength += ...` alongside it would double-count and
  // could wrap silently.
  auto addLength = [&](int64_t value) {
    if (value < 0) {
      hasNegativeLength = true;
      return;
    }
    size_t length = static_cast<size_t>(value);
    size_t nextTotalLength = 0;
    bool overflow = addWithOverflow(totalLength, length, &nextTotalLength);
    hasLengthOverflow |= overflow;
    if (!overflow) {
      totalLength = nextTotalLength;
    }
  };
  if (notNull) {
    for (size_t i = 0; i < numValues; ++i) {
      if (notNull[i]) {
        addLength(lengths[i]);
      }
    }
  } else {
    for (size_t i = 0; i < numValues; ++i) {
      addLength(lengths[i]);
    }
  }
  // A negative length is reported in preference to an overflow when both occurred.
  if (hasNegativeLength) {
    throw ParseError("Negative string length in StringDirectColumn");
  }
  if (hasLengthOverflow) {
    throw ParseError("String length overflow in StringDirectColumn");
  }
  return totalLength;
}

Expand All @@ -728,7 +789,7 @@ namespace orc {
size_t bytesBuffered = 0;
byteBatch.blob.resize(totalLength);
char* ptr = byteBatch.blob.data();
while (bytesBuffered + lastBufferLength_ < totalLength) {
while (bytesBuffered < totalLength && lastBufferLength_ < totalLength - bytesBuffered) {
if (lastBuffer_ != nullptr) {
memcpy(ptr + bytesBuffered, lastBuffer_, lastBufferLength_);
}
Expand Down Expand Up @@ -922,7 +983,7 @@ namespace orc {
uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE);
rle_->next(buffer, chunk, nullptr);
for (size_t i = 0; i < chunk; ++i) {
childrenElements += static_cast<size_t>(buffer[i]);
addLengthToTotal(&childrenElements, buffer[i], "List");
}
lengthsRead += chunk;
}
Expand Down Expand Up @@ -954,18 +1015,18 @@ namespace orc {
if (notNull) {
for (size_t i = 0; i < numValues; ++i) {
if (notNull[i]) {
uint64_t tmp = static_cast<uint64_t>(offsets[i]);
int64_t length = offsets[i];
offsets[i] = static_cast<int64_t>(totalChildren);
totalChildren += tmp;
addLengthToTotal(&totalChildren, length, "List");
} else {
offsets[i] = static_cast<int64_t>(totalChildren);
}
}
} else {
for (size_t i = 0; i < numValues; ++i) {
uint64_t tmp = static_cast<uint64_t>(offsets[i]);
int64_t length = offsets[i];
offsets[i] = static_cast<int64_t>(totalChildren);
totalChildren += tmp;
addLengthToTotal(&totalChildren, length, "List");
}
}
offsets[numValues] = static_cast<int64_t>(totalChildren);
Expand Down Expand Up @@ -1050,7 +1111,7 @@ namespace orc {
uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE);
rle_->next(buffer, chunk, nullptr);
for (size_t i = 0; i < chunk; ++i) {
childrenElements += static_cast<size_t>(buffer[i]);
addLengthToTotal(&childrenElements, buffer[i], "Map");
}
lengthsRead += chunk;
}
Expand Down Expand Up @@ -1087,18 +1148,18 @@ namespace orc {
if (notNull) {
for (size_t i = 0; i < numValues; ++i) {
if (notNull[i]) {
uint64_t tmp = static_cast<uint64_t>(offsets[i]);
int64_t length = offsets[i];
offsets[i] = static_cast<int64_t>(totalChildren);
totalChildren += tmp;
addLengthToTotal(&totalChildren, length, "Map");
} else {
offsets[i] = static_cast<int64_t>(totalChildren);
}
}
} else {
for (size_t i = 0; i < numValues; ++i) {
uint64_t tmp = static_cast<uint64_t>(offsets[i]);
int64_t length = offsets[i];
offsets[i] = static_cast<int64_t>(totalChildren);
totalChildren += tmp;
addLengthToTotal(&totalChildren, length, "Map");
}
}
offsets[numValues] = static_cast<int64_t>(totalChildren);
Expand Down Expand Up @@ -1199,7 +1260,7 @@ namespace orc {
uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE);
rle_->next(reinterpret_cast<char*>(buffer), chunk, nullptr);
for (size_t i = 0; i < chunk; ++i) {
counts[getCheckedUnionTag(buffer[i], numChildren_)] += 1;
incrementUnionChildCount(counts, getCheckedUnionTag(buffer[i], numChildren_));
}
lengthsRead += chunk;
}
Expand Down Expand Up @@ -1236,13 +1297,15 @@ namespace orc {
for (size_t i = 0; i < numValues; ++i) {
if (notNull[i]) {
size_t tag = getCheckedUnionTag(tags[i], numChildren_);
offsets[i] = static_cast<uint64_t>(counts[tag]++);
offsets[i] = static_cast<uint64_t>(counts[tag]);
incrementUnionChildCount(counts, tag);
}
}
} else {
for (size_t i = 0; i < numValues; ++i) {
size_t tag = getCheckedUnionTag(tags[i], numChildren_);
offsets[i] = static_cast<uint64_t>(counts[tag]++);
offsets[i] = static_cast<uint64_t>(counts[tag]);
incrementUnionChildCount(counts, tag);
}
}
// read the right number of each child column
Expand Down
24 changes: 19 additions & 5 deletions c++/src/DictionaryLoader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "DictionaryLoader.hh"
#include "RLE.hh"
#include "Utils.hh"

namespace orc {

Expand All @@ -32,7 +33,7 @@ namespace orc {
if (!stream->Next(&chunk, &length)) {
throw ParseError("bad read in readFully");
}
if (posn + length > bufferSize) {
if (length < 0 || length > bufferSize - posn) {
throw ParseError("Corrupt dictionary blob");
}
memcpy(buffer + posn, chunk, static_cast<size_t>(length));
Expand Down Expand Up @@ -64,19 +65,32 @@ namespace orc {
createRleDecoder(std::move(stream), false, rleVersion, pool, stripe.getReaderMetrics());

// Decode dictionary entry lengths
dictionary->dictionaryOffset.resize(dictSize + 1);
uint64_t dictionaryOffsetSize = 0;
if (addWithOverflow(static_cast<uint64_t>(dictSize), static_cast<uint64_t>(1),
&dictionaryOffsetSize)) {
std::stringstream ss;
ss << "Dictionary size overflow for column " << columnId;
throw ParseError(ss.str());
}
dictionary->dictionaryOffset.resize(dictionaryOffsetSize);
int64_t* lengthArray = dictionary->dictionaryOffset.data();
lengthDecoder->next(lengthArray + 1, dictSize, nullptr);
lengthArray[0] = 0;

// Convert lengths to cumulative offsets
for (uint32_t i = 1; i < dictSize + 1; ++i) {
for (uint64_t i = 1; i < dictionaryOffsetSize; ++i) {
if (lengthArray[i] < 0) {
std::stringstream ss;
ss << "Negative dictionary entry length for column " << columnId;
throw ParseError(ss.str());
}
lengthArray[i] += lengthArray[i - 1];
int64_t nextOffset = 0;
if (addWithOverflow(lengthArray[i], lengthArray[i - 1], &nextOffset)) {
std::stringstream ss;
ss << "Dictionary entry length overflow for column " << columnId;
throw ParseError(ss.str());
}
lengthArray[i] = nextOffset;
}

int64_t blobSize = lengthArray[dictSize];
Expand All @@ -97,4 +111,4 @@ namespace orc {
return dictionary;
}

} // namespace orc
} // namespace orc
Loading
Loading