// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OMNIBOX_COMPOSEBOX_COMPOSEBOX_QUERY_CONTROLLER_H_
#define COMPONENTS_OMNIBOX_COMPOSEBOX_COMPOSEBOX_QUERY_CONTROLLER_H_
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/memory/ref_counted_memory.h"
#include "base/memory/scoped_refptr.h"
#include "base/time/time.h"
#include "base/unguessable_token.h"
#include "components/endpoint_fetcher/endpoint_fetcher.h"
#include "components/lens/contextual_input.h"
#include "components/lens/lens_bitmap_processing.h"
#include "components/lens/lens_overlay_mime_type.h"
#include "components/lens/lens_overlay_request_id_generator.h"
#include "components/omnibox/composebox/composebox_query.mojom.h"
#include "components/search_engines/util.h"
#include "components/variations/variations_client.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "third_party/lens_server_proto/lens_overlay_client_context.pb.h"
#include "third_party/lens_server_proto/lens_overlay_cluster_info.pb.h"
#include "third_party/lens_server_proto/lens_overlay_server.pb.h"
#include "third_party/lens_server_proto/lens_overlay_surface.pb.h"
#include "url/gurl.h"

namespace lens {
class RefCountedLensOverlayClientLogs;
}  // namespace lens

class TemplateURLService;

#if !BUILDFLAG(IS_IOS)
#include "third_party/skia/include/core/SkBitmap.h"
#endif  // !BUILDFLAG(IS_IOS)

// State of the query controller with respect to its cluster info request.
// Stored in ComposeboxQueryController::query_controller_state_ and passed to
// QueryControllerStateChangedCallback when it changes.
enum class QueryControllerState {
  // The initial state, before NotifySessionStarted() is called.
  kOff = 0,
  // The cluster info request is in flight.
  kAwaitingClusterInfoResponse = 1,
  // The cluster info response has been received and is valid.
  kClusterInfoReceived = 2,
  // The cluster info response was not received, or the cluster info has
  // expired.
  kClusterInfoInvalid = 3,
};

namespace version_info {
enum class Channel;
}  // namespace version_info

namespace signin {
class IdentityManager;
}  // namespace signin

// Callback invoked with the created OAuth headers. As documented on
// ComposeboxQueryController::CreateOAuthHeadersAndContinue(), the vector is
// empty when the OAuth token could not be retrieved.
using OAuthHeadersCreatedCallback =
    base::OnceCallback<void(std::vector<std::string>)>;
// Shorthand for the mojom enum describing why a file upload failed.
using FileUploadErrorType = composebox_query::mojom::FileUploadErrorType;
// Callback invoked with the created request body proto and an optional error
// type. NOTE(review): the error type is presumably set only when building the
// request failed -- confirm against the implementation.
using RequestBodyProtoCreatedCallback =
    base::OnceCallback<void(lens::LensOverlayServerRequest,
                            std::optional<FileUploadErrorType>)>;
// Callback reporting upload progress as a `position` out of `total`.
using UploadProgressCallback =
    base::RepeatingCallback<void(uint64_t position, uint64_t total)>;
// Callback for when the query controller state changes.
using QueryControllerStateChangedCallback =
    base::RepeatingCallback<void(QueryControllerState state)>;
// Shorthand for the mojom enum describing the upload status of a file.
using FileUploadStatus = composebox_query::mojom::FileUploadStatus;
// Callback for when the file upload status changes. Receives the file token
// as a string together with the new status.
using FileUploadStatusChangedCallback =
    base::RepeatingCallback<void(std::string file_token,
                                 FileUploadStatus status)>;

