Skip to content

Commit

Permalink
Use OS-provided RSA OAEP implementation for Android
Browse files Browse the repository at this point in the history
Instead of using the managed OAEP implementatation, allow the OS to handle OAEP with SHA2 algorithms.

Android also supports algorithm identifiers like "RSA/ECB/OAEPwithSHA-256andMGF1Padding" which would make the code simpler, but it makes the hash algorithm used in MGF1 ambiguous. Since the MGF1 algorithm is always the same as the OAEP algorithm, we explicitly use the `OEAPParameterSpec` so that we can control both the MGF1 digest algorithm and the OAEP digest algorithm.
  • Loading branch information
vcsjones authored Jul 6, 2022
1 parent 3cf8f49 commit 907f395
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 207 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ internal enum RsaPadding : int
{
Pkcs1 = 0,
OaepSHA1 = 1,
NoPadding = 2,
OaepSHA256 = 2,
OaepSHA384 = 3,
OaepSHA512 = 4,
}
}
}
Expand Down
118 changes: 26 additions & 92 deletions src/libraries/Common/src/System/Security/Cryptography/RSAAndroid.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public override byte[] Decrypt(byte[] data, RSAEncryptionPadding padding)
ArgumentNullException.ThrowIfNull(data);
ArgumentNullException.ThrowIfNull(padding);

Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding, out HashAlgorithmName? oaepHashAlgorithmName);
Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding);
SafeRsaHandle key = GetKey();

int rsaSize = Interop.AndroidCrypto.RsaSize(key);
Expand All @@ -89,7 +89,7 @@ public override byte[] Decrypt(byte[] data, RSAEncryptionPadding padding)
{
destination = new Span<byte>(buf, 0, rsaSize);

if (!TryDecrypt(key, data, destination, rsaPadding, oaepHashAlgorithmName, out int bytesWritten))
if (!TryDecrypt(key, data, destination, rsaPadding, out int bytesWritten))
{
Debug.Fail($"{nameof(TryDecrypt)} should not return false for RSA_size buffer");
throw new CryptographicException();
Expand All @@ -112,7 +112,7 @@ public override bool TryDecrypt(
{
ArgumentNullException.ThrowIfNull(padding);

Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding, out HashAlgorithmName? oaepHashAlgorithmName);
Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding);
SafeRsaHandle key = GetKey();

int keySizeBytes = Interop.AndroidCrypto.RsaSize(key);
Expand All @@ -135,7 +135,7 @@ public override bool TryDecrypt(
tmp = rent;
}

bool ret = TryDecrypt(key, data, tmp, rsaPadding, oaepHashAlgorithmName, out bytesWritten);
bool ret = TryDecrypt(key, data, tmp, rsaPadding, out bytesWritten);

if (ret)
{
Expand Down Expand Up @@ -163,23 +163,16 @@ public override bool TryDecrypt(
return ret;
}

return TryDecrypt(key, data, destination, rsaPadding, oaepHashAlgorithmName, out bytesWritten);
return TryDecrypt(key, data, destination, rsaPadding, out bytesWritten);
}

private static bool TryDecrypt(
SafeRsaHandle key,
ReadOnlySpan<byte> data,
Span<byte> destination,
Interop.AndroidCrypto.RsaPadding rsaPadding,
HashAlgorithmName? oaepHashAlgorithmName,
out int bytesWritten)
{
// If rsaPadding is PKCS1 or OAEP-SHA1 then no depadding method should be present.
// If rsaPadding is NoPadding then a depadding method should be present.
Debug.Assert(
(rsaPadding == Interop.AndroidCrypto.RsaPadding.NoPadding) ==
(oaepHashAlgorithmName != null));

// Caller should have already checked this.
Debug.Assert(!key.IsInvalid);

Expand All @@ -196,54 +189,18 @@ private static bool TryDecrypt(
return false;
}

Span<byte> decryptBuf = destination;
byte[]? paddingBuf = null;

if (oaepHashAlgorithmName != null)
{
paddingBuf = CryptoPool.Rent(rsaSize);
decryptBuf = paddingBuf;
}

try
{
int returnValue = Interop.AndroidCrypto.RsaPrivateDecrypt(data.Length, data, decryptBuf, key, rsaPadding);
CheckReturn(returnValue);

if (oaepHashAlgorithmName != null)
{
return RsaPaddingProcessor.DepadOaep(oaepHashAlgorithmName.Value, paddingBuf, destination, out bytesWritten);
}
else
{
// If the padding mode is RSA_NO_PADDING then the size of the decrypted block
// will be RSA_size. If any padding was used, then some amount (determined by the padding algorithm)
// will have been reduced, and only returnValue bytes were part of the decrypted
// body. Either way, we can just use returnValue, but some additional bytes may have been overwritten
// in the destination span.
bytesWritten = returnValue;
}

return true;
}
finally
{
if (paddingBuf != null)
{
// DecryptBuf is paddingBuf if paddingBuf is not null, erase it before returning it.
// If paddingBuf IS null then decryptBuf was destination, and shouldn't be cleared.
CryptographicOperations.ZeroMemory(decryptBuf);
CryptoPool.Return(paddingBuf, clearSize: 0);
}
}
int returnValue = Interop.AndroidCrypto.RsaPrivateDecrypt(data.Length, data, destination, key, rsaPadding);
CheckReturn(returnValue);
bytesWritten = returnValue;
return true;
}

public override byte[] Encrypt(byte[] data, RSAEncryptionPadding padding)
{
ArgumentNullException.ThrowIfNull(data);
ArgumentNullException.ThrowIfNull(padding);

Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding, out HashAlgorithmName? oaepHashAlgorithmName);
Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding);
SafeRsaHandle key = GetKey();

byte[] buf = new byte[Interop.AndroidCrypto.RsaSize(key)];
Expand All @@ -253,7 +210,6 @@ public override byte[] Encrypt(byte[] data, RSAEncryptionPadding padding)
data,
buf,
rsaPadding,
oaepHashAlgorithmName,
out int bytesWritten);

