blob: 15c9de414240d0be1b893bdf26791c48620e2146 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "fuzz/mem_hash_tree.h"
#include <algorithm>
#include <cassert>
MemHashTree::MemHashTree() : bits_per_level_(0), height_(0) {}
// Looks up the level-0 (leaf) node for |label|.
//
// If the leaf is stored, its hash is copied into the front of |leaf_hash| and
// true is returned. Otherwise every byte of |leaf_hash| is zeroed and false is
// returned. |leaf_hash| must hold at least SHA256_DIGEST_SIZE bytes.
bool MemHashTree::GetLeaf(uint64_t label, fuzz::span<uint8_t> leaf_hash) const {
  assert(leaf_hash.size() >= SHA256_DIGEST_SIZE);
  const auto entry = hash_tree_.find(MaskedLabel(label, 0));
  const bool present = entry != hash_tree_.end();
  if (present) {
    const auto& stored = entry->second;
    std::copy(stored.begin(), stored.end(), leaf_hash.begin());
  } else {
    std::fill(leaf_hash.begin(), leaf_hash.end(), 0);
  }
  return present;
}
// Serializes the sibling hashes along the path from the leaf |label| toward
// the root into |path_hashes|.
//
// For each level from 0 up to height_ - 1, the hashes of all nodes that share
// a parent with the current node (excluding the node itself) are written in
// ascending child-index order. A sibling that has no entry in hash_tree_ is
// written as empty_node_hashes_[level]. Each hash occupies
// SHA256_DIGEST_SIZE bytes.
//
// Returns the number of bytes written to |path_hashes|.
size_t MemHashTree::GetPath(uint64_t label,
                            fuzz::span<uint8_t> path_hashes) const {
  const uint8_t fan_out = 1 << bits_per_level_;
  // The per-level sibling count and the child-index mask share the value
  // fan_out - 1; they are kept as separate names because they serve
  // different purposes below.
  const uint8_t siblings_per_level = fan_out - 1;
  const uint64_t index_mask = fan_out - 1;
  assert(path_hashes.size() >=
         siblings_per_level * height_ * SHA256_DIGEST_SIZE);

  uint8_t* cursor = path_hashes.begin();
  uint64_t node_label = label;
  for (uint8_t level = 0; level < height_; ++level) {
    // Position of the current node among its siblings, and the label of the
    // first child of its parent at this level.
    const uint8_t self_index = node_label & index_mask;
    const uint64_t sibling_base = node_label & ~index_mask;
    for (uint8_t child = 0; child < fan_out; ++child) {
      if (child != self_index) {
        const auto stored =
            hash_tree_.find(MaskedLabel(sibling_base | child, level));
        if (stored != hash_tree_.end()) {
          std::copy(stored->second.begin(), stored->second.end(), cursor);
        } else {
          const auto& empty_hash = empty_node_hashes_[level];
          std::copy(empty_hash.begin(), empty_hash.end(), cursor);
        }
        cursor += SHA256_DIGEST_SIZE;
      }
    }
    // Move up one level: the parent's label drops the child-index bits.
    node_label = sibling_base >> bits_per_level_;
  }
  return cursor - path_hashes.begin();
}
void MemHashTree::UpdatePath(uint64_t label,
fuzz::span<const uint8_t> path_hash) {
std::array<uint8_t, SHA256_DIGEST_SIZE> hash;
if (path_hash.empty()) {
std::fill(hash.begin(), hash.end(), 0);
hash_tree_.erase(MaskedLabel(label, 0));
} else {
assert(path_hash.size() == SHA256_DIGEST_SIZE);
std::copy(path_hash.begin(), path_hash.end(), hash.begin());
hash_tree_[MaskedLabel(label, 0)] = hash;
}
uint8_t fan_out = 1 << bits_per_level_;
uint64_t child_index_mask = fan_out - 1;
uint64_t shifted_parent_label = label;
for (int level = 0; level < height_; ++level) {
shifted_parent_label &= ~child_index_mask;
LITE_SHA256_CTX ctx;
DCRYPTO_SHA256_init(&ctx, 1);
int empty_nodes = 0;
for (int index = 0; index < fan_out; ++index) {
auto itr =
hash_tree_.find(MaskedLabel(shifted_parent_label | index, level));
if (itr == hash_tree_.end()) {
HASH_update(&ctx, empty_node_hashes_[level].data(),
empty_node_hashes_[level].size());
++empty_nodes;
} else {
HASH_update(&ctx, itr->second.data(), itr->second.size());
}
}
shifted_parent_label = shifted_parent_label >> bits_per_level_;
const uint8_t* temp = HASH_final(&ctx);
std::copy(temp, temp + SHA256_DIGEST_SIZE, hash.begin());
MaskedLabel node_key(shifted_parent_label, level + 1);
if (empty_nodes == fan_out) {
hash_tree_.erase(node_key);
} else {
hash_tree_[node_key] = hash;
}
}
}
// Discards all stored nodes and cached empty-node hashes and returns the
// geometry to zero bits per level and zero height.
void MemHashTree::Reset() {
  hash_tree_.clear();
  empty_node_hashes_.clear();
  height_ = 0;
  bits_per_level_ = 0;
}
// Reconfigures the tree with the given geometry and discards all stored nodes.
//
// empty_node_hashes_[level] is precomputed for every level in [0, height):
// level 0 is an all-zero digest, and each higher level is the SHA-256 of
// fan_out copies of the level below. A |height| of 0 leaves
// empty_node_hashes_ empty.
//
// NOTE: fan_out is held in a uint8_t (matching GetPath/UpdatePath), so
// bits_per_level must be below 8; 1 << 8 would truncate to 0.
void MemHashTree::Reset(uint8_t bits_per_level, uint8_t height) {
  bits_per_level_ = bits_per_level;
  height_ = height;
  hash_tree_.clear();
  empty_node_hashes_.resize(height);
  if (height == 0) {
    // No levels means no empty-node hashes; indexing [0] of the now-empty
    // vector below would be an out-of-bounds write.
    return;
  }

  std::array<uint8_t, SHA256_DIGEST_SIZE> hash;
  std::fill(hash.begin(), hash.end(), 0);
  empty_node_hashes_[0] = hash;
  uint8_t fan_out = 1 << bits_per_level;
  for (int level = 1; level < height; ++level) {
    LITE_SHA256_CTX ctx;
    DCRYPTO_SHA256_init(&ctx, 1);
    // An empty node's hash covers fan_out empty children from the level below.
    for (int index = 0; index < fan_out; ++index) {
      HASH_update(&ctx, hash.data(), hash.size());
    }
    const uint8_t* temp = HASH_final(&ctx);
    std::copy(temp, temp + SHA256_DIGEST_SIZE, hash.begin());
    empty_node_hashes_[level] = hash;
  }
}