// TODO(crbug.com/440427508): Move this class to components/lens and rename it.
//
// Drives the composebox query flow declared below: session start/abandon
// notifications, the cluster info request, upload of contextual inputs
// ("files") with per-file status notifications to observers, and creation of
// the AIM URL for a submitted query.
class ComposeboxQueryController {
 public:
  // Observer interface for the Page Handler to get updates on file upload
  // status.
  class FileUploadStatusObserver : public base::CheckedObserver {
   public:
    // Called when the upload status of the file identified by `file_token`
    // changes. `error_type` carries the failure reason if the upload failed.
    virtual void OnFileUploadStatusChanged(
        const base::UnguessableToken& file_token,
        lens::MimeType mime_type,
        FileUploadStatus file_upload_status,
        const std::optional<FileUploadErrorType>& error_type) = 0;

   protected:
    ~FileUploadStatusObserver() override = default;
  };

  // Struct containing information about an individual network request.
  // TODO(crbug.com/441351005): Make this struct private and rename it.
  struct UploadRequest {
   public:
    UploadRequest();
    ~UploadRequest();

    // The time the request was sent.
    base::Time start_time;
    // The time the response was received.
    base::Time response_time;
    // The response code of the request. 0 if the response has not been
    // received.
    int response_code = 0;
    // The request body to be sent to the server. Will be set asynchronously
    // after StartFileUploadFlow() is called.
    std::unique_ptr<lens::LensOverlayServerRequest> request_body;
    // The endpoint fetcher used for the request.
    std::unique_ptr<endpoint_fetcher::EndpointFetcher> endpoint_fetcher_;
  };

  // Struct containing file information for a file upload.
  // TODO(crbug.com/441351005): Make this struct private and rename it.
  struct FileInfo {
   public:
    FileInfo();
    ~FileInfo();

    // Name of the selected file.
    std::string file_name;

    // Size in bytes of the file.
    uint64_t file_size_bytes = 0;

    // The time the file was selected.
    base::Time webui_selection_time;

    // Client-side unique identifier generated by UI. Used as the key in the
    // `active_files_` map.
    base::UnguessableToken file_token_;

    // The mime type of the file. Value-initialized so that the field never
    // holds an indeterminate value if it is read before being assigned; the
    // out-of-line constructor is not visible from this header.
    lens::MimeType mime_type_{};

    // Gets the file upload status.
    FileUploadStatus GetFileUploadStatus() const { return upload_status_; }

    // Gets the file upload error type.
    FileUploadErrorType GetFileUploadErrorType() const {
      return upload_error_type_;
    }

    // Gets a pointer to the request ID for this request for testing. Returns
    // nullptr if no request ID has been set yet.
    lens::LensOverlayRequestId* GetRequestIdForTesting() {
      return request_id_.get();
    }

   private:
    friend class ComposeboxQueryController;
    friend class ComposeboxQueryControllerIOS;

    // Defaults to kNotUploaded until an upload is started.
    // Do not modify this field directly, use UpdateFileUploadStatus() instead.
    FileUploadStatus upload_status_ = FileUploadStatus::kNotUploaded;

    // The error type if the upload failed.
    FileUploadErrorType upload_error_type_ = FileUploadErrorType::kUnknown;

    // The request ID for this request. Set by StartFileUploadFlow().
    std::unique_ptr<lens::LensOverlayRequestId> request_id_;

    // The headers to attach to the request. Will be set asynchronously after
    // StartFileUploadFlow() is called.
    std::unique_ptr<std::vector<std::string>> request_headers_;

    // The access token fetcher used for getting OAuth for the file upload
    // request. Will be discarded after the OAuth headers are created.
    std::unique_ptr<signin::PrimaryAccountAccessTokenFetcher>
        file_upload_access_token_fetcher_;

    // The upload requests.
    std::vector<std::unique_ptr<UploadRequest>> upload_requests_;

    // The number of outstanding network requests. This is set in
    // StartFileUploadFlow() and decremented when successful network responses
    // are received.
    size_t num_outstanding_network_requests_ = 0;
  };

  ComposeboxQueryController(
      signin::IdentityManager* identity_manager,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
      version_info::Channel channel,
      std::string locale,
      TemplateURLService* template_url_service,
      variations::VariationsClient* variations_client,
      bool send_lns_surface,
      bool enable_multi_context_input_flow);
  virtual ~ComposeboxQueryController();

