Skip to content

Commit

Permalink
Fixed Group so that it properly handles commits
Browse files Browse the repository at this point in the history
Passes all passive client tests
  • Loading branch information
royb committed Sep 1, 2023
1 parent 6d7b81a commit 1aca8dd
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 107 deletions.
18 changes: 18 additions & 0 deletions core/src/main/java/org/bouncycastle/crypto/hpke/HPKE.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ public HPKE(byte mode, short kemId, short kdfId, short aeadId)
}
}

public int getEncSize()
{
switch (kemId)
{
case HPKE.kem_P256_SHA256:
return 65;
case HPKE.kem_P384_SHA348:
return 97;
case HPKE.kem_P521_SHA512:
return 133;
case HPKE.kem_X25519_SHA256:
return 32;
case HPKE.kem_X448_SHA512:
return 56;
default:
throw new IllegalArgumentException("invalid kem id");
}
}
public short getAeadId()
{
return aeadId;
Expand Down
53 changes: 53 additions & 0 deletions mls/src/main/java/org/bouncycastle/mls/KeyScheduleEpoch.java
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,59 @@ public KeyScheduleEpoch(CipherSuite suite, TreeSize treeSize, Secret epochSecret
this.encryptionSecret = epochSecret.deriveSecret(suite, "encryption");
this.groupKeySet = new GroupKeySet(suite, treeSize, encryptionSecret);
}
// public KeyScheduleEpoch copy()
// {
// return new KeyScheduleEpoch(suite, groupKeySet.secretTree.treeSize, )
// this.suite = suite;
// this.initSecret = epochSecret.deriveSecret(suite, "init");
// this.senderDataSecret = epochSecret.deriveSecret(suite, "sender data");
// this.exporterSecret = epochSecret.deriveSecret(suite, "exporter");
// this.confirmationKey = epochSecret.deriveSecret(suite, "confirm");
// this.membershipKey = epochSecret.deriveSecret(suite, "membership");
// this.resumptionPSK = epochSecret.deriveSecret(suite, "resumption");
// this.epochAuthenticator = epochSecret.deriveSecret(suite, "authentication");
//
// this.externalSecret = epochSecret.deriveSecret(suite, "external");
// this.externalKeyPair = suite.getHPKE().deriveKeyPair(externalSecret.value());
//
// this.encryptionSecret = epochSecret.deriveSecret(suite, "encryption");
// this.groupKeySet = new GroupKeySet(suite, treeSize, encryptionSecret);
// }


public KeyScheduleEpoch(CipherSuite suite, TreeSize treeSize, Secret joinerSecret, Secret pskSecret, byte[] context) throws IOException, IllegalAccessException {
this.suite = suite;

Secret memSecret = Secret.extract(suite, joinerSecret, pskSecret);
Secret epochSecret = memSecret.expandWithLabel(suite, "epoch", context, suite.getKDF().getHashLength());

this.senderDataSecret = epochSecret.deriveSecret(suite, "sender data");
this.encryptionSecret = epochSecret.deriveSecret(suite, "encryption");
this.exporterSecret = epochSecret.deriveSecret(suite, "exporter");
this.epochAuthenticator = epochSecret.deriveSecret(suite, "authentication");
this.externalSecret = epochSecret.deriveSecret(suite, "external");
this.confirmationKey = epochSecret.deriveSecret(suite, "confirm");
this.membershipKey = epochSecret.deriveSecret(suite, "membership");
this.resumptionPSK = epochSecret.deriveSecret(suite, "resumption");
this.initSecret = epochSecret.deriveSecret(suite, "init");

this.externalKeyPair = suite.getHPKE().deriveKeyPair(externalSecret.value());

this.groupKeySet = new GroupKeySet(suite, treeSize, encryptionSecret);
}
public KeyScheduleEpoch nextG(TreeSize treeSize, byte[] externalInit, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException {
Secret currentSecret = initSecret;
if (externalInit != null)
{
currentSecret = new Secret(externalInit);
}
Secret preJoinerSecret = Secret.extract(suite, currentSecret, commitSecret);
Secret joinerSecret = preJoinerSecret.expandWithLabel(suite, "joiner", context, suite.getKDF().getHashLength());

return new KeyScheduleEpoch(this.suite, treeSize, joinerSecret, JoinSecrets.pskSecret(suite, psks), context);


}

public KeyScheduleEpoch next(TreeSize treeSize, byte[] externalInit, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException, IllegalAccessException {
Secret currInitSecret = initSecret;
Expand Down
17 changes: 16 additions & 1 deletion mls/src/main/java/org/bouncycastle/mls/TranscriptHash.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.bouncycastle.mls;

import org.bouncycastle.mls.codec.AuthenticatedContent;
import org.bouncycastle.mls.codec.MLSOutputStream;
import org.bouncycastle.mls.crypto.CipherSuite;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.encoders.Hex;

import java.io.IOException;

Expand All @@ -22,6 +24,16 @@ public TranscriptHash(CipherSuite suite)
{
this.suite = suite;
}
public TranscriptHash(CipherSuite suite, byte[] confirmed, byte[] interim)
{
this.suite = suite;
this.confirmed = confirmed.clone();
this.interim = interim.clone();
}
public TranscriptHash copy()
{
return new TranscriptHash(suite, confirmed, interim);
}

public void update(AuthenticatedContent auth) throws IOException
{
Expand All @@ -38,10 +50,13 @@ private void updateInterim(AuthenticatedContent auth) throws IOException
{
byte[] transcript = Arrays.concatenate(confirmed, auth.getInterimTranscriptHashInput());
interim = suite.hash(transcript);

}
public void updateInterim(byte[] confirmationTag) throws IOException
{
byte[] transcript = Arrays.concatenate(confirmed, confirmationTag);
MLSOutputStream stream = new MLSOutputStream();
stream.writeOpaque(confirmationTag);
byte[] transcript = Arrays.concatenate(confirmed, stream.toByteArray());
interim = suite.hash(transcript);
}
}
7 changes: 6 additions & 1 deletion mls/src/main/java/org/bouncycastle/mls/TreeKEM/LeafNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public class LeafNode
/* SignWithLabel(., "LeafNodeTBS", LeafNodeTBS) */
public byte[] signature; // not in TBS

public CredentialType getCredentialType()
{
return credential.credentialType;
}

public LeafNode()
{
}
Expand Down Expand Up @@ -142,7 +147,7 @@ public boolean verifyLifetime()

public boolean verify(CipherSuite suite, byte[] tbs) throws IOException
{
if (credential.credentialType == CredentialType.x509)
if (getCredentialType() == CredentialType.x509)
{
//TODO: get credential and check if it's signature scheme matches the cipher suite signature scheme
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public static TreeKEMPrivateKey joiner(TreeKEMPublicKey pub, LeafIndex index, As
NodeIndex intersect, Secret pathSecret) throws IOException
{
TreeKEMPrivateKey priv = new TreeKEMPrivateKey(pub.suite, index);

priv.privateKeyCache.put(new NodeIndex(index), leafPriv);

if (pathSecret != null)
Expand All @@ -74,7 +75,7 @@ public void dump() throws IOException
for (NodeIndex node :
pathSecrets.keySet())
{
setPrivateKey(node);
setPrivateKey(node, true);
}

System.out.println("Tree (priv)");
Expand Down Expand Up @@ -179,7 +180,7 @@ public void decap(LeafIndex from, TreeKEMPublicKey pub, byte[] context, UpdatePa
}

// decrypt and implant
AsymmetricCipherKeyPair priv = getPrivateKey(res.get(resi));
AsymmetricCipherKeyPair priv = setPrivateKey(res.get(resi), false);
HPKECiphertext ct = path.nodes.get(dpi).encrypted_path_secret.get(resi);

Secret pathSecret = new Secret(suite.decryptWithLabel(
Expand All @@ -203,7 +204,7 @@ private boolean havePrivateKey(NodeIndex n)
return pathSecrets.containsKey(n) || privateKeyCache.containsKey(n);
}

public boolean consistent(TreeKEMPublicKey other) throws IOException
public final boolean consistent(TreeKEMPublicKey other) throws IOException
{
if (suite.getSuiteId() != other.suite.getSuiteId())
{
Expand All @@ -212,7 +213,7 @@ public boolean consistent(TreeKEMPublicKey other) throws IOException

for (NodeIndex node : pathSecrets.keySet())
{
setPrivateKey(node);
setPrivateKey(node, true);
}

for (NodeIndex key : privateKeyCache.keySet())
Expand All @@ -233,17 +234,16 @@ public boolean consistent(TreeKEMPublicKey other) throws IOException
return true;
}

private AsymmetricCipherKeyPair setPrivateKey(NodeIndex n) throws IOException
protected AsymmetricCipherKeyPair setPrivateKey(NodeIndex n, boolean isConst) throws IOException
{
AsymmetricCipherKeyPair priv = getPrivateKey(n);
if (priv != null)
if (priv != null && !isConst)
{
//TODO: Why is this adding more than what we want???
privateKeyCache.put(n, priv);
}
return priv;
}
protected AsymmetricCipherKeyPair getPrivateKey(NodeIndex n) throws IOException
private AsymmetricCipherKeyPair getPrivateKey(NodeIndex n) throws IOException
{
if (privateKeyCache.containsKey(n))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public TreeKEMPrivateKey update(LeafIndex from, Secret leafSecret, byte[] groupI
for (NodeIndex n: dp.parents)
{
Secret pathSecret = priv.pathSecrets.get(n);
AsymmetricCipherKeyPair nodePriv = priv.getPrivateKey(n);
AsymmetricCipherKeyPair nodePriv = priv.setPrivateKey(n, false);

pathNodes.add(new UpdatePathNode(suite.getHPKE().serializePublicKey(nodePriv.getPublic()), new ArrayList<>()));
}
Expand All @@ -190,7 +190,7 @@ public TreeKEMPrivateKey update(LeafIndex from, Secret leafSecret, byte[] groupI
ph0 = ph[0];
}

byte[] leafPub = suite.getHPKE().serializePublicKey(priv.getPrivateKey(new NodeIndex(from)).getPublic());
byte[] leafPub = suite.getHPKE().serializePublicKey(priv.setPrivateKey(new NodeIndex(from), false).getPublic());
LeafNode newLeaf = leafNode.getLeafNode().forCommit(suite, groupId, from, leafPub, ph0, sigPriv);

// Merge the changes into the tree
Expand All @@ -206,11 +206,11 @@ public UpdatePath encap(TreeKEMPrivateKey priv, byte[] context, List<LeafIndex>
for (int i = 0; i < dp.parents.size(); i++)
{
NodeIndex n = dp.parents.get(i);
List<NodeIndex> res = dp.resolutions.get(i);
List<NodeIndex> res = (List<NodeIndex>) dp.resolutions.get(i).clone();
removeLeaves(res, except);

Secret pathSecret = priv.pathSecrets.get(n);
AsymmetricCipherKeyPair nodePriv = priv.getPrivateKey(n);
AsymmetricCipherKeyPair nodePriv = priv.setPrivateKey(n, false);

List<HPKECiphertext> cts = new ArrayList<>();
for (NodeIndex nr: res)
Expand Down Expand Up @@ -408,7 +408,7 @@ public LeafIndex addLeaf(LeafNode leaf)
List<NodeIndex> dp = index.directPath(size);

// Update the unmerged list
for (NodeIndex n : index.directPath(size))
for (NodeIndex n : dp)
{
if (nodeAt(n).node == null)
{
Expand Down
Loading

0 comments on commit 1aca8dd

Please sign in to comment.