blob: 3661431e04708887c69854ef1c67005c4d6d0541 [file] [log] [blame]
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/quantized_nn_classifier.h"
#include "base/logging.h"
#include "components/assist_ranker/nn_classifier.h"
#include "components/assist_ranker/nn_classifier_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace assist_ranker {
namespace quantized_nn_classifier {
namespace {
using ::google::protobuf::RepeatedFieldBackInserter;
using ::google::protobuf::RepeatedPtrField;
using ::std::copy;
using ::std::vector;
void CreateLayer(const vector<int>& biases,
const vector<vector<int>>& weights,
float low,
float high,
QuantizedNNLayer* layer) {
layer->set_biases(std::string(biases.begin(), biases.end()));
for (const auto& i : weights) {
layer->mutable_weights()->Add(std::string(i.begin(), i.end()));
}
layer->set_low(low);
layer->set_high(high);
}
// Creates a QuantizedDNNClassifierModel proto using a trained set of biases and
// weights.
QuantizedNNClassifierModel CreateModel(
const vector<int>& hidden_biases,
const vector<vector<int>>& hidden_weights,
const vector<int>& logits_biases,
const vector<vector<int>>& logits_weights,
float low,
float high) {
QuantizedNNClassifierModel model;
CreateLayer(hidden_biases, hidden_weights, low, high,
model.mutable_hidden_layer());
CreateLayer(logits_biases, logits_weights, low, high,
model.mutable_logits_layer());
return model;
}
TEST(QuantizedNNClassifierTest, Dequantize) {
const QuantizedNNClassifierModel quantized = CreateModel(
// Hidden biases.
{{8, 16, 32}},
// Hidden weights.
{{2, 4, 6}, {10, 4, 8}},
// Logits biases.
{2},
// Logits weights.
{{4}, {2}, {6}},
// Low.
0,
// High.
128);
ASSERT_TRUE(Validate(quantized));
const NNClassifierModel model = Dequantize(quantized);
const NNClassifierModel expected = nn_classifier::CreateModel(
// Hidden biases.
{{4, 8, 16}},
// Hidden weights.
{{1, 2, 3}, {5, 2, 4}},
// Logits biases.
{1},
// Logits weights.
{{2}, {1}, {3}});
EXPECT_EQ(model.SerializeAsString(), expected.SerializeAsString());
}
TEST(QuantizedNNClassifierTest, XorTest) {
// Creates a NN with a single hidden layer of 5 units that solves XOR.
// Creates a QuantizedDNNClassifier model containing the trained biases and
// weights.
const QuantizedNNClassifierModel quantized = CreateModel(
// Hidden biases.
{{110, 139, 175, 55, 106}},
// Hidden weights.
{{228, 127, 97, 217, 158}, {55, 219, 80, 199, 152}},
// Logits biases.
{74},
// Logits weights.
{{255}, {211}, {53}, {0}, {86}},
// Low.
-2.96390629,
// High.
2.8636384);
ASSERT_TRUE(Validate(quantized));
const NNClassifierModel model = Dequantize(quantized);
ASSERT_TRUE(nn_classifier::Validate(model));
EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 0}, {-2.7032}));
EXPECT_TRUE(nn_classifier::CheckInference(model, {0, 1}, {2.80681}));
EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 0}, {2.64435}));
EXPECT_TRUE(nn_classifier::CheckInference(model, {1, 1}, {-3.17825}));
}
TEST(QuantizedNNClassifierTest, ValidateQuantizedNNClassifierModel) {
// Empty model.
QuantizedNNClassifierModel model;
EXPECT_FALSE(Validate(model));
// Valid model.
model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
0, 1);
EXPECT_TRUE(Validate(model));
// Hidden bias incorrect size.
model =
CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
EXPECT_FALSE(Validate(model));
// Hidden weight vector incorrect size.
model =
CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}}, 0, 1);
EXPECT_FALSE(Validate(model));
// Logits weights incorrect size.
model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}}, 0, 1);
EXPECT_FALSE(Validate(model));
// Empty logits bias.
model =
CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}}, 0, 1);
EXPECT_FALSE(Validate(model));
// Low / high incorrect.
model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}},
1, 0);
EXPECT_FALSE(Validate(model));
}
} // namespace
} // namespace quantized_nn_classifier
} // namespace assist_ranker