Skip to content

Commit

Permalink
JCA implementation of CRYSTALS-Dilithium
Browse files Browse the repository at this point in the history
additional refactoring of lightweight version to deal with draft RFC for private key format for Dilithium
  • Loading branch information
dghgit committed Aug 15, 2022
1 parent abefebb commit 510dd86
Show file tree
Hide file tree
Showing 26 changed files with 1,155 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ public interface BCObjectIdentifiers
public static final ASN1ObjectIdentifier falcon_512 = new ASN1ObjectIdentifier("1.3.9999.3.1"); // falcon.branch("1");
public static final ASN1ObjectIdentifier falcon_1024 = new ASN1ObjectIdentifier("1.3.9999.3.4"); // falcon.branch("2");

/*
* Dilithium
*/
public static final ASN1ObjectIdentifier dilithium = bc_sig.branch("8");

// OpenSSL OIDs
public static ASN1ObjectIdentifier dilithium2 = new ASN1ObjectIdentifier("1.3.6.1.4.1.2.267.7.4.4"); // dilithium.branch("1");
public static ASN1ObjectIdentifier dilithium3 = new ASN1ObjectIdentifier("1.3.6.1.4.1.2.267.7.6.5"); // dilithium.branch("2");
public static ASN1ObjectIdentifier dilithium5 = new ASN1ObjectIdentifier("1.3.6.1.4.1.2.267.7.8.7"); // dilithium.branch("3");

/**
* key_exchange(3) algorithms
*/
Expand Down Expand Up @@ -320,7 +330,7 @@ public interface BCObjectIdentifiers
public static final ASN1ObjectIdentifier ntruhrss701 = pqc_kem_ntru.branch("4");

/**
* NTRU
* Kyber
*/
public static final ASN1ObjectIdentifier pqc_kem_kyber = bc_kem.branch("6");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ else if (this.DilithiumGamma1 == (1 << 19))
{
throw new RuntimeException("Wrong Dilithium Gamma1!");
}

}

public byte[][] generateKeyPair()
Expand Down Expand Up @@ -286,38 +285,26 @@ public byte[][] generateKeyPair()
shake256Digest.update(pk, 0, CryptoPublicKeyBytes);
shake256Digest.doFinal(tr, 0, SeedBytes);

byte[] sk = Packing.packSecretKey(rho, tr, key, t0, s1, s2, this);
byte[][] sk = Packing.packSecretKey(rho, tr, key, t0, s1, s2, this);
// System.out.println("sk engine = ");
// Helper.printByteArray(sk);

return new byte[][]{pk, sk};
return new byte[][]{pk, sk[0], sk[1], sk[2], sk[3], sk[4], sk[5]};
}

public void sign(int signMsglen, byte[] msg, int msglen, byte[] sk)
{
byte[] out = new byte[signMsglen];
System.arraycopy(msg, 0, out, CryptoBytes, msglen);
signSignature(msg, msglen, sk);
}

public byte[] signSignature(byte[] msg, int msglen, byte[] secretKey)
public byte[] signSignature(byte[] msg, int msglen, byte[] rho, byte[] key, byte[] tr, byte[] secretKey)
{
int n;
byte[] outSig = new byte[CryptoBytes + msglen];
byte[] seedBuf = new byte[3 * DilithiumEngine.SeedBytes + 2 * DilithiumEngine.CrhBytes];
byte[] rho, tr, key, mu = new byte[CrhBytes], rhoPrime = new byte[CrhBytes];
byte[] mu = new byte[CrhBytes], rhoPrime = new byte[CrhBytes];
short nonce = 0;
PolyVecL s1 = new PolyVecL(this), y = new PolyVecL(this), z = new PolyVecL(this);
PolyVecK t0 = new PolyVecK(this), s2 = new PolyVecK(this), w1 = new PolyVecK(this), w0 = new PolyVecK(this), h = new PolyVecK(this);
Poly cp = new Poly(this);
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
boolean rej = true;


byte[][] temp = Packing.unpackSecretKey(t0, s1, s2, secretKey, this);
rho = temp[0];
tr = temp[1];
key = temp[2];
Packing.unpackSecretKey(t0, s1, s2, secretKey, this);

// System.out.print("rho = ");
// Helper.printByteArray(rho);
Expand Down Expand Up @@ -452,11 +439,11 @@ public byte[] signSignature(byte[] msg, int msglen, byte[] secretKey)

}

