blob: 237b83ad2633be6113d21dfef503b935d40b27ae [file] [log] [blame]
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/safe_browsing/client_side_model_loader.h"
#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include "base/callback.h"
#include "base/macros.h"
#include "base/metrics/field_trial.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/time/time.h"
#include "chrome/common/safe_browsing/client_model.pb.h"
#include "components/safe_browsing/proto/csd.pb.h"
#include "components/variations/variations_associated_data.h"
#include "content/public/test/test_browser_thread_bundle.h"
#include "net/http/http_status_code.h"
#include "net/url_request/test_url_fetcher_factory.h"
#include "net/url_request/url_request_status.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
using ::testing::Invoke;
using ::testing::Mock;
using ::testing::StrictMock;
using ::testing::_;
namespace safe_browsing {
namespace {
// Test double for ModelLoader. ScheduleFetch() and EndFetch() are declared as
// gMock methods so tests can set expectations on them; everything else is
// inherited from ModelLoader.
class MockModelLoader : public ModelLoader {
 public:
  // Forwards both arguments to the two-argument ModelLoader constructor.
  // |model_name| is taken by const reference so the caller's string is not
  // copied merely to be passed along.
  explicit MockModelLoader(base::Closure update_renderers_callback,
                           const std::string& model_name)
      : ModelLoader(update_renderers_callback, model_name) {}
  ~MockModelLoader() override {}

  MOCK_METHOD1(ScheduleFetch, void(int64_t));
  MOCK_METHOD2(EndFetch, void(ClientModelStatus, base::TimeDelta));

