blob: fdde6907e0aa13f71d81b330c037cf0c0b6f512d [file] [edit]
#include <sframe/sframe.h>
#include "crypto.h"
#include "header.h"
#include <algorithm>
#include <array>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <tuple>
namespace sframe {
/// Write `data` to `str` as lowercase hexadecimal, two digits per byte, with
/// no separators between bytes.
///
/// The stream's format flags and fill character are saved on entry and
/// restored before returning, so the hex / zero-fill settings used here do not
/// leak into whatever the caller writes to the same stream afterwards.
///
/// @param str   Stream to write to
/// @param data  Bytes to render
/// @return `str`, to allow chaining
std::ostream&
operator<<(std::ostream& str, const input_bytes data)
{
  const auto saved_flags = str.flags();
  const auto saved_fill = str.fill();

  str.flags(std::ios::hex);
  str.fill('0');
  for (const auto& byte : data) {
    // The field width is reset after every insertion, so set it per byte
    str << std::setw(2) << int(byte);
  }

  str.fill(saved_fill);
  str.flags(saved_flags);
  return str;
}
///
/// Errors
///
/// Build the exception with the fixed message "Unsupported ciphersuite".
unsupported_ciphersuite_error::unsupported_ciphersuite_error()
: std::runtime_error("Unsupported ciphersuite")
{
}
/// Build the exception with the fixed message "AEAD authentication failure".
authentication_error::authentication_error()
: std::runtime_error("AEAD authentication failure")
{
}
///
/// ContextBase
///
/// Record the ciphersuite for this context. The stored `suite` is what the
/// other ContextBase methods in this file pass to the key-derivation and
/// seal / open helpers.
ContextBase::ContextBase(CipherSuite suite_in)
: suite(suite_in)
{
}
/// Compiler-generated destructor, defined out of line in this file.
ContextBase::~ContextBase() = default;
/// Derive the key and salt for `base_key` under this context's ciphersuite and
/// store them under `key_id`.
///
/// The record is inserted with emplace(); NOTE(review): if `keys` is a
/// standard map type, an entry already stored under `key_id` is left as-is
/// rather than replaced — confirm that is the intended behavior.
///
/// @param key_id    Identifier the derived record is stored under
/// @param base_key  Input keying material for the derivation
void
ContextBase::add_key(KeyID key_id, const bytes& base_key)
{
  auto derived = KeyAndSalt::from_base_key(suite, base_key);
  keys.emplace(key_id, std::move(derived));
}
/// Build a nonce by XORing the counter into the tail of a copy of the salt.
///
/// The counter is applied big-endian: its least significant byte is XORed into
/// the last byte of the nonce, the next byte into the one before it, and so
/// on for sizeof(ctr) bytes.
///
/// NOTE(review): the indexing assumes `salt` holds at least sizeof(ctr) bytes;
/// nothing here checks that — confirm every caller guarantees it.
///
/// @param ctr   Counter to mix in
/// @param salt  Base value for the nonce; not modified
/// @return The salt with the counter XORed into its trailing bytes
static bytes
form_nonce(Counter ctr, const bytes& salt)
{
  bytes nonce(salt);
  const auto last = nonce.size() - 1;
  for (size_t offset = 0; offset < sizeof(ctr); ++offset) {
    const auto shift = 8 * offset;
    nonce[last - offset] ^= uint8_t(ctr >> shift);
  }
  return nonce;
}
/// Encrypt `plaintext` into `ciphertext` under the key selected by
/// `header.key_id`, using the encoded header as associated data and a nonce
/// formed from `header.counter` and the key's salt.
///
/// @param header      Supplies the key ID, the counter, and the AAD encoding
/// @param ciphertext  Output buffer; must hold the plaintext plus
///                    overhead(suite) bytes
/// @param plaintext   Data to encrypt
/// @return Whatever seal() returns for the given output buffer
/// @throws buffer_too_small_error if `ciphertext` is too small
/// @throws (from keys.at) if no key is stored under `header.key_id`
output_bytes
ContextBase::protect(const Header& header,
                     output_bytes ciphertext,
                     input_bytes plaintext)
{
  const auto required_size = plaintext.size() + overhead(suite);
  if (ciphertext.size() < required_size) {
    throw buffer_too_small_error("Ciphertext too small for cipher overhead");
  }

  const auto& record = keys.at(header.key_id);
  const auto aad = header.encoded();
  const auto nonce = form_nonce(header.counter, record.salt);

  return seal(suite, record.key, nonce, ciphertext, aad, plaintext);
}
/// Decrypt `ciphertext` into `plaintext` under the key selected by
/// `header.key_id`, using the encoded header as associated data and a nonce
/// formed from `header.counter` and the key's salt.
///
/// @param header      Supplies the key ID, the counter, and the AAD encoding
/// @param plaintext   Output buffer; must hold the ciphertext size minus
///                    overhead(suite) bytes
/// @param ciphertext  Data to decrypt; must be at least overhead(suite) bytes
/// @return Whatever open() returns for the given output buffer
/// @throws buffer_too_small_error if either buffer fails the size checks
/// @throws (from keys.at) if no key is stored under `header.key_id`
output_bytes
ContextBase::unprotect(const Header& header,
                       output_bytes plaintext,
                       input_bytes ciphertext)
{
  const auto cipher_overhead = overhead(suite);

  // Checked first so the subtraction in the next check cannot wrap around
  if (ciphertext.size() < cipher_overhead) {
    throw buffer_too_small_error("Ciphertext too small for cipher overhead");
  }

  if (plaintext.size() < ciphertext.size() - cipher_overhead) {
    throw buffer_too_small_error("Plaintext too small for decrypted value");
  }

  const auto& record = keys.at(header.key_id);
  const auto aad = header.encoded();
  const auto nonce = form_nonce(header.counter, record.salt);

  return open(suite, record.key, nonce, plaintext, aad, ciphertext);
}
// Fixed byte strings passed to hkdf_extract / hkdf_expand by
// KeyAndSalt::from_base_key. Each is ASCII text written out as hex bytes; the
// comment next to each one gives the text it spells.
static const bytes sframe_label{
0x53, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x31, 0x30 // "SFrame10"
};
static const bytes sframe_key_label{ 0x6b, 0x65, 0x79 }; // "key"
static const bytes sframe_salt_label{ 0x73, 0x61, 0x6c, 0x74 }; // "salt"
static const bytes sframe_ctr_label{
// "SFrame10 AES CM AEAD"
0x53, 0x46, 0x72, 0x61, 0x6d, 0x65, 0x31, 0x30, 0x20, 0x41,
0x45, 0x53, 0x20, 0x43, 0x4d, 0x20, 0x41, 0x45, 0x41, 0x44,
};
static const bytes sframe_enc_label{ 0x65, 0x6e, 0x63 }; // "enc"
static const bytes sframe_auth_label{ 0x61, 0x75, 0x74, 0x68 }; // "auth"
/// Derive the key and salt for `suite` from `base_key`.
///
/// A secret is extracted from `base_key` with the "SFrame10" label, then
/// expanded twice: with the "key" label to the cipher key size, and with the
/// "salt" label to the cipher nonce size.
///
/// For the two AES_CM_128_HMAC_SHA256 suites the key is derived one step
/// further: a second secret is extracted from it, expanded into an encryption
/// key ("enc", cipher key size) and an authentication key ("auth", digest
/// size), and the returned key is the concatenation enc_key || auth_key.
///
/// @param suite     Ciphersuite that determines the output sizes
/// @param base_key  Input keying material
/// @return A KeyAndSalt holding the derived key, the derived salt, and 0 as
///         its third field
ContextBase::KeyAndSalt
ContextBase::KeyAndSalt::from_base_key(CipherSuite suite, const bytes& base_key)
{
  const auto key_size = cipher_key_size(suite);
  const auto nonce_size = cipher_nonce_size(suite);
  const auto hash_size = cipher_digest_size(suite);

  auto secret = hkdf_extract(suite, sframe_label, base_key);
  auto key = hkdf_expand(suite, secret, sframe_key_label, key_size);
  auto salt = hkdf_expand(suite, secret, sframe_salt_label, nonce_size);

  // If using CTR+HMAC, set key = enc_key || auth_key
  if (suite == CipherSuite::AES_CM_128_HMAC_SHA256_4 ||
      suite == CipherSuite::AES_CM_128_HMAC_SHA256_8) {
    secret = hkdf_extract(suite, sframe_ctr_label, key);

    const auto enc_key = hkdf_expand(suite, secret, sframe_enc_label, key_size);
    const auto auth_key =
      hkdf_expand(suite, secret, sframe_auth_label, hash_size);

    key = enc_key;
    key.insert(key.end(), auth_key.begin(), auth_key.end());
  }

  return KeyAndSalt{ key, salt, 0 };
}
///
/// Context
///
/// Construct a Context for the given ciphersuite by forwarding it to the
/// ContextBase constructor.
Context::Context(CipherSuite suite_in)
: ContextBase(suite_in)
{
}
/// Compiler-generated destructor, defined out of line in this file.
Context::~Context() = default;
/// Register `base_key` under `key_id` via ContextBase::add_key, and create a
/// send counter for that key ID starting at 0.
///
/// The counter is inserted with emplace(); NOTE(review): if `counters` is a
/// standard map type, a counter already present for `key_id` keeps its current
/// value — confirm.
///
/// @param key_id    Identifier for the key and its counter
/// @param base_key  Input keying material
void
Context::add_key(KeyID key_id, const bytes& base_key)
{
  ContextBase::add_key(key_id, base_key);

  const auto initial_counter = 0;
  counters.emplace(key_id, initial_counter);
}
/// Encrypt `plaintext` under `key_id` and write the encoded header followed by
/// the encrypted payload into `ciphertext`.
///
/// The counter stored for `key_id` is read and incremented before any buffer
/// check, so a call that later throws still consumes a counter value.
///
/// @param key_id      Key to encrypt with; must have been added already
/// @param ciphertext  Output buffer for header + encrypted payload
/// @param plaintext   Data to encrypt
/// @return The leading part of `ciphertext` that was filled in
/// @throws buffer_too_small_error if the header does not fit, or from
///         ContextBase::protect if the payload does not fit
/// @throws (from counters.at) if no counter exists for `key_id`
output_bytes
Context::protect(KeyID key_id, output_bytes ciphertext, input_bytes plaintext)
{
  // Take the current counter value for this frame and advance the stored one
  auto& stored_counter = counters.at(key_id);
  const auto counter = stored_counter;
  stored_counter += 1;

  const Header header{ key_id, counter };
  const auto encoded_header = header.encoded();
  const auto header_size = encoded_header.size();
  if (ciphertext.size() < header_size) {
    throw buffer_too_small_error("Ciphertext too small for SFrame header");
  }

  // Header goes first; the encrypted payload is written right after it
  std::copy(encoded_header.begin(), encoded_header.end(), ciphertext.begin());
  const auto payload = ContextBase::protect(
    header, ciphertext.subspan(header_size), plaintext);

  return ciphertext.subspan(0, header_size + payload.size());
}
/// Parse the header from the front of `ciphertext`, then decrypt the bytes
/// after it into `plaintext`.
///
/// @param plaintext   Output buffer for the decrypted payload
/// @param ciphertext  Header followed by the encrypted payload
/// @return The result of ContextBase::unprotect
output_bytes
Context::unprotect(output_bytes plaintext, input_bytes ciphertext)
{
  const auto header = Header::parse(ciphertext);
  return ContextBase::unprotect(
    header, plaintext, ciphertext.subspan(header.size()));
}
///
/// MLSContext
///
/// Construct an MLS context that keeps one cache slot per value of the low
/// `epoch_bits_in` bits of an epoch ID.
///
/// `epoch_mask` is set to 2^epoch_bits_in - 1 and `epoch_cache` is sized to
/// 2^epoch_bits_in entries; every entry is then reset to null.
///
/// @param suite_in       Ciphersuite, forwarded to Context
/// @param epoch_bits_in  Number of low epoch-ID bits used to index the cache
MLSContext::MLSContext(CipherSuite suite_in, size_t epoch_bits_in)
  : Context(suite_in)
  , epoch_bits(epoch_bits_in)
  , epoch_mask((size_t(1) << epoch_bits_in) - 1)
  , epoch_cache(size_t(1) << epoch_bits_in)
{
  // Start with every epoch slot empty
  for (std::unique_ptr<EpochKeys>& slot : epoch_cache) {
    slot.reset(nullptr);
  }
}
/// Install an epoch with `sender_bits` set to 0. The EpochKeys constructor
/// treats 0 as "the sender ID takes all key-ID bits not used by the epoch",
/// which leaves no room for context IDs.
///
/// @param epoch_id             Full epoch identifier
/// @param sframe_epoch_secret  Secret the epoch's keys are derived from
void
MLSContext::add_epoch(EpochID epoch_id, const bytes& sframe_epoch_secret)
{
  const size_t sender_takes_remaining_bits = 0;
  add_epoch(epoch_id, sframe_epoch_secret, sender_takes_remaining_bits);
}
/// Install a new epoch in the cache slot selected by the low bits of
/// `epoch_id`, replacing whatever epoch currently occupies that slot.
///
/// If the slot is occupied, the keys and counters of the epoch being replaced
/// are purged first.
///
/// The new EpochKeys is built before anything is purged. Its constructor
/// throws on an invalid `sender_bits`; building it first means a failed call
/// leaves the previous epoch, its keys, and its counters untouched.
///
/// @param epoch_id             Full epoch identifier
/// @param sframe_epoch_secret  Secret the epoch's keys are derived from
/// @param sender_bits          Width of the sender field in the key ID; 0
///                             means "all bits not used by the epoch"
/// @throws invalid_parameter_error (from EpochKeys) if `sender_bits` does not
///         fit next to the epoch bits
void
MLSContext::add_epoch(EpochID epoch_id,
                      const bytes& sframe_epoch_secret,
                      size_t sender_bits)
{
  std::unique_ptr<EpochKeys> incoming(
    new EpochKeys(epoch_id, sframe_epoch_secret, epoch_bits, sender_bits));

  const auto epoch_index = epoch_id & epoch_mask;
  auto& slot = epoch_cache.at(epoch_index);
  if (slot) {
    purge_epoch(slot->full_epoch);
  }

  slot = std::move(incoming);
}
/// Drop every cached epoch whose full epoch ID is less than `keeper`, purging
/// its keys and counters and emptying its cache slot.
///
/// @param keeper  Lowest epoch ID to keep
void
MLSContext::purge_before(EpochID keeper)
{
  for (auto& slot : epoch_cache) {
    // Skip empty slots and epochs that are recent enough to keep
    if (!slot || slot->full_epoch >= keeper) {
      continue;
    }

    purge_epoch(slot->full_epoch);
    slot.reset(nullptr);
  }
}
/// Encrypt `plaintext` for the given epoch and sender with a context ID of 0.
///
/// @param epoch_id    Epoch whose secret the key is derived from
/// @param sender_id   Sender the key is derived for
/// @param ciphertext  Output buffer
/// @param plaintext   Data to encrypt
/// @return The result of the five-argument protect() overload
output_bytes
MLSContext::protect(EpochID epoch_id,
                    SenderID sender_id,
                    output_bytes ciphertext,
                    input_bytes plaintext)
{
  const ContextID default_context = 0;
  return protect(epoch_id, sender_id, default_context, ciphertext, plaintext);
}
/// Encrypt `plaintext` for the given epoch, sender, and context.
///
/// The three IDs are packed into a key ID, the key for that ID is derived if
/// it is not stored yet, and the frame is then protected by Context::protect.
///
/// @param epoch_id    Epoch whose secret the key is derived from
/// @param sender_id   Sender the key is derived for
/// @param context_id  Context within the sender
/// @param ciphertext  Output buffer
/// @param plaintext   Data to encrypt
/// @return The result of Context::protect
/// @throws invalid_parameter_error (from form_key_id / ensure_key) for an
///         unknown epoch or an out-of-range sender or context ID
output_bytes
MLSContext::protect(EpochID epoch_id,
                    SenderID sender_id,
                    ContextID context_id,
                    output_bytes ciphertext,
                    input_bytes plaintext)
{
  const auto packed_key_id = form_key_id(epoch_id, sender_id, context_id);

  // Derive and store the key on first use
  ensure_key(packed_key_id);

  return Context::protect(packed_key_id, ciphertext, plaintext);
}
/// Parse the header from the front of `ciphertext`, make sure the key named
/// by its key ID is available (deriving it if needed), and decrypt the rest
/// into `plaintext`.
///
/// @param plaintext   Output buffer for the decrypted payload
/// @param ciphertext  Header followed by the encrypted payload
/// @return The result of ContextBase::unprotect
/// @throws invalid_parameter_error (from ensure_key) if the key ID refers to
///         an epoch that is not cached
output_bytes
MLSContext::unprotect(output_bytes plaintext, input_bytes ciphertext)
{
  const auto parsed_header = Header::parse(ciphertext);
  const auto payload = ciphertext.subspan(parsed_header.size());

  ensure_key(parsed_header.key_id);

  return ContextBase::unprotect(parsed_header, plaintext, payload);
}
/// Record an epoch's secret and work out how the 64-bit key ID is divided
/// between the epoch, sender, and context fields.
///
/// The epoch field takes `epoch_bits`, the sender field takes `sender_bits`,
/// and the context field takes whatever is left. The largest sender and
/// context IDs are the largest values that fit in their fields, i.e.
/// 2^bits - 1, so a zero-width context field only admits context ID 0.
///
/// @param full_epoch_in           Full epoch identifier
/// @param sframe_epoch_secret_in  Epoch secret; moved into the object
/// @param epoch_bits              Width of the epoch field in the key ID
/// @param sender_bits_in          Width of the sender field; 0 means "all
///                                bits not used by the epoch"
/// @throws invalid_parameter_error if the sender field does not fit in the
///         bits left over after the epoch field
MLSContext::EpochKeys::EpochKeys(MLSContext::EpochID full_epoch_in,
                                 bytes sframe_epoch_secret_in,
                                 size_t epoch_bits,
                                 size_t sender_bits_in)
  : full_epoch(full_epoch_in)
  , sframe_epoch_secret(std::move(sframe_epoch_secret_in))
  , sender_bits(sender_bits_in)
{
  static constexpr uint64_t one = 1;
  static constexpr size_t key_id_bits = 64;

  if (sender_bits > key_id_bits - epoch_bits) {
    throw invalid_parameter_error("Sender ID field too large");
  }

  // XXX(RLB) We use 0 as a signifier that the sender takes the rest of the key
  // ID, and context IDs are not allowed. This would be more explicit if we
  // used std::optional, but would require more modern C++.
  if (sender_bits == 0) {
    sender_bits = key_id_bits - epoch_bits;
  }

  context_bits = key_id_bits - sender_bits - epoch_bits;

  // Largest value that fits in a field `width` bits wide. Shifting a 64-bit
  // value by 64 is undefined, so a full-width field is handled separately.
  const auto max_for_width = [](size_t width) -> uint64_t {
    return (width >= key_id_bits) ? ~uint64_t(0) : (one << width) - 1;
  };

  max_sender_id = max_for_width(sender_bits);
  max_context_id = max_for_width(context_bits);
}
/// Derive the base key for `sender_id` in this epoch.
///
/// The sender ID is encoded into an 8-byte buffer with encode_uint, and that
/// buffer is used as the label when expanding the epoch secret to the
/// ciphersuite's digest size.
///
/// @param ciphersuite  Ciphersuite used for the expansion and output size
/// @param sender_id    Value to derive the key for
/// @return The derived base key
bytes
MLSContext::EpochKeys::base_key(CipherSuite ciphersuite,
                                SenderID sender_id) const
{
  const auto output_size = cipher_digest_size(ciphersuite);

  bytes sender_label(8);
  encode_uint(sender_id, sender_label);

  return hkdf_expand(
    ciphersuite, sframe_epoch_secret, sender_label, output_size);
}
/// Erase every stored key and counter whose key ID belongs to the epoch slot
/// of `epoch_id`.
///
/// A key ID carries its epoch index in the low bits selected by `epoch_mask`
/// (see form_key_id), so membership is tested by comparing those bits. The
/// mask must be `epoch_mask`, not `epoch_bits`: the latter is the number of
/// bits, and using it as a mask would select the wrong key IDs.
///
/// @param epoch_id  Epoch whose keys and counters are removed
void
MLSContext::purge_epoch(EpochID epoch_id)
{
  const auto doomed_index = epoch_id & epoch_mask;
  const auto in_doomed_epoch = [&](KeyID key_id) {
    return (key_id & epoch_mask) == doomed_index;
  };

  // Remove keys for this epoch
  for (auto i = keys.begin(); i != keys.end();) {
    if (in_doomed_epoch(i->first)) {
      i = keys.erase(i);
    } else {
      ++i;
    }
  }

  // Remove counters for this epoch
  for (auto i = counters.begin(); i != counters.end();) {
    if (in_doomed_epoch(i->first)) {
      i = counters.erase(i);
    } else {
      ++i;
    }
  }
}
/// Pack an epoch, sender, and context into a key ID.
///
/// Layout, from the least significant bit upwards: the epoch index (low
/// `epoch_bits` bits of `epoch_id`), then the sender ID, then the context ID.
/// The context field is only written when the epoch has a non-zero
/// `context_bits`.
///
/// @param epoch_id    Epoch; must be present in the cache
/// @param sender_id   Must not exceed the epoch's max_sender_id
/// @param context_id  Must not exceed the epoch's max_context_id
/// @return The packed key ID
/// @throws invalid_parameter_error for an unknown epoch or an ID that is out
///         of range
KeyID
MLSContext::form_key_id(EpochID epoch_id,
                        SenderID sender_id,
                        ContextID context_id) const
{
  const auto epoch_index = epoch_id & epoch_mask;
  const auto& cached_epoch = epoch_cache.at(epoch_index);
  if (!cached_epoch) {
    throw invalid_parameter_error(
      "Unknown epoch. epoch_index: " + std::to_string(epoch_index) +
      ", sender_id:" + std::to_string(sender_id));
  }

  if (sender_id > cached_epoch->max_sender_id) {
    throw invalid_parameter_error("Sender ID overflow");
  }

  if (context_id > cached_epoch->max_context_id) {
    throw invalid_parameter_error("Context ID overflow");
  }

  const auto sender_part = uint64_t(sender_id) << epoch_bits;

  // With no context bits there is no context field to shift into
  const auto context_part =
    (cached_epoch->context_bits > 0)
      ? uint64_t(context_id) << (epoch_bits + cached_epoch->sender_bits)
      : uint64_t(0);

  return KeyID(context_part | sender_part | epoch_index);
}
/// Make sure a key is stored for `key_id`, deriving it from the cached epoch
/// if it is missing.
///
/// The epoch is looked up first, so an unknown epoch is rejected even when a
/// key for `key_id` is already stored. The value handed to base_key() is
/// everything above the epoch bits of the key ID, which includes any context
/// bits as well as the sender bits.
///
/// @param key_id  Key ID whose key must be available
/// @throws invalid_parameter_error if the epoch slot for `key_id` is empty
void
MLSContext::ensure_key(KeyID key_id)
{
  const auto epoch_index = key_id & epoch_mask;
  const auto& cached_epoch = epoch_cache.at(epoch_index);
  if (!cached_epoch) {
    throw invalid_parameter_error("Unknown epoch: " +
                                  std::to_string(epoch_index));
  }

  // Nothing to do when the key has already been derived
  const bool already_stored = keys.count(key_id) > 0;
  if (!already_stored) {
    const auto sender_id = key_id >> epoch_bits;
    Context::add_key(key_id, cached_epoch->base_key(suite, sender_id));
  }
}
} // namespace sframe