if (!encrypted || bytesWritten != buf.Length)
Expand All @@ -269,18 +225,17 @@ public override bool TryEncrypt(ReadOnlySpan<byte> data, Span<byte> destination,
{
ArgumentNullException.ThrowIfNull(padding);

Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding, out HashAlgorithmName? oaepHashAlgorithmName);
Interop.AndroidCrypto.RsaPadding rsaPadding = GetInteropPadding(padding);
SafeRsaHandle key = GetKey();

return TryEncrypt(key, data, destination, rsaPadding, oaepHashAlgorithmName, out bytesWritten);
return TryEncrypt(key, data, destination, rsaPadding, out bytesWritten);
}

private static bool TryEncrypt(
SafeRsaHandle key,
ReadOnlySpan<byte> data,
Span<byte> destination,
Interop.AndroidCrypto.RsaPadding rsaPadding,
HashAlgorithmName? oaepHashAlgorithmName,
out int bytesWritten)
{
int rsaSize = Interop.AndroidCrypto.RsaSize(key);
Expand All @@ -291,60 +246,39 @@ private static bool TryEncrypt(
return false;
}

int returnValue;

if (oaepHashAlgorithmName != null)
{
Debug.Assert(rsaPadding == Interop.AndroidCrypto.RsaPadding.NoPadding);
byte[] rented = CryptoPool.Rent(rsaSize);
Span<byte> tmp = new Span<byte>(rented, 0, rsaSize);

try
{
RsaPaddingProcessor.PadOaep(oaepHashAlgorithmName.Value, data, tmp);
returnValue = Interop.AndroidCrypto.RsaPublicEncrypt(tmp.Length, tmp, destination, key, rsaPadding);
}
finally
{
CryptographicOperations.ZeroMemory(tmp);
CryptoPool.Return(rented, clearSize: 0);
}
}
else
{
Debug.Assert(rsaPadding != Interop.AndroidCrypto.RsaPadding.NoPadding);

returnValue = Interop.AndroidCrypto.RsaPublicEncrypt(data.Length, data, destination, key, rsaPadding);
}

int returnValue = Interop.AndroidCrypto.RsaPublicEncrypt(data.Length, data, destination, key, rsaPadding);
CheckReturn(returnValue);

bytesWritten = returnValue;
Debug.Assert(returnValue == rsaSize, $"{returnValue} != {rsaSize}");
return true;

}

