Skip to content

Commit

Permalink
Made CipherSuites static and added interfaces for it. Patched grease …
Browse files Browse the repository at this point in the history
…removal.
  • Loading branch information
royb committed Jan 19, 2024
1 parent fc9b93a commit 70dc775
Show file tree
Hide file tree
Showing 31 changed files with 874 additions and 830 deletions.
6 changes: 3 additions & 3 deletions mls/src/main/java/org/bouncycastle/mls/GroupKeySet.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import org.bouncycastle.mls.TreeKEM.LeafIndex;
import org.bouncycastle.mls.TreeKEM.NodeIndex;
import org.bouncycastle.mls.codec.ContentType;
import org.bouncycastle.mls.crypto.CipherSuite;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;

import java.io.IOException;
Expand All @@ -14,7 +14,7 @@
import java.util.Map;

public class GroupKeySet {
final CipherSuite suite;
final MlsCipherSuite suite;
final int secretSize;
// We store a commitment to the encryption secret that was used to create this structure, so that we can compare
// for purposes of equivalence checking without violating forward secrecy.
Expand All @@ -25,7 +25,7 @@ public class GroupKeySet {
Map<LeafIndex, HashRatchet> applicationRatchets;


public GroupKeySet(CipherSuite suite, TreeSize treeSize, Secret encryptionSecret) throws IOException, IllegalAccessException {
public GroupKeySet(MlsCipherSuite suite, TreeSize treeSize, Secret encryptionSecret) throws IOException, IllegalAccessException {
this.suite = suite;
this.secretSize = suite.getKDF().getHashLength();
this.encryptionSecretCommit = encryptionSecret.deriveSecret(suite, "commitment");
Expand Down
92 changes: 22 additions & 70 deletions mls/src/main/java/org/bouncycastle/mls/KeyScheduleEpoch.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import org.bouncycastle.crypto.hpke.HPKEContextWithEncapsulation;
import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
import org.bouncycastle.mls.codec.MLSOutputStream;
import org.bouncycastle.mls.crypto.CipherSuite;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.mls.codec.PreSharedKeyID;
import org.bouncycastle.util.Arrays;
Expand Down Expand Up @@ -39,7 +39,7 @@ public byte[] receiveExternalInit(byte[] kemOut) throws IOException

public static class JoinSecrets {
// Cached values
private final CipherSuite suite;
private final MlsCipherSuite suite;
// Public values
public final Secret joinerSecret;

Expand Down Expand Up @@ -83,7 +83,7 @@ public void writeTo(MLSOutputStream stream) throws IOException {
V V
psk_[n-1] --> Extract --> ExpandWithLabel --> Extract = psk_secret_[n]
*/
public static Secret pskSecret(CipherSuite suite, List<PSKWithSecret> psks) throws IOException {
public static Secret pskSecret(MlsCipherSuite suite, List<PSKWithSecret> psks) throws IOException {
Secret pskSecret = Secret.zero(suite);
if (psks == null || psks.isEmpty()) {
return pskSecret;
Expand Down Expand Up @@ -127,7 +127,7 @@ public static Secret pskSecret(CipherSuite suite, List<PSKWithSecret> psks) thro
// return new JoinSecrets(suite, joinerSecret, psks);
// }
// todo change to
public static JoinSecrets forMember(CipherSuite suite, Secret initSecret, Secret commitSecret, Secret pskSecret, byte[] context) throws IOException {
public static JoinSecrets forMember(MlsCipherSuite suite, Secret initSecret, Secret commitSecret, Secret pskSecret, byte[] context) throws IOException {
Secret preJoinerSecret = Secret.extract(suite, initSecret, commitSecret);
Secret joinerSecret = preJoinerSecret.expandWithLabel(suite,"joiner", context, suite.getKDF().getHashLength());
return new JoinSecrets(suite, joinerSecret, pskSecret);
Expand All @@ -151,7 +151,7 @@ public static JoinSecrets forMember(CipherSuite suite, Secret initSecret, Secret
V
epoch_secret
*/
public JoinSecrets(CipherSuite suite, Secret joinerSecret, List<PSKWithSecret> psks) throws IOException {
public JoinSecrets(MlsCipherSuite suite, Secret joinerSecret, List<PSKWithSecret> psks) throws IOException {
this.suite = suite;
this.joinerSecret = joinerSecret;
this.memberSecret = Secret.extract(suite, joinerSecret, pskSecret(suite, psks));
Expand All @@ -161,7 +161,7 @@ public JoinSecrets(CipherSuite suite, Secret joinerSecret, List<PSKWithSecret> p
this.welcomeKey = welcomeSecret.expand(suite, "key", suite.getAEAD().getKeySize());
this.welcomeNonce = welcomeSecret.expand(suite, "nonce", suite.getAEAD().getNonceSize());
}
public JoinSecrets(CipherSuite suite, Secret joinerSecret, Secret pskSecret) throws IOException {
public JoinSecrets(MlsCipherSuite suite, Secret joinerSecret, Secret pskSecret) throws IOException {
this.suite = suite;
this.joinerSecret = joinerSecret;
this.memberSecret = Secret.extract(suite, joinerSecret, pskSecret);
Expand Down Expand Up @@ -193,7 +193,7 @@ public static class ExternalInitParams {
public byte[] kemOutput;
public Secret initSecret;

public ExternalInitParams(CipherSuite suite, AsymmetricKeyParameter externalPub) {
public ExternalInitParams(MlsCipherSuite suite, AsymmetricKeyParameter externalPub) {
final byte[] exportContext = "MLS 1.0 external init secret".getBytes(StandardCharsets.UTF_8);
final int L = suite.getKDF().getHashLength();

Expand All @@ -215,7 +215,7 @@ public byte[] getKEMOutput() {



final CipherSuite suite;
final MlsCipherSuite suite;

// Secrets derived from the epoch secret
public final Secret initSecret;
Expand Down Expand Up @@ -248,7 +248,7 @@ public GroupKeySet getEncryptionKeys(TreeSize size) throws IOException, IllegalA
return new GroupKeySet(suite, size, encryptionSecret);
}

public static KeyGeneration senderDataKeys(CipherSuite suite, byte[] senderDataSecretBytes, byte[] ciphertext) throws IOException
public static KeyGeneration senderDataKeys(MlsCipherSuite suite, byte[] senderDataSecretBytes, byte[] ciphertext) throws IOException
{
Secret senderDataSecret = new Secret(senderDataSecretBytes);
int sampleSize = suite.getKDF().getHashLength();
Expand All @@ -261,20 +261,13 @@ public static KeyGeneration senderDataKeys(CipherSuite suite, byte[] senderDataS
}


public static Secret welcomeSecret(CipherSuite suite, byte[] joinerSecret, List<PSKWithSecret> psk) throws IOException
public static Secret welcomeSecret(MlsCipherSuite suite, byte[] joinerSecret, List<PSKWithSecret> psk) throws IOException
{
Secret pskSecret = JoinSecrets.pskSecret(suite, psk);
Secret extract = new Secret(suite.getKDF().extract(joinerSecret, pskSecret.value()));
return extract.deriveSecret(suite, "welcome");
}
public static KeyScheduleEpoch forCreatorTEST(CipherSuite suite, byte[] groupContext, byte[] initSecret) throws IOException, IllegalAccessException
{
JoinSecrets joinerSecret = JoinSecrets.forMember(suite, new Secret(initSecret), Secret.zero(suite), new Secret(new byte[0]), groupContext);
TreeSize size = TreeSize.forLeaves(1);
return joinerSecret.complete(size, groupContext);
// return KeyScheduleEpoch.joiner(suite, joinerSecret.joinerSecret.value(), new ArrayList<>(), groupContext);
}
public static KeyScheduleEpoch forCreator(CipherSuite suite, byte[] groupContext) throws IOException, IllegalAccessException
public static KeyScheduleEpoch forCreator(MlsCipherSuite suite, byte[] groupContext) throws IOException, IllegalAccessException
{
SecureRandom random = new SecureRandom();
byte[] initSecret = new byte[suite.getKDF().getHashLength()];
Expand All @@ -285,24 +278,19 @@ public static KeyScheduleEpoch forCreator(CipherSuite suite, byte[] groupContext
// return joinerSecret.complete(size, groupContext);
return KeyScheduleEpoch.joiner(suite, joinerSecret.joinerSecret.value(), new ArrayList<>(), groupContext);
}
public static KeyScheduleEpoch forCreator(CipherSuite suite) throws IOException, IllegalAccessException {
public static KeyScheduleEpoch forCreator(MlsCipherSuite suite) throws IOException, IllegalAccessException {
SecureRandom rng = new SecureRandom();
return forCreator(suite, rng);
}
public static KeyScheduleEpoch forCreator(CipherSuite suite, SecureRandom rng)
public static KeyScheduleEpoch forCreator(MlsCipherSuite suite, SecureRandom rng)
throws IOException, IllegalAccessException {
byte[] epochSecret = new byte[suite.getKDF().getHashLength()];
rng.nextBytes(epochSecret);
TreeSize treeSize = TreeSize.forLeaves(1);
return new KeyScheduleEpoch(suite, treeSize, new Secret(epochSecret));
}
// public static KeyScheduleEpoch forCreator(CipherSuite suite, byte[] epochSecret)
// throws IOException, IllegalAccessException {
// TreeSize treeSize = TreeSize.forLeaves(1);
// return new KeyScheduleEpoch(suite, treeSize, new Secret(epochSecret));
// }

public static KeyScheduleEpoch forExternalJoiner(CipherSuite suite, TreeSize treeSize, ExternalInitParams externalInitParams, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException {
public static KeyScheduleEpoch forExternalJoiner(MlsCipherSuite suite, TreeSize treeSize, ExternalInitParams externalInitParams, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException {
return JoinSecrets.forMember(suite, externalInitParams.initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context).complete(treeSize, context); //TODO external: pskSecret OR new byte [0]
}

Expand Down Expand Up @@ -330,7 +318,7 @@ public byte[] confirmationTag(byte[] confirmedTranscriptHash)
return suite.getKDF().extract(confirmationKey.value(), confirmedTranscriptHash);
}

public KeyScheduleEpoch(CipherSuite suite, Secret initSecret, Secret senderDataSecret, Secret exporterSecret, Secret confirmationKey, Secret membershipKey, Secret resumptionPSK, Secret epochAuthenticator, Secret encryptionSecret, Secret externalSecret, AsymmetricCipherKeyPair externalKeyPair, GroupKeySet groupKeySet, Secret joinerSecret)
public KeyScheduleEpoch(MlsCipherSuite suite, Secret initSecret, Secret senderDataSecret, Secret exporterSecret, Secret confirmationKey, Secret membershipKey, Secret resumptionPSK, Secret epochAuthenticator, Secret encryptionSecret, Secret externalSecret, AsymmetricCipherKeyPair externalKeyPair, GroupKeySet groupKeySet, Secret joinerSecret)
{
this.suite = suite;
this.initSecret = new Secret(initSecret.value());
Expand All @@ -352,7 +340,7 @@ public KeyScheduleEpoch copy()
return new KeyScheduleEpoch(suite, initSecret, senderDataSecret, exporterSecret, confirmationKey, membershipKey, resumptionPSK, epochAuthenticator, encryptionSecret, externalSecret, externalKeyPair, groupKeySet, joinerSecret);
}

public static KeyScheduleEpoch joiner(CipherSuite suite, byte[] joinerSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException
public static KeyScheduleEpoch joiner(MlsCipherSuite suite, byte[] joinerSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException
{
TreeSize size = TreeSize.forLeaves(1);
JoinSecrets joinSecrets = new JoinSecrets(suite, new Secret(joinerSecret), psks);
Expand All @@ -361,7 +349,7 @@ public static KeyScheduleEpoch joiner(CipherSuite suite, byte[] joinerSecret, Li
}

// ONLY USED BY EXTERNAL JOINER
public KeyScheduleEpoch(CipherSuite suite) throws IOException, IllegalAccessException
public KeyScheduleEpoch(MlsCipherSuite suite) throws IOException, IllegalAccessException
{
this.suite = suite;
this.initSecret = new Secret(new byte[0]);
Expand All @@ -376,7 +364,7 @@ public KeyScheduleEpoch(CipherSuite suite) throws IOException, IllegalAccessExce
this.encryptionSecret = new Secret(new byte[0]);
this.groupKeySet = null;
}
public KeyScheduleEpoch(CipherSuite suite, TreeSize treeSize, Secret epochSecret) throws IOException, IllegalAccessException {
public KeyScheduleEpoch(MlsCipherSuite suite, TreeSize treeSize, Secret epochSecret) throws IOException, IllegalAccessException {
this.suite = suite;
this.initSecret = epochSecret.deriveSecret(suite, "init");
this.senderDataSecret = epochSecret.deriveSecret(suite, "sender data");
Expand All @@ -392,27 +380,8 @@ public KeyScheduleEpoch(CipherSuite suite, TreeSize treeSize, Secret epochSecret
this.encryptionSecret = epochSecret.deriveSecret(suite, "encryption");
this.groupKeySet = new GroupKeySet(suite, treeSize, encryptionSecret);
}
// public KeyScheduleEpoch copy()
// {
// return new KeyScheduleEpoch(suite, groupKeySet.secretTree.treeSize, )
// this.suite = suite;
// this.initSecret = epochSecret.deriveSecret(suite, "init");
// this.senderDataSecret = epochSecret.deriveSecret(suite, "sender data");
// this.exporterSecret = epochSecret.deriveSecret(suite, "exporter");
// this.confirmationKey = epochSecret.deriveSecret(suite, "confirm");
// this.membershipKey = epochSecret.deriveSecret(suite, "membership");
// this.resumptionPSK = epochSecret.deriveSecret(suite, "resumption");
// this.epochAuthenticator = epochSecret.deriveSecret(suite, "authentication");
//
// this.externalSecret = epochSecret.deriveSecret(suite, "external");
// this.externalKeyPair = suite.getHPKE().deriveKeyPair(externalSecret.value());
//
// this.encryptionSecret = epochSecret.deriveSecret(suite, "encryption");
// this.groupKeySet = new GroupKeySet(suite, treeSize, encryptionSecret);
// }


public KeyScheduleEpoch(CipherSuite suite, TreeSize treeSize, Secret joinerSecret, Secret pskSecret, byte[] context) throws IOException, IllegalAccessException {

public KeyScheduleEpoch(MlsCipherSuite suite, TreeSize treeSize, Secret joinerSecret, Secret pskSecret, byte[] context) throws IOException, IllegalAccessException {
this.suite = suite;

Secret memSecret = Secret.extract(suite, joinerSecret, pskSecret);
Expand All @@ -434,27 +403,10 @@ public KeyScheduleEpoch(CipherSuite suite, TreeSize treeSize, Secret joinerSecre
}


public KeyScheduleEpoch next(TreeSize treeSize, byte[] externalInit, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException {

/*
Secret currentSecret = initSecret;
if (externalInit != null)
{
currentSecret = new Secret(externalInit);
}
Secret preJoinerSecret = Secret.extract(suite, currentSecret, commitSecret);
Secret joinerSecret = preJoinerSecret.expandWithLabel(suite, "joiner", context, suite.getKDF().getHashLength());
return new KeyScheduleEpoch(this.suite, treeSize, joinerSecret, JoinSecrets.pskSecret(suite, psks), context);
*/
public KeyScheduleEpoch next(TreeSize treeSize, byte[] externalInit, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException
{
Secret currInitSecret = initSecret;
if (externalInit != null) {
// final byte[] exportContext = "MLS 1.0 external init secret".getBytes(StandardCharsets.UTF_8);
// final int L = suite.getKDF().getHashLength();
// HPKEContext ctx = suite.getHPKE().setupBaseR(externalInit, externalKeyPair, new byte[0]);
// currInitSecret = new Secret(ctx.export(exportContext, L));
currInitSecret = new Secret(externalInit);
}

Expand Down
10 changes: 5 additions & 5 deletions mls/src/main/java/org/bouncycastle/mls/TranscriptHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import org.bouncycastle.mls.codec.AuthenticatedContent;
import org.bouncycastle.mls.codec.MLSOutputStream;
import org.bouncycastle.mls.crypto.CipherSuite;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.util.Arrays;

import java.io.IOException;

public class TranscriptHash
{

private CipherSuite suite;
private MlsCipherSuite suite;
byte[] confirmed;
byte[] interim;

Expand All @@ -29,20 +29,20 @@ public void setInterim(byte[] interim)
this.interim = interim;
}

public TranscriptHash(CipherSuite suite)
public TranscriptHash(MlsCipherSuite suite)
{
this.suite = suite;
confirmed = new byte[0];
}

public TranscriptHash(CipherSuite suite, byte[] confirmed, byte[] interim)
public TranscriptHash(MlsCipherSuite suite, byte[] confirmed, byte[] interim)
{
this.suite = suite;
this.confirmed = confirmed;
this.interim = interim;
}

static public TranscriptHash fromConfirmationTag(CipherSuite suite, byte[] confirmed, byte[] confirmationTag) throws IOException
static public TranscriptHash fromConfirmationTag(MlsCipherSuite suite, byte[] confirmed, byte[] confirmationTag) throws IOException
{
TranscriptHash out = new TranscriptHash(suite, confirmed.clone(), new byte[0]);
out.updateInterim(confirmationTag);
Expand Down
16 changes: 8 additions & 8 deletions mls/src/main/java/org/bouncycastle/mls/TreeKEM/LeafNode.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package org.bouncycastle.mls.TreeKEM;

import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.protocol.Group;
import org.bouncycastle.mls.codec.Capabilities;
import org.bouncycastle.mls.codec.Credential;
import org.bouncycastle.mls.codec.CredentialType;
import org.bouncycastle.mls.codec.Extension;
import org.bouncycastle.mls.codec.MLSInputStream;
import org.bouncycastle.mls.codec.MLSOutputStream;
import org.bouncycastle.mls.crypto.CipherSuite;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -17,7 +17,7 @@
public class LeafNode
implements MLSInputStream.Readable, MLSOutputStream.Writable
{
CipherSuite suite;
MlsCipherSuite suite;
byte[] encryption_key;
byte[] signature_key;
Credential credential;
Expand All @@ -42,7 +42,7 @@ public CredentialType getCredentialType()
return credential.getCredentialType();
}

public CipherSuite getSuite()
public MlsCipherSuite getSuite()
{
return suite;
}
Expand All @@ -68,7 +68,7 @@ public List<Extension> getExtensions()
}

public LeafNode(
CipherSuite suite,
MlsCipherSuite suite,
byte[] encryption_key,
byte[] signature_key,
Credential credential,
Expand Down Expand Up @@ -201,7 +201,7 @@ public boolean verifyLifetime()
return lifeTime.verify();
}

public boolean verify(CipherSuite suite, byte[] tbs) throws IOException
public boolean verify(MlsCipherSuite suite, byte[] tbs) throws IOException
{
if (getCredentialType() == CredentialType.x509)
{
Expand All @@ -211,7 +211,7 @@ public boolean verify(CipherSuite suite, byte[] tbs) throws IOException
return suite.verifyWithLabel(signature_key, "LeafNodeTBS", tbs, signature);
}

public LeafNode forCommit(CipherSuite suite, byte[] groupId, LeafIndex leafIndex, byte[] encKeyIn, byte[] parentHash, Group.LeafNodeOptions options, byte[] sigPriv) throws Exception
public LeafNode forCommit(MlsCipherSuite suite, byte[] groupId, LeafIndex leafIndex, byte[] encKeyIn, byte[] parentHash, Group.LeafNodeOptions options, byte[] sigPriv) throws Exception
{
LeafNode clone = copyWithOptions(encKeyIn, options);
clone.leaf_node_source = LeafNodeSource.COMMIT;
Expand All @@ -221,7 +221,7 @@ public LeafNode forCommit(CipherSuite suite, byte[] groupId, LeafIndex leafIndex

return clone;
}
public LeafNode forUpdate(CipherSuite suite, byte[] groupId, LeafIndex leafIndex, byte[] encKeyIn, Group.LeafNodeOptions options, byte[] sigPriv) throws Exception
public LeafNode forUpdate(MlsCipherSuite suite, byte[] groupId, LeafIndex leafIndex, byte[] encKeyIn, Group.LeafNodeOptions options, byte[] sigPriv) throws Exception
{
LeafNode clone = copyWithOptions(encKeyIn, options);
clone.leaf_node_source = LeafNodeSource.UPDATE;
Expand All @@ -231,7 +231,7 @@ public LeafNode forUpdate(CipherSuite suite, byte[] groupId, LeafIndex leafIndex
return clone;
}

private void sign(CipherSuite suite, byte[] sigPriv, byte[] tbs) throws Exception
private void sign(MlsCipherSuite suite, byte[] sigPriv, byte[] tbs) throws Exception
{
byte[] sigPub = suite.serializeSignaturePublicKey(suite.deserializeSignaturePrivateKey(sigPriv).getPublic());
if (!Arrays.equals(sigPub, signature_key))
Expand Down
Loading

0 comments on commit 70dc775

Please sign in to comment.