2025-05-07 11:20:40 +08:00

339 lines
13 KiB
C#

#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.Collections;
using System.IO;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Asn1.X509;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Parameters;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Security;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.IO;
namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
/// <summary>
/// (D)TLS PSK key exchange (RFC 4279), supporting the <c>PSK</c>, <c>DHE_PSK</c>,
/// <c>ECDHE_PSK</c> and <c>RSA_PSK</c> key exchange algorithms.
/// </summary>
/// <remarks>
/// In every mode the premaster secret is <c>opaque16(other_secret) || opaque16(psk)</c>.
/// Here <c>other_secret</c> is one of the following:
/// <list type="bullet">
/// <item>the DH agreement (<c>DHE_PSK</c>);</item>
/// <item>the ECDH agreement (<c>ECDHE_PSK</c>);</item>
/// <item>the RSA-transported premaster secret (<c>RSA_PSK</c>);</item>
/// <item>a zero-filled buffer as long as the PSK (plain <c>PSK</c>).</item>
/// </list>
/// </remarks>
public class TlsPskKeyExchange
    : AbstractTlsKeyExchange
{
    // Client-side PSK source (identity + key) and server-side PSK lookup.
    protected TlsPskIdentity mPskIdentity;
    protected TlsPskIdentityManager mPskIdentityManager;

    // Verifier applied to DH parameters received in a DHE_PSK ServerKeyExchange.
    protected TlsDHVerifier mDHVerifier;
    protected DHParameters mDHParameters;

    // EC negotiation state for ECDHE_PSK.
    protected int[] mNamedCurves;
    protected byte[] mClientECPointFormats, mServerECPointFormats;

    protected byte[] mPskIdentityHint = null;

    // The PSK itself; wiped and nulled once the premaster secret has been generated.
    protected byte[] mPsk = null;

    protected DHPrivateKeyParameters mDHAgreePrivateKey = null;
    protected DHPublicKeyParameters mDHAgreePublicKey = null;

    protected ECPrivateKeyParameters mECAgreePrivateKey = null;
    protected ECPublicKeyParameters mECAgreePublicKey = null;

    // RSA_PSK only: the server's certificate key and (server side) its decryption credentials.
    protected AsymmetricKeyParameter mServerPublicKey = null;
    protected RsaKeyParameters mRsaServerPublicKey = null;
    protected TlsEncryptionCredentials mServerCredentials = null;
    protected byte[] mPremasterSecret;

    /// <summary>
    /// Creates a PSK key exchange that checks received DH parameters with a
    /// <see cref="DefaultTlsDHVerifier"/>.
    /// </summary>
    [Obsolete("Use constructor that takes a TlsDHVerifier")]
    public TlsPskKeyExchange(int keyExchange, IList supportedSignatureAlgorithms, TlsPskIdentity pskIdentity,
        TlsPskIdentityManager pskIdentityManager, DHParameters dhParameters, int[] namedCurves,
        byte[] clientECPointFormats, byte[] serverECPointFormats)
        : this(keyExchange, supportedSignatureAlgorithms, pskIdentity, pskIdentityManager, new DefaultTlsDHVerifier(),
            dhParameters, namedCurves, clientECPointFormats, serverECPointFormats)
    {
    }

    /// <summary>Creates a PSK key exchange.</summary>
    /// <param name="keyExchange">One of DHE_PSK, ECDHE_PSK, PSK or RSA_PSK (<see cref="KeyExchangeAlgorithm"/>).</param>
    /// <param name="supportedSignatureAlgorithms">Passed through to the base key exchange.</param>
    /// <param name="pskIdentity">Client-side source of the PSK identity and key.</param>
    /// <param name="pskIdentityManager">Server-side source of the identity hint and PSK lookup.</param>
    /// <param name="dhVerifier">Verifier used on DH parameters received from the server (DHE_PSK).</param>
    /// <param name="dhParameters">Server-side DH group used for DHE_PSK.</param>
    /// <param name="namedCurves">Named curves offered/accepted for ECDHE_PSK.</param>
    /// <param name="clientECPointFormats">Client EC point formats.</param>
    /// <param name="serverECPointFormats">Server EC point formats.</param>
    /// <exception cref="InvalidOperationException">If <paramref name="keyExchange"/> is not a PSK variant.</exception>
    public TlsPskKeyExchange(int keyExchange, IList supportedSignatureAlgorithms, TlsPskIdentity pskIdentity,
        TlsPskIdentityManager pskIdentityManager, TlsDHVerifier dhVerifier, DHParameters dhParameters, int[] namedCurves,
        byte[] clientECPointFormats, byte[] serverECPointFormats)
        : base(keyExchange, supportedSignatureAlgorithms)
    {
        switch (keyExchange)
        {
            case KeyExchangeAlgorithm.DHE_PSK:
            case KeyExchangeAlgorithm.ECDHE_PSK:
            case KeyExchangeAlgorithm.PSK:
            case KeyExchangeAlgorithm.RSA_PSK:
                break;
            default:
                throw new InvalidOperationException("unsupported key exchange algorithm");
        }

        this.mPskIdentity = pskIdentity;
        this.mPskIdentityManager = pskIdentityManager;
        this.mDHVerifier = dhVerifier;
        this.mDHParameters = dhParameters;
        this.mNamedCurves = namedCurves;
        this.mClientECPointFormats = clientECPointFormats;
        this.mServerECPointFormats = serverECPointFormats;
    }

    /// <summary>
    /// Called when the server sends no Certificate. RSA_PSK needs the server's RSA key,
    /// so skipping credentials is an unexpected_message in that mode.
    /// </summary>
    public override void SkipServerCredentials()
    {
        if (mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
            throw new TlsFatalAlert(AlertDescription.unexpected_message);
    }

    /// <summary>
    /// Server side: installs the server's encryption credentials. They are used later to decrypt
    /// the RSA_PSK premaster secret.
    /// </summary>
    /// <exception cref="TlsFatalAlert">internal_error if the credentials are not <see cref="TlsEncryptionCredentials"/>.</exception>
    public override void ProcessServerCredentials(TlsCredentials serverCredentials)
    {
        if (!(serverCredentials is TlsEncryptionCredentials))
            throw new TlsFatalAlert(AlertDescription.internal_error);

        ProcessServerCertificate(serverCredentials.Certificate);

        this.mServerCredentials = (TlsEncryptionCredentials)serverCredentials;
    }

    /// <summary>
    /// Server side: builds the ServerKeyExchange body. The body is the PSK identity hint (opaque16,
    /// possibly empty), followed by ephemeral DH or ECDH parameters for DHE_PSK / ECDHE_PSK.
    /// </summary>
    /// <returns>The encoded message, or null when there is no hint and the mode needs no ServerKeyExchange.</returns>
    public override byte[] GenerateServerKeyExchange()
    {
        this.mPskIdentityHint = mPskIdentityManager.GetHint();

        if (this.mPskIdentityHint == null && !RequiresServerKeyExchange)
            return null;

        MemoryStream buf = new MemoryStream();

        // An absent hint is still sent (as an empty opaque16) when the message is required.
        if (this.mPskIdentityHint == null)
        {
            TlsUtilities.WriteOpaque16(TlsUtilities.EmptyBytes, buf);
        }
        else
        {
            TlsUtilities.WriteOpaque16(this.mPskIdentityHint, buf);
        }

        if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            if (this.mDHParameters == null)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            this.mDHAgreePrivateKey = TlsDHUtilities.GenerateEphemeralServerKeyExchange(mContext.SecureRandom,
                this.mDHParameters, buf);
        }
        else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            this.mECAgreePrivateKey = TlsEccUtilities.GenerateEphemeralServerKeyExchange(mContext.SecureRandom,
                mNamedCurves, mClientECPointFormats, buf);
        }

        return buf.ToArray();
    }

    /// <summary>
    /// Processes the server's certificate. Only RSA_PSK uses one; its leaf must carry an RSA public
    /// key usable for key encipherment.
    /// </summary>
    /// <exception cref="TlsFatalAlert">
    /// unexpected_message when not in RSA_PSK mode; bad_certificate for an empty chain;
    /// unsupported_certificate when the key cannot be parsed; certificate_unknown when it is not an RSA key.
    /// </exception>
    public override void ProcessServerCertificate(Certificate serverCertificate)
    {
        if (mKeyExchange != KeyExchangeAlgorithm.RSA_PSK)
            throw new TlsFatalAlert(AlertDescription.unexpected_message);
        if (serverCertificate.IsEmpty)
            throw new TlsFatalAlert(AlertDescription.bad_certificate);

        X509CertificateStructure x509Cert = serverCertificate.GetCertificateAt(0);

        SubjectPublicKeyInfo keyInfo = x509Cert.SubjectPublicKeyInfo;
        try
        {
            this.mServerPublicKey = PublicKeyFactory.CreateKey(keyInfo);
        }
        catch (Exception e)
        {
            throw new TlsFatalAlert(AlertDescription.unsupported_certificate, e);
        }

        // Sanity check the PublicKeyFactory
        if (this.mServerPublicKey.IsPrivate)
            throw new TlsFatalAlert(AlertDescription.internal_error);

        // A peer may present a certificate with a non-RSA key. Report that as a TLS alert
        // instead of letting an InvalidCastException escape from a direct cast.
        RsaKeyParameters rsaKey = this.mServerPublicKey as RsaKeyParameters;
        if (rsaKey == null)
            throw new TlsFatalAlert(AlertDescription.certificate_unknown);

        this.mRsaServerPublicKey = ValidateRsaPublicKey(rsaKey);

        TlsUtilities.ValidateKeyUsage(x509Cert, KeyUsage.KeyEncipherment);

        base.ProcessServerCertificate(serverCertificate);
    }

    /// <summary>True for the ephemeral modes (DHE_PSK, ECDHE_PSK), which must send a ServerKeyExchange.</summary>
    public override bool RequiresServerKeyExchange
    {
        get
        {
            switch (mKeyExchange)
            {
                case KeyExchangeAlgorithm.DHE_PSK:
                case KeyExchangeAlgorithm.ECDHE_PSK:
                    return true;
                default:
                    return false;
            }
        }
    }

    /// <summary>
    /// Client side: reads the PSK identity hint. For DHE_PSK / ECDHE_PSK it then reads and
    /// validates the server's ephemeral public value.
    /// </summary>
    public override void ProcessServerKeyExchange(Stream input)
    {
        this.mPskIdentityHint = TlsUtilities.ReadOpaque16(input);

        if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            this.mDHParameters = TlsDHUtilities.ReceiveDHParameters(mDHVerifier, input);
            this.mDHAgreePublicKey = new DHPublicKeyParameters(TlsDHUtilities.ReadDHParameter(input), mDHParameters);
        }
        else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            ECDomainParameters ecParams = TlsEccUtilities.ReadECParameters(mNamedCurves, mClientECPointFormats, input);

            byte[] point = TlsUtilities.ReadOpaque8(input);

            this.mECAgreePublicKey = TlsEccUtilities.ValidateECPublicKey(TlsEccUtilities.DeserializeECPublicKey(
                mClientECPointFormats, ecParams, point));
        }
    }

    /// <summary>PSK key exchanges never use client certificates, so a CertificateRequest is rejected.</summary>
    public override void ValidateCertificateRequest(CertificateRequest certificateRequest)
    {
        throw new TlsFatalAlert(AlertDescription.unexpected_message);
    }

    /// <summary>Client credentials are never valid here; reaching this is an internal error.</summary>
    public override void ProcessClientCredentials(TlsCredentials clientCredentials)
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    /// <summary>
    /// Client side: writes the ClientKeyExchange body. The body is the PSK identity (opaque16),
    /// followed by the ephemeral DH/ECDH public value or the RSA-encrypted premaster secret.
    /// It also captures the PSK for <see cref="GeneratePremasterSecret"/>.
    /// </summary>
    /// <exception cref="TlsFatalAlert">internal_error if the identity provider yields no identity or PSK.</exception>
    public override void GenerateClientKeyExchange(Stream output)
    {
        if (mPskIdentityHint == null)
        {
            mPskIdentity.SkipIdentityHint();
        }
        else
        {
            mPskIdentity.NotifyIdentityHint(mPskIdentityHint);
        }

        byte[] psk_identity = mPskIdentity.GetPskIdentity();
        if (psk_identity == null)
            throw new TlsFatalAlert(AlertDescription.internal_error);

        this.mPsk = mPskIdentity.GetPsk();
        if (mPsk == null)
            throw new TlsFatalAlert(AlertDescription.internal_error);

        TlsUtilities.WriteOpaque16(psk_identity, output);

        mContext.SecurityParameters.pskIdentity = psk_identity;

        if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            this.mDHAgreePrivateKey = TlsDHUtilities.GenerateEphemeralClientKeyExchange(mContext.SecureRandom,
                mDHParameters, output);
        }
        else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            this.mECAgreePrivateKey = TlsEccUtilities.GenerateEphemeralClientKeyExchange(mContext.SecureRandom,
                mServerECPointFormats, mECAgreePublicKey.Parameters, output);
        }
        else if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
        {
            this.mPremasterSecret = TlsRsaUtilities.GenerateEncryptedPreMasterSecret(mContext,
                this.mRsaServerPublicKey, output);
        }
    }

    /// <summary>
    /// Server side: reads the client's PSK identity and looks up its PSK. It then reads the
    /// client's ephemeral public value, or decrypts the RSA_PSK premaster secret.
    /// </summary>
    /// <exception cref="TlsFatalAlert">unknown_psk_identity if the identity manager has no PSK for the identity.</exception>
    public override void ProcessClientKeyExchange(Stream input)
    {
        byte[] psk_identity = TlsUtilities.ReadOpaque16(input);

        this.mPsk = mPskIdentityManager.GetPsk(psk_identity);
        if (mPsk == null)
            throw new TlsFatalAlert(AlertDescription.unknown_psk_identity);

        mContext.SecurityParameters.pskIdentity = psk_identity;

        if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            this.mDHAgreePublicKey = new DHPublicKeyParameters(TlsDHUtilities.ReadDHParameter(input), mDHParameters);
        }
        else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            byte[] point = TlsUtilities.ReadOpaque8(input);

            // The client's point must lie on the curve this server chose in its ServerKeyExchange.
            ECDomainParameters curve_params = this.mECAgreePrivateKey.Parameters;

            this.mECAgreePublicKey = TlsEccUtilities.ValidateECPublicKey(TlsEccUtilities.DeserializeECPublicKey(
                mServerECPointFormats, curve_params, point));
        }
        else if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
        {
            byte[] encryptedPreMasterSecret;
            if (TlsUtilities.IsSsl(mContext))
            {
                // TODO Do any SSLv3 clients actually include the length?
                encryptedPreMasterSecret = Streams.ReadAll(input);
            }
            else
            {
                encryptedPreMasterSecret = TlsUtilities.ReadOpaque16(input);
            }

            this.mPremasterSecret = mServerCredentials.DecryptPreMasterSecret(encryptedPreMasterSecret);
        }
    }

    /// <summary>
    /// Builds the premaster secret <c>opaque16(other_secret) || opaque16(psk)</c>, then wipes and
    /// forgets the PSK. It can therefore be called only once per key exchange.
    /// </summary>
    /// <exception cref="TlsFatalAlert">internal_error if no PSK is available (client/server key exchange not yet processed, or already consumed).</exception>
    public override byte[] GeneratePremasterSecret()
    {
        // Fail with a TLS alert rather than a NullReferenceException when the PSK is missing.
        if (mPsk == null)
            throw new TlsFatalAlert(AlertDescription.internal_error);

        byte[] other_secret = GenerateOtherSecret(mPsk.Length);

        // 2-byte length prefix for each of the two opaque16 fields.
        MemoryStream buf = new MemoryStream(4 + other_secret.Length + mPsk.Length);
        TlsUtilities.WriteOpaque16(other_secret, buf);
        TlsUtilities.WriteOpaque16(mPsk, buf);

        Arrays.Fill(mPsk, (byte)0);
        this.mPsk = null;

        return buf.ToArray();
    }

    /// <summary>
    /// Computes the <c>other_secret</c> part of the premaster secret for the active key exchange.
    /// </summary>
    /// <param name="pskLength">Length of the PSK; plain PSK mode uses this many zero bytes.</param>
    /// <returns>The DH/ECDH agreement, the RSA premaster secret, or a zero-filled buffer.</returns>
    /// <exception cref="TlsFatalAlert">internal_error if the state required by the active mode is missing.</exception>
    protected virtual byte[] GenerateOtherSecret(int pskLength)
    {
        if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
        {
            if (mDHAgreePrivateKey != null)
            {
                return TlsDHUtilities.CalculateDHBasicAgreement(mDHAgreePublicKey, mDHAgreePrivateKey);
            }

            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
        {
            if (mECAgreePrivateKey != null)
            {
                return TlsEccUtilities.CalculateECDHBasicAgreement(mECAgreePublicKey, mECAgreePrivateKey);
            }

            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
        {
            // Same contract as the DH/ECDH branches: missing state is an internal error,
            // not a null that fails later inside WriteOpaque16.
            if (this.mPremasterSecret != null)
            {
                return this.mPremasterSecret;
            }

            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        return new byte[pskLength];
    }

    /// <summary>
    /// Minimal validation of the server's RSA key: the public exponent must be (probably) prime.
    /// </summary>
    /// <exception cref="TlsFatalAlert">illegal_parameter if the exponent fails the primality check.</exception>
    protected virtual RsaKeyParameters ValidateRsaPublicKey(RsaKeyParameters key)
    {
        // TODO What is the minimum bit length required?
        // key.Modulus.BitLength;

        if (!key.Exponent.IsProbablePrime(2))
            throw new TlsFatalAlert(AlertDescription.illegal_parameter);

        return key;
    }
}
}
#pragma warning restore
#endif