blob: b208b125aa00f6f3a5b02f0a1cd49b35291830a2 [file] [log] [blame]
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/nn_classifier_test_util.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "components/assist_ranker/nn_classifier.h"
namespace assist_ranker {
namespace nn_classifier {
namespace {
using ::google::protobuf::RepeatedFieldBackInserter;
using ::std::copy;
using ::std::vector;
void CreateLayer(const vector<float>& biases,
const vector<vector<float>>& weights,
NNLayer* layer) {
copy(biases.begin(), biases.end(),
RepeatedFieldBackInserter(layer->mutable_biases()->mutable_values()));
for (const auto& w : weights) {
auto* p = layer->add_weights();
copy(w.begin(), w.end(), RepeatedFieldBackInserter(p->mutable_values()));
}
}
} // namespace
// Assembles an NNClassifierModel with two layers. The hidden layer is filled
// from |hidden_biases| / |hidden_weights| and the logits layer from
// |logits_biases| / |logits_weights|, each via CreateLayer().
NNClassifierModel CreateModel(const vector<float>& hidden_biases,
                              const vector<vector<float>>& hidden_weights,
                              const vector<float>& logits_biases,
                              const vector<vector<float>>& logits_weights) {
  NNClassifierModel result;

  auto* const hidden_layer = result.mutable_hidden_layer();
  CreateLayer(hidden_biases, hidden_weights, hidden_layer);

  auto* const logits_layer = result.mutable_logits_layer();
  CreateLayer(logits_biases, logits_weights, logits_layer);

  return result;
}
// Runs |model| on |input| via Inference() and returns true iff the resulting
// scores have the same length as |expected_scores| and every score lies within
// an absolute tolerance of 1e-05 of its expected counterpart.
bool CheckInference(const NNClassifierModel& model,
                    const vector<float>& input,
                    const vector<float>& expected_scores) {
  // Maximum allowed absolute difference between a score and its expectation.
  constexpr double kTolerance = 1e-05;
  const vector<float> scores = Inference(model, input);
  if (scores.size() != expected_scores.size())
    return false;
  for (size_t i = 0; i < scores.size(); ++i) {
    // Must be std::abs from <cmath>: an unqualified abs() may resolve to C's
    // int abs(int), which truncates the difference to an integer and lets any
    // error smaller than 1.0 pass the tolerance check.
    if (std::abs(scores[i] - expected_scores[i]) > kTolerance)
      return false;
  }
  return true;
}
// Returns a small fixed-weight model which, per its name, classifies XOR.
// The hidden layer has 5 biases and 2 weight rows of 5 values each (presumably
// one row per input feature — confirm against Inference()). The logits layer
// has a single bias and 5 weight rows of one value each (presumably one row per
// hidden unit).
NNClassifierModel CreateXorClassifierModel() {
  return CreateModel(
      // Hidden biases.
      {-0.45737201, 0.2009858, 1.02393341, -1.72199488, -0.54427308},
      // Hidden weights.
      {{2.21626472, -0.08185583, -0.7542417, 1.97279537, 0.62363654},
       {-1.71283901, 2.0275352, -1.14731216, 1.56915629, 0.49627137}},
      // Logits biases.
      {-1.27781141},
      // Logits weights.
      {{2.8636384}, {1.84202337}, {-1.76555872}, {-2.96390629}, {-1.00649774}});
}
} // namespace nn_classifier
} // namespace assist_ranker