Skip to content

Commit a9a07a2

Browse files
authored
bindings: match tcp EOF behavior (#4323)
1 parent d4c3ebb commit a9a07a2

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

bindings/rust/s2n-tls-tokio/tests/common/mod.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,25 @@ where
9595
);
9696
Ok((client?, server?))
9797
}
98+
99+
pub async fn get_tls_streams<A: Builder, B: Builder>(
100+
server_builder: A,
101+
client_builder: B,
102+
) -> Result<
103+
(
104+
TlsStream<TcpStream, A::Output>,
105+
TlsStream<TcpStream, B::Output>,
106+
),
107+
Box<dyn std::error::Error>,
108+
>
109+
where
110+
<A as Builder>::Output: Unpin,
111+
<B as Builder>::Output: Unpin,
112+
{
113+
let (server_stream, client_stream) = get_streams().await?;
114+
let connector = TlsConnector::new(client_builder);
115+
let acceptor = TlsAcceptor::new(server_builder);
116+
let (client_tls, server_tls) =
117+
run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;
118+
Ok((server_tls, client_tls))
119+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
5+
6+
pub mod common;
7+
8+
async fn assert_read_from_closed<S>(mut reader: S, writer: S)
9+
where
10+
S: AsyncRead + AsyncWrite + Unpin,
11+
{
12+
drop(writer);
13+
let result = reader.read_u8().await;
14+
assert!(result.is_err());
15+
let error = result.unwrap_err();
16+
assert!(error.kind() == std::io::ErrorKind::UnexpectedEof);
17+
}
18+
19+
#[tokio::test]
20+
async fn match_tcp_read_from_closed() -> Result<(), Box<dyn std::error::Error>> {
21+
let (tcp_server, tcp_client) = common::get_streams().await?;
22+
assert_read_from_closed(tcp_server, tcp_client).await;
23+
24+
let (tls13_server, tls13_client) = common::get_tls_streams(
25+
common::server_config()?.build()?,
26+
common::client_config()?.build()?,
27+
)
28+
.await?;
29+
assert_read_from_closed(tls13_server, tls13_client).await;
30+
31+
let (tls12_server, tls12_client) = common::get_tls_streams(
32+
common::server_config_tls12()?.build()?,
33+
common::client_config_tls12()?.build()?,
34+
)
35+
.await?;
36+
assert_read_from_closed(tls12_server, tls12_client).await;
37+
Result::Ok(())
38+
}
39+
40+
async fn assert_write_to_closed<S>(reader: S, mut writer: S)
41+
where
42+
S: AsyncRead + AsyncWrite + Unpin,
43+
{
44+
drop(reader);
45+
let result = writer.write_u8(0).await;
46+
assert!(result.is_ok());
47+
}
48+
49+
#[tokio::test]
50+
async fn match_tcp_write_to_closed() -> Result<(), Box<dyn std::error::Error>> {
51+
let (tcp_server, tcp_client) = common::get_streams().await?;
52+
assert_write_to_closed(tcp_server, tcp_client).await;
53+
54+
let (tls13_server, tls13_client) = common::get_tls_streams(
55+
common::server_config()?.build()?,
56+
common::client_config()?.build()?,
57+
)
58+
.await?;
59+
assert_write_to_closed(tls13_server, tls13_client).await;
60+
61+
let (tls12_server, tls12_client) = common::get_tls_streams(
62+
common::server_config_tls12()?.build()?,
63+
common::client_config_tls12()?.build()?,
64+
)
65+
.await?;
66+
assert_write_to_closed(tls12_server, tls12_client).await;
67+
Result::Ok(())
68+
}

bindings/rust/s2n-tls/src/error.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,19 @@ impl TryFrom<std::io::Error> for Error {
306306

307307
impl From<Error> for std::io::Error {
308308
fn from(input: Error) -> Self {
309-
if let Context::Code(_, errno) = input.0 {
310-
if ErrorType::IOError == input.kind() {
311-
let bare = std::io::Error::from_raw_os_error(errno.0);
312-
return std::io::Error::new(bare.kind(), input);
309+
let kind = match input.kind() {
310+
ErrorType::IOError => {
311+
if let Context::Code(_, errno) = input.0 {
312+
let bare = std::io::Error::from_raw_os_error(errno.0);
313+
bare.kind()
314+
} else {
315+
std::io::ErrorKind::Other
316+
}
313317
}
314-
}
315-
std::io::Error::new(std::io::ErrorKind::Other, input)
318+
ErrorType::ConnectionClosed => std::io::ErrorKind::UnexpectedEof,
319+
_ => std::io::ErrorKind::Other,
320+
};
321+
std::io::Error::new(kind, input)
316322
}
317323
}
318324

0 commit comments

Comments
 (0)