private static Interop.AndroidCrypto.RsaPadding GetInteropPadding(
RSAEncryptionPadding padding,
out HashAlgorithmName? oaepHashAlgorithmName)
private static Interop.AndroidCrypto.RsaPadding GetInteropPadding(RSAEncryptionPadding padding)
{
if (padding == RSAEncryptionPadding.Pkcs1)
{
oaepHashAlgorithmName = null;
return Interop.AndroidCrypto.RsaPadding.Pkcs1;
}

if (padding == RSAEncryptionPadding.OaepSHA1)
{
oaepHashAlgorithmName = null;
return Interop.AndroidCrypto.RsaPadding.OaepSHA1;
}

if (padding.Mode == RSAEncryptionPaddingMode.Oaep)
if (padding == RSAEncryptionPadding.OaepSHA256)
{
return Interop.AndroidCrypto.RsaPadding.OaepSHA256;
}

if (padding == RSAEncryptionPadding.OaepSHA384)
{
return Interop.AndroidCrypto.RsaPadding.OaepSHA384;
}

if (padding == RSAEncryptionPadding.OaepSHA512)
{
oaepHashAlgorithmName = padding.OaepHashAlgorithm;
return Interop.AndroidCrypto.RsaPadding.NoPadding;
return Interop.AndroidCrypto.RsaPadding.OaepSHA512;
}

throw PaddingModeNotSupported();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,105 +256,6 @@ internal static void PadOaep(
}
}

internal static bool DepadOaep(
HashAlgorithmName hashAlgorithmName,
ReadOnlySpan<byte> source,
Span<byte> destination,
out int bytesWritten)
{
int hLen = HashLength(hashAlgorithmName);

// https://tools.ietf.org/html/rfc3447#section-7.1.2
using (IncrementalHash hasher = IncrementalHash.CreateHash(hashAlgorithmName))
{
Debug.Assert(hasher.HashLengthInBytes == hLen);

Span<byte> lHash = stackalloc byte[hLen];

if (!hasher.TryGetHashAndReset(lHash, out int hLen2) || hLen2 != hLen)
{
Debug.Fail("TryGetHashAndReset failed with exact-size destination");
throw new CryptographicException();
}

int y = source[0];
ReadOnlySpan<byte> maskedSeed = source.Slice(1, hLen);
ReadOnlySpan<byte> maskedDB = source.Slice(1 + hLen);

Span<byte> seed = stackalloc byte[hLen];
// seedMask = MGF(maskedDB, hLen)
Mgf1(hasher, maskedDB, seed);

// seed = seedMask XOR maskedSeed
Xor(seed, maskedSeed);

byte[] tmp = CryptoPool.Rent(source.Length);

try
{
Span<byte> dbMask = new Span<byte>(tmp, 0, maskedDB.Length);
// dbMask = MGF(seed, k - hLen - 1)
Mgf1(hasher, seed, dbMask);

// DB = dbMask XOR maskedDB
Xor(dbMask, maskedDB);

ReadOnlySpan<byte> lHashPrime = dbMask.Slice(0, hLen);

int separatorPos = int.MaxValue;

for (int i = dbMask.Length - 1; i >= hLen; i--)
{
// if dbMask[i] is 1, val is 0. otherwise val is [01,FF]
byte dbMinus1 = (byte)(dbMask[i] - 1);
int val = dbMinus1;

// if val is 0: FFFFFFFF & FFFFFFFF => FFFFFFFF
// if val is any other byte value, val-1 will be in the range 00000000 to 000000FE,
// and so the high bit will not be set.
val = (~val & (val - 1)) >> 31;

// if val is 0: separator = (0 & i) | (~0 & separator) => separator
// else: separator = (~0 & i) | (0 & separator) => i
//
// Net result: non-branching "if (dbMask[i] == 1) separatorPos = i;"
separatorPos = (val & i) | (~val & separatorPos);
}

bool lHashMatches = CryptographicOperations.FixedTimeEquals(lHash, lHashPrime);
bool yIsZero = y == 0;
bool separatorMadeSense = separatorPos < dbMask.Length;

// This intentionally uses non-short-circuiting operations to hide the timing
// differential between the three failure cases
bool shouldContinue = lHashMatches & yIsZero & separatorMadeSense;

if (!shouldContinue)
{
throw new CryptographicException(SR.Cryptography_OAEP_Decryption_Failed);
}

Span<byte> message = dbMask.Slice(separatorPos + 1);

if (message.Length <= destination.Length)
{
message.CopyTo(destination);
bytesWritten = message.Length;
return true;
}
else
{
bytesWritten = 0;
return false;
}
}
finally
{
CryptoPool.Return(tmp, source.Length);
}
}
}