  // Session management. Virtual for testing.
  virtual void NotifySessionStarted();
  virtual void NotifySessionAbandoned();

  // Returns the next request id for the given update mode, mime type and
  // media type. NOTE(review): presumably produced by `request_id_generator_`
  // -- confirm against the implementation.
  std::unique_ptr<lens::LensOverlayRequestId> GetNextRequestId(
      lens::RequestIdUpdateMode update_mode,
      lens::MimeType mime_type,
      lens::LensOverlayRequestId_MediaType media_type);

  // Called when a query has been submitted; returns the AIM URL created for
  // `query_text`. `query_start_time` is the time that the user clicked the
  // submit button. `additional_params` is optional and defaults to empty.
  GURL CreateAimUrl(const std::string& query_text,
                    base::Time query_start_time,
                    std::map<std::string, std::string> additional_params = {});

  // Observer management.
  void AddObserver(FileUploadStatusObserver* obs);
  void RemoveObserver(FileUploadStatusObserver* obs);

  // Triggers upload of the file with data and stores the file info in the
  // internal map. Call after setting the file info fields. Virtual for testing.
  // TODO(crbug.com/441161325): Rename this method to reference
  // "contextual inputs" instead of "files".
  virtual void StartFileUploadFlow(
      const base::UnguessableToken& file_token,
      std::unique_ptr<lens::ContextualInputData> contextual_input_data,
      std::optional<lens::ImageEncodingOptions> image_options);

  // Removes file from file cache.
  virtual bool DeleteFile(const base::UnguessableToken& file_token);

  // Clear entire file cache.
  virtual void ClearFiles();

  // Clears the suggest inputs.
  virtual void ClearSuggestInputs();

  // Returns the number of files that are sent in the AIM request. Const so
  // that it is callable on const instances, like suggest_inputs().
  int num_files_in_request() const { return num_files_in_request_; }

  // Return the file from `active_files_` map or nullptr if not found.
  virtual FileInfo* GetFileInfo(const base::UnguessableToken& file_token);

  // Returns the current suggest inputs.
  const lens::proto::LensOverlaySuggestInputs& suggest_inputs() const {
    return suggest_inputs_;
  }

 protected:
  // Creates the request body proto for an image and calls the callback with the
  // request. Static, so it does not touch controller state.
  static void CreateFileUploadRequestProtoWithImageDataAndContinue(
      lens::LensOverlayRequestId request_id,
      lens::LensOverlayClientContext client_context,
      scoped_refptr<lens::RefCountedLensOverlayClientLogs> client_logs,
      RequestBodyProtoCreatedCallback callback,
      lens::ImageData image_data);

  // Creates the image upload request body proto for the file identified by
  // `file_token` from the bytes in `image_data`, using `options` if provided,
  // and calls `callback` with the request.
  virtual void CreateImageUploadRequest(
      const base::UnguessableToken& file_token,
      const std::vector<uint8_t>& image_data,
      std::optional<lens::ImageEncodingOptions> options,
      RequestBodyProtoCreatedCallback callback);

  // Returns the EndpointFetcher to use with the given params. Protected to
  // allow overriding in tests to mock server responses.
  virtual std::unique_ptr<endpoint_fetcher::EndpointFetcher>
  CreateEndpointFetcher(std::string request_string,
                        const GURL& fetch_url,
                        endpoint_fetcher::HttpMethod http_method,
                        base::TimeDelta timeout,
                        const std::vector<std::string>& request_headers,
                        const std::vector<std::string>& cors_exempt_headers,
                        UploadProgressCallback upload_progress_callback);

  // Creates the client context for Lens requests. Protected to allow access
  // from tests.
  lens::LensOverlayClientContext CreateClientContext() const;

  // Clears all cluster info.
  void ClearClusterInfo();

  // Resets the request cluster info state. Protected to allow tests to
  // override. `session_id` is used to determine if the async timer is
  // from an old, invalid session.
  virtual void ResetRequestClusterInfoState(int session_id);

