@@ -9,51 +9,66 @@ use std::{
99 sync:: atomic:: { AtomicBool , Ordering } ,
1010} ;
1111
12+ use brass_aphid_wire_decryption:: decryption:: stream_decrypter:: StreamDecrypter ;
1213use byteorder:: { BigEndian , ReadBytesExt } ;
1314
15+ use crate :: Mode ;
16+
1417pub type LocalDataBuffer = RefCell < VecDeque < u8 > > ;
1518
1619#[ derive( Debug , Default ) ]
1720pub struct TestPairIO {
1821 /// a data buffer that the server writes to and the client reads from
19- pub server_tx_stream : Rc < LocalDataBuffer > ,
22+ pub server_tx_stream : LocalDataBuffer ,
2023 /// a data buffer that the client writes to and the server reads from
21- pub client_tx_stream : Rc < LocalDataBuffer > ,
22-
23- pub recording : Rc < AtomicBool > ,
24- pub client_tx_transcript : Rc < RefCell < Vec < u8 > > > ,
25- pub server_tx_transcript : Rc < RefCell < Vec < u8 > > > ,
24+ pub client_tx_stream : LocalDataBuffer ,
25+
26+ /// indicates whether all client/server writes should be stored to the
27+ /// transcript fields
28+ pub recording : AtomicBool ,
29+ pub client_tx_transcript : RefCell < Vec < u8 > > ,
30+ pub server_tx_transcript : RefCell < Vec < u8 > > ,
31+ /// [`Self::enable_decryption`] will initialize the stream decrypter, which
32+ /// allows tests to make assertions on the decrypted TLS transcript.
33+ ///
34+ /// This is especially useful for TLS 1.3 where much of the handshake is encrypted.
35+ pub decrypter : RefCell < Option < StreamDecrypter > > ,
2636}
2737
2838impl TestPairIO {
29- pub fn client_view ( & self ) -> ViewIO {
39+ pub fn client_view ( self : & Rc < Self > ) -> ViewIO {
3040 ViewIO {
31- send_ctx : self . client_tx_stream . clone ( ) ,
32- recv_ctx : self . server_tx_stream . clone ( ) ,
33- recording : self . recording . clone ( ) ,
34- send_transcript : self . client_tx_transcript . clone ( ) ,
41+ identity : Mode :: Client ,
42+ io : Rc :: clone ( self ) ,
3543 }
3644 }
3745
38- pub fn server_view ( & self ) -> ViewIO {
46+ pub fn server_view ( self : & Rc < Self > ) -> ViewIO {
3947 ViewIO {
40- send_ctx : self . server_tx_stream . clone ( ) ,
41- recv_ctx : self . client_tx_stream . clone ( ) ,
42- recording : self . recording . clone ( ) ,
43- send_transcript : self . server_tx_transcript . clone ( ) ,
48+ identity : Mode :: Server ,
49+ io : Rc :: clone ( self ) ,
4450 }
4551 }
4652
47- pub fn enable_recording ( & mut self ) {
53+ pub fn enable_recording ( & self ) {
4854 self . recording . store ( true , Ordering :: Relaxed ) ;
4955 }
5056
57+ /// Note: this is only available for TLS 1.3
58+ pub fn enable_decryption (
59+ & self ,
60+ keys : brass_aphid_wire_decryption:: decryption:: key_manager:: KeyManager ,
61+ ) {
62+ let stream_decrypter = StreamDecrypter :: new ( keys) ;
63+ * self . decrypter . borrow_mut ( ) = Some ( stream_decrypter) ;
64+ }
65+
5166 pub fn client_record_sizes ( & self ) -> Vec < u16 > {
52- Self :: record_sizes ( self . client_tx_transcript . as_ref ( ) . borrow ( ) . as_slice ( ) ) . unwrap ( )
67+ Self :: record_sizes ( self . client_tx_transcript . borrow ( ) . as_slice ( ) ) . unwrap ( )
5368 }
5469
5570 pub fn server_record_sizes ( & self ) -> Vec < u16 > {
56- Self :: record_sizes ( self . server_tx_transcript . as_ref ( ) . borrow ( ) . as_slice ( ) ) . unwrap ( )
71+ Self :: record_sizes ( self . server_tx_transcript . borrow ( ) . as_slice ( ) ) . unwrap ( )
5772 }
5873
5974 /// Return a list of the record sizes contained in `buffer`.
@@ -80,20 +95,37 @@ impl TestPairIO {
8095///
8196/// This view is client/server specific, and notably implements the read and write
8297/// traits.
83- ///
84- // This struct is used by Openssl and Rustls which both rely on a "stream" abstraction
85- // which implements read and write. This is not used by s2n-tls, which relies on
86- // lower level callbacks.
8798pub struct ViewIO {
88- pub send_ctx : Rc < LocalDataBuffer > ,
89- pub recv_ctx : Rc < LocalDataBuffer > ,
90- pub recording : Rc < AtomicBool > ,
91- pub send_transcript : Rc < RefCell < Vec < u8 > > > ,
99+ pub identity : Mode ,
100+ pub io : Rc < TestPairIO > ,
101+ }
102+
103+ impl ViewIO {
104+ fn recv_ctx ( & self ) -> & LocalDataBuffer {
105+ match self . identity {
106+ Mode :: Client => & self . io . server_tx_stream ,
107+ Mode :: Server => & self . io . client_tx_stream ,
108+ }
109+ }
110+
111+ fn send_ctx ( & self ) -> & LocalDataBuffer {
112+ match self . identity {
113+ Mode :: Client => & self . io . client_tx_stream ,
114+ Mode :: Server => & self . io . server_tx_stream ,
115+ }
116+ }
117+
118+ fn send_transcript ( & self ) -> & RefCell < Vec < u8 > > {
119+ match self . identity {
120+ Mode :: Client => & self . io . client_tx_transcript ,
121+ Mode :: Server => & self . io . server_tx_transcript ,
122+ }
123+ }
92124}
93125
94126impl std:: io:: Read for ViewIO {
95127 fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
96- let res = self . recv_ctx . borrow_mut ( ) . read ( buf) ;
128+ let res = self . recv_ctx ( ) . borrow_mut ( ) . read ( buf) ;
97129 if let Ok ( 0 ) = res {
98130 // We are "faking" a TcpStream, where a read of length 0 indicates
99131 // EoF. That is incorrect for this scenario. Instead we return WouldBlock
@@ -107,15 +139,29 @@ impl std::io::Read for ViewIO {
107139
108140impl std:: io:: Write for ViewIO {
109141 fn write ( & mut self , buf : & [ u8 ] ) -> std:: io:: Result < usize > {
110- let write_result = self . send_ctx . borrow_mut ( ) . write ( buf) ;
111-
112- if self . recording . load ( Ordering :: Relaxed ) {
113- if let Ok ( written) = write_result {
114- self . send_transcript
142+ let write_result = self . send_ctx ( ) . borrow_mut ( ) . write ( buf) ;
143+
144+ // if we successfully wrote data, we need to record it in the various test
145+ // utilities.
146+ if let Ok ( written) = write_result {
147+ // recorder
148+ if self . io . recording . load ( Ordering :: Relaxed ) {
149+ self . send_transcript ( )
115150 . borrow_mut ( )
116151 . write_all ( & buf[ 0 ..written] )
117152 . unwrap ( ) ;
118153 }
154+
155+ // decrypter
156+ let mut decrypter = self . io . decrypter . borrow_mut ( ) ;
157+ if let Some ( decrypter) = decrypter. as_mut ( ) {
158+ let wire_mode = match self . identity {
159+ Mode :: Client => brass_aphid_wire_decryption:: decryption:: Mode :: Client ,
160+ Mode :: Server => brass_aphid_wire_decryption:: decryption:: Mode :: Server ,
161+ } ;
162+ decrypter. record_tx ( & buf[ 0 ..written] , wire_mode) ;
163+ decrypter. decrypt_records ( wire_mode) . unwrap ( ) ;
164+ }
119165 }
120166
121167 write_result
0 commit comments