blob: 56497cd8c7cb5cb27a4ddba4fea0ddafcecbf590 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OMNIBOX_COMPOSEBOX_COMPOSEBOX_QUERY_CONTROLLER_H_
#define COMPONENTS_OMNIBOX_COMPOSEBOX_COMPOSEBOX_QUERY_CONTROLLER_H_
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted_memory.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"
#include "base/observer_list_types.h"
#include "base/task/task_runner.h"
#include "base/time/time.h"
#include "base/unguessable_token.h"
#include "build/build_config.h"
#include "components/endpoint_fetcher/endpoint_fetcher.h"
#include "components/lens/lens_overlay_mime_type.h"
#include "components/lens/lens_overlay_request_id_generator.h"
#include "components/omnibox/composebox/composebox_query.mojom.h"
#include "components/search_engines/util.h"
#include "components/variations/variations_client.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "third_party/lens_server_proto/lens_overlay_client_context.pb.h"
#include "third_party/lens_server_proto/lens_overlay_cluster_info.pb.h"
#include "third_party/lens_server_proto/lens_overlay_server.pb.h"
#include "third_party/lens_server_proto/lens_overlay_surface.pb.h"
#include "url/gurl.h"
// The types below are only named in declarations in this header, so a forward
// declaration is sufficient and avoids pulling in their full definitions.
class TemplateURLService;

namespace lens {
class RefCountedLensOverlayClientLogs;
}  // namespace lens
#if !BUILDFLAG(IS_IOS)
#include "third_party/skia/include/core/SkBitmap.h"
#endif // !BUILDFLAG(IS_IOS)
// Lifecycle of the cluster info tracked by the query controller. Every
// enumerator carries an explicit numeric value.
enum class QueryControllerState {
  kOff = 0,                          // Initial; NotifySessionStarted() not yet
                                     // called.
  kAwaitingClusterInfoResponse = 1,  // Cluster info request is in flight.
  kClusterInfoReceived = 2,          // Valid cluster info response received.
  kClusterInfoInvalid = 3,           // No response received, or the cluster
                                     // info has expired.
};
namespace version_info {
enum class Channel;
}  // namespace version_info

// Forward declarations for the signin types referenced below. The access token
// fetcher is only held through std::unique_ptr members whose owners declare
// out-of-line destructors, so a full definition is not needed in this header.
namespace signin {
class IdentityManager;
class PrimaryAccountAccessTokenFetcher;
}  // namespace signin
namespace composebox {

// Image encoding options for an uploaded image.
//
// Every field has a default member initializer so that a default-constructed
// instance never holds indeterminate values. The struct remains an aggregate,
// so brace / designated initialization by callers keeps working unchanged.
struct ImageEncodingOptions {
  // Whether WebP encoding is enabled for the uploaded image.
  bool enable_webp_encoding = false;
  // NOTE(review): unit (pixel area vs. bytes) is not visible here — confirm.
  int max_size = 0;
  // Maximum image height. NOTE(review): presumably pixels — confirm.
  int max_height = 0;
  // Maximum image width. NOTE(review): presumably pixels — confirm.
  int max_width = 0;
  // Encoder compression quality. NOTE(review): range not visible here.
  int compression_quality = 0;
};

}  // namespace composebox
// Callback type alias for the OAuth headers created.
using OAuthHeadersCreatedCallback =
    base::OnceCallback<void(std::vector<std::string>)>;

// Error type reported for a failed file upload.
using FileUploadErrorType = composebox_query::mojom::FileUploadErrorType;

// Callback type alias for the request body proto created. The optional error
// type presumably describes why the request body could not be built — confirm.
using RequestBodyProtoCreatedCallback =
    base::OnceCallback<void(lens::LensOverlayServerRequest,
                            std::optional<FileUploadErrorType>)>;

// Callback type alias for the upload progress (`position` out of `total`).
using UploadProgressCallback =
    base::RepeatingCallback<void(uint64_t position, uint64_t total)>;

// Callback for when the query controller state changes.
using QueryControllerStateChangedCallback =
    base::RepeatingCallback<void(QueryControllerState state)>;

// Upload status of a single file.
using FileUploadStatus = composebox_query::mojom::FileUploadStatus;

// Callback for when the file upload status changes.
// NOTE(review): identifies the file by a string token, whereas the observer
// interface uses base::UnguessableToken — confirm this is intended.
using FileUploadStatusChangedCallback =
    base::RepeatingCallback<void(std::string file_token,
                                 FileUploadStatus status)>;