  // The internal state of the query controller. Protected to allow tests to
  // access the state. Do not modify this state directly, use
  // SetQueryControllerState() instead.
  QueryControllerState query_controller_state_ = QueryControllerState::kOff;

  // Callback for when the query controller state changes. Protected to allow
  // tests to set the callback.
  QueryControllerStateChangedCallback
      on_query_controller_state_changed_callback_;

  // The map of active files, keyed by the file token.
  // Protected to allow tests to access the files.
  std::map<base::UnguessableToken, std::unique_ptr<FileInfo>> active_files_;

  // Task runner used to create the file upload request proto asynchronously.
  scoped_refptr<base::TaskRunner> create_request_task_runner_;

 private:
  // Fetches the OAuth headers and calls the callback with the headers. If the
  // OAuth cannot be retrieved (like if the user is not logged in), the callback
  // will be called with an empty vector. Returns the access token fetcher
  // making the request so it can be kept alive.
  std::unique_ptr<signin::PrimaryAccountAccessTokenFetcher>
  CreateOAuthHeadersAndContinue(OAuthHeadersCreatedCallback callback);

  // Gets an OAuth token for the cluster info request and proceeds with sending
  // a LensOverlayServerClusterInfoRequest to get the cluster info.
  void FetchClusterInfo();

  // Asynchronous handler for when the fetch cluster info request headers are
  // ready. Creates the endpoint fetcher and sends the cluster info request.
  void SendClusterInfoNetworkRequest(std::vector<std::string> request_headers);

  // Handles the response from the cluster info request.
  void HandleClusterInfoResponse(
      std::unique_ptr<endpoint_fetcher::EndpointResponse> response);

  // Sets the query controller state and notifies the query controller state
  // changed callback if it has changed.
  void SetQueryControllerState(QueryControllerState new_state);

  // Updates the file upload status and notifies the file upload status
  // observers with an optional error type if the upload failed.
  void UpdateFileUploadStatus(const base::UnguessableToken& file_token,
                              FileUploadStatus status,
                              std::optional<FileUploadErrorType> error_type);

#if !BUILDFLAG(IS_IOS)
  // Handler for when the image from an image file upload is decoded. Creates
  // the request body proto and calls the callback with the request.
  void ProcessDecodedImageAndContinue(lens::LensOverlayRequestId request_id,
                                      const lens::ImageEncodingOptions& options,
                                      RequestBodyProtoCreatedCallback callback,
                                      const SkBitmap& bitmap);
#endif  // !BUILDFLAG(IS_IOS)

  // Creates the request body protos for the file and viewport upload requests
  // and calls the callbacks with the request.
  void CreateUploadRequestBodiesAndContinue(
      const base::UnguessableToken& file_token,
      std::unique_ptr<lens::ContextualInputData> contextual_input_data,
      std::optional<lens::ImageEncodingOptions> options);

  // Callback that takes the image request body proto and adds the pdf page
  // index to it, then forwards the request and `error_type` to `callback`.
  void AddPageIndexToImageUploadRequestAndContinue(
      std::optional<size_t> pdf_page_index,
      RequestBodyProtoCreatedCallback callback,
      lens::LensOverlayServerRequest request,
      std::optional<FileUploadErrorType> error_type);

  // Asynchronous handler for when an upload request body is ready.
  // `request_index` identifies which of the file's upload requests the body
  // belongs to.
  void OnUploadRequestBodyReady(const base::UnguessableToken& file_token,
                                size_t request_index,
                                lens::LensOverlayServerRequest request,
                                std::optional<FileUploadErrorType> error_type);

  // Asynchronous handler for when the request headers for uploading file and
  // viewport data are ready.
  void OnUploadRequestHeadersReady(const base::UnguessableToken& file_token,
                                   std::vector<std::string> headers);

