Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions cpp/celeborn/client/reader/CelebornInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ bool CelebornInputStream::moveToNextChunk() {

if (currReader_->hasNext()) {
currChunk_ = getNextChunk();
return true;
if (currChunk_) {
return true;
}
}
if (currLocationIndex_ < locations_.size()) {
moveToNextReader();
Expand All @@ -169,11 +171,59 @@ bool CelebornInputStream::moveToNextChunk() {

std::unique_ptr<memory::ReadOnlyByteBuffer>
CelebornInputStream::getNextChunk() {
// TODO: support the failure retrying, including excluding the failed
// location, open a reader to read from the location's peer.
auto chunk = currReader_->next();
verifyChunk(chunk);
return std::move(chunk);
while (fetchChunkRetryCnt_ < fetchChunkMaxRetry_) {
try {
if (isExcluded(currReader_->getLocation())) {
CELEBORN_FAIL(
"Fetch data from excluded worker! {}",
Comment thread
afterincomparableyum marked this conversation as resolved.
currReader_->getLocation().hostAndFetchPort());
}
if (!currReader_->hasNext()) {
return nullptr;
}
auto chunk = currReader_->next();
verifyChunk(chunk);
return std::move(chunk);
} catch (const std::exception& e) {
auto failedLocation = currReader_->getLocation();
shuffleClient_->excludeFailedFetchLocation(
failedLocation.hostAndFetchPort(), e);
fetchChunkRetryCnt_++;
currReader_ = nullptr;

if (fetchChunkRetryCnt_ == fetchChunkMaxRetry_) {
CELEBORN_FAIL(
"Fetch chunk failed for {} times for location {}. Error: {}",
fetchChunkRetryCnt_,
failedLocation.hostAndFetchPort(),
e.what());
}

if (failedLocation.hasPeer() && !readSkewPartitionWithoutMapRange_) {
LOG(WARNING) << "Fetch chunk failed " << fetchChunkRetryCnt_ << "/"
<< fetchChunkMaxRetry_ << " times for location "
<< failedLocation.hostAndFetchPort()
<< ", change to peer. Error: " << e.what();
// fetchChunkRetryCnt_ % 2 == 0 means both replicas have been tried,
// so sleep before next try.
if (fetchChunkRetryCnt_ % 2 == 0) {
std::this_thread::sleep_for(retryWait_);
}
currReader_ = createReaderWithRetry(*failedLocation.getPeer());
} else {
LOG(WARNING) << "Fetch chunk failed " << fetchChunkRetryCnt_ << "/"
<< fetchChunkMaxRetry_ << " times for location "
<< failedLocation.hostAndFetchPort()
<< ". Error: " << e.what();
std::this_thread::sleep_for(retryWait_);
// TODO: Pass checkpoint metadata when supported to skip
// already-read chunks, improving retry performance.
currReader_ = createReaderWithRetry(failedLocation);
}
}
}

CELEBORN_FAIL("Fetch chunk failed!");
}

void CelebornInputStream::verifyChunk(
Expand Down Expand Up @@ -204,7 +254,9 @@ void CelebornInputStream::moveToNextReader() {
currLocationIndex_++;
if (currReader_->hasNext()) {
currChunk_ = getNextChunk();
return;
if (currChunk_) {
return;
}
}
moveToNextReader();
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/celeborn/client/reader/WorkerPartitionReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ WorkerPartitionReader::~WorkerPartitionReader() {
client_->sendRpcRequestWithoutResponse(request);
}

const protocol::PartitionLocation& WorkerPartitionReader::getLocation() const {
return location_;
}

bool WorkerPartitionReader::hasNext() {
return toConsumeChunkId_ < streamHandler_->numChunks;
}
Expand Down
8 changes: 8 additions & 0 deletions cpp/celeborn/client/reader/WorkerPartitionReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class PartitionReader {
virtual bool hasNext() = 0;

virtual std::unique_ptr<memory::ReadOnlyByteBuffer> next() = 0;

virtual const protocol::PartitionLocation& getLocation() const = 0;
};

class WorkerPartitionReader
Expand All @@ -51,6 +53,8 @@ class WorkerPartitionReader

std::unique_ptr<memory::ReadOnlyByteBuffer> next() override;

const protocol::PartitionLocation& getLocation() const override;

private:
// Disable creating the object directly to make sure that
// std::enable_shared_from_this works properly.
Expand Down Expand Up @@ -88,6 +92,10 @@ class WorkerPartitionReader
static constexpr auto kDefaultConsumeIter = std::chrono::milliseconds(500);

// TODO: add other params, such as fetchChunkRetryCnt, fetchChunkMaxRetries
// TODO: add TEST_CLIENT_FETCH_FAILURE support (matching Java's testFetch
// flag) to enable integration testing of getNextChunk() retry logic, as
// done by Java's ReadWriteTestWithFailures. This will likely be done once
// full C++ write support is achieved.
};
} // namespace client
} // namespace celeborn
244 changes: 244 additions & 0 deletions cpp/celeborn/client/tests/CelebornInputStreamRetryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ using namespace celeborn::client;
using namespace celeborn::network;
using namespace celeborn::protocol;
using namespace celeborn::conf;
using namespace celeborn::memory;

namespace {
using MS = std::chrono::milliseconds;
Expand Down Expand Up @@ -180,6 +181,117 @@ std::shared_ptr<CelebornConf> makeTestConf(bool replicateEnabled = true) {
CelebornConf::kClientFetchExcludeWorkerOnFailureEnabled, "true");
return conf;
}
// Creates a valid PbStreamHandler RpcResponse for WorkerPartitionReader
// construction. Each copy is safe to use independently (RpcResponse clones
// the body on copy).
RpcResponse makeStreamHandlerResponse(int numChunks = 1) {
PbStreamHandler pb;
pb.set_streamid(100);
pb.set_numchunks(numChunks);
for (int i = 0; i < numChunks; i++) {
pb.add_chunkoffsets(i);
}
pb.set_fullpath("test-fullpath");
TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
return RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
}

// Creates a chunk buffer
std::unique_ptr<ReadOnlyByteBuffer> makeChunkBuffer(
int mapId,
int attemptId,
int batchId,
const std::string& payload) {
const size_t totalSize = 4 * sizeof(int) + payload.size();
auto buffer = ByteBuffer::createWriteOnly(totalSize, false);
buffer->writeLE<int>(mapId);
buffer->writeLE<int>(attemptId);
buffer->writeLE<int>(batchId);
buffer->writeLE<int>(static_cast<int>(payload.size()));
buffer->writeFromString(payload);
return ByteBuffer::toReadOnly(std::move(buffer));
}

// A TransportClient whose sendRpcRequestSync always returns a valid
// stream handler, and whose fetchChunkAsync walks through a pre-configured
// sequence of success/failure behaviors. Used to exercise the getNextChunk()
// retry loop independently of reader-creation failures.
class SequencedMockTransportClient : public TransportClient {
public:
SequencedMockTransportClient()
: TransportClient(nullptr, nullptr, MS(100)),
streamHandlerResponse_(makeStreamHandlerResponse()) {}

RpcResponse sendRpcRequestSync(const RpcRequest& request, Timeout timeout)
override {
return streamHandlerResponse_;
}

void sendRpcRequestWithoutResponse(const RpcRequest& request) override {}

void fetchChunkAsync(
const StreamChunkSlice& streamChunkSlice,
const RpcRequest& request,
FetchChunkSuccessCallback onSuccess,
FetchChunkFailureCallback onFailure) override {
auto idx = fetchCallIdx_++;
if (idx < fetchBehaviors_.size()) {
fetchBehaviors_[idx](streamChunkSlice, onSuccess, onFailure);
}
}

using FetchBehavior = std::function<void(
const StreamChunkSlice&,
FetchChunkSuccessCallback,
FetchChunkFailureCallback)>;

void addFetchSuccess(std::unique_ptr<ReadOnlyByteBuffer> chunk) {
auto iobuf = std::shared_ptr<folly::IOBuf>(chunk->getData());
fetchBehaviors_.push_back([iobuf](
const StreamChunkSlice& slice,
FetchChunkSuccessCallback onSuccess,
FetchChunkFailureCallback) {
onSuccess(slice, ByteBuffer::createReadOnly(iobuf->clone(), false));
});
}

void addFetchFailure(const std::string& errorMessage) {
fetchBehaviors_.push_back([errorMessage](
const StreamChunkSlice& slice,
FetchChunkSuccessCallback,
FetchChunkFailureCallback onFailure) {
onFailure(slice, std::make_unique<std::runtime_error>(errorMessage));
});
}

private:
RpcResponse streamHandlerResponse_;
std::vector<FetchBehavior> fetchBehaviors_;
size_t fetchCallIdx_{0};
};

class SequencedMockClientFactory : public TransportClientFactory {
public:
explicit SequencedMockClientFactory(
std::shared_ptr<SequencedMockTransportClient> client)
: TransportClientFactory(std::make_shared<CelebornConf>()),
client_(std::move(client)) {}

std::shared_ptr<TransportClient> createClient(
const std::string& host,
uint16_t port) override {
hosts_.push_back(host);
return client_;
}

const std::vector<std::string>& hosts() const {
return hosts_;
}

private:
std::shared_ptr<SequencedMockTransportClient> client_;
std::vector<std::string> hosts_;
};
} // namespace

// Verifies that createReaderWithRetry exhausts all retries and throws.
Expand Down Expand Up @@ -408,4 +520,136 @@ TEST(CelebornInputStreamRetryTest, replicationDoublesMaxRetries) {
// With maxRetriesForEachReplica=2 and replication enabled,
// fetchChunkMaxRetry = 2 * 2 = 4 total attempts
EXPECT_EQ(factory->hosts().size(), 4u);
}

// getNextChunk() retry tests
// These tests exercise the retry loop inside getNextChunk(), which is
// triggered when a successfully-created reader's next() call fails during
// chunk fetching.

// Verifies that when a chunk fetch fails on the primary, getNextChunk()
// switches to the peer replica and successfully reads data on retry.
TEST(CelebornInputStreamRetryTest, fetchChunkRetrySucceedsWithPeerSwitch) {
auto mockClient = std::make_shared<SequencedMockTransportClient>();
mockClient->addFetchFailure("chunk fetch failed");
const std::string payload = "hello";
mockClient->addFetchSuccess(makeChunkBuffer(0, 0, 0, payload));

auto factory = std::make_shared<SequencedMockClientFactory>(mockClient);
auto conf = makeTestConf(true);
auto excludedWorkers =
std::make_shared<CelebornInputStream::FetchExcludedWorkers>();
StubShuffleClient shuffleClient(conf, excludedWorkers);

auto location = makeLocationWithPeer();
std::vector<std::shared_ptr<const PartitionLocation>> locations;
locations.push_back(std::move(location));
std::vector<int> attempts = {0};

CelebornInputStream stream(
"test-shuffle-key",
conf,
factory,
std::move(locations),
attempts,
0,
0,
100,
false,
excludedWorkers,
&shuffleClient);

std::vector<uint8_t> buffer(payload.size());
int bytesRead = stream.read(buffer.data(), 0, payload.size());
EXPECT_EQ(bytesRead, payload.size());
EXPECT_EQ(std::string(buffer.begin(), buffer.end()), payload);

auto& hosts = factory->hosts();
ASSERT_GE(hosts.size(), 2u);
EXPECT_EQ(hosts[0], "primary-host");
EXPECT_EQ(hosts[1], "replica-host");
}

// Verifies that getNextChunk() throws after exhausting all chunk-fetch retries.
TEST(CelebornInputStreamRetryTest, fetchChunkRetryExhaustsAllRetries) {
auto mockClient = std::make_shared<SequencedMockTransportClient>();
for (int i = 0; i < 4; i++) {
mockClient->addFetchFailure("chunk fetch failed");
}

auto factory = std::make_shared<SequencedMockClientFactory>(mockClient);
auto conf = makeTestConf(true);
auto excludedWorkers =
std::make_shared<CelebornInputStream::FetchExcludedWorkers>();
StubShuffleClient shuffleClient(conf, excludedWorkers);

auto location = makeLocationWithPeer();
std::vector<std::shared_ptr<const PartitionLocation>> locations;
locations.push_back(std::move(location));
std::vector<int> attempts = {0};

EXPECT_THROW(
CelebornInputStream(
"test-shuffle-key",
conf,
factory,
std::move(locations),
attempts,
0,
0,
100,
false,
excludedWorkers,
&shuffleClient),
std::exception);

auto& hosts = factory->hosts();
EXPECT_EQ(hosts.size(), 4u);
EXPECT_EQ(hosts[0], "primary-host");
for (size_t i = 1; i < hosts.size(); i++) {
EXPECT_EQ(hosts[i], "replica-host");
}
}

// Verifies that without a peer, getNextChunk() retries the same location
// and succeeds on the second attempt.
TEST(CelebornInputStreamRetryTest, fetchChunkRetryNoPeerRetriesSameLocation) {
auto mockClient = std::make_shared<SequencedMockTransportClient>();
mockClient->addFetchFailure("chunk fetch failed");
const std::string payload = "world";
mockClient->addFetchSuccess(makeChunkBuffer(0, 0, 0, payload));

auto factory = std::make_shared<SequencedMockClientFactory>(mockClient);
auto conf = makeTestConf(false);
auto excludedWorkers =
std::make_shared<CelebornInputStream::FetchExcludedWorkers>();
StubShuffleClient shuffleClient(conf, excludedWorkers);

auto location = makeLocationWithoutPeer();
std::vector<std::shared_ptr<const PartitionLocation>> locations;
locations.push_back(std::move(location));
std::vector<int> attempts = {0};

CelebornInputStream stream(
"test-shuffle-key",
conf,
factory,
std::move(locations),
attempts,
0,
0,
100,
false,
excludedWorkers,
&shuffleClient);

std::vector<uint8_t> buffer(payload.size());
int bytesRead = stream.read(buffer.data(), 0, payload.size());
EXPECT_EQ(bytesRead, payload.size());
EXPECT_EQ(std::string(buffer.begin(), buffer.end()), payload);

for (const auto& host : factory->hosts()) {
EXPECT_EQ(host, "solo-host");
}
EXPECT_EQ(factory->hosts().size(), 2u);
}
Loading
Loading