// Manages a composebox query session: fetches the Lens cluster info, uploads
// files attached to the query, and builds the AIM URL for a submitted query.
class ComposeboxQueryController {
 public:
  // Observer interface for the Page Handler to get updates on file upload
  // status changes.
  class FileUploadStatusObserver : public base::CheckedObserver {
   public:
    // Called when the upload status of the file identified by `file_token`
    // changes. `error_type` carries the error when the upload failed.
    virtual void OnFileUploadStatusChanged(
        const base::UnguessableToken& file_token,
        lens::MimeType mime_type,
        FileUploadStatus file_upload_status,
        const std::optional<FileUploadErrorType>& error_type) = 0;

   protected:
    ~FileUploadStatusObserver() override = default;
  };

  // Struct containing file information for a file upload.
  struct FileInfo {
    FileInfo();
    ~FileInfo();

    // Name of the selected file.
    std::string file_name;
    // Size in bytes of the file.
    uint64_t file_size_bytes = 0;
    // The time the file was selected.
    base::Time webui_selection_time;
    // Client-side unique identifier generated by UI. Used as the key in the
    // `active_files_` map.
    base::UnguessableToken file_token_;
    // The mime type of the file. Value-initialized so it is never
    // indeterminate before the caller assigns it.
    lens::MimeType mime_type_{};

    // Returns the current file upload status.
    FileUploadStatus GetFileUploadStatus() const { return upload_status_; }
    // Returns the error type recorded for a failed upload (kUnknown until
    // set).
    FileUploadErrorType GetFileUploadErrorType() const {
      return upload_error_type_;
    }
    // Returns the network response code of the upload request (0 until set).
    int GetResponseCode() const { return response_code_; }
    // Returns a pointer to the request ID for this request, for testing. May
    // be null before StartFileUploadFlow() sets it.
    lens::LensOverlayRequestId* GetRequestIdForTesting() {
      return request_id_.get();
    }

   private:
    friend class ComposeboxQueryController;
    friend class ComposeboxQueryControllerIOS;

    // Default to kNotUploaded, until UploadFile() is called.
    // Do not modify this field directly, use UpdateFileUploadStatus() instead.
    FileUploadStatus upload_status_ = FileUploadStatus::kNotUploaded;
    // The error type if the upload failed.
    FileUploadErrorType upload_error_type_ = FileUploadErrorType::kUnknown;
    // The request ID for this request. Set by StartFileUploadFlow().
    std::unique_ptr<lens::LensOverlayRequestId> request_id_;
    // When browser started the network request for the file upload.
    base::Time upload_network_request_start_time_;
    // When Lens server response was received.
    base::Time server_response_time_;
    // The network response code.
    int response_code_ = 0;
    // The request to be sent to the server. Will be set asynchronously after
    // StartFileUploadFlow() is called.
    std::unique_ptr<lens::LensOverlayServerRequest> request_body_;
    // The headers to attach to the request. Will be set asynchronously after
    // StartFileUploadFlow() is called.
    std::unique_ptr<std::vector<std::string>> request_headers_;
    // The access token fetcher used for getting OAuth for the file upload
    // request. Will be discarded after the OAuth headers are created.
    std::unique_ptr<signin::PrimaryAccountAccessTokenFetcher>
        file_upload_access_token_fetcher_;
    // The endpoint fetcher used for the file upload request.
    std::unique_ptr<endpoint_fetcher::EndpointFetcher>
        file_upload_endpoint_fetcher_;
  };

  ComposeboxQueryController(
      signin::IdentityManager* identity_manager,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
      version_info::Channel channel,
      std::string locale,
      TemplateURLService* template_url_service,
      variations::VariationsClient* variations_client,
      bool send_lns_surface);
  // Non-copyable: owns fetchers, file state and a WeakPtrFactory.
  ComposeboxQueryController(const ComposeboxQueryController&) = delete;
  ComposeboxQueryController& operator=(const ComposeboxQueryController&) =
      delete;
  virtual ~ComposeboxQueryController();

  // Session management. Virtual for testing.
  virtual void NotifySessionStarted();
  virtual void NotifySessionAbandoned();

  // Creates the AIM URL for a submitted query with text `query_text`.
  // `query_start_time` is the time that the user clicked the submit button.
  GURL CreateAimUrl(const std::string& query_text, base::Time query_start_time);

