blob: 3ac195f0b9ecf2780e6d2dfe2d4670f23c8065a0 [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/ranker_model.h"
#include <memory>
#include "base/time/time.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
// Bring the class under test into scope so the tests can name it unqualified.
using assist_ranker::RankerModel;
// Source URL handed to NewModel() by the tests below; the Serialization test
// expects to read the same value back through GetSourceURL().
const char kModelURL[] = "https://some.url.net/model";
// Returns the number of seconds separating |t| from a default-constructed
// base::Time, i.e. the value the tests store in (and compare against) the
// model metadata's last_modified_sec field.
int64_t InSeconds(const base::Time t) {
  const auto elapsed = t - base::Time();
  return elapsed.InSeconds();
}
std::unique_ptr<RankerModel> NewModel(const std::string& model_url,
base::Time last_modified,
base::TimeDelta cache_duration) {
std::unique_ptr<RankerModel> model = std::make_unique<RankerModel>();
auto* metadata = model->mutable_proto()->mutable_metadata();
if (!model_url.empty())
metadata->set_source(model_url);
if (!last_modified.is_null())
metadata->set_last_modified_sec(InSeconds(last_modified));
if (!cache_duration.is_zero())
metadata->set_cache_duration_sec(cache_duration.InSeconds());
auto* translate = model->mutable_proto()->mutable_translate();
translate->set_version(1);
auto* logit = translate->mutable_translate_logistic_regression_model();
logit->set_bias(0.1f);
logit->set_accept_ratio_weight(0.2f);
logit->set_decline_ratio_weight(0.3f);
logit->set_ignore_ratio_weight(0.4f);
return model;
}
} // namespace
// Checks that a model survives a serialize -> FromString -> serialize round
// trip: the re-serialized bytes must equal the original bytes, and the source
// URL, last-modified time and cache duration must read back unchanged.
TEST(RankerModelTest, Serialization) {
  base::Time last_modified = base::Time::Now();
  base::TimeDelta cache_duration = base::TimeDelta::FromDays(3);
  std::unique_ptr<RankerModel> original_model =
      NewModel(kModelURL, last_modified, cache_duration);
  std::string original_model_str = original_model->SerializeAsString();

  std::unique_ptr<RankerModel> serialized_model =
      RankerModel::FromString(original_model_str);
  // Every check below dereferences |serialized_model|. Use a fatal assertion
  // so that an empty pointer is reported as a test failure instead of
  // crashing the test binary on a null dereference.
  ASSERT_TRUE(serialized_model != nullptr);

  std::string serialized_model_str = serialized_model->SerializeAsString();
  EXPECT_EQ(serialized_model_str, original_model_str);
  EXPECT_EQ(serialized_model->GetSourceURL(), kModelURL);
  EXPECT_EQ(serialized_model->proto().metadata().last_modified_sec(),
            InSeconds(last_modified));
  EXPECT_EQ(serialized_model->proto().metadata().cache_duration_sec(),
            cache_duration.InSeconds());
}
// Exercises IsExpired() with a 30-day cache duration and three different
// last-modified times: now, 15 days ago and 60 days ago. Only the model
// modified 60 days ago is expected to report itself as expired.
TEST(RankerModelTest, IsExpired) {
  const base::Time now = base::Time::Now();
  const base::TimeDelta cache_lifetime = base::TimeDelta::FromDays(30);

  const base::Time modified_just_now = now;
  const base::Time modified_15_days_ago = now - base::TimeDelta::FromDays(15);
  const base::Time modified_60_days_ago = now - base::TimeDelta::FromDays(60);

  // Modification time still inside the cache lifetime.
  EXPECT_FALSE(
      NewModel(kModelURL, modified_just_now, cache_lifetime)->IsExpired());
  EXPECT_FALSE(
      NewModel(kModelURL, modified_15_days_ago, cache_lifetime)->IsExpired());

  // Modification time older than the cache lifetime.
  EXPECT_TRUE(
      NewModel(kModelURL, modified_60_days_ago, cache_lifetime)->IsExpired());
}