Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,28 @@ public override int Encrypt(byte[] input, int offset, int length, byte[] output,
/// <returns>The decrypted plaintext.</returns>
public override byte[] Decrypt(byte[] input, int offset, int length)
{
byte[] output;
var output = new byte[length];

_cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv));

var keyStream = new byte[64];
_cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0);
_mac.Init(new KeyParameter(keyStream, 0, 32));

if (_aadLength > 0)
{
// If we are in 'AAD mode', then put these bytes through the AAD cipher.

_mac.BlockUpdate(input, offset, length);

Debug.Assert(_aadCipher != null);

_aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv));

output = new byte[length];
_aadCipher.ProcessBytes(input, offset, length, output, 0);
}
else
{
output = new byte[length];

var bytesWritten = Decrypt(input, offset, length, output, 0);

Debug.Assert(bytesWritten == length);
Expand All @@ -169,7 +174,7 @@ public override byte[] Decrypt(byte[] input, int offset, int length)
/// <param name="input">
/// The input data with below format:
/// <code>
/// [----][----Cipher AAD----(offset)][----Cipher Text----(length)][----TAG----]
/// [----(offset)][----Cipher Text----(length)][----TAG----]
/// </code>
/// </param>
/// <param name="offset">The zero-based offset in <paramref name="input"/> at which to begin decrypting and authenticating.</param>
Expand All @@ -179,16 +184,8 @@ public override byte[] Decrypt(byte[] input, int offset, int length)
/// <returns>The number of plaintext bytes written to <paramref name="output"/>.</returns>
public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset)
{
Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length");

_cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv));

var keyStream = new byte[64];
_cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0);
_mac.Init(new KeyParameter(keyStream, 0, 32));

var tag = new byte[TagSize];
_mac.BlockUpdate(input, offset - _aadLength, length + _aadLength);
_mac.BlockUpdate(input, offset, length);
_ = _mac.DoFinal(tag, 0);
if (!Arrays.FixedTimeEquals(TagSize, tag, 0, input, offset + length))
{
Expand Down
135 changes: 73 additions & 62 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,23 @@ public sealed class Session : ISession
/// </summary>
private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);

private readonly byte[] _inboundPacketSequenceBytes = new byte[4];

/// <summary>
/// Gets or sets the incoming packet number.
/// </summary>
private uint InboundPacketSequence
{
get
{
return BinaryPrimitives.ReadUInt32BigEndian(_inboundPacketSequenceBytes);
}
set
{
BinaryPrimitives.WriteUInt32BigEndian(_inboundPacketSequenceBytes, value);
}
}

/// <summary>
/// Holds metadata about session messages.
/// </summary>
Expand All @@ -120,11 +137,6 @@ public sealed class Session : ISession
/// </summary>
private volatile uint _outboundPacketSequence;

/// <summary>
/// Specifies incoming packet number.
/// </summary>
private uint _inboundPacketSequence;

/// <summary>
/// WaitHandle to signal that last service request was accepted.
/// </summary>
Expand Down Expand Up @@ -200,7 +212,6 @@ public sealed class Session : ISession
private Socket _socket;

private ArrayBuffer _receiveBuffer = new(4 * 1024);
private byte[] _plaintextReceiveBuffer = new byte[4 * 1024];

/// <summary>
/// Gets the session semaphore that controls session channels.
Expand Down Expand Up @@ -1213,9 +1224,6 @@ private bool TrySendMessage(Message message)
/// </remarks>
private Message ReceiveMessage(Socket socket)
{
// the length of the packet sequence field in bytes
const int inboundPacketSequenceLength = 4;

// The length of the "packet length" field in bytes
const int packetLengthFieldLength = 4;

Expand Down Expand Up @@ -1272,31 +1280,28 @@ private Message ReceiveMessage(Socket socket)
}
}

var firstBlock = new ArraySegment<byte>(
_receiveBuffer.DangerousGetUnderlyingBuffer(),
_receiveBuffer.ActiveStartOffset,
blockSize);

var plainFirstBlock = firstBlock;

// For ETM or AES-GCM, firstBlock holds the packet length which is
// not encrypted. Otherwise, we decrypt the first "blockSize" bytes.
// (For chacha20-poly1305, this means passing the encrypted packet
// length as AAD).
// For ETM or AES-GCM, the first "blockSize" bytes hold the packet length
// which is not encrypted. Otherwise, we decrypt them.
// (For chacha20-poly1305, this means passing the encrypted packet length
// to its AAD cipher instance - it is the awkward difference between the
// 3-arg and 5-arg Decrypt, and explains why we don't just decrypt these
// bytes in-place).
if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
{
_serverCipher.SetSequenceNumber(_inboundPacketSequence);
_serverCipher.SetSequenceNumber(InboundPacketSequence);

if (_serverMac == null || !_serverEtm)
{
plainFirstBlock = new ArraySegment<byte>(_serverCipher.Decrypt(
firstBlock.Array,
firstBlock.Offset,
firstBlock.Count));
var plainFirstBlock = _serverCipher.Decrypt(
_receiveBuffer.DangerousGetUnderlyingBuffer(),
_receiveBuffer.ActiveStartOffset,
blockSize);

plainFirstBlock.CopyTo(_receiveBuffer.ActiveSpan);
}
}

var packetLength = BinaryPrimitives.ReadInt32BigEndian(plainFirstBlock);
var packetLength = BinaryPrimitives.ReadInt32BigEndian(_receiveBuffer.ActiveReadOnlySpan);

