Skip to content

Commit

Permalink
Removed completed TODOS, made tree dumps return String rather than pr…
Browse files Browse the repository at this point in the history
…inting
  • Loading branch information
royb committed Jan 26, 2024
1 parent 2ea0147 commit ed2132d
Show file tree
Hide file tree
Showing 17 changed files with 58 additions and 158 deletions.
11 changes: 2 additions & 9 deletions mls/src/main/java/org/bouncycastle/mls/KeyScheduleEpoch.java
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,6 @@ public static Secret pskSecret(MlsCipherSuite suite, List<PSKWithSecret> psks)
joiner_secret
*/
// public static JoinSecrets forMember(CipherSuite suite, Secret initSecret, Secret commitSecret, List<PSKWithSecret> psks, byte[] context) throws IOException
// {
// Secret preJoinerSecret = Secret.extract(suite, initSecret, commitSecret);
// Secret joinerSecret = preJoinerSecret.expandWithLabel(suite, "joiner", context, suite.getKDF().getHashLength());
// return new JoinSecrets(suite, joinerSecret, psks);
// }
// todo change to
public static JoinSecrets forMember(MlsCipherSuite suite, Secret initSecret, Secret commitSecret, Secret pskSecret, byte[] context)
throws IOException
{
Expand Down Expand Up @@ -329,13 +322,13 @@ public static KeyScheduleEpoch forCreator(MlsCipherSuite suite, SecureRandom rng
public static KeyScheduleEpoch forExternalJoiner(MlsCipherSuite suite, TreeSize treeSize, ExternalInitParams externalInitParams, Secret commitSecret, List<PSKWithSecret> psks, byte[] context)
throws IOException, IllegalAccessException
{
return JoinSecrets.forMember(suite, externalInitParams.initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context).complete(treeSize, context); //TODO external: pskSecret OR new byte [0]
return JoinSecrets.forMember(suite, externalInitParams.initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context).complete(treeSize, context);
}

public JoinSecrets startCommit(Secret commitSecret, List<PSKWithSecret> psks, byte[] context)
throws IOException
{
return JoinSecrets.forMember(suite, initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context);//TODO: pskSecret OR new byte [0]
return JoinSecrets.forMember(suite, initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context);
}

/*
Expand Down
4 changes: 0 additions & 4 deletions mls/src/main/java/org/bouncycastle/mls/TreeKEM/LeafIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ public class LeafIndex
implements MLSInputStream.Readable, MLSOutputStream.Writable
{
protected int value;


//TODO: make a setter for value
//TODO: make this an int not a long?
public int value()
{
return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ public boolean verifyExtensionSupport(List<Extension> extensions)

public boolean verifyLifetime()
{
//TODO: check
if (leaf_node_source != LeafNodeSource.KEY_PACKAGE)
{
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.bouncycastle.mls.codec.UpdatePath;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.util.Strings;
import org.bouncycastle.util.encoders.Hex;

import static org.bouncycastle.mls.TreeKEM.Utils.removeLeaves;
Expand Down Expand Up @@ -68,7 +69,7 @@ public static TreeKEMPrivateKey create(TreeKEMPublicKey pub, LeafIndex from, Sec
throws Exception
{
TreeKEMPrivateKey priv = new TreeKEMPrivateKey(pub.suite, from);
priv.implant(pub, new NodeIndex(from), leafSecret);//todo check
priv.implant(pub, new NodeIndex(from), leafSecret);
return priv;
}

Expand All @@ -87,45 +88,44 @@ public static TreeKEMPrivateKey joiner(TreeKEMPublicKey pub, LeafIndex index, As
return priv;
}

public void dump()
public String dump()
throws IOException
{
StringBuilder sb = new StringBuilder();
for (NodeIndex node :
pathSecrets.keySet())
{
setPrivateKey(node, true);
}

// -DM System.out.println
System.out.println("Tree (priv)");
// -DM System.out.println
System.out.println(" Index: " + (new NodeIndex(index)).value());
sb.append("Tree (priv)").append(Strings.lineSeparator());
sb.append(" Index: ").append((new NodeIndex(index)).value()).append(Strings.lineSeparator());

// -DM System.out.println
System.out.println(" Secrets: ");
sb.append(" Secrets: ").append(Strings.lineSeparator());
for (NodeIndex n : pathSecrets.keySet())
{
Secret pathSecret = pathSecrets.get(n);
Secret nodeSecret = pathSecret.deriveSecret(suite, "node");
AsymmetricCipherKeyPair sk = suite.getHPKE().deriveKeyPair(nodeSecret.value());

// -DM System.out.println
// -DM Hex.toHexString
// -DM Hex.toHexString
System.out.println(" " + n.value()
+ " => " + Hex.toHexString(pathSecret.value(), 0, 4)
+ " => " + Hex.toHexString(suite.getHPKE().serializePublicKey(sk.getPublic()), 0, 4));
sb.append(" ").append(n.value())
.append(" => ").append(Hex.toHexString(pathSecret.value(), 0, 4))
.append(" => ").append(Hex.toHexString(suite.getHPKE().serializePublicKey(sk.getPublic()), 0, 4))
.append(Strings.lineSeparator());
}

// -DM System.out.println
System.out.println(" Cached key pairs: ");
sb.append(" Cached key pairs: ").append(Strings.lineSeparator());
for (NodeIndex n : privateKeyCache.keySet())
{
AsymmetricCipherKeyPair sk = privateKeyCache.get(n);
// -DM System.out.println
// -DM Hex.toHexString
System.out.println(" " + n.value() + " => " + Hex.toHexString(suite.getHPKE().serializePublicKey(sk.getPublic()), 0, 4));
sb.append(" ").append(n.value()).append(" => ")
.append(Hex.toHexString(suite.getHPKE().serializePublicKey(sk.getPublic()), 0, 4))
.append(Strings.lineSeparator());
}
return sb.toString();
}

public void truncate(TreeSize size)
Expand Down Expand Up @@ -253,7 +253,6 @@ public final boolean consistent(TreeKEMPublicKey other)
}
byte[] pub = optNode.getPublicKey();
AsymmetricCipherKeyPair priv = privateKeyCache.get(key);
// todo maybe i have to initilize the public keys for testing
if (!Arrays.equals(pub, suite.getHPKE().serializePublicKey(priv.getPublic())))
{
return false;
Expand Down Expand Up @@ -310,7 +309,6 @@ private void implant(TreeKEMPublicKey pub, NodeIndex start, Secret pathSecret)

public Secret getSharedPathSecret(LeafIndex to)
{
//TODO: make a triplet class
NodeIndex n = index.commonAncestor(to);
if (!pathSecrets.containsKey(n))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.mls.protocol.Group;
import org.bouncycastle.util.Strings;
import org.bouncycastle.util.encoders.Hex;

import static org.bouncycastle.mls.TreeKEM.Utils.removeLeaves;
Expand Down Expand Up @@ -139,80 +140,68 @@ public void setSuite(MlsCipherSuite suite)
this.suite = suite;
}

public void dumpHashes()
public String dumpHashes()
{
StringBuilder sb = new StringBuilder();
for (NodeIndex n : hashes.keySet())
{
// -DM System.out.println
System.out.print(n.value() + " : ");
// -DM System.out.println
sb.append(n.value()).append(" : ");
// -DM Hex.toHexString
System.out.println(Hex.toHexString(hashes.get(n)));
sb.append(Hex.toHexString(hashes.get(n))).append(Strings.lineSeparator());
}
return sb.toString();
}

public void dump()
public String dump()
{
// -DM System.out.println
System.out.println("Tree:");
StringBuilder sb = new StringBuilder();
sb.append("Tree:").append(Strings.lineSeparator());
for (int i = 0; i < size.width(); i++)
{
NodeIndex index = new NodeIndex(i);
// -DM System.out.println
System.out.printf(" %03d : ", i);
sb.append(String.format(" %03d : ", i));
if (!nodeAt(index).isBlank())
{
byte[] pkRm = nodeAt(index).node.getPublicKey();
// -DM System.out.println
// -DM Hex.toHexString
System.out.print(Hex.toHexString(pkRm, 0, 4));
sb.append(Hex.toHexString(pkRm, 0, 4));
}
else
{
// -DM System.out.println
System.out.print(" ");
sb.append(" ");
}

// -DM System.out.println
System.out.print(" | ");
sb.append(" | ");
for (int j = 0; j < index.level(); j++)
{
// -DM System.out.println
System.out.print(" ");
sb.append(" ");
}

if (!nodeAt(index).isBlank())
{
// -DM System.out.println
System.out.print("X");
sb.append("X");

if (!index.isLeaf())
{
ParentNode parent = nodeAt(index).getParentNode();
// -DM System.out.println
System.out.print(" [");
sb.append(" [");
for (LeafIndex u : parent.unmerged_leaves)
{
// -DM System.out.println
System.out.print(u.value + ", ");
sb.append(u.value).append( ", ");
}
// -DM System.out.println
System.out.print("]");
sb.append("]");
}
}
else
{
// -DM System.out.println
System.out.print("_");
sb.append("_");
}
// -DM System.out.println
System.out.println();
sb.append(Strings.lineSeparator());
}
// -DM System.out.println
System.out.println("nodeCount: " + nodes.size());
sb.append("nodeCount: ").append(nodes.size()).append(Strings.lineSeparator());
return sb.toString();
}

//TODO: include leaf node options
public TreeKEMPrivateKey update(LeafIndex from, Secret leafSecret, byte[] groupId, byte[] sigPriv, Group.LeafNodeOptions options)
throws Exception
{
Expand Down Expand Up @@ -608,7 +597,6 @@ public void truncate()
// Remove the right subtree until the tree is of minimal size
while (size.leafCount() / 2 > index.value)
{
//TODO: better way of clearing from index to end
nodes.subList(nodes.size() / 2, nodes.size()).clear();
size = TreeSize.forLeaves(size.leafCount() / 2);
}
Expand Down Expand Up @@ -744,7 +732,6 @@ private byte[] originalTreeHash(NodeIndex index, List<LeafIndex> parentExcept)
throws IOException
{
List<LeafIndex> except = new ArrayList<LeafIndex>();
//TODO: check adding sequence
for (LeafIndex i : parentExcept)
{
NodeIndex n = new NodeIndex(i);
Expand Down Expand Up @@ -796,22 +783,8 @@ private byte[] originalTreeHash(NodeIndex index, List<LeafIndex> parentExcept)
{
parentHashInput.parentNode = nodeAt(index).getParentNode();

//TODO: check which one runs faster (assuming the latter)
List<LeafIndex> unmergedOriginal = new ArrayList<LeafIndex>(parentHashInput.parentNode.unmerged_leaves);
parentHashInput.parentNode.unmerged_leaves.removeAll(except);
// int end = parentHashInput.parentNode.unmerged_leaves.size();
// for (LeafIndex leaf: parentHashInput.parentNode.unmerged_leaves)
// {
// if (except.contains(leaf))
// {
// end--;
// }
// else
// {
// break;
// }
// }
// parentHashInput.parentNode.unmerged_leaves = parentHashInput.parentNode.unmerged_leaves.subList(0, end);

hash = suite.hash(MLSOutputStream.encode(TreeHashInput.forParentNode(parentHashInput)));

Expand Down
18 changes: 0 additions & 18 deletions mls/src/main/java/org/bouncycastle/mls/client/MLSClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,18 @@ private static String getCallerMethodName()

private static <T> void catchWrap(Function f, StreamObserver<T> observer)
{
// System.out.println("Executing function: " + getCallerMethodName());
try
{
f.run();
}
catch (Exception e)
{
// System.out.println(e.getMessage());
// for (StackTraceElement elm : e.getStackTrace())
// {
// System.out.println(elm);
// }
// System.exit(0);
observer.onError(Status.INTERNAL.withDescription(e.getMessage()).asException());
}
}

private <T> void stateWrap(FunctionWithState f, MessageOrBuilder request, StreamObserver<T> observer)
{
// System.out.println("Executing function: " + getCallerMethodName());

int stateID = (int)request.getField(request.getDescriptorForType().findFieldByName("state_id"));
CachedGroup group = loadGroup(stateID);
if (group == null)
Expand All @@ -157,11 +148,6 @@ private <T> void stateWrap(FunctionWithState f, MessageOrBuilder request, Stream
}
catch (Exception e)
{
// System.out.println(e.getMessage());
// for (StackTraceElement elm : e.getStackTrace())
// {
// System.out.println(elm);
// }
observer.onError(Status.INTERNAL.withDescription(e.getMessage()).asException());
}
}
Expand Down Expand Up @@ -281,8 +267,6 @@ private KeyPackageWithSecrets newKeyPackage(MlsCipherSuite suite, byte[] identit
suite.serializeSignaturePrivateKey(sigKeyPair.getPrivate())
);
return new KeyPackageWithSecrets(initKeyPair, encryptionKeyPair, sigKeyPair, kp);
//TODO: cache transactions?
// return key package as byte string/array?
}

private LeafIndex findMember(TreeKEMPublicKey tree, byte[] id)
Expand Down Expand Up @@ -661,7 +645,6 @@ private void externalJoinImpl(MlsClient.ExternalJoinRequest request, StreamObser
}
if (ratchetTree != null)
{
//TODO: check if it should be a deep copy
outTree = TreeKEMPublicKey.clone(ratchetTree);
}
else if (outTree == null)
Expand Down Expand Up @@ -1219,7 +1202,6 @@ private void commitImpl(CachedGroup entry, MlsClient.CommitRequest request, Stre
gwm.message.wireFormat = WireFormat.mls_welcome;
byte[] welcomeBytes = MLSOutputStream.encode(gwm.message);

//TODO: check if entry.group should/shouldn't be replaced by commit()
int nextID = storeGroup(gwm.group, entry.encryptHandshake);

entry.pendingCommit = commitBytes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,26 @@ public byte[] getInterimTranscriptHashInput()
return MLSOutputStream.encode(new InterimTranscriptHashInput(auth.confirmation_tag));
}

public AuthenticatedContent(WireFormat wireFormat, FramedContent content, FramedContentAuthData auth)
public AuthenticatedContent(WireFormat wireFormat, FramedContent content, FramedContentAuthData auth) throws Exception
{
this.wireFormat = wireFormat;
this.content = content;
this.auth = auth;

//TODO move elsewhere?
if (auth.contentType == ContentType.COMMIT && auth.confirmation_tag == null)
{
//TODO
// throw new MissingConfirmationTag()
throw new Exception("missing confirmation tag");
}

if (auth.contentType == ContentType.APPLICATION)
{
if (wireFormat != WireFormat.mls_private_message)
{
// throw new UnencryptedApplicationMessage()
throw new Exception("Unencrypted application message");
}
else if (content.sender.senderType != SenderType.MEMBER)
{
// throw new Exception("sender must be a member")
throw new Exception("sender must be a member");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ public List<Short> getExtensions()

public Capabilities()
{
//TODO: make default to support all
versions = Arrays.asList(DEFAULT_SUPPORTED_VERSIONS);
cipherSuites = new ArrayList<Short>();
for (short suite : DEFAULT_SUPPORTED_CIPHERSUITES)
Expand Down
Loading

0 comments on commit ed2132d

Please sign in to comment.