blob: c4510d44432c99572f20d4c467e2a0f9fb43e88c [file] [edit]
#include <sframe/sframe.h>
#include "crypto.h"
#include "header.h"
namespace SFRAME_NAMESPACE {
///
/// Errors
///
unsupported_ciphersuite_error::unsupported_ciphersuite_error()
: std::runtime_error("Unsupported ciphersuite")
{
}
authentication_error::authentication_error()
: std::runtime_error("AEAD authentication failure")
{
}
///
/// KeyRecord
///
static auto
from_ascii(const char* str, size_t len)
{
const auto ptr = reinterpret_cast<const uint8_t*>(str);
return input_bytes(ptr, len);
}
static const auto base_label = from_ascii("SFrame 1.0 Secret ", 18);
static const auto key_label = from_ascii("key ", 4);
static const auto salt_label = from_ascii("salt ", 5);
static owned_bytes<32>
sframe_key_label(CipherSuite suite, KeyID key_id)
{
auto label = owned_bytes<32>(base_label);
label.append(key_label);
label.resize(32);
auto label_data = output_bytes(label);
encode_uint(key_id, label_data.subspan(22).first(8));
encode_uint(static_cast<uint64_t>(suite), label_data.subspan(30));
return label;
}
static owned_bytes<33>
sframe_salt_label(CipherSuite suite, KeyID key_id)
{
auto label = owned_bytes<33>(base_label);
label.append(salt_label);
label.resize(33);
auto label_data = output_bytes(label);
encode_uint(key_id, label_data.last(10).first(8));
encode_uint(static_cast<uint64_t>(suite), label_data.last(2));
return label;
}
KeyRecord
KeyRecord::from_base_key(CipherSuite suite,
KeyID key_id,
KeyUsage usage,
input_bytes base_key)
{
auto key_size = cipher_key_size(suite);
auto nonce_size = cipher_nonce_size(suite);
const auto empty_byte_string = owned_bytes<1>();
const auto key_label = sframe_key_label(suite, key_id);
const auto salt_label = sframe_salt_label(suite, key_id);
auto secret = hkdf_extract(suite, empty_byte_string, base_key);
auto key = hkdf_expand(suite, secret, key_label, key_size);
auto salt = hkdf_expand(suite, secret, salt_label, nonce_size);
return KeyRecord{ key, salt, usage, 0 };
}
///
/// Context
///
Context::Context(CipherSuite suite_in)
: suite(suite_in)
{
}
Context::~Context() = default;
void
Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key)
{
keys.emplace(key_id,
KeyRecord::from_base_key(suite, key_id, usage, base_key));
}
static owned_bytes<KeyRecord::max_salt_size>
form_nonce(Counter ctr, input_bytes salt)
{
auto nonce = owned_bytes<KeyRecord::max_salt_size>(salt);
for (size_t i = 0; i < sizeof(ctr); i++) {
nonce[nonce.size() - i - 1] ^= uint8_t(ctr >> (8 * i));
}
return nonce;
}
static constexpr auto max_aad_size =
Header::max_size + Context::max_metadata_size;
static owned_bytes<max_aad_size>
form_aad(const Header& header, input_bytes metadata)
{
if (metadata.size() > Context::max_metadata_size) {
throw buffer_too_small_error("Metadata too large");
}
auto aad = owned_bytes<max_aad_size>(0);
aad.append(header.encoded());
aad.append(metadata);
return aad;
}
output_bytes
Context::protect(KeyID key_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata)
{
auto& key_record = keys.at(key_id);
const auto counter = key_record.counter;
key_record.counter += 1;
const auto header = Header{ key_id, counter };
const auto header_data = header.encoded();
if (ciphertext.size() < header_data.size()) {
throw buffer_too_small_error("Ciphertext too small for SFrame header");
}
std::copy(header_data.begin(), header_data.end(), ciphertext.begin());
auto inner_ciphertext = ciphertext.subspan(header_data.size());
auto final_ciphertext =
Context::protect_inner(header, inner_ciphertext, plaintext, metadata);
return ciphertext.first(header_data.size() + final_ciphertext.size());
}
output_bytes
Context::unprotect(output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata)
{
const auto header = Header::parse(ciphertext);
const auto inner_ciphertext = ciphertext.subspan(header.size());
return Context::unprotect_inner(
header, plaintext, inner_ciphertext, metadata);
}
output_bytes
Context::protect_inner(const Header& header,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata)
{
if (ciphertext.size() < plaintext.size() + cipher_overhead(suite)) {
throw buffer_too_small_error("Ciphertext too small for cipher overhead");
}
const auto& key_and_salt = keys.at(header.key_id);
const auto aad = form_aad(header, metadata);
const auto nonce = form_nonce(header.counter, key_and_salt.salt);
return seal(suite, key_and_salt.key, nonce, ciphertext, aad, plaintext);
}
output_bytes
Context::unprotect_inner(const Header& header,
output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata)
{
if (ciphertext.size() < cipher_overhead(suite)) {
throw buffer_too_small_error("Ciphertext too small for cipher overhead");
}
if (plaintext.size() < ciphertext.size() - cipher_overhead(suite)) {
throw buffer_too_small_error("Plaintext too small for decrypted value");
}
const auto& key_and_salt = keys.at(header.key_id);
const auto aad = form_aad(header, metadata);
const auto nonce = form_nonce(header.counter, key_and_salt.salt);
return open(suite, key_and_salt.key, nonce, plaintext, aad, ciphertext);
}
///
/// MLSContext
///
MLSContext::MLSContext(CipherSuite suite_in, size_t epoch_bits_in)
: Context(suite_in)
, epoch_bits(epoch_bits_in)
, epoch_mask((size_t(1) << epoch_bits_in) - 1)
{
epoch_cache.resize(1 << epoch_bits_in);
}
void
MLSContext::add_epoch(EpochID epoch_id, input_bytes sframe_epoch_secret)
{
add_epoch(epoch_id, sframe_epoch_secret, 0);
}
void
MLSContext::add_epoch(EpochID epoch_id,
input_bytes sframe_epoch_secret,
size_t sender_bits)
{
auto epoch_index = epoch_id & epoch_mask;
auto& epoch = epoch_cache[epoch_index];
if (epoch) {
purge_epoch(epoch->full_epoch);
}
epoch.emplace(epoch_id, sframe_epoch_secret, epoch_bits, sender_bits);
}
void
MLSContext::purge_before(EpochID keeper)
{
for (auto& ptr : epoch_cache) {
if (ptr && ptr->full_epoch < keeper) {
purge_epoch(ptr->full_epoch);
ptr.reset();
}
}
}
output_bytes
MLSContext::protect(EpochID epoch_id,
SenderID sender_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata)
{
return protect(epoch_id, sender_id, 0, ciphertext, plaintext, metadata);
}
output_bytes
MLSContext::protect(EpochID epoch_id,
SenderID sender_id,
ContextID context_id,
output_bytes ciphertext,
input_bytes plaintext,
input_bytes metadata)
{
auto key_id = form_key_id(epoch_id, sender_id, context_id);
ensure_key(key_id, KeyUsage::protect);
return Context::protect(key_id, ciphertext, plaintext, metadata);
}
output_bytes
MLSContext::unprotect(output_bytes plaintext,
input_bytes ciphertext,
input_bytes metadata)
{
const auto header = Header::parse(ciphertext);
const auto inner_ciphertext = ciphertext.subspan(header.size());
ensure_key(header.key_id, KeyUsage::unprotect);
return Context::unprotect_inner(
header, plaintext, inner_ciphertext, metadata);
}
MLSContext::EpochKeys::EpochKeys(MLSContext::EpochID full_epoch_in,
input_bytes sframe_epoch_secret_in,
size_t epoch_bits,
size_t sender_bits_in)
: full_epoch(full_epoch_in)
, sframe_epoch_secret(sframe_epoch_secret_in)
, sender_bits(sender_bits_in)
{
static constexpr uint64_t one = 1;
static constexpr size_t key_id_bits = 64;
if (sender_bits > key_id_bits - epoch_bits) {
throw invalid_parameter_error("Sender ID field too large");
}
// XXX(RLB) We use 0 as a signifier that the sender takes the rest of the key
// ID, and context IDs are not allowed. This would be more explicit if we
// used std::optional, but would require more modern C++.
if (sender_bits == 0) {
sender_bits = key_id_bits - epoch_bits;
}
context_bits = key_id_bits - sender_bits - epoch_bits;
max_sender_id = (one << sender_bits) - 1;
max_context_id = (one << context_bits) - 1;
}
owned_bytes<MLSContext::EpochKeys::max_secret_size>
MLSContext::EpochKeys::base_key(CipherSuite ciphersuite,
SenderID sender_id) const
{
const auto hash_size = cipher_digest_size(ciphersuite);
auto enc_sender_id = owned_bytes<8>();
encode_uint(sender_id, enc_sender_id);
return hkdf_expand(
ciphersuite, sframe_epoch_secret, enc_sender_id, hash_size);
}
void
MLSContext::purge_epoch(EpochID epoch_id)
{
const auto drop_bits = epoch_id & epoch_bits;
keys.erase_if_key(
[&](const auto& epoch) { return (epoch & epoch_bits) == drop_bits; });
}
KeyID
MLSContext::form_key_id(EpochID epoch_id,
SenderID sender_id,
ContextID context_id) const
{
auto epoch_index = epoch_id & epoch_mask;
auto& epoch = epoch_cache[epoch_index];
if (!epoch) {
throw invalid_parameter_error(
"Unknown epoch. epoch_index: " + std::to_string(epoch_index) +
", sender_id:" + std::to_string(sender_id));
}
if (sender_id > epoch->max_sender_id) {
throw invalid_parameter_error("Sender ID overflow");
}
if (context_id > epoch->max_context_id) {
throw invalid_parameter_error("Context ID overflow");
}
auto sender_part = uint64_t(sender_id) << epoch_bits;
auto context_part = uint64_t(0);
if (epoch->context_bits > 0) {
context_part = uint64_t(context_id) << (epoch_bits + epoch->sender_bits);
}
return KeyID(context_part | sender_part | epoch_index);
}
void
MLSContext::ensure_key(KeyID key_id, KeyUsage usage)
{
// If the required key already exists, we are done
const auto epoch_index = key_id & epoch_mask;
auto& epoch = epoch_cache[epoch_index];
if (!epoch) {
throw invalid_parameter_error("Unknown epoch: " +
std::to_string(epoch_index));
}
if (keys.contains(key_id)) {
return;
}
// Otherwise, derive a key and implant it
const auto sender_id = key_id >> epoch_bits;
Context::add_key(key_id, usage, epoch->base_key(suite, sender_id));
return;
}
} // namespace SFRAME_NAMESPACE