  // Observer management.
  void AddObserver(FileUploadStatusObserver* obs);
  void RemoveObserver(FileUploadStatusObserver* obs);

  // Triggers upload of the file with data and stores the file info in the
  // internal map. Call after setting the file info fields. Virtual for testing.
  virtual void StartFileUploadFlow(
      std::unique_ptr<FileInfo> file_info,
      scoped_refptr<base::RefCountedBytes> file_data,
      std::optional<composebox::ImageEncodingOptions> image_options);

  // Removes the file identified by `file_token` from the file cache.
  // NOTE(review): the bool presumably reports whether a file was removed —
  // confirm against the implementation.
  virtual bool DeleteFile(const base::UnguessableToken& file_token);

  // Clear entire file cache.
  virtual void ClearFiles();

  // Returns the number of files sent in the AIM request.
  int num_files_in_request() const { return num_files_in_request_; }

  // Return the file from `active_files_` map or nullptr if not found.
  virtual FileInfo* GetFileInfo(const base::UnguessableToken& file_token);

 protected:
  // Builds the file upload request proto from already-prepared `image_data`
  // plus the given request id, client context and client logs, then passes
  // the request to `callback`. Static, so it does not touch controller state;
  // presumably run on `create_request_task_runner_` — confirm.
  static void CreateFileUploadRequestProtoWithImageDataAndContinue(
      lens::LensOverlayRequestId request_id,
      lens::LensOverlayClientContext client_context,
      scoped_refptr<lens::RefCountedLensOverlayClientLogs> client_logs,
      RequestBodyProtoCreatedCallback callback,
      lens::ImageData image_data);

  // Creates the request body proto for the image bytes in `file_data` (with
  // the optional encoding `options`) for the file identified by `file_token`,
  // and calls `callback` with the request or an error type. Virtual for
  // testing.
  virtual void CreateImageUploadRequest(
      const base::UnguessableToken& file_token,
      scoped_refptr<base::RefCountedBytes> file_data,
      std::optional<composebox::ImageEncodingOptions> options,
      RequestBodyProtoCreatedCallback callback);

  // Returns the EndpointFetcher to use with the given params. Protected to
  // allow overriding in tests to mock server responses.
  virtual std::unique_ptr<endpoint_fetcher::EndpointFetcher>
  CreateEndpointFetcher(std::string request_string,
                        const GURL& fetch_url,
                        endpoint_fetcher::HttpMethod http_method,
                        base::TimeDelta timeout,
                        const std::vector<std::string>& request_headers,
                        const std::vector<std::string>& cors_exempt_headers,
                        UploadProgressCallback upload_progress_callback);

  // Creates the client context for Lens requests. Protected to allow access
  // from tests.
  lens::LensOverlayClientContext CreateClientContext() const;

  // Clears all cluster info.
  void ClearClusterInfo();

  // Resets the request cluster info state. Protected to allow tests to
  // override. `session_id` is used to determine if the async timer is
  // from an old, invalid session.
  virtual void ResetRequestClusterInfoState(int session_id);

  // The internal state of the query controller. Protected to allow tests to
  // access the state. Do not modify this state directly, use
  // SetQueryControllerState() instead.
  QueryControllerState query_controller_state_ = QueryControllerState::kOff;

  // Callback for when the query controller state changes. Protected to allow
  // tests to set the callback.
  QueryControllerStateChangedCallback
      on_query_controller_state_changed_callback_;

  // The map of active files, keyed by the file token.
  // Protected to allow tests to access the files.
  std::map<base::UnguessableToken, std::unique_ptr<FileInfo>> active_files_;

  // Task runner used to create the file upload request proto asynchronously.
  scoped_refptr<base::TaskRunner> create_request_task_runner_;

 private:
  // Fetches the OAuth headers and calls the callback with the headers. If the
  // OAuth cannot be retrieved (like if the user is not logged in), the callback
  // will be called with an empty vector. Returns the access token fetcher
  // making the request so it can be kept alive.
  std::unique_ptr<signin::PrimaryAccountAccessTokenFetcher>
  CreateOAuthHeadersAndContinue(OAuthHeadersCreatedCallback callback);

  // Gets an OAuth token for the cluster info request and proceeds with sending
  // a LensOverlayServerClusterInfoRequest to get the cluster info.
  void FetchClusterInfo();

  // Asynchronous handler for when the fetch cluster info request headers are
  // ready. Creates the endpoint fetcher and sends the cluster info request.
  void SendClusterInfoNetworkRequest(std::vector<std::string> request_headers);