// Test packet minimum and maximum boundaries
if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
Expand Down Expand Up @@ -1330,26 +1335,13 @@ private Message ReceiveMessage(Socket socket)
}
}

// Construct buffer for holding the payload and the inbound packet sequence as we need both in order
// to generate the hash.
var plaintextLength = 4 + totalPacketLength - serverMacLength;

if (_plaintextReceiveBuffer.Length < plaintextLength)
{
Array.Resize(ref _plaintextReceiveBuffer, Math.Max(plaintextLength, 2 * _plaintextReceiveBuffer.Length));
}

BinaryPrimitives.WriteUInt32BigEndian(_plaintextReceiveBuffer, _inboundPacketSequence);

plainFirstBlock.AsSpan().CopyTo(_plaintextReceiveBuffer.AsSpan(4));

if (_serverMac != null && _serverEtm)
{
// ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)

// sequence_number
_ = _serverMac.TransformBlock(
inputBuffer: _plaintextReceiveBuffer,
inputBuffer: _inboundPacketSequenceBytes,
inputOffset: 0,
inputCount: 4,
outputBuffer: null,
Expand Down Expand Up @@ -1377,58 +1369,77 @@ private Message ReceiveMessage(Socket socket)
{
Debug.Assert(numberOfBytesToDecrypt % blockSize == 0);

var decryptBuffer = _receiveBuffer.DangerousGetUnderlyingBuffer();
var decryptOffset = _receiveBuffer.ActiveStartOffset + blockSize;

var numberOfBytesDecrypted = _serverCipher.Decrypt(
input: _receiveBuffer.DangerousGetUnderlyingBuffer(),
offset: _receiveBuffer.ActiveStartOffset + blockSize,
input: decryptBuffer,
offset: decryptOffset,
length: numberOfBytesToDecrypt,
output: _plaintextReceiveBuffer,
outputOffset: 4 + blockSize);
output: decryptBuffer,
outputOffset: decryptOffset);

Debug.Assert(numberOfBytesDecrypted == numberOfBytesToDecrypt);
}
else
{
_receiveBuffer.ActiveReadOnlySpan
.Slice(blockSize, numberOfBytesToDecrypt)
.CopyTo(_plaintextReceiveBuffer.AsSpan(4 + blockSize));
}

if (_serverMac != null && !_serverEtm)
{
// non-ETM mac = MAC(key, sequence_number || unencrypted_packet)

var clientHash = _serverMac.ComputeHash(_plaintextReceiveBuffer, 0, plaintextLength);
// sequence_number
_ = _serverMac.TransformBlock(
inputBuffer: _inboundPacketSequenceBytes,
inputOffset: 0,
inputCount: 4,
outputBuffer: null,
outputOffset: 0);

// unencrypted_packet
_ = _serverMac.TransformBlock(
inputBuffer: _receiveBuffer.DangerousGetUnderlyingBuffer(),
inputOffset: _receiveBuffer.ActiveStartOffset,
inputCount: totalPacketLength - serverMacLength,
outputBuffer: null,
outputOffset: 0);

_ = _serverMac.TransformFinalBlock(Array.Empty<byte>(), 0, 0);

if (!CryptoAbstraction.FixedTimeEquals(clientHash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength)))
if (!CryptoAbstraction.FixedTimeEquals(_serverMac.Hash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength)))
{
throw new SshConnectionException("MAC error", DisconnectReason.MacError);
}
}

_receiveBuffer.Discard(totalPacketLength);

var paddingLength = _plaintextReceiveBuffer[inboundPacketSequenceLength + packetLengthFieldLength];
var paddingLength = _receiveBuffer.ActiveReadOnlySpan[packetLengthFieldLength];

ArraySegment<byte> payload = new(
_plaintextReceiveBuffer,
offset: inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength,
_receiveBuffer.DangerousGetUnderlyingBuffer(),
offset: _receiveBuffer.ActiveStartOffset + packetLengthFieldLength + paddingLengthFieldLength,
count: packetLength - paddingLength - paddingLengthFieldLength);

if (_serverDecompression != null)
{
payload = new(_serverDecompression.Decompress(payload.Array, payload.Offset, payload.Count));
}

_inboundPacketSequence++;
var newInboundPacketSequence = ++InboundPacketSequence;

// The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
// It ensures the integrity of key exchange process.
if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
if (newInboundPacketSequence == uint.MaxValue && _isInitialKex)
{
throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
}

return LoadMessage(payload.Array, payload.Offset, payload.Count);
var message = LoadMessage(payload.Array, payload.Offset, payload.Count);

// The deserialised message may still reference data in the buffer, so calling Discard
// here might seem misguided. It is OK because Discard does not mutate the buffer
// and it will not be touched again until the next call to ReceiveMessage, which will
// only occur after the message has been fully processed.
_receiveBuffer.Discard(totalPacketLength);

return message;
}

private void TrySendDisconnect(DisconnectReason reasonCode, string message)
Expand Down Expand Up @@ -1545,7 +1556,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)

_logger.LogDebug("[{SessionId}] Enabling strict key exchange extension.", SessionIdHex);

if (_inboundPacketSequence != 1)
if (InboundPacketSequence != 1)
{
throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
}
Expand Down Expand Up @@ -1646,7 +1657,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)

if (_isStrictKex)
{
_inboundPacketSequence = 0;
InboundPacketSequence = 0;
}

NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));
Expand Down