Skip to content

Commit

Permalink
More integration fixes (#365)
Browse files Browse the repository at this point in the history
* Properly handle one-member trees without the one member in leaf 0

* Move uniqueness checking to tree operations

* Rename WireFormat enum values to match RFC

* clang-format

* clang-tidy

---------

Co-authored-by: Richard Barnes <richbarn@cisco.com>
  • Loading branch information
bifurcation and Richard Barnes committed Sep 9, 2023
1 parent 572be18 commit a684f2b
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 103 deletions.
12 changes: 6 additions & 6 deletions include/mls/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ struct GroupContext;
enum struct WireFormat : uint16_t
{
reserved = 0,
mls_plaintext = 1,
mls_ciphertext = 2,
mls_public_message = 1,
mls_private_message = 2,
mls_welcome = 3,
mls_group_info = 4,
mls_key_package = 5,
Expand Down Expand Up @@ -653,8 +653,8 @@ struct MLSMessage
WireFormat wire_format() const;

MLSMessage() = default;
MLSMessage(PublicMessage mls_plaintext);
MLSMessage(PrivateMessage mls_ciphertext);
MLSMessage(PublicMessage public_message);
MLSMessage(PrivateMessage private_message);
MLSMessage(Welcome welcome);
MLSMessage(GroupInfo group_info);
MLSMessage(KeyPackage key_package);
Expand Down Expand Up @@ -718,10 +718,10 @@ TLS_VARIANT_MAP(MLS_NAMESPACE::SenderType,

TLS_VARIANT_MAP(MLS_NAMESPACE::WireFormat,
MLS_NAMESPACE::PublicMessage,
mls_plaintext)
mls_public_message)
TLS_VARIANT_MAP(MLS_NAMESPACE::WireFormat,
MLS_NAMESPACE::PrivateMessage,
mls_ciphertext)
mls_private_message)
TLS_VARIANT_MAP(MLS_NAMESPACE::WireFormat, MLS_NAMESPACE::Welcome, mls_welcome)
TLS_VARIANT_MAP(MLS_NAMESPACE::WireFormat,
MLS_NAMESPACE::GroupInfo,
Expand Down
40 changes: 40 additions & 0 deletions include/mls/treekem.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct TreeKEMPublicKey
TreeKEMPublicKey& operator=(const TreeKEMPublicKey& other) = default;
TreeKEMPublicKey& operator=(TreeKEMPublicKey&& other) = default;

LeafIndex allocate_leaf();
LeafIndex add_leaf(const LeafNode& leaf);
void update_leaf(LeafIndex index, const LeafNode& leaf);
void blank_path(LeafIndex index);
Expand All @@ -149,6 +150,40 @@ struct TreeKEMPublicKey
std::optional<LeafNode> leaf_node(LeafIndex index) const;
std::vector<NodeIndex> resolve(NodeIndex index) const;

template<typename UnaryPredicate>
bool all_leaves(const UnaryPredicate& pred) const
{
for (LeafIndex i{ 0 }; i < size; i.val++) {
const auto& node = node_at(i);
if (node.blank()) {
continue;
}

if (!pred(i, node.leaf_node())) {
return false;
}
}

return true;
}

template<typename UnaryPredicate>
bool any_leaf(const UnaryPredicate& pred) const
{
for (LeafIndex i{ 0 }; i < size; i.val++) {
const auto& node = node_at(i);
if (node.blank()) {
continue;
}

if (pred(i, node.leaf_node())) {
return true;
}
}

return false;
}

using FilteredDirectPath =
std::vector<std::tuple<NodeIndex, std::vector<NodeIndex>>>;
FilteredDirectPath filtered_direct_path(NodeIndex index) const;
Expand Down Expand Up @@ -188,6 +223,11 @@ struct TreeKEMPublicKey
NodeIndex parent,
NodeIndex sibling) const;

bool exists_in_tree(const HPKEPublicKey& key,
std::optional<LeafIndex> except) const;
bool exists_in_tree(const SignaturePublicKey& key,
std::optional<LeafIndex> except) const;

OptionalNode blank_node;

friend struct TreeKEMPrivateKey;
Expand Down
41 changes: 27 additions & 14 deletions lib/mls_vectors/src/mls_vectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ MessageProtectionTestVector::protect_pub(
auto content =
GroupContent{ group_id, epoch, sender, authenticated_data, raw_content };

auto auth_content = AuthenticatedContent::sign(WireFormat::mls_plaintext,
auto auth_content = AuthenticatedContent::sign(WireFormat::mls_public_message,
content,
cipher_suite,
signature_priv,
Expand All @@ -853,11 +853,12 @@ MessageProtectionTestVector::protect_priv(
auto content =
GroupContent{ group_id, epoch, sender, authenticated_data, raw_content };

auto auth_content = AuthenticatedContent::sign(WireFormat::mls_ciphertext,
content,
cipher_suite,
signature_priv,
group_context());
auto auth_content =
AuthenticatedContent::sign(WireFormat::mls_private_message,
content,
cipher_suite,
signature_priv,
group_context());
if (content.content_type() == ContentType::commit) {
auto confirmation_tag = prg.secret("confirmation_tag");
auth_content.set_confirmation_tag(confirmation_tag);
Expand Down Expand Up @@ -973,7 +974,7 @@ TranscriptTestVector::TranscriptTestVector(CipherSuite suite)
auto leaf_index = LeafIndex{ 0 };

authenticated_content = AuthenticatedContent::sign(
WireFormat::mls_plaintext,
WireFormat::mls_public_message,
GroupContent{
group_id, epoch, { MemberSender{ leaf_index } }, {}, Commit{} },
suite,
Expand Down Expand Up @@ -1812,11 +1813,15 @@ MessagesTestVector::MessagesTestVector()

auto version = ProtocolVersion::mls10;
auto hpke_priv = prg.hpke_key("hpke_priv");
auto hpke_priv_2 = prg.hpke_key("hpke_priv_2");
auto hpke_pub = hpke_priv.public_key;
auto hpke_pub_2 = hpke_priv_2.public_key;
auto hpke_ct =
HPKECiphertext{ prg.secret("kem_output"), prg.secret("ciphertext") };
auto sig_priv = prg.signature_key("signature_priv");
auto sig_priv_2 = prg.signature_key("signature_priv_2");
auto sig_pub = sig_priv.public_key;
auto sig_pub_2 = sig_priv_2.public_key;

// KeyPackage and extensions
auto cred = Credential::basic(user_id);
Expand All @@ -1828,6 +1833,14 @@ MessagesTestVector::MessagesTestVector()
Lifetime::create_default(),
ext_list,
sig_priv };
auto leaf_node_2 = LeafNode{ suite,
hpke_pub_2,
sig_pub_2,
cred,
Capabilities::create_default(),
Lifetime::create_default(),
ext_list,
sig_priv_2 };
auto key_package_obj = KeyPackage{ suite, hpke_pub, leaf_node, {}, sig_priv };

auto leaf_node_update =
Expand All @@ -1839,7 +1852,7 @@ MessagesTestVector::MessagesTestVector()

auto tree = TreeKEMPublicKey{ suite };
tree.add_leaf(leaf_node);
tree.add_leaf(leaf_node);
tree.add_leaf(leaf_node_2);
auto ratchet_tree_obj = RatchetTreeExtension{ tree };

// Welcome and its substituents
Expand Down Expand Up @@ -1886,7 +1899,7 @@ MessagesTestVector::MessagesTestVector()
auto membership_key = prg.secret("membership_key");

auto content_auth_proposal = AuthenticatedContent::sign(
WireFormat::mls_plaintext,
WireFormat::mls_public_message,
{ group_id, epoch, sender, {}, Proposal{ remove } },
suite,
sig_priv,
Expand All @@ -1895,7 +1908,7 @@ MessagesTestVector::MessagesTestVector()
content_auth_proposal, suite, membership_key, group_context);

auto content_auth_commit =
AuthenticatedContent::sign(WireFormat::mls_plaintext,
AuthenticatedContent::sign(WireFormat::mls_public_message,
{ group_id, epoch, sender, {}, commit_obj },
suite,
sig_priv,
Expand All @@ -1906,7 +1919,7 @@ MessagesTestVector::MessagesTestVector()

// PrivateMessage
auto content_auth_application_obj = AuthenticatedContent::sign(
WireFormat::mls_ciphertext,
WireFormat::mls_private_message,
{ group_id, epoch, sender, {}, ApplicationData{} },
suite,
sig_priv,
Expand Down Expand Up @@ -1982,15 +1995,15 @@ MessagesTestVector::verify() const
VERIFY_TLS_RTT_VAL("Public(Proposal)",
MLSMessage,
public_message_proposal,
require_format(WireFormat::mls_plaintext));
require_format(WireFormat::mls_public_message));
VERIFY_TLS_RTT_VAL("Public(Commit)",
MLSMessage,
public_message_commit,
require_format(WireFormat::mls_plaintext));
require_format(WireFormat::mls_public_message));
VERIFY_TLS_RTT_VAL("PrivateMessage",
MLSMessage,
private_message,
require_format(WireFormat::mls_ciphertext));
require_format(WireFormat::mls_private_message));

return std::nullopt;
}
Expand Down
24 changes: 12 additions & 12 deletions src/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ AuthenticatedContent::sign(WireFormat wire_format,
const SignaturePrivateKey& sig_priv,
const std::optional<GroupContext>& context)
{
if (wire_format == WireFormat::mls_plaintext &&
if (wire_format == WireFormat::mls_public_message &&
content.content_type() == ContentType::application) {
throw InvalidParameterError(
"Application data cannot be sent as PublicMessage");
Expand All @@ -369,7 +369,7 @@ AuthenticatedContent::verify(CipherSuite suite,
const SignaturePublicKey& sig_pub,
const std::optional<GroupContext>& context) const
{
if (wire_format == WireFormat::mls_plaintext &&
if (wire_format == WireFormat::mls_public_message &&
content.content_type() == ContentType::application) {
return false;
}
Expand Down Expand Up @@ -545,7 +545,7 @@ PublicMessage::unprotect(CipherSuite suite,
}

return AuthenticatedContent{
WireFormat::mls_plaintext,
WireFormat::mls_public_message,
content,
auth,
};
Expand All @@ -561,7 +561,7 @@ AuthenticatedContent
PublicMessage::authenticated_content() const
{
auto auth_content = AuthenticatedContent{};
auth_content.wire_format = WireFormat::mls_plaintext;
auth_content.wire_format = WireFormat::mls_public_message;
auth_content.content = content;
auth_content.auth = auth;
return auth_content;
Expand All @@ -571,7 +571,7 @@ PublicMessage::PublicMessage(AuthenticatedContent content_auth)
: content(std::move(content_auth.content))
, auth(std::move(content_auth.auth))
{
if (content_auth.wire_format != WireFormat::mls_plaintext) {
if (content_auth.wire_format != WireFormat::mls_public_message) {
throw InvalidParameterError("Wire format mismatch (not mls_plaintext)");
}
}
Expand All @@ -590,7 +590,7 @@ PublicMessage::membership_mac(CipherSuite suite,
const std::optional<GroupContext>& context) const
{
auto tbm = tls::marshal(GroupContentTBM{
{ WireFormat::mls_plaintext, content, context },
{ WireFormat::mls_public_message, content, context },
auth,
});

Expand Down Expand Up @@ -813,7 +813,7 @@ PrivateMessage::unprotect(CipherSuite suite,
unmarshal_ciphertext_content(opt::get(content_pt), content, auth);

return AuthenticatedContent{
WireFormat::mls_ciphertext,
WireFormat::mls_private_message,
std::move(content),
std::move(auth),
};
Expand Down Expand Up @@ -851,13 +851,13 @@ MLSMessage::wire_format() const
return tls::variant<WireFormat>::type(message);
}

MLSMessage::MLSMessage(PublicMessage mls_plaintext)
: message(std::move(mls_plaintext))
MLSMessage::MLSMessage(PublicMessage public_message)
: message(std::move(public_message))
{
}

MLSMessage::MLSMessage(PrivateMessage mls_ciphertext)
: message(std::move(mls_ciphertext))
MLSMessage::MLSMessage(PrivateMessage private_message)
: message(std::move(private_message))
{
}

Expand Down Expand Up @@ -906,7 +906,7 @@ external_proposal(CipherSuite suite,
{ /* no authenticated data */ },
{ proposal } };
auto content_auth = AuthenticatedContent::sign(
WireFormat::mls_plaintext, std::move(content), suite, sig_priv, {});
WireFormat::mls_public_message, std::move(content), suite, sig_priv, {});

return PublicMessage::protect(std::move(content_auth), suite, {}, {});
}
Expand Down
4 changes: 2 additions & 2 deletions src/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ Session::Inner::import_handshake(const bytes& encoded) const
auto msg = tls::get<MLSMessage>(encoded);

switch (msg.wire_format()) {
case WireFormat::mls_plaintext:
case WireFormat::mls_public_message:
if (encrypt_handshake) {
throw ProtocolError("Handshake not encrypted as required");
}

return msg;

case WireFormat::mls_ciphertext: {
case WireFormat::mls_private_message: {
if (!encrypt_handshake) {
throw ProtocolError("Unexpected handshake encryption");
}
Expand Down
Loading

0 comments on commit a684f2b

Please sign in to comment.