// 2025-05-07 11:20:40 +08:00
//
// 439 lines
// 15 KiB
// C#

#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.Collections;
using System.IO;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;
namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
internal class DtlsReliableHandshake
{
    /// <summary>
    /// How many message sequence numbers beyond the next expected one are buffered for
    /// reassembly; handshake messages further ahead than this are ignored.
    /// </summary>
    private const int MaxReceiveAhead = 16;

    /// <summary>
    /// Size of a DTLS handshake message header: msg_type (1 byte), length (3),
    /// message_seq (2), fragment_offset (3), fragment_length (3).
    /// </summary>
    private const int MessageHeaderLength = 12;

    private readonly DtlsRecordLayer mRecordLayer;

    // Transcript hash that sent and received handshake messages are fed into.
    private TlsHandshakeHash mHandshakeHash;

    // Reassemblers for the flight currently being received, keyed by message_seq (boxed int).
    private IDictionary mCurrentInboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable();

    // Reassemblers for the peer's previous flight. If that flight is received again in full,
    // our last outbound flight is retransmitted. Cleared as soon as a message of the current
    // inbound flight completes.
    private IDictionary mPreviousInboundFlight = null;

    // Messages of the flight we most recently sent, kept so it can be retransmitted.
    private IList mOutboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateArrayList();

    // True while we are producing an outbound flight, false while consuming an inbound one.
    private bool mSending = true;

    // message_seq for the next message we send / the next message we expect to receive.
    private int mMessageSeq = 0, mNextReceiveSeq = 0;

