blob: 6423595515f3c83b1a675e325c22a70a800061bb [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/language_detection/core/ngram_hash.h"

#include <stdint.h>

#include <algorithm>
#include <string>
#include <vector>

#include "base/compiler_specific.h"
#include "base/containers/span.h"
#include "components/language_detection/core/ngram_hash_ops_utils.h"
#include "third_party/flatbuffers/src/include/flatbuffers/flexbuffers.h"
#include "third_party/flatbuffers/src/include/flatbuffers/util.h"
#include "third_party/smhasher/src/src/MurmurHash2.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tflite/src/tensorflow/lite/string_util.h"
namespace language_detection {
namespace {
// FlexBuffers helpers used by `Init` to decode the op's serialized options.
using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
// TFLite string-tensor helpers used to read the input string.
using ::tflite::GetString;
using ::tflite::StringRef;
// Index of the input string tensor among the node's inputs (see `Eval`).
constexpr int kInputMessage = 0;
// Index of the output tensor among the node's outputs (see `Resize`/`Eval`).
constexpr int kOutputLabel = 0;
// Value used for `max_splits` when the op's options leave it unset (see
// `Init`).
constexpr int kDefaultMaxSplits = 128;
// This op takes in a string, finds the character ngrams for it and then
// maps each of these ngrams to an index using the specified vocabulary sizes.
// Input(s):
// - input: Input string.
// - seed: Seed for the hash function used to hash each ngram.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
// be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features
// respectively. The op would generate vocab ids to be less than or equal to
// the vocab size. The index 0 implies an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
// limit is `kDefaultMaxSplits`.
// - lowercase_input: If this is set to true, the input string would be
// lower-cased before any processing. If this is unset, it defaults to true.
// Output(s):
// - output: A tensor of shape [1, number of ngrams, number of tokens + 2],
// where 2 tokens are reserved for the padding. If `max_splits` is set, the
// length of the last dimension is <= max_splits, otherwise it is <=
// `kDefaultMaxSplits`.
// Helper class used for pre-processing the input.
//
// Stores the op's options (hash seed, ngram lengths, vocabulary sizes, maximum
// number of splits and the lower-casing flag) together with the tokenization
// of the most recently pre-processed input string.
class NGramHashParams {
 public:
  NGramHashParams(const uint64_t seed,
                  const std::vector<int>& ngram_lengths,
                  const std::vector<int>& vocab_sizes,
                  int max_splits,
                  bool lower_case_input)
      : seed_(seed),
        ngram_lengths_(ngram_lengths),
        vocab_sizes_(vocab_sizes),
        max_splits_(max_splits),
        lower_case_input_(lower_case_input) {}

  // Validates the stored options and tokenizes the string held by `input_t`.
  //
  // On success the tokenization is stored (see `GetTokenizedOutput()` and
  // `GetNumTokens()`) and `kTfLiteOk` is returned. On any validation failure
  // an error is reported through `context` and `kTfLiteError` is returned.
  TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
                               TfLiteContext* context) {
    if (input_t == nullptr) {
      context->ReportError(context, "Input tensor must not be null.");
      return kTfLiteError;
    }
    if (input_t->bytes == 0) {
      context->ReportError(context, "Empty input not supported.");
      return kTfLiteError;
    }
    // Do sanity checks on the input.
    if (ngram_lengths_.empty()) {
      context->ReportError(context, "`ngram_lengths` must be non-empty.");
      return kTfLiteError;
    }
    if (vocab_sizes_.empty()) {
      context->ReportError(context, "`vocab_sizes` must be non-empty.");
      return kTfLiteError;
    }
    if (ngram_lengths_.size() != vocab_sizes_.size()) {
      context->ReportError(
          context,
          "Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
      return kTfLiteError;
    }
    // Each vocabulary size is later used as a modulus when mapping a hash to
    // an index, so a zero entry would be a division by zero.
    if (std::any_of(vocab_sizes_.begin(), vocab_sizes_.end(),
                    [](int vocab_size) { return vocab_size <= 0; })) {
      context->ReportError(context, "`vocab_sizes` entries must be > 0.");
      return kTfLiteError;
    }
    if (max_splits_ <= 0) {
      context->ReportError(context, "`max_splits` must be > 0.");
      return kTfLiteError;
    }
    // Obtain and tokenize the input.
    StringRef input_ref = GetString(input_t, /*string_index=*/0);
    tokenized_output_ =
        Tokenize({input_ref.str, input_ref.len}, max_splits_,
                 /*exclude_nonalphaspace_tokens=*/true, lower_case_input_);
    return kTfLiteOk;
  }

  // Seed passed to the hash function.
  uint64_t GetSeed() const { return seed_; }
  // Number of tokens produced by the last successful `PreprocessInput()`.
  int GetNumTokens() const {
    return static_cast<int>(tokenized_output_.tokens.size());
  }
  // Number of configured ngram lengths.
  int GetNumNGrams() const { return static_cast<int>(ngram_lengths_.size()); }
  const std::vector<int>& GetNGramLengths() const { return ngram_lengths_; }
  const std::vector<int>& GetVocabSizes() const { return vocab_sizes_; }
  // Tokenization produced by the last successful `PreprocessInput()`.
  const TokenizedOutput& GetTokenizedOutput() const {
    return tokenized_output_;
  }

 private:
  const uint64_t seed_;
  std::vector<int> ngram_lengths_;
  std::vector<int> vocab_sizes_;
  const int max_splits_;
  const bool lower_case_input_;
  // Only reachable through `GetTokenizedOutput()`.
  TokenizedOutput tokenized_output_;
};
// Copies every element of the FlexBuffers `typed_vec`, read as a 32-bit
// integer, into a newly allocated std::vector (same order, same length).
std::vector<int> GetIntVector(TypedVector typed_vec) {
  const size_t count = typed_vec.size();
  std::vector<int> values;
  values.reserve(count);
  for (size_t index = 0; index < count; ++index) {
    values.push_back(typed_vec[index].AsInt32());
  }
  return values;
}
// Computes the vocabulary index of every ngram of the tokenized input held by
// `params` and writes the results into `data`.
//
// `data` is written row-major with one row per configured ngram length and one
// column per token: the value for the ngram of length `ngram_lengths[n]` that
// starts at token `t` is stored at `data[n * num_tokens + t]`, and equals
// `(hash % vocab_sizes[n]) + 1`.
void GetNGramHashIndices(NGramHashParams* params, base::span<int32_t> data) {
  const int max_unicode_length = params->GetNumTokens();
  // The getters return const references; bind them as references so the
  // vectors are not copied on every invocation.
  const auto& ngram_lengths = params->GetNGramLengths();
  const auto& vocab_sizes = params->GetVocabSizes();
  const auto& tokenized_output = params->GetTokenizedOutput();
  const auto seed = params->GetSeed();
  // Compute for each ngram.
  for (size_t ngram = 0; ngram < ngram_lengths.size(); ngram++) {
    const int vocab_size = vocab_sizes[ngram];
    const int ngram_length = ngram_lengths[ngram];
    // Compute for each token within the input.
    for (size_t start = 0; start < tokenized_output.tokens.size(); start++) {
      // Compute the number of bytes for the ngram starting at the given
      // token. Each token is a pair whose `first` is used as an offset into
      // `tokenized_output.str` and whose `second` is summed as a byte count.
      int num_bytes = 0;
      for (size_t i = start;
           i < tokenized_output.tokens.size() && i < (start + ngram_length);
           i++) {
        num_bytes += tokenized_output.tokens[i].second;
      }
      // Compute the hash for the ngram starting at the token.
      //
      // TODO(crbug.com/40601370): Murmur2 has only 2 remaining uses in
      // Chrome. Migrate to a different hash that's more widely used in future
      // versions of the model and also supports 32/64 bit platforms
      // seamlessly. Anything over num_bytes = 7 can overflow on 32-bit. By
      // limiting to 7, this may truncate the last byte of the input and result
      // in a slightly different hash but impact should be minimal.
      const auto str_hash = MurmurHash64A(
          &tokenized_output.str[tokenized_output.tokens[start].first],
          std::min(num_bytes, 7), seed);
      // Map the hash to an index in the vocab.
      data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
    }
  }
}
// Decodes the op's FlexBuffers-encoded options from `buffer` (`length` bytes)
// and returns a heap-allocated NGramHashParams built from them. The returned
// pointer is released by `Free`.
//
// Keys read from the options map: "seed", "ngram_lengths", "vocab_sizes",
// "max_splits" (falls back to `kDefaultMaxSplits` when null) and
// "lowercase_input" (falls back to true when null).
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const Map options =
      GetRoot(reinterpret_cast<const uint8_t*>(buffer), length).AsMap();

  const uint64_t seed = options["seed"].AsInt64();
  const std::vector<int> ngram_lengths =
      GetIntVector(options["ngram_lengths"].AsTypedVector());
  const std::vector<int> vocab_sizes =
      GetIntVector(options["vocab_sizes"].AsTypedVector());

  // Optional entries: look each one up once and fall back to the default when
  // it is absent.
  const auto max_splits_option = options["max_splits"];
  int max_splits = kDefaultMaxSplits;
  if (!max_splits_option.IsNull()) {
    max_splits = max_splits_option.AsInt32();
  }

  const auto lowercase_option = options["lowercase_input"];
  bool lowercase_input = true;
  if (!lowercase_option.IsNull()) {
    lowercase_input = lowercase_option.AsBool();
  }

  return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
                             lowercase_input);
}
// Destroys the NGramHashParams instance that `Init` allocated; `buffer` is the
// pointer `Init` returned.
void Free(TfLiteContext* context, void* buffer) {
  auto* params = static_cast<NGramHashParams*>(buffer);
  delete params;
}
// Marks the op's output tensor as dynamic; its actual shape is set in `Eval`,
// once the input has been tokenized. Returns an error status if the output
// tensor cannot be obtained.
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
tflite::SetTensorToDynamic(output);
return kTfLiteOk;
}
// Runs the op: tokenizes the input string, resizes the (dynamic) output tensor
// to [1, number of ngram lengths, number of tokens] and fills it with the
// hashed ngram indices.
//
// Returns `kTfLiteError` (after reporting through `context`) if the node's
// params, input or output are missing, if pre-processing fails, if the output
// tensor is not dynamic, or if the output type is not Int32.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  NGramHashParams* params = static_cast<NGramHashParams*>(node->user_data);
  TF_LITE_ENSURE(context, params != nullptr);

  // Check the input here so that a missing tensor is reported as an error
  // instead of being dereferenced during pre-processing.
  const TfLiteTensor* input = tflite::GetInput(context, node, kInputMessage);
  TF_LITE_ENSURE(context, input != nullptr);
  TF_LITE_ENSURE_OK(context, params->PreprocessInput(input, context));

  TfLiteTensor* output = tflite::GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);
  if (!tflite::IsDynamicTensor(output)) {
    context->ReportError(context, "Output must be dynamic.");
    return kTfLiteError;
  }

  // The shape depends on the token count, which is only known after
  // pre-processing. `output_size` is handed over to `ResizeTensor`.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  {
    // SAFETY: Without span accessors, we can only trust the
    // result of `TfLiteIntArrayCreate()`.
    auto dims = UNSAFE_BUFFERS(base::span(output_size->data, 3u));
    dims[0] = 1;
    dims[1] = params->GetNumNGrams();
    dims[2] = params->GetNumTokens();
  }
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  if (output->type != kTfLiteInt32) {
    context->ReportError(context, "Output type must be Int32.");
    return kTfLiteError;
  }

  // SAFETY: Without span accessors, we can only trust the size
  // provided.
  base::span<int32_t> data =
      UNSAFE_BUFFERS(base::span(tflite::GetTensorData<int32_t>(output),
                                output->bytes / sizeof(int32_t)));
  GetNGramHashIndices(params, data);
  return kTfLiteOk;
}
} // namespace
// Returns the registration for the NGRAM_HASH custom op, wiring up `Init`,
// `Free`, `Resize` and `Eval` (in that order). The registration object is a
// function-local static, so every call returns the same pointer.
TfLiteRegistration* Register_NGRAM_HASH() {
  static TfLiteRegistration registration = {Init, Free, Resize, Eval};
  return &registration;
}
} // namespace language_detection