 private:
  DISALLOW_COPY_AND_ASSIGN(MockModelLoader);
};
} // namespace
class ModelLoaderTest : public testing::Test {
protected:
ModelLoaderTest()
: factory_(new net::FakeURLFetcherFactory(nullptr)),
field_trials_(new base::FieldTrialList(nullptr)) {}
void SetUp() override {
variations::testing::ClearAllVariationIDs();
variations::testing::ClearAllVariationParams();
}
// Set up the finch experiment to control the model number
// used in the model URL. This clears all existing state.
void SetFinchModelNumber(int model_number) {
// Destroy the existing FieldTrialList before creating a new one to avoid
// a DCHECK.
field_trials_.reset();
field_trials_.reset(new base::FieldTrialList(nullptr));
variations::testing::ClearAllVariationIDs();
variations::testing::ClearAllVariationParams();
const std::string group_name = "ModelFoo"; // Not used in CSD code.
ASSERT_TRUE(base::FieldTrialList::CreateFieldTrial(
ModelLoader::kClientModelFinchExperiment, group_name));
std::map<std::string, std::string> params;
params[ModelLoader::kClientModelFinchParam] =
base::IntToString(model_number);
ASSERT_TRUE(variations::AssociateVariationParams(
ModelLoader::kClientModelFinchExperiment, group_name, params));
}
// Set the URL for future SetModelFetchResponse() calls.
void SetModelUrl(const ModelLoader& loader) { model_url_ = loader.url_; }
void SetModelFetchResponse(std::string response_data,
net::HttpStatusCode response_code,
net::URLRequestStatus::Status status) {
CHECK(model_url_.is_valid());
factory_->SetFakeResponse(model_url_, response_data, response_code, status);
}
private:
content::TestBrowserThreadBundle thread_bundle_;
std::unique_ptr<net::FakeURLFetcherFactory> factory_;
std::unique_ptr<base::FieldTrialList> field_trials_;
GURL model_url_;
};
// gMock action that runs |closure| when the mocked method it is attached to
// is called. The tests in this file pass it a RunLoop quit closure so that an
// expected EndFetch() call ends the loop.
ACTION_P(InvokeClosure, closure) {
closure.Run();
}
// Test the response to many variations of model responses.
TEST_F(ModelLoaderTest, FetchModelTest) {
  StrictMock<MockModelLoader> loader(base::Closure(), "top_model.pb");
  SetModelUrl(loader);

  // Runs one fetch scenario: registers |response_data| with the given HTTP
  // code and request status on the fake fetcher, starts a fetch, and spins a
  // RunLoop until the loader calls EndFetch() with |expected_status|.
  // |scenario| is attached to any failure via SCOPED_TRACE because every case
  // shares these source lines and could not otherwise be told apart.
  auto expect_fetch_result = [this, &loader](
      const char* scenario, const std::string& response_data,
      net::HttpStatusCode response_code, net::URLRequestStatus::Status status,
      ModelLoader::ClientModelStatus expected_status) {
    SCOPED_TRACE(scenario);
    base::RunLoop loop;
    SetModelFetchResponse(response_data, response_code, status);
    EXPECT_CALL(loader, EndFetch(expected_status, _))
        .WillOnce(InvokeClosure(loop.QuitClosure()));
    loader.StartFetch();
    loop.Run();
    Mock::VerifyAndClearExpectations(&loader);
  };

  // The model fetch failed.
  expect_fetch_result("model fetch failed", "blamodel",
                      net::HTTP_INTERNAL_SERVER_ERROR,
                      net::URLRequestStatus::FAILED,
                      ModelLoader::MODEL_FETCH_FAILED);

  // Empty model file.
  expect_fetch_result("empty model file", std::string(), net::HTTP_OK,
                      net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_EMPTY);

  // Model is too large.
  expect_fetch_result("model is too large",
                      std::string(ModelLoader::kMaxModelSizeBytes + 1, 'x'),
                      net::HTTP_OK, net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_TOO_LARGE);

  // Unable to parse the model file.
  expect_fetch_result("unparseable model file", "Invalid model file",
                      net::HTTP_OK, net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_PARSE_ERROR);

  // Model that is missing some required fields (missing the version field).
  // SerializePartialAsString() is used because required fields are unset.
  ClientSideModel model;
  model.set_max_words_per_term(4);
  expect_fetch_result("model missing required fields",
                      model.SerializePartialAsString(), net::HTTP_OK,
                      net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_MISSING_FIELDS);

  // Model that points to hashes that don't exist.
  model.set_version(10);
  model.add_hashes("bla");
  model.add_page_term(1);  // Should be 0 instead of 1.
  expect_fetch_result("model with bad hash ids",
                      model.SerializePartialAsString(), net::HTTP_OK,
                      net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_BAD_HASH_IDS);
  // Repair the page term so the remaining cases fail (or pass) only on version.
  model.set_page_term(0, 0);

  // Model version number is wrong.
  model.set_version(-1);
  expect_fetch_result("negative model version", model.SerializeAsString(),
                      net::HTTP_OK, net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_INVALID_VERSION_NUMBER);

  // Normal model.
  model.set_version(10);
  expect_fetch_result("normal model", model.SerializeAsString(), net::HTTP_OK,
                      net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_SUCCESS);

  // Model version number is decreasing. Set the model version number of the
  // model that is currently loaded in the loader object to 11.
  loader.model_.reset(new ClientSideModel(model));
  loader.model_->set_version(11);
  expect_fetch_result("decreasing model version", model.SerializeAsString(),
                      net::HTTP_OK, net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_INVALID_VERSION_NUMBER);

  // Model version hasn't changed since the last reload.
  loader.model_->set_version(10);
  expect_fetch_result("unchanged model version", model.SerializeAsString(),
                      net::HTTP_OK, net::URLRequestStatus::SUCCESS,
                      ModelLoader::MODEL_NOT_CHANGED);
}
// Test that a successful response will update the renderers.
TEST_F(ModelLoaderTest, UpdateRenderersTest) {
  // The RunLoop's quit closure doubles as the loader's update-renderers
  // callback, which makes the RunLoop a convenient way to detect that the
  // callback was invoked.
  base::RunLoop renderers_updated;
  StrictMock<MockModelLoader> loader(renderers_updated.QuitClosure(),
                                     "top_model.pb");

  // Calling the real (unmocked) EndFetch() must schedule exactly one follow-up
  // fetch; the delay is not checked in this test.
  EXPECT_CALL(loader, ScheduleFetch(_)).Times(1);
  loader.ModelLoader::EndFetch(ModelLoader::MODEL_SUCCESS, base::TimeDelta());

  renderers_updated.Run();
  Mock::VerifyAndClearExpectations(&loader);
}
// Test that one fetch schedules another fetch.
TEST_F(ModelLoaderTest, RescheduleFetchTest) {
  StrictMock<MockModelLoader> loader(base::Closure(), "top_model.pb");

  const base::TimeDelta zero_max_age = base::TimeDelta();
  const base::TimeDelta header_max_age = base::TimeDelta::FromMinutes(42);

  // A zero max_age means the default fetch interval is used.
  EXPECT_CALL(loader, ScheduleFetch(ModelLoader::kClientModelFetchIntervalMs));
  loader.ModelLoader::EndFetch(ModelLoader::MODEL_NOT_CHANGED, zero_max_age);
  Mock::VerifyAndClearExpectations(&loader);

  // A non-zero max_age from the header is used, plus one extra minute.
  const int64_t padded_delay_ms =
      (header_max_age + base::TimeDelta::FromMinutes(1)).InMilliseconds();
  EXPECT_CALL(loader, ScheduleFetch(padded_delay_ms));
  loader.ModelLoader::EndFetch(ModelLoader::MODEL_NOT_CHANGED, header_max_age);
  Mock::VerifyAndClearExpectations(&loader);

  // With a failed load the non-zero max_age is ignored and the default
  // interval is used instead.
  EXPECT_CALL(loader, ScheduleFetch(ModelLoader::kClientModelFetchIntervalMs));
  loader.ModelLoader::EndFetch(ModelLoader::MODEL_FETCH_FAILED, header_max_age);
  Mock::VerifyAndClearExpectations(&loader);
}
// Test that Finch params control the model names.
TEST_F(ModelLoaderTest, ModelNamesTest) {
  // Test the name-templating: the first argument selects the "_ext" variant
  // and the second is the variation number embedded in the file name.
  EXPECT_EQ(ModelLoader::FillInModelName(true, 3),
            "client_model_v5_ext_variation_3.pb");
  EXPECT_EQ(ModelLoader::FillInModelName(false, 5),
            "client_model_v5_variation_5.pb");

  // No Finch setup. Should default to 0.
  // A null request-context argument is passed throughout; only the computed
  // name and URL are inspected.
  std::unique_ptr<ModelLoader> loader;
  loader.reset(new ModelLoader(base::Closure(), nullptr,
                               false /* is_extended_reporting */));
  EXPECT_EQ(loader->name(), "client_model_v5_variation_0.pb");
  EXPECT_EQ(loader->url_.spec(),
            "https://ssl.gstatic.com/safebrowsing/csd/"
            "client_model_v5_variation_0.pb");

  // Model 1, no extended reporting.
  SetFinchModelNumber(1);
  loader.reset(new ModelLoader(base::Closure(), nullptr,
                               false /* is_extended_reporting */));
  EXPECT_EQ(loader->name(), "client_model_v5_variation_1.pb");

  // Model 2, with extended reporting.
  SetFinchModelNumber(2);
  loader.reset(new ModelLoader(base::Closure(), nullptr,
                               true /* is_extended_reporting */));
  EXPECT_EQ(loader->name(), "client_model_v5_ext_variation_2.pb");
}
// Exercises ModelLoader::ModelHasValidHashIds() with page terms and rule
// features that index into the model's hash list.
TEST_F(ModelLoaderTest, ModelHasValidHashIds) {
  ClientSideModel model;

  // An empty model is accepted.
  EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));

  // One hash and no references to it.
  model.add_hashes("bla");
  EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));

  // Page term 0 refers to the single existing hash.
  model.add_page_term(0);
  EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));

  // A second page term that is negative is rejected...
  model.add_page_term(-1);
  EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
  // ...as is index 1 while only one hash exists...
  model.set_page_term(1, 1);
  EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
  // ...whereas index 0 is accepted again.
  model.set_page_term(1, 0);
  EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));

  // Test bad rules. A second hash makes indices 0 and 1 usable.
  model.add_hashes("blu");
  ClientSideModel::Rule* first_rule = model.add_rule();
  first_rule->set_weight(0.1f);
  first_rule->add_feature(0);
  first_rule->add_feature(1);
  EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));

  // The second rule's third feature starts negative, which is rejected.
  ClientSideModel::Rule* second_rule = model.add_rule();
  second_rule->set_weight(0.2f);
  second_rule->add_feature(0);
  second_rule->add_feature(1);
  second_rule->add_feature(-1);
  EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));

  // Index 2 is rejected with two hashes present; index 1 is accepted.
  second_rule->set_feature(2, 2);
  EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
  second_rule->set_feature(2, 1);
  EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
}
} // namespace safe_browsing