internal static void EncodePss(HashAlgorithmName hashAlgorithmName, ReadOnlySpan<byte> mHash, Span<byte> destination, int keySize)
{
int hLen = HashLength(hashAlgorithmName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ jmethodID g_sslCtxGetDefaultSslParamsMethod;
jclass g_GCMParameterSpecClass;
jmethodID g_GCMParameterSpecCtor;

// java/security/spec/MGF1ParameterSpec
jclass g_MGF1ParameterSpecClass;
jfieldID g_MGF1ParameterSpec_SHA1Field;
jfieldID g_MGF1ParameterSpec_SHA256Field;
jfieldID g_MGF1ParameterSpec_SHA384Field;
jfieldID g_MGF1ParameterSpec_SHA512Field;

// javax/crypto/spec/OAEPParameterSpec
jclass g_OAEPParameterSpecClass;
jmethodID g_OAEPParameterSpecCtor;

// javax/crypto/spec/PSource$PSpecified
jclass g_PSourcePSpecifiedClass;
jfieldID g_PSourcePSpecified_DefaultField;

// java/security/interfaces/DSAKey
jclass g_DSAKeyClass;

Expand Down Expand Up @@ -692,6 +707,18 @@ JNI_OnLoad(JavaVM *vm, void *reserved)
g_GCMParameterSpecClass = GetClassGRef(env, "javax/crypto/spec/GCMParameterSpec");
g_GCMParameterSpecCtor = GetMethod(env, false, g_GCMParameterSpecClass, "<init>", "(I[B)V");

g_MGF1ParameterSpecClass = GetClassGRef(env, "java/security/spec/MGF1ParameterSpec");
g_MGF1ParameterSpec_SHA1Field = GetField(env, true, g_MGF1ParameterSpecClass, "SHA1", "Ljava/security/spec/MGF1ParameterSpec;");
g_MGF1ParameterSpec_SHA256Field = GetField(env, true, g_MGF1ParameterSpecClass, "SHA256", "Ljava/security/spec/MGF1ParameterSpec;");
g_MGF1ParameterSpec_SHA384Field = GetField(env, true, g_MGF1ParameterSpecClass, "SHA384", "Ljava/security/spec/MGF1ParameterSpec;");
g_MGF1ParameterSpec_SHA512Field = GetField(env, true, g_MGF1ParameterSpecClass, "SHA512", "Ljava/security/spec/MGF1ParameterSpec;");

g_OAEPParameterSpecClass = GetClassGRef(env, "javax/crypto/spec/OAEPParameterSpec");
g_OAEPParameterSpecCtor = GetMethod(env, false, g_OAEPParameterSpecClass, "<init>", "(Ljava/lang/String;Ljava/lang/String;Ljava/security/spec/AlgorithmParameterSpec;Ljavax/crypto/spec/PSource;)V");

g_PSourcePSpecifiedClass = GetClassGRef(env, "javax/crypto/spec/PSource$PSpecified");
g_PSourcePSpecified_DefaultField = GetField(env, true, g_PSourcePSpecifiedClass, "DEFAULT", "Ljavax/crypto/spec/PSource$PSpecified;");

g_bigNumClass = GetClassGRef(env, "java/math/BigInteger");
g_bigNumCtor = GetMethod(env, false, g_bigNumClass, "<init>", "([B)V");
g_bigNumCtorWithSign = GetMethod(env, false, g_bigNumClass, "<init>", "(I[B)V");
Expand Down
Loading

0 comments on commit 907f395

Please sign in to comment.