blob: 022440b39fdd6e64f0c61d61221180df82f56231 [file] [log] [blame] [edit]
// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_PROVIDER_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_PROVIDER_H_
#include <memory>
#include <string>
#include <vector>

#include "base/callback_list.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "components/omnibox/browser/autocomplete_enums.h"
#include "components/omnibox/browser/autocomplete_provider.h"
#include "components/omnibox/browser/autocomplete_provider_client.h"
#include "components/omnibox/browser/on_device_head_model.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
// TODO(crbug.com/40241602): clean up this build flag guard later if possible.
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#endif
class AutocompleteProviderListener;
// An asynchronous autocomplete provider which receives an input string and
// tries to find matches in an on-device head model. This provider is designed
// to help users get suggestions when they have a poor network connection.
// By default, all matches provided by this provider will have a relevance no
// greater than 99, such that its matches will not show before those of any
// other provider; however, the relevance can be changed to any arbitrary value
// by Finch when the input is not classified as a URL.
// TODO(crbug.com/40241602): rename this provider to "OnDeviceProvider" since it
// will serve both head and tail suggestions.
class OnDeviceHeadProvider : public AutocompleteProvider {
 public:
  // Factory method. The constructor is private, so this is the only way for
  // code outside the class to create an OnDeviceHeadProvider. |client| is
  // retained and exposed through client().
  // NOTE(review): |listener| is presumably the object notified when matches
  // become available — confirm in the implementation file.
  static OnDeviceHeadProvider* Create(AutocompleteProviderClient* client,
                                      AutocompleteProviderListener* listener);

  OnDeviceHeadProvider(const OnDeviceHeadProvider&) = delete;
  OnDeviceHeadProvider& operator=(const OnDeviceHeadProvider&) = delete;

  // AutocompleteProvider:
  void Start(const AutocompleteInput& input, bool minimal_changes) override;
  void Stop(AutocompleteStopReason stop_reason) override;
  void AddProviderInfo(ProvidersInfo* provider_info) const override;

  // Returns the client this provider was created with (held via raw_ptr,
  // i.e. not owned).
  AutocompleteProviderClient* client() { return client_; }

 private:
  friend class OnDeviceHeadProviderTest;

  // Stores the AutocompleteInput and the suggestions fetched from the
  // on-device models for a single search request. Only forward-declared here.
  struct OnDeviceHeadProviderParams;

  OnDeviceHeadProvider(AutocompleteProviderClient* client,
                       AutocompleteProviderListener* listener);
  ~OnDeviceHeadProvider() override;

  // Returns whether this provider is allowed to serve suggestions for |input|.
  bool IsOnDeviceHeadProviderAllowed(const AutocompleteInput& input);

  // Helper functions used for asynchronous searches against the on-device
  // models. The AutocompleteInput and the model output are carried from
  // DoSearch() to AllSearchDone() in an OnDeviceHeadProviderParams object.
  //
  // DoSearch: searches the on-device models and returns the top suggestions
  // matching the given AutocompleteInput.
  void DoSearch(std::unique_ptr<OnDeviceHeadProviderParams> params);

  // AllSearchDone: called after all searches are completed; fills |matches_|
  // with the suggestions fetched by DoSearch() and then calls
  // NotifyListeners().
  void AllSearchDone(std::unique_ptr<OnDeviceHeadProviderParams> params);

  // Called when the search against the head model is done.
  void HeadModelSearchDone(std::unique_ptr<OnDeviceHeadProviderParams> params);

#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  // Called when the search against the tail model is done, together with the
  // predictions that search produced.
  void TailModelSearchDone(
      std::unique_ptr<OnDeviceHeadProviderParams> params,
      std::vector<OnDeviceTailModelExecutor::Prediction> predictions);
#endif

  // Reads the head model filename from the static
  // OnDeviceModelUpdateListener instance.
  std::string GetOnDeviceHeadModelFilename() const;

  // Fetches suggestions matching |params| from the on-device head model
  // identified by |model_filename| and returns the params object.
  // NOTE(review): |provider_max_matches| presumably caps the number of
  // suggestions fetched — confirm in the implementation.
  static std::unique_ptr<OnDeviceHeadProviderParams>
  GetSuggestionsFromHeadModel(
      const std::string& model_filename,
      size_t provider_max_matches,
      std::unique_ptr<OnDeviceHeadProviderParams> params);

  // Determines whether tail suggestions should be fetched for |params| given
  // |locale|.
  static bool ShouldFetchTailSuggestions(
      const OnDeviceHeadProviderParams& params,
      const std::string& locale);

  raw_ptr<AutocompleteProviderClient> client_;

  // Task runner dedicated to on-device model operations, added to offload
  // expensive work off the UI sequence.
  scoped_refptr<base::SequencedTaskRunner> worker_task_runner_;

  // Ensures that autocomplete request handling only happens on the main
  // sequence.
  SEQUENCE_CHECKER(main_sequence_checker_);

  // The id used to track the current request to the on-device models. It is
  // incremented whenever a new request is received from the
  // AutocompleteController. Default-initialized so it never holds an
  // indeterminate value, regardless of how the constructor is written.
  size_t on_device_search_request_id_ = 0;

  // Kept as the last member so that outstanding weak pointers are invalidated
  // before any other member is destroyed.
  base::WeakPtrFactory<OnDeviceHeadProvider> weak_ptr_factory_{this};
};
#endif // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_PROVIDER_H_