  // Sends the upload request if the request body, headers, and cluster
  // info are ready.
  void MaybeSendUploadNetworkRequest(const base::UnguessableToken& file_token,
                                     size_t request_index);

  // Creates the endpoint fetcher and sends the upload network request.
  void SendUploadNetworkRequest(FileInfo* file_info, size_t request_index);

  // Callback for when an upload endpoint fetcher is created, storing it and
  // updating the file info state.
  void OnUploadEndpointFetcherCreated(
      const base::UnguessableToken& file_token,
      size_t request_index,
      std::unique_ptr<endpoint_fetcher::EndpointFetcher> endpoint_fetcher);

  // Handles the response from an upload request.
  void HandleUploadResponse(
      const base::UnguessableToken& file_token,
      size_t request_index,
      std::unique_ptr<endpoint_fetcher::EndpointResponse> response);

  // Performs the fetch request for `request` using `request_headers`.
  // `fetcher_created_callback` receives the created endpoint fetcher and
  // `response_received_callback` receives the response.
  // `upload_progress_callback` is optional and defaults to a null callback.
  void PerformFetchRequest(
      lens::LensOverlayServerRequest* request,
      std::vector<std::string>* request_headers,
      base::TimeDelta timeout,
      base::OnceCallback<
          void(std::unique_ptr<endpoint_fetcher::EndpointFetcher>)>
          fetcher_created_callback,
      endpoint_fetcher::EndpointFetcherCallback response_received_callback,
      UploadProgressCallback upload_progress_callback = base::NullCallback());

  // The last received cluster info.
  std::optional<lens::LensOverlayClusterInfo> cluster_info_ = std::nullopt;

  // The endpoint fetcher used for the cluster info request.
  std::unique_ptr<endpoint_fetcher::EndpointFetcher>
      cluster_info_endpoint_fetcher_;

  // The access token fetcher used for getting OAuth for the cluster info
  // request. Will be discarded after the OAuth headers are created.
  std::unique_ptr<signin::PrimaryAccountAccessTokenFetcher>
      cluster_info_access_token_fetcher_;

  // Unowned IdentityManager for fetching access tokens. Could be null for
  // incognito profiles.
  const raw_ptr<signin::IdentityManager> identity_manager_;

  // The url loader factory to use for Lens network requests.
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;

  // The channel to use for Lens network requests.
  version_info::Channel channel_;

  // The locale used for creating the client context.
  std::string locale_;

  // The request id generator for this query flow instance.
  lens::LensOverlayRequestIdGenerator request_id_generator_;

  // The observer list, managed via AddObserver() and RemoveObserver().
  base::ObserverList<FileUploadStatusObserver> observers_;

  // Owned by the Profile, and thus guaranteed to outlive this instance.
  const raw_ptr<TemplateURLService> template_url_service_;

  // Owned by the Profile, and thus guaranteed to outlive this instance.
  const raw_ptr<variations::VariationsClient> variations_client_;

  // Whether or not to send the lns_surface parameter.
  // TODO(crbug.com/430070871): Remove this once the server supports the
  // `lns_surface` parameter.
  bool send_lns_surface_ = false;

  // Whether or not to use the multiple-input id request generation flow.
  bool enable_multi_context_input_flow_ = false;

  // The suggest inputs returned by suggest_inputs() and cleared by
  // ClearSuggestInputs().
  lens::proto::LensOverlaySuggestInputs suggest_inputs_;

  // The session counter, incremented when the session is stopped. This is used
  // to determine if the session is active when handling cluster info
  // expiration.
  int session_id_ = 0;

  // The number of files that are sent in the AIM request.
  int num_files_in_request_ = 0;

  // Factory for weak pointers to this instance. Declared last so that it is
  // destroyed first, invalidating outstanding weak pointers before the other
  // members are torn down.
  base::WeakPtrFactory<ComposeboxQueryController> weak_ptr_factory_{this};
};

#endif  // COMPONENTS_OMNIBOX_COMPOSEBOX_COMPOSEBOX_QUERY_CONTROLLER_H_
