blob: 225fde873587045c5c9533664a261d72cdd8d045 [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include <utility>
#include <memory>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/command_line.h"
#include "base/files/file_util.h"
#include "base/files/important_file_writer.h"
#include "base/macros.h"
#include "base/metrics/histogram_macros.h"
#include "base/sequenced_task_runner.h"
#include "base/strings/string_util.h"
#include "base/task/post_task.h"
#include "base/task_runner_util.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_url_fetcher.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace assist_ranker {
namespace {
// The minimum duration, in minutes, between download attempts.
constexpr int kMinRetryDelayMins = 3;
// Suffixes for the various histograms produced by the backend.
const char kWriteTimerHistogram[] = ".Timer.WriteModel";
const char kReadTimerHistogram[] = ".Timer.ReadModel";
const char kDownloadTimerHistogram[] = ".Timer.DownloadModel";
const char kParsetimerHistogram[] = ".Timer.ParseModel";
const char kModelStatusHistogram[] = ".Model.Status";
// Helper function to UMA log a timer histograms.
void RecordTimerHistogram(const std::string& name, base::TimeDelta duration) {
base::HistogramBase* counter = base::Histogram::FactoryTimeGet(
name, base::TimeDelta::FromMilliseconds(10),
base::TimeDelta::FromMilliseconds(200000), 100,
base::HistogramBase::kUmaTargetedHistogramFlag);
DCHECK(counter);
counter->AddTime(duration);
}
// A helper class to produce a scoped timer histogram that supports using a
// non-static-const name.
class MyScopedHistogramTimer {
public:
MyScopedHistogramTimer(const base::StringPiece& name)
: name_(name.begin(), name.end()), start_(base::TimeTicks::Now()) {}
~MyScopedHistogramTimer() {
RecordTimerHistogram(name_, base::TimeTicks::Now() - start_);
}
private:
const std::string name_;
const base::TimeTicks start_;
DISALLOW_COPY_AND_ASSIGN(MyScopedHistogramTimer);
};
std::string LoadFromFile(const base::FilePath& model_path) {
DCHECK(!model_path.empty());
DVLOG(2) << "Reading data from: " << model_path.value();
std::string data;
if (!base::ReadFileToString(model_path, &data) || data.empty()) {
DVLOG(2) << "Failed to read data from: " << model_path.value();
data.clear();
}
return data;
}
void SaveToFile(const GURL& model_url,
const base::FilePath& model_path,
const std::string& model_data,
const std::string& uma_prefix) {
DVLOG(2) << "Saving model from '" << model_url << "'' to '"
<< model_path.value() << "'.";
MyScopedHistogramTimer timer(uma_prefix + kWriteTimerHistogram);
base::ImportantFileWriter::WriteFileAtomically(model_path, model_data);
}
} // namespace
RankerModelLoaderImpl::RankerModelLoaderImpl(
ValidateModelCallback validate_model_cb,
OnModelAvailableCallback on_model_available_cb,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
base::FilePath model_path,
GURL model_url,
std::string uma_prefix)
: background_task_runner_(base::CreateSequencedTaskRunnerWithTraits(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT,
base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})),
validate_model_cb_(std::move(validate_model_cb)),
on_model_available_cb_(std::move(on_model_available_cb)),
url_loader_factory_(std::move(url_loader_factory)),
model_path_(std::move(model_path)),
model_url_(std::move(model_url)),
uma_prefix_(std::move(uma_prefix)),
url_fetcher_(std::make_unique<RankerURLFetcher>()),
weak_ptr_factory_(this) {}
RankerModelLoaderImpl::~RankerModelLoaderImpl() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
void RankerModelLoaderImpl::NotifyOfRankerActivity() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
switch (state_) {
case LoaderState::NOT_STARTED:
if (!model_path_.empty()) {
StartLoadFromFile();
break;
}
// There was no configured model path. Switch the state to IDLE and
// fall through to consider the URL.
state_ = LoaderState::IDLE;
FALLTHROUGH;
case LoaderState::IDLE:
if (model_url_.is_valid()) {
StartLoadFromURL();
break;
}
// There was no configured model URL. Switch the state to FINISHED and
// fall through.
state_ = LoaderState::FINISHED;
FALLTHROUGH;
case LoaderState::FINISHED:
case LoaderState::LOADING_FROM_FILE:
case LoaderState::LOADING_FROM_URL:
// Nothing to do.
break;
}
}
void RankerModelLoaderImpl::StartLoadFromFile() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_EQ(state_, LoaderState::NOT_STARTED);
DCHECK(!model_path_.empty());
state_ = LoaderState::LOADING_FROM_FILE;
load_start_time_ = base::TimeTicks::Now();
base::PostTaskAndReplyWithResult(
background_task_runner_.get(), FROM_HERE,
base::BindOnce(&LoadFromFile, model_path_),
base::BindOnce(&RankerModelLoaderImpl::OnFileLoaded,
weak_ptr_factory_.GetWeakPtr()));
}
void RankerModelLoaderImpl::OnFileLoaded(const std::string& data) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_EQ(state_, LoaderState::LOADING_FROM_FILE);
// Record the duration of the download.
RecordTimerHistogram(uma_prefix_ + kReadTimerHistogram,
base::TimeTicks::Now() - load_start_time_);
// Empty data means |model_path| wasn't successfully read. Otherwise,
// parse and validate the model.
std::unique_ptr<RankerModel> model;
if (data.empty()) {
ReportModelStatus(RankerModelStatus::LOAD_FROM_CACHE_FAILED);
} else {
model = CreateAndValidateModel(data);
}
// If |model| is nullptr, then data is empty or the parse failed. Transition
// to IDLE, from which URL download can be attempted.
if (!model) {
state_ = LoaderState::IDLE;
} else {
// The model is valid. The client is willing/able to use it. Keep track
// of where it originated and whether or not is has expired.
std::string url_spec = model->GetSourceURL();
bool is_expired = model->IsExpired();
bool is_finished = url_spec == model_url_.spec() && !is_expired;
DVLOG(2) << (is_expired ? "Expired m" : "M") << "odel in '"
<< model_path_.value() << "' was originally downloaded from '"
<< url_spec << "'.";
// If the cached model came from currently configured |model_url_| and has
// not expired, transition to FINISHED, as there is no need for a model
// download; otherwise, transition to IDLE.
state_ = is_finished ? LoaderState::FINISHED : LoaderState::IDLE;
// Transfer the model to the client.
on_model_available_cb_.Run(std::move(model));
}
// Notify the state machine. This will immediately kick off a download if
// one is required, instead of waiting for the next organic detection of
// ranker activity.
NotifyOfRankerActivity();
}
void RankerModelLoaderImpl::StartLoadFromURL() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_EQ(state_, LoaderState::IDLE);
DCHECK(model_url_.is_valid());
// Do nothing if the download attempts should be throttled.
if (base::TimeTicks::Now() < next_earliest_download_time_) {
DVLOG(2) << "Last download attempt was too recent.";
ReportModelStatus(RankerModelStatus::DOWNLOAD_THROTTLED);
return;
}
// Kick off the next download attempt and reset the time of the next earliest
// allowable download attempt.
DVLOG(2) << "Downloading model from: " << model_url_;
state_ = LoaderState::LOADING_FROM_URL;
load_start_time_ = base::TimeTicks::Now();
next_earliest_download_time_ =
load_start_time_ + base::TimeDelta::FromMinutes(kMinRetryDelayMins);
bool request_started =
url_fetcher_->Request(model_url_,
base::BindOnce(&RankerModelLoaderImpl::OnURLFetched,
weak_ptr_factory_.GetWeakPtr()),
url_loader_factory_.get());
// |url_fetcher_| maintains a request retry counter. If all allowed attempts
// have already been exhausted, then the loader is finished and has abandoned
// loading the model.
if (!request_started) {
DVLOG(2) << "Model download abandoned.";
ReportModelStatus(RankerModelStatus::MODEL_LOADING_ABANDONED);
state_ = LoaderState::FINISHED;
}
}
void RankerModelLoaderImpl::OnURLFetched(bool success,
const std::string& data) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_EQ(state_, LoaderState::LOADING_FROM_URL);
// Record the duration of the download.
RecordTimerHistogram(uma_prefix_ + kDownloadTimerHistogram,
base::TimeTicks::Now() - load_start_time_);
// On request failure, transition back to IDLE. The loader will retry, or
// enforce the max download attempts, later.
if (!success || data.empty()) {
DVLOG(2) << "Download from '" << model_url_ << "'' failed.";
ReportModelStatus(RankerModelStatus::DOWNLOAD_FAILED);
state_ = LoaderState::IDLE;
return;
}
// Attempt to loads the model. If this fails, transition back to IDLE. The
// loader will retry, or enfore the max download attempts, later.
auto model = CreateAndValidateModel(data);
if (!model) {
DVLOG(2) << "Model from '" << model_url_ << "'' not valid.";
state_ = LoaderState::IDLE;
return;
}
// The model is valid. Update the metadata to track the source URL and
// download timestamp.
auto* metadata = model->mutable_proto()->mutable_metadata();
metadata->set_source(model_url_.spec());
metadata->set_last_modified_sec(
(base::Time::Now() - base::Time()).InSeconds());
// Cache the model to model_path_, in the background.
if (!model_path_.empty()) {
background_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&SaveToFile, model_url_, model_path_,
model->SerializeAsString(), uma_prefix_));
}
// The loader is finished. Transfer the model to the client.
state_ = LoaderState::FINISHED;
on_model_available_cb_.Run(std::move(model));
}
std::unique_ptr<RankerModel> RankerModelLoaderImpl::CreateAndValidateModel(
const std::string& data) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
MyScopedHistogramTimer timer(uma_prefix_ + kParsetimerHistogram);
auto model = RankerModel::FromString(data);
if (ReportModelStatus(model ? validate_model_cb_.Run(*model)
: RankerModelStatus::PARSE_FAILED) !=
RankerModelStatus::OK) {
return nullptr;
}
return model;
}
RankerModelStatus RankerModelLoaderImpl::ReportModelStatus(
RankerModelStatus model_status) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
base::HistogramBase* histogram = base::LinearHistogram::FactoryGet(
uma_prefix_ + kModelStatusHistogram, 1,
static_cast<int>(RankerModelStatus::MAX),
static_cast<int>(RankerModelStatus::MAX) + 1,
base::HistogramBase::kUmaTargetedHistogramFlag);
if (histogram)
histogram->Add(static_cast<int>(model_status));
return model_status;
}
} // namespace assist_ranker