Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions src/network/host_pairing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ namespace {
constexpr std::size_t CLIENT_CHALLENGE_BYTE_COUNT = 16;
constexpr std::size_t CLIENT_SECRET_BYTE_COUNT = 16;
constexpr int SOCKET_TIMEOUT_MILLISECONDS = 5000;
constexpr int PIN_ENTRY_SOCKET_TIMEOUT_MILLISECONDS = 90000;
constexpr uint16_t DEFAULT_SERVERINFO_HTTP_PORT = 47989;
constexpr uint16_t FALLBACK_SERVERINFO_HTTP_PORT = 47984;
constexpr uint16_t DEFAULT_SERVERINFO_HTTPS_PORT = 47990;
Expand Down Expand Up @@ -217,7 +218,7 @@ namespace {

bool is_timeout_error(int errorCode) {
#if defined(NXDK) || !defined(_WIN32)
return errorCode == ETIMEDOUT;
return errorCode == ETIMEDOUT || errorCode == EWOULDBLOCK || errorCode == EAGAIN;
#else
return errorCode == WSAETIMEDOUT;
#endif
Expand Down Expand Up @@ -249,18 +250,18 @@ namespace {
return true;
}

void set_socket_timeouts(SOCKET socketHandle) {
void set_socket_timeouts(SOCKET socketHandle, int timeoutMilliseconds) {
#if defined(NXDK) || !defined(_WIN32)
timeval timeout {
SOCKET_TIMEOUT_MILLISECONDS / 1000,
(SOCKET_TIMEOUT_MILLISECONDS % 1000) * 1000,
timeoutMilliseconds / 1000,
(timeoutMilliseconds % 1000) * 1000,
};
setsockopt(socketHandle, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<const char *>(&timeout), sizeof(timeout));
setsockopt(socketHandle, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<const char *>(&timeout), sizeof(timeout));
#else
const DWORD timeoutMilliseconds = SOCKET_TIMEOUT_MILLISECONDS;
setsockopt(socketHandle, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<const char *>(&timeoutMilliseconds), sizeof(timeoutMilliseconds));
setsockopt(socketHandle, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<const char *>(&timeoutMilliseconds), sizeof(timeoutMilliseconds));
const DWORD platformTimeoutMilliseconds = static_cast<DWORD>(timeoutMilliseconds);
setsockopt(socketHandle, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<const char *>(&platformTimeoutMilliseconds), sizeof(platformTimeoutMilliseconds));
setsockopt(socketHandle, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast<const char *>(&platformTimeoutMilliseconds), sizeof(platformTimeoutMilliseconds));
#endif
}

Expand Down Expand Up @@ -351,7 +352,8 @@ namespace {
std::string_view expectedTlsCertificatePem,
HttpResponse *response,
std::string *errorMessage,
const std::atomic<bool> *cancelRequested = nullptr
const std::atomic<bool> *cancelRequested = nullptr,
int socketIoTimeoutMilliseconds = SOCKET_TIMEOUT_MILLISECONDS
);

std::string summarize_http_payload_preview(std::string_view text) {
Expand Down Expand Up @@ -1496,18 +1498,25 @@ namespace {
return true;
}

bool finalize_connected_socket(SOCKET socketHandle, std::string *errorMessage) {
bool finalize_connected_socket(SOCKET socketHandle, int socketIoTimeoutMilliseconds, std::string *errorMessage) {
trace_pairing_phase("restoring blocking mode after connect");
if (!set_socket_non_blocking(socketHandle, false, errorMessage)) {
return false;
}

set_socket_timeouts(socketHandle);
set_socket_timeouts(socketHandle, socketIoTimeoutMilliseconds);
trace_pairing_phase("socket connected");
return true;
}

bool connect_socket(const std::string &address, uint16_t port, SocketGuard *socketGuard, std::string *errorMessage, const std::atomic<bool> *cancelRequested = nullptr) {
bool connect_socket(
const std::string &address,
uint16_t port,
SocketGuard *socketGuard,
int socketIoTimeoutMilliseconds,
std::string *errorMessage,
const std::atomic<bool> *cancelRequested = nullptr
) {
if (socketGuard == nullptr) {
return append_error(errorMessage, "Internal pairing error while preparing the host connection");
}
Expand Down Expand Up @@ -1547,7 +1556,7 @@ namespace {
}
}

return finalize_connected_socket(socketGuard->handle, errorMessage);
return finalize_connected_socket(socketGuard->handle, socketIoTimeoutMilliseconds, errorMessage);
}

bool recv_all_plain(SOCKET socketHandle, std::string *response, std::string *errorMessage, const std::atomic<bool> *cancelRequested = nullptr) {
Expand Down Expand Up @@ -1855,7 +1864,8 @@ namespace {
std::string_view expectedTlsCertificatePem,
HttpResponse *response,
std::string *errorMessage,
const std::atomic<bool> *cancelRequested
const std::atomic<bool> *cancelRequested,
int socketIoTimeoutMilliseconds
) {
if (pairing_cancel_requested(cancelRequested)) {
return append_cancelled_pairing_error(errorMessage);
Expand All @@ -1869,6 +1879,7 @@ namespace {
useTls,
tlsClientIdentity,
std::string(expectedTlsCertificatePem),
socketIoTimeoutMilliseconds,
};
network::testing::HostPairingHttpTestResponse testResponse {};
if (std::string testError; !testHandler(testRequest, &testResponse, &testError, cancelRequested)) {
Expand All @@ -1891,7 +1902,7 @@ namespace {

SocketGuard socketGuard;
trace_pairing_phase("http_get: connect_socket");
if (!connect_socket(address, port, &socketGuard, errorMessage, cancelRequested)) {
if (!connect_socket(address, port, &socketGuard, socketIoTimeoutMilliseconds, errorMessage, cancelRequested)) {
return false;
}

Expand Down Expand Up @@ -2149,12 +2160,18 @@ namespace {
return true;
}

bool execute_pairing_phase_request(PairingSessionState *session, const std::string &path, bool useTls, std::string_view expectedTlsCertificatePem = {}) {
bool execute_pairing_phase_request(
PairingSessionState *session,
const std::string &path,
bool useTls,
std::string_view expectedTlsCertificatePem = {},
int socketIoTimeoutMilliseconds = SOCKET_TIMEOUT_MILLISECONDS
) {
if (session == nullptr) {
return false;
}

if (!http_get(session->request.address, useTls ? session->serverInfo.httpsPort : session->serverInfo.httpPort, path, useTls, useTls ? &session->request.identity : nullptr, expectedTlsCertificatePem, &session->response, &session->errorMessage, session->cancelRequested)) {
if (!http_get(session->request.address, useTls ? session->serverInfo.httpsPort : session->serverInfo.httpPort, path, useTls, useTls ? &session->request.identity : nullptr, expectedTlsCertificatePem, &session->response, &session->errorMessage, session->cancelRequested, socketIoTimeoutMilliseconds)) {
return false;
}
return parse_pairing_tag(session->response, "paired", &session->phaseValue, &session->errorMessage);
Expand Down Expand Up @@ -2187,7 +2204,7 @@ namespace {

const std::string phasePath = "/pair?uniqueid=" + session->uniqueId + "&uuid=" + session->requestUuid + "&devicename=" + session->deviceName + "&updateState=1&phrase=getservercert&salt=" + session->saltHex + "&clientcert=" + certHex;
trace_pairing_phase("phase 1 getservercert request");
if (!execute_pairing_phase_request(session, phasePath, false)) {
if (!execute_pairing_phase_request(session, phasePath, false, {}, PIN_ENTRY_SOCKET_TIMEOUT_MILLISECONDS)) {
return false;
}
if (session->phaseValue != "1") {
Expand Down
1 change: 1 addition & 0 deletions src/network/host_pairing.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ namespace network {
bool useTls = false; ///< True when the request would normally use TLS.
const PairingIdentity *tlsClientIdentity = nullptr; ///< Optional client identity attached to TLS requests.
std::string expectedTlsCertificatePem; ///< Optional pinned host certificate expected by the request.
int socketIoTimeoutMilliseconds = 0; ///< Socket read/write timeout that the real transport would apply.
};

/**
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/network/host_pairing_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace {
using network::testing::HostPairingHttpTestRequest;
using network::testing::HostPairingHttpTestResponse;

constexpr int kDefaultPairingSocketTimeoutMilliseconds = 5000;
constexpr int kPinEntrySocketTimeoutMilliseconds = 90000;
constexpr std::string_view kUnpairedClientErrorMessage = "The host reports that this client is no longer paired. Pair the host again.";

class ScopedHostPairingHttpTestHandler {
Expand Down Expand Up @@ -1309,6 +1311,39 @@ namespace {
EXPECT_EQ(result.message, "Pairing failed during phase 1 (getservercert): The host rejected the initial pairing request");
}

TEST(HostPairingTest, PairHostUsesExtendedTimeoutForPinEntryResponse) {
const network::PairingIdentity identity = network::create_pairing_identity();
ASSERT_TRUE(network::is_valid_pairing_identity(identity));

std::size_t callCount = 0U;
ScopedHostPairingHttpTestHandler guard([&callCount](const HostPairingHttpTestRequest &request, HostPairingHttpTestResponse *response, std::string *, const std::atomic<bool> *) {
if (callCount++ == 0U) {
EXPECT_EQ(request.socketIoTimeoutMilliseconds, kDefaultPairingSocketTimeoutMilliseconds);
response->statusCode = 200;
response->body = make_server_info_xml(false, 47989U, 47990U, "Pair Host", "pair-host");
return true;
}

EXPECT_NE(request.pathAndQuery.find("phrase=getservercert"), std::string::npos);
EXPECT_EQ(request.socketIoTimeoutMilliseconds, kPinEntrySocketTimeoutMilliseconds);
response->statusCode = 200;
response->body = make_pair_phase_response("0");
return true;
});

const network::HostPairingResult result = network::pair_host({
test_support::kTestIpv4Addresses[test_support::kIpLivingRoom],
47989U,
"1234",
"MoonlightXboxOG",
identity,
});

EXPECT_EQ(callCount, 2U);
EXPECT_FALSE(result.success);
EXPECT_EQ(result.message, "Pairing failed during phase 1 (getservercert): The host rejected the initial pairing request");
}

TEST(HostPairingTest, PairHostFailsWhenTheChallengeResponseIsTooShort) {
const network::PairingIdentity clientIdentity = network::create_pairing_identity();
const network::PairingIdentity serverIdentity = network::create_pairing_identity();
Expand Down
Loading