    /// <summary>
    /// Creates the handshake helper on top of <paramref name="transport"/> and initialises a
    /// <see cref="DeferredHash"/> transcript hash for <paramref name="context"/>.
    /// </summary>
    internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport)
    {
        this.mRecordLayer = transport;
        this.mHandshakeHash = new DeferredHash();
        this.mHandshakeHash.Init(context);
    }

    /// <summary>
    /// Replaces the transcript hash with the instance returned by
    /// <c>NotifyPrfDetermined()</c>; called once the hello exchange is complete.
    /// </summary>
    internal void NotifyHelloComplete()
    {
        this.mHandshakeHash = mHandshakeHash.NotifyPrfDetermined();
    }

    /// <summary>The transcript hash currently in use.</summary>
    internal TlsHandshakeHash HandshakeHash
    {
        get { return mHandshakeHash; }
    }

    /// <summary>
    /// Returns the current transcript hash and replaces it with the instance returned by
    /// <c>StopTracking()</c>.
    /// </summary>
    internal TlsHandshakeHash PrepareToFinish()
    {
        TlsHandshakeHash result = mHandshakeHash;
        this.mHandshakeHash = mHandshakeHash.StopTracking();
        return result;
    }

    /// <summary>
    /// Sends one handshake message as part of the current outbound flight.
    /// </summary>
    /// <remarks>
    /// The first send after a receive phase starts a new outbound flight, which discards the
    /// previous one. The message is written (fragmented if necessary), remembered for
    /// retransmission, and added to the transcript hash.
    /// </remarks>
    /// <param name="msg_type">Handshake message type.</param>
    /// <param name="body">Message body; its length must fit in a uint24.</param>
    internal void SendMessage(byte msg_type, byte[] body)
    {
        TlsUtilities.CheckUint24(body.Length);

        if (!mSending)
        {
            CheckInboundFlight();
            mSending = true;
            mOutboundFlight.Clear();
        }

        Message message = new Message(mMessageSeq++, msg_type, body);

        mOutboundFlight.Add(message);

        WriteMessage(message);
        UpdateHandshakeMessagesDigest(message);
    }

    /// <summary>
    /// Receives the next handshake message and returns its body.
    /// </summary>
    /// <exception cref="TlsFatalAlert">
    /// <c>unexpected_message</c> if the received message type is not <paramref name="msg_type"/>.
    /// </exception>
    internal byte[] ReceiveMessageBody(byte msg_type)
    {
        Message message = ReceiveMessage();
        if (message.Type != msg_type)
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message);
        }

        return message.Body;
    }

    /// <summary>
    /// Blocks until the next in-sequence handshake message has been fully reassembled, then
    /// returns it.
    /// </summary>
    /// <remarks>
    /// The first receive after a send phase starts a new inbound flight. When a receive
    /// returns a negative count or throws an <see cref="IOException"/>, the whole outbound
    /// flight is retransmitted. The read timeout starts at 1000 ms and doubles, via
    /// <see cref="BackOff"/>, after every retransmission.
    /// </remarks>
    internal Message ReceiveMessage()
    {
        if (mSending)
        {
            mSending = false;
            PrepareInboundFlight(BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable());
        }

        byte[] buf = null;

        // TODO Check the conditions under which we should reset this
        int readTimeoutMillis = 1000;

        for (;;)
        {
            try
            {
                for (;;)
                {
                    Message pending = GetPendingMessage();
                    if (pending != null)
                    {
                        return pending;
                    }

                    // Grow (never shrink) the scratch buffer to the record layer's current limit.
                    int receiveLimit = mRecordLayer.GetReceiveLimit();
                    if (buf == null || buf.Length < receiveLimit)
                    {
                        buf = new byte[receiveLimit];
                    }

                    int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
                    if (received < 0)
                    {
                        // Nothing usable arrived: fall through to the retransmit below.
                        break;
                    }

                    bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
                    if (resentOutbound)
                    {
                        readTimeoutMillis = BackOff(readTimeoutMillis);
                    }
                }
            }
            catch (IOException)
            {
                // NOTE: Assume this is a timeout for the moment
                // NOTE(review): if TlsFatalAlert derives from IOException (as in upstream
                // BouncyCastle), fatal alerts raised by the record layer are also treated as a
                // timeout here rather than aborting the handshake - confirm this is intended.
            }

            ResendOutboundFlight();
            readTimeoutMillis = BackOff(readTimeoutMillis);
        }
    }

    /// <summary>
    /// Completes the handshake and hands control to the record layer.
    /// </summary>
    /// <remarks>
    /// If we sent the last flight and the peer's previous flight is still tracked, a
    /// <see cref="Retransmit"/> hook is passed to the record layer so that late
    /// retransmissions from the peer can still be answered.
    /// </remarks>
    internal void Finish()
    {
        DtlsHandshakeRetransmit retransmit = null;
        if (!mSending)
        {
            CheckInboundFlight();
        }
        else
        {
            PrepareInboundFlight(null);

            if (mPreviousInboundFlight != null)
            {
                /*
                 * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
                 * when in the FINISHED state, the node that transmits the last flight (the server in an
                 * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
                 * of the peer's last flight with a retransmit of the last flight.
                 */
                retransmit = new Retransmit(this);
            }
        }

        mRecordLayer.HandshakeSuccessful(retransmit);
    }

    /// <summary>Resets the transcript hash.</summary>
    internal void ResetHandshakeMessagesDigest()
    {
        mHandshakeHash.Reset();
    }

    /// <summary>
    /// Returns the next retransmission timeout: double the current one, capped at 60000 ms.
    /// </summary>
    private int BackOff(int timeoutMillis)
    {
        /*
         * TODO[DTLS] implementations SHOULD back off handshake packet size during the
         * retransmit backoff.
         */
        return System.Math.Min(timeoutMillis * 2, 60000);
    }

    /**
     * Check that there are no "extra" messages left in the current inbound flight
     */
    private void CheckInboundFlight()
    {
        // Currently a no-op: leftover messages are detected but not acted upon.
        foreach (int key in mCurrentInboundFlight.Keys)
        {
            if (key >= mNextReceiveSeq)
            {
                // TODO Should this be considered an error?
            }
        }
    }

    /// <summary>
    /// Returns the next expected message if it has been fully reassembled, otherwise null.
    /// </summary>
    /// <remarks>
    /// Consuming a message advances <c>mNextReceiveSeq</c>, stops tracking the previous inbound
    /// flight (the peer has evidently moved on), and adds the message to the transcript hash.
    /// </remarks>
    private Message GetPendingMessage()
    {
        DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
        if (next != null)
        {
            byte[] body = next.GetBodyIfComplete();
            if (body != null)
            {
                mPreviousInboundFlight = null;
                return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
            }
        }
        return null;
    }

    /// <summary>
    /// Resets the current inbound flight's reassemblers, demotes it to "previous", and installs
    /// <paramref name="nextFlight"/> as the current one. <see cref="Finish"/> passes null.
    /// </summary>
    private void PrepareInboundFlight(IDictionary nextFlight)
    {
        ResetAll(mCurrentInboundFlight);
        mPreviousInboundFlight = mCurrentInboundFlight;
        mCurrentInboundFlight = nextFlight;
    }

    /// <summary>
    /// Parses the handshake fragments contained in one record and feeds them to reassemblers.
    /// </summary>
    /// <param name="windowSize">
    /// How far beyond <c>mNextReceiveSeq</c> a message_seq may be and still be buffered;
    /// 0 means only retransmissions of the previous inbound flight are considered.
    /// </param>
    /// <param name="epoch">Epoch the record was received in.</param>
    /// <param name="buf">Buffer holding the record's plaintext.</param>
    /// <param name="off">Start of the record's handshake data in <paramref name="buf"/>.</param>
    /// <param name="len">Number of bytes of handshake data.</param>
    /// <returns>
    /// True if this record completed a full retransmission of the peer's previous flight, in
    /// which case our outbound flight has been resent.
    /// </returns>
    private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
    {
        bool checkPreviousFlight = false;

        while (len >= MessageHeaderLength)
        {
            int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
            int message_length = fragment_length + MessageHeaderLength;
            if (len < message_length)
            {
                // NOTE: Truncated message - ignore it
                break;
            }

            int length = TlsUtilities.ReadUint24(buf, off + 1);
            int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
            if (fragment_offset + fragment_length > length)
            {
                // NOTE: Malformed fragment - ignore it and the rest of the record
                break;
            }

            /*
             * NOTE: This very simple epoch check will only work until we want to support
             * renegotiation (and we're not likely to do that anyway).
             */
            byte msg_type = TlsUtilities.ReadUint8(buf, off + 0);
            int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
            if (epoch != expectedEpoch)
            {
                break;
            }

            int message_seq = TlsUtilities.ReadUint16(buf, off + 4);
            if (message_seq >= (mNextReceiveSeq + windowSize))
            {
                // NOTE: Too far ahead - ignore
            }
            else if (message_seq >= mNextReceiveSeq)
            {
                DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[message_seq];
                if (reassembler == null)
                {
                    reassembler = new DtlsReassembler(msg_type, length);
                    mCurrentInboundFlight[message_seq] = reassembler;
                }

                reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
                    fragment_length);
            }
            else if (mPreviousInboundFlight != null)
            {
                /*
                 * NOTE: If we receive the previous flight of incoming messages in full again,
                 * retransmit our last flight
                 */

                DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[message_seq];
                if (reassembler != null)
                {
                    reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
                        fragment_length);
                    checkPreviousFlight = true;
                }
            }

            off += message_length;
            len -= message_length;
        }

        bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
        if (result)
        {
            ResendOutboundFlight();

            // Start collecting the next possible retransmission of the previous flight afresh.
            ResetAll(mPreviousInboundFlight);
        }
        return result;
    }

    /// <summary>
    /// Calls <c>ResetWriteEpoch()</c> on the record layer, then rewrites every message of the
    /// last outbound flight in order.
    /// </summary>
    private void ResendOutboundFlight()
    {
        mRecordLayer.ResetWriteEpoch();
        for (int i = 0; i < mOutboundFlight.Count; ++i)
        {
            WriteMessage((Message)mOutboundFlight[i]);
        }
    }

    /// <summary>
    /// Adds <paramref name="message"/> to the transcript hash, encoded as a single unfragmented
    /// message (fragment_offset 0, fragment_length equal to the body length). hello_request
    /// messages are not hashed.
    /// </summary>
    /// <returns><paramref name="message"/>, to allow call chaining.</returns>
    private Message UpdateHandshakeMessagesDigest(Message message)
    {
        if (message.Type != HandshakeType.hello_request)
        {
            byte[] body = message.Body;
            byte[] buf = new byte[MessageHeaderLength];
            TlsUtilities.WriteUint8(message.Type, buf, 0);
            TlsUtilities.WriteUint24(body.Length, buf, 1);
            TlsUtilities.WriteUint16(message.Seq, buf, 4);
            TlsUtilities.WriteUint24(0, buf, 6);
            TlsUtilities.WriteUint24(body.Length, buf, 9);
            mHandshakeHash.BlockUpdate(buf, 0, buf.Length);
            mHandshakeHash.BlockUpdate(body, 0, body.Length);
        }
        return message;
    }

    /// <summary>
    /// Writes <paramref name="message"/> as one or more handshake fragments, each no larger
    /// than the record layer's send limit minus the 12-byte header.
    /// </summary>
    /// <exception cref="TlsFatalAlert">
    /// <c>internal_error</c> if the send limit leaves no room for even one body byte.
    /// </exception>
    private void WriteMessage(Message message)
    {
        int sendLimit = mRecordLayer.GetSendLimit();
        int fragmentLimit = sendLimit - MessageHeaderLength;

        // TODO Support a higher minimum fragment size?
        if (fragmentLimit < 1)
        {
            // TODO Should we be throwing an exception here?
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        int length = message.Body.Length;

        // NOTE: Must still send a fragment if body is empty
        int fragment_offset = 0;
        do
        {
            int fragment_length = System.Math.Min(length - fragment_offset, fragmentLimit);
            WriteHandshakeFragment(message, fragment_offset, fragment_length);
            fragment_offset += fragment_length;
        }
        while (fragment_offset < length);
    }

    /// <summary>
    /// Serialises the header plus the body slice
    /// [<paramref name="fragment_offset"/>, +<paramref name="fragment_length"/>) and sends it
    /// through the record layer.
    /// </summary>
    private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
    {
        RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
        TlsUtilities.WriteUint8(message.Type, fragment);
        TlsUtilities.WriteUint24(message.Body.Length, fragment);
        TlsUtilities.WriteUint16(message.Seq, fragment);
        TlsUtilities.WriteUint24(fragment_offset, fragment);
        TlsUtilities.WriteUint24(fragment_length, fragment);
        fragment.Write(message.Body, fragment_offset, fragment_length);

        fragment.SendToRecordLayer(mRecordLayer);
    }

    /// <summary>True if every reassembler in <paramref name="inboundFlight"/> has a complete body.</summary>
    private static bool CheckAll(IDictionary inboundFlight)
    {
        foreach (DtlsReassembler r in inboundFlight.Values)
        {
            if (r.GetBodyIfComplete() == null)
            {
                return false;
            }
        }
        return true;
    }

    /// <summary>Calls <c>Reset()</c> on every reassembler in <paramref name="inboundFlight"/>.</summary>
    private static void ResetAll(IDictionary inboundFlight)
    {
        foreach (DtlsReassembler r in inboundFlight.Values)
        {
            r.Reset();
        }
    }

    /// <summary>An immutable, fully reassembled handshake message.</summary>
    internal class Message
    {
        private readonly int mMessageSeq;
        private readonly byte mMsgType;
        private readonly byte[] mBody;

        internal Message(int message_seq, byte msg_type, byte[] body)
        {
            this.mMessageSeq = message_seq;
            this.mMsgType = msg_type;
            this.mBody = body;
        }

        /// <summary>The message_seq of this message.</summary>
        public int Seq
        {
            get { return mMessageSeq; }
        }

        /// <summary>The handshake message type.</summary>
        public byte Type
        {
            get { return mMsgType; }
        }

        /// <summary>The message body (the array is not copied).</summary>
        public byte[] Body
        {
            get { return mBody; }
        }
    }

    /// <summary>
    /// Memory buffer used to assemble one handshake fragment before it is sent.
    /// </summary>
    internal class RecordLayerBuffer
        : MemoryStream
    {
        internal RecordLayerBuffer(int size)
            : base(size)
        {
        }

        /// <summary>
        /// Sends the buffered bytes through <paramref name="recordLayer"/> and then disposes
        /// this buffer, also when sending throws.
        /// </summary>
        internal void SendToRecordLayer(DtlsRecordLayer recordLayer)
        {
            try
            {
#if PORTABLE || NETFX_CORE
                byte[] buf = ToArray();
                int bufLen = buf.Length;
#else
                // GetBuffer avoids a copy; only the first Length bytes are meaningful.
                byte[] buf = GetBuffer();
                int bufLen = (int)Length;
#endif
                recordLayer.Send(buf, 0, bufLen);
            }
            finally
            {
                BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.Dispose(this);
            }
        }
    }

    /// <summary>
    /// Post-handshake hook handed to the record layer so that retransmissions of the peer's
    /// last flight can still trigger a retransmission of ours.
    /// </summary>
    internal class Retransmit
        : DtlsHandshakeRetransmit
    {
        private readonly DtlsReliableHandshake mOuter;

        internal Retransmit(DtlsReliableHandshake outer)
        {
            this.mOuter = outer;
        }

        public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
        {
            // Window size 0: only fragments belonging to the previous inbound flight are processed.
            mOuter.ProcessRecord(0, epoch, buf, off, len);
        }
    }
}
}
#pragma warning restore
#endif