// blob: c080eb34adccd67a8a9a9e8c41388ce6b88bc9a9 [file] [log] [blame]
// Copyright (c) 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/nn_classifier.h"
#include "base/logging.h"
#include "components/assist_ranker/nn_classifier_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace assist_ranker {
namespace nn_classifier {
namespace {
// Convenience aliases for this test file.
// NOTE(review): none of these three names is referenced by the tests visible
// in this file — confirm whether they are still needed.
using ::google::protobuf::RepeatedFieldBackInserter;
using ::std::copy;
using ::std::vector;
TEST(NNClassifierTest, XorTest) {
  // Obtain the XOR classifier: a network with a single hidden layer of 5
  // units, packaged as a model holding the trained biases and weights.
  const NNClassifierModel xor_net = CreateXorClassifierModel();

  // Abort the test early if the model itself is rejected; the inference
  // checks below would be meaningless on an invalid model.
  ASSERT_TRUE(Validate(xor_net));

  // Walk the XOR truth table. The expected scores are negative when both
  // inputs are equal and positive when they differ.
  EXPECT_TRUE(CheckInference(xor_net, {0, 0}, {-2.7154054}));
  EXPECT_TRUE(CheckInference(xor_net, {0, 1}, {2.8271765}));
  EXPECT_TRUE(CheckInference(xor_net, {1, 0}, {2.6790769}));
  EXPECT_TRUE(CheckInference(xor_net, {1, 1}, {-3.1652793}));
}
TEST(NNClassifierTest, ValidateNNClassifierModel) {
  // A default-constructed (empty) model must be rejected.
  EXPECT_FALSE(Validate(NNClassifierModel()));

  // Baseline: a model whose arguments are mutually consistent is accepted.
  // Every case after this one shrinks exactly one argument of the baseline.
  EXPECT_TRUE(Validate(
      CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}})));

  // Hidden layer biases: two entries instead of three.
  EXPECT_FALSE(Validate(
      CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}})));

  // Hidden layer weights: second row has two entries instead of three.
  EXPECT_FALSE(Validate(
      CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}})));

  // Logits weights: two rows instead of three.
  EXPECT_FALSE(Validate(
      CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}})));

  // Logits biases: left empty.
  EXPECT_FALSE(Validate(
      CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}})));
}
} // namespace
} // namespace nn_classifier
} // namespace assist_ranker