| // Copyright (c) 2018 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| #include "components/assist_ranker/nn_classifier.h" |
| #include "base/logging.h" |
| #include "components/assist_ranker/nn_classifier_test_util.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace assist_ranker { |
| namespace nn_classifier { |
| namespace { |
| |
| using ::google::protobuf::RepeatedFieldBackInserter; |
| using ::std::copy; |
| using ::std::vector; |
| |
| TEST(NNClassifierTest, XorTest) { |
| // Creates a NN with a single hidden layer of 5 units that solves XOR. |
| // Creates a DNNClassifier model containing the trained biases and weights. |
| const NNClassifierModel model = CreateXorClassifierModel(); |
| ASSERT_TRUE(Validate(model)); |
| EXPECT_TRUE(CheckInference(model, {0, 0}, {-2.7154054})); |
| EXPECT_TRUE(CheckInference(model, {0, 1}, {2.8271765})); |
| EXPECT_TRUE(CheckInference(model, {1, 0}, {2.6790769})); |
| EXPECT_TRUE(CheckInference(model, {1, 1}, {-3.1652793})); |
| } |
| |
| TEST(NNClassifierTest, ValidateNNClassifierModel) { |
| // Empty model. |
| NNClassifierModel model; |
| EXPECT_FALSE(Validate(model)); |
| |
| // Valid model. |
| model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}}); |
| EXPECT_TRUE(Validate(model)); |
| |
| // Too few hidden layer biases. |
| model = CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}, {0}}); |
| EXPECT_FALSE(Validate(model)); |
| |
| // Too few hidden layer weights. |
| model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0}, {{0}, {0}, {0}}); |
| EXPECT_FALSE(Validate(model)); |
| |
| // Too few logits weights. |
| model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0}, {{0}, {0}}); |
| EXPECT_FALSE(Validate(model)); |
| |
| // Logits biases empty. |
| model = CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {}, {{0}, {0}, {0}}); |
| EXPECT_FALSE(Validate(model)); |
| } |
| |
| } // namespace |
| } // namespace nn_classifier |
| } // namespace assist_ranker |