  // Handles the response from the cluster info request.
  void HandleClusterInfoResponse(
      std::unique_ptr<endpoint_fetcher::EndpointResponse> response);

  // Sets the query controller state and notifies the query controller state
  // changed callback if it has changed.
  void SetQueryControllerState(QueryControllerState new_state);

  // Updates the file upload status and notifies the file upload status
  // observers with an optional error type if the upload failed.
  void UpdateFileUploadStatus(const base::UnguessableToken& file_token,
                              FileUploadStatus status,
                              std::optional<FileUploadErrorType> error_type);

#if !BUILDFLAG(IS_IOS)
  // Handler for when the image from an image file upload is decoded. Creates
  // the request body proto and calls the callback with the request.
  void ProcessDecodedImageAndContinue(
      lens::LensOverlayRequestId request_id,
      const composebox::ImageEncodingOptions& options,
      RequestBodyProtoCreatedCallback callback,
      const SkBitmap& bitmap);
#endif  // !BUILDFLAG(IS_IOS)

  // Creates the request body proto and calls the callback with the request.
  void CreateFileUploadRequestBodyAndContinue(
      const base::UnguessableToken& file_token,
      scoped_refptr<base::RefCountedBytes> file_data,
      std::optional<composebox::ImageEncodingOptions> options,
      RequestBodyProtoCreatedCallback callback);

  // Asynchronous handler for when the file upload request body is ready.
  void OnUploadFileRequestBodyReady(
      const base::UnguessableToken& file_token,
      lens::LensOverlayServerRequest request,
      std::optional<FileUploadErrorType> error_type);

  // Asynchronous handler for when the file upload request headers are ready.
  void OnUploadFileRequestHeadersReady(const base::UnguessableToken& file_token,
                                       std::vector<std::string> headers);

  // Sends the file upload request if the request body, headers, and cluster
  // info are ready.
  void MaybeSendFileUploadNetworkRequest(
      const base::UnguessableToken& file_token);

  // Creates the endpoint fetcher and sends the file upload network request.
  void SendFileUploadNetworkRequest(FileInfo* file_info);

  // Handles the response from the file upload request.
  void HandleFileUploadResponse(
      const base::UnguessableToken& file_token,
      std::unique_ptr<endpoint_fetcher::EndpointResponse> response);

  // The last received cluster info.
  std::optional<lens::LensOverlayClusterInfo> cluster_info_ = std::nullopt;

  // The endpoint fetcher used for the cluster info request.
  std::unique_ptr<endpoint_fetcher::EndpointFetcher>
      cluster_info_endpoint_fetcher_;

  // The access token fetcher used for getting OAuth for the cluster info
  // request. Will be discarded after the OAuth headers are created.
  std::unique_ptr<signin::PrimaryAccountAccessTokenFetcher>
      cluster_info_access_token_fetcher_;

  // Unowned IdentityManager for fetching access tokens. Could be null for
  // incognito profiles.
  const raw_ptr<signin::IdentityManager> identity_manager_;

  // The url loader factory to use for Lens network requests.
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;

  // The channel to use for Lens network requests.
  version_info::Channel channel_;

  // The locale used for creating the client context.
  std::string locale_;

  // The request id generator for this query flow instance.
  lens::LensOverlayRequestIdGenerator request_id_generator_;

  // The observer list, managed via AddObserver() and RemoveObserver().
  base::ObserverList<FileUploadStatusObserver> observers_;

  // Owned by the Profile, and thus guaranteed to outlive this instance.
  const raw_ptr<TemplateURLService> template_url_service_;

  // Owned by the Profile, and thus guaranteed to outlive this instance.
  const raw_ptr<variations::VariationsClient> variations_client_;

  // Whether or not to send the lns_surface parameter.
  // TODO(crbug.com/430070871): Remove this once the server supports the
  // `lns_surface` parameter.
  bool send_lns_surface_ = false;

  // The session counter, incremented when the session is stopped. This is used
  // to determine if the session is active when handling cluster info
  // expiration.
  int session_id_ = 0;

  // The number of files that are sent in the AIM request.
  int num_files_in_request_ = 0;

  base::WeakPtrFactory<ComposeboxQueryController> weak_ptr_factory_{this};
};
#endif // COMPONENTS_OMNIBOX_COMPOSEBOX_COMPOSEBOX_QUERY_CONTROLLER_H_