public byte[] sign(byte[] msg, int mlen, byte[] secretKey)
public byte[] sign(byte[] msg, int mlen, byte[] rho, byte[] key, byte[] tr, byte[] secretKey)
{
byte[] signedMessage = new byte[CryptoBytes];

System.arraycopy(signSignature(msg, mlen, secretKey), 0, signedMessage, 0, CryptoBytes);
System.arraycopy(signSignature(msg, mlen, rho, key, tr, secretKey), 0, signedMessage, 0, CryptoBytes);
return signedMessage;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ private AsymmetricCipherKeyPair genKeyPair()
// Helper.printByteArray(keyPair[0]);

DilithiumPublicKeyParameters pubKey = new DilithiumPublicKeyParameters(dilithiumParams, keyPair[0]);
DilithiumPrivateKeyParameters privKey = new DilithiumPrivateKeyParameters(dilithiumParams, keyPair[1]);
DilithiumPrivateKeyParameters privKey = new DilithiumPrivateKeyParameters(dilithiumParams, keyPair[1], keyPair[2], keyPair[3], keyPair[4], keyPair[5], keyPair[6]);

return new AsymmetricCipherKeyPair(pubKey, privKey);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,26 @@

public class DilithiumParameters
{
public static final DilithiumParameters dilithium2 = new DilithiumParameters(2);
public static final DilithiumParameters dilithium3 = new DilithiumParameters(3);
public static final DilithiumParameters dilithium5 = new DilithiumParameters(5);
public static final DilithiumParameters dilithium2 = new DilithiumParameters("dilithium2", 2);
public static final DilithiumParameters dilithium3 = new DilithiumParameters("dilithium3", 3);
public static final DilithiumParameters dilithium5 = new DilithiumParameters("dilithium5", 5);

private final int k;
private final String name;

private DilithiumParameters(int k)
private DilithiumParameters(String name, int k)
{
this.name = name;
this.k = k;
}

DilithiumEngine getEngine(SecureRandom random)
{
return new DilithiumEngine(k, random);
}

public String getName()
{
return name;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,61 @@
public class DilithiumPrivateKeyParameters
extends DilithiumKeyParameters
{
private final byte[] privateKey;
protected final byte[] rho;
protected final byte[] k;
protected final byte[] tr;
private final byte[] s1;
private final byte[] s2;
private final byte[] t0;

public byte[] getPrivateKey()
{
return Arrays.clone(privateKey);
return getEncoded();
}

public DilithiumPrivateKeyParameters(DilithiumParameters params, byte[] privateKey)
public DilithiumPrivateKeyParameters(DilithiumParameters params, byte[] rho, byte[] K, byte[] tr, byte[] s1, byte[] s2, byte[] t0)
{
super(true, params);
this.privateKey = Arrays.clone(privateKey);
this.rho = Arrays.clone(rho);
this.k = Arrays.clone(K);
this.tr = Arrays.clone(tr);
this.s1 = Arrays.clone(s1);
this.s2 = Arrays.clone(s2);
this.t0 = Arrays.clone(t0);
}

public byte[] getRho()
{
return Arrays.clone(rho);
}

public byte[] getK()
{
return Arrays.clone(k);
}

public byte[] getTr()
{
return Arrays.clone(tr);
}

public byte[] getS1()
{
return Arrays.clone(s1);
}

public byte[] getS2()
{
return Arrays.clone(s2);
}

public byte[] getT0()
{
return Arrays.clone(t0);
}

public byte[] getEncoded()
{
return Arrays.clone(privateKey);
return Arrays.concatenate(new byte[][] { rho, k, tr, s1, s2, t0 });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ public byte[] generateSignature(byte[] message)
{
DilithiumEngine engine = privKey.getParameters().getEngine(random);


return engine.sign(message, message.length, privKey.getPrivateKey());
// TODO: finish unpacking secret key
return engine.sign(message, message.length, privKey.rho, privKey.k, privKey.tr, privKey.getPrivateKey());
}

public boolean verifySignature(byte[] message, byte[] signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,35 @@ static byte[] unpackPublicKey(PolyVecK t1, byte[] publicKey, DilithiumEngine eng
return Arrays.copyOfRange(publicKey, 0, DilithiumEngine.SeedBytes);
}

static byte[] packSecretKey(byte[] rho, byte[] tr, byte[] key, PolyVecK t0, PolyVecL s1, PolyVecK s2, DilithiumEngine engine)
static byte[][] packSecretKey(byte[] rho, byte[] tr, byte[] key, PolyVecK t0, PolyVecL s1, PolyVecK s2, DilithiumEngine engine)
{
byte[] out = new byte[engine.getCryptoSecretKeyBytes()];
int i, end = 0;
System.arraycopy(rho, 0, out, 0, DilithiumEngine.SeedBytes);
end += DilithiumEngine.SeedBytes;
byte[][] out = new byte[6][];

System.arraycopy(key, 0, out, end, DilithiumEngine.SeedBytes);
end += DilithiumEngine.SeedBytes;
out[0] = new byte[DilithiumEngine.SeedBytes];
System.arraycopy(rho, 0, out[0], 0, DilithiumEngine.SeedBytes);

System.arraycopy(tr, 0, out, end, DilithiumEngine.SeedBytes);
end += DilithiumEngine.SeedBytes;
out[1] = new byte[DilithiumEngine.SeedBytes];
System.arraycopy(key, 0, out[1], 0, DilithiumEngine.SeedBytes);

for (i = 0; i < engine.getDilithiumL(); ++i)
out[2] = new byte[DilithiumEngine.SeedBytes];
System.arraycopy(tr, 0, out[2], 0, DilithiumEngine.SeedBytes);

out[3] = new byte[engine.getDilithiumL() * engine.getDilithiumPolyEtaPackedBytes()];
for (int i = 0; i < engine.getDilithiumL(); ++i)
{
System.arraycopy(s1.getVectorIndex(i).polyEtaPack(), 0, out, end + i * engine.getDilithiumPolyEtaPackedBytes(), engine.getDilithiumPolyEtaPackedBytes());
System.arraycopy(s1.getVectorIndex(i).polyEtaPack(), 0, out[3], i * engine.getDilithiumPolyEtaPackedBytes(), engine.getDilithiumPolyEtaPackedBytes());
}
end += engine.getDilithiumL() * engine.getDilithiumPolyEtaPackedBytes();

for (i = 0; i < engine.getDilithiumK(); ++i)
out[4] = new byte[engine.getDilithiumK() * engine.getDilithiumPolyEtaPackedBytes()];
for (int i = 0; i < engine.getDilithiumK(); ++i)
{
System.arraycopy(s2.getVectorIndex(i).polyEtaPack(), 0, out, end + i * engine.getDilithiumPolyEtaPackedBytes(), engine.getDilithiumPolyEtaPackedBytes());
System.arraycopy(s2.getVectorIndex(i).polyEtaPack(), 0, out[4], i * engine.getDilithiumPolyEtaPackedBytes(), engine.getDilithiumPolyEtaPackedBytes());
}
end += engine.getDilithiumK() * engine.getDilithiumPolyEtaPackedBytes();

for (i = 0; i < engine.getDilithiumK(); ++i)
out[5] = new byte[engine.getDilithiumK() * DilithiumEngine.DilithiumPolyT0PackedBytes];
for (int i = 0; i < engine.getDilithiumK(); ++i)
{
System.arraycopy(t0.getVectorIndex(i).polyt0Pack(), 0, out, end + i * DilithiumEngine.DilithiumPolyT0PackedBytes, DilithiumEngine.DilithiumPolyT0PackedBytes);
System.arraycopy(t0.getVectorIndex(i).polyt0Pack(), 0, out[5], i * DilithiumEngine.DilithiumPolyT0PackedBytes, DilithiumEngine.DilithiumPolyT0PackedBytes);
}
return out;
}
Expand All @@ -69,39 +70,33 @@ static byte[] packSecretKey(byte[] rho, byte[] tr, byte[] key, PolyVecK t0, Poly
* @return Byte matrix where byte[0] = rho, byte[1] = tr, byte[2] = key
*/

static byte[][] unpackSecretKey(PolyVecK t0, PolyVecL s1, PolyVecK s2, byte[] secretKey, DilithiumEngine engine)
static void unpackSecretKey(PolyVecK t0, PolyVecL s1, PolyVecK s2, byte[] secretKey, DilithiumEngine engine)
{
int i, end;
byte[][] out = new byte[3][];
out[0] = new byte[DilithiumEngine.SeedBytes]; // rho
out[1] = new byte[DilithiumEngine.SeedBytes]; // tr
out[2] = new byte[DilithiumEngine.SeedBytes]; // key

// secretkey = {rho | key | tr} ...
System.arraycopy(secretKey, 0, out[0], 0, DilithiumEngine.SeedBytes);
System.arraycopy(secretKey, DilithiumEngine.SeedBytes * 2, out[1], 0, DilithiumEngine.SeedBytes);
System.arraycopy(secretKey, DilithiumEngine.SeedBytes, out[2], 0, DilithiumEngine.SeedBytes);

end = 3 * DilithiumEngine.SeedBytes;

for (i = 0; i < engine.getDilithiumL(); ++i)
{
// TODO: reduce copying
s1.getVectorIndex(i).polyEtaUnpack(Arrays.copyOfRange(secretKey, end + i * engine.getDilithiumPolyEtaPackedBytes(), end + (i + 1) * engine.getDilithiumPolyEtaPackedBytes()));
}

end += engine.getDilithiumL() * engine.getDilithiumPolyEtaPackedBytes();

for (i = 0; i < engine.getDilithiumK(); ++i)
{
// TODO: reduce copying
s2.getVectorIndex(i).polyEtaUnpack(Arrays.copyOfRange(secretKey, end + i * engine.getDilithiumPolyEtaPackedBytes(), end + (i + 1) * engine.getDilithiumPolyEtaPackedBytes()));
}

end += engine.getDilithiumK() * engine.getDilithiumPolyEtaPackedBytes();

for (i = 0; i < engine.getDilithiumK(); ++i)
{
// TODO: reduce copying
t0.getVectorIndex(i).polyt0Unpack(Arrays.copyOfRange(secretKey, end + i * DilithiumEngine.DilithiumPolyT0PackedBytes, end + (i + 1) * DilithiumEngine.DilithiumPolyT0PackedBytes));
}

return out;
}

static byte[] packSignature(byte[] c, PolyVecL z, PolyVecK h, DilithiumEngine engine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.asn1.bc.BCObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
Expand All @@ -23,6 +24,8 @@
import org.bouncycastle.pqc.asn1.XMSSPrivateKey;
import org.bouncycastle.pqc.crypto.cmce.CMCEParameters;
import org.bouncycastle.pqc.crypto.cmce.CMCEPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.crystals.dilithium.DilithiumParameters;
import org.bouncycastle.pqc.crypto.crystals.dilithium.DilithiumPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.crystals.kyber.KyberParameters;
import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.falcon.FalconParameters;
Expand Down Expand Up @@ -195,6 +198,21 @@ else if (algOID.on(BCObjectIdentifiers.pqc_kem_kyber))

return new KyberPrivateKeyParameters(spParams, keyEnc);
}
else if (algOID.equals(BCObjectIdentifiers.dilithium2)
|| algOID.equals(BCObjectIdentifiers.dilithium3) || algOID.equals(BCObjectIdentifiers.dilithium5))
{
ASN1Sequence keyEnc = ASN1Sequence.getInstance(keyInfo.parsePrivateKey());

DilithiumParameters spParams = Utils.dilithiumParamsLookup(keyInfo.getPrivateKeyAlgorithm().getAlgorithm());

return new DilithiumPrivateKeyParameters(spParams,
ASN1BitString.getInstance(keyEnc.getObjectAt(0)).getOctets(),
ASN1BitString.getInstance(keyEnc.getObjectAt(1)).getOctets(),
ASN1BitString.getInstance(keyEnc.getObjectAt(2)).getOctets(),
ASN1BitString.getInstance(keyEnc.getObjectAt(3)).getOctets(),
ASN1BitString.getInstance(keyEnc.getObjectAt(4)).getOctets(),
ASN1BitString.getInstance(keyEnc.getObjectAt(5)).getOctets());
}
else if (algOID.equals(BCObjectIdentifiers.falcon_512) || algOID.equals(BCObjectIdentifiers.falcon_1024))
{
byte[] keyEnc = ASN1OctetString.getInstance(keyInfo.parsePrivateKey()).getOctets();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import java.io.IOException;

import org.bouncycastle.asn1.ASN1EncodableVector;
import org.bouncycastle.asn1.ASN1Set;
import org.bouncycastle.asn1.DERBitString;
import org.bouncycastle.asn1.DEROctetString;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
Expand All @@ -18,6 +21,7 @@
import org.bouncycastle.pqc.asn1.XMSSMTPrivateKey;
import org.bouncycastle.pqc.asn1.XMSSPrivateKey;
import org.bouncycastle.pqc.crypto.cmce.CMCEPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.crystals.dilithium.DilithiumPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.falcon.FalconPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.frodo.FrodoPrivateKeyParameters;
Expand Down Expand Up @@ -244,6 +248,23 @@ else if (privateKey instanceof KyberPrivateKeyParameters)

return new PrivateKeyInfo(algorithmIdentifier, new DEROctetString(encoding), attributes);
}
else if (privateKey instanceof DilithiumPrivateKeyParameters)
{
DilithiumPrivateKeyParameters params = (DilithiumPrivateKeyParameters)privateKey;

ASN1EncodableVector v = new ASN1EncodableVector();

v.add(new DERBitString(params.getRho()));
v.add(new DERBitString(params.getK()));
v.add(new DERBitString(params.getTr()));
v.add(new DERBitString(params.getS1()));
v.add(new DERBitString(params.getS2()));
v.add(new DERBitString(params.getT0()));

AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(Utils.dilithiumOidLookup(params.getParameters()));

return new PrivateKeyInfo(algorithmIdentifier, new DERSequence(v), attributes);
}
else
{
throw new IOException("key parameters not recognized");
Expand Down
Loading

0 comments on commit 510dd86

Please sign in to comment.