// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/history_embeddings/ml_answerer.h"

#include <algorithm>

#include "base/barrier_callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/strings/stringprintf.h"
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/optimization_guide/core/model_quality/model_execution_logging_wrappers.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/history_answer.pb.h"

namespace history_embeddings {

using ModelExecutionError = optimization_guide::
    OptimizationGuideModelExecutionError::ModelExecutionError;
using optimization_guide::OptimizationGuideModelExecutionError;
using optimization_guide::OptimizationGuideModelStreamingExecutionResult;
using optimization_guide::SessionConfigParams;
using optimization_guide::proto::Answer;
using optimization_guide::proto::HistoryAnswerRequest;
using optimization_guide::proto::Passage;

namespace {

static constexpr std::string kPassageIdToken = "ID";
// Estimated token count of the preamble text in the prompt.
static constexpr size_t kPreambleTokenBufferSize = 100u;
// Estimated token count of overhead text per passage.
static constexpr size_t kExtraTokensPerPassage = 10u;

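// Formats a passage id as a fixed-width, zero-padded string, e.g. 3 -> "0003".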
std::string GetPassageIdStr(size_t id) {
  return base::StringPrintf("%04d", static_cast<int>(id));
}

float GetMlAnswerScoreThreshold() {
  return GetFeatureParameters().ml_answerer_min_score;
}

}  // namespace

// Helper struct to bundle raw model input (query/passages) with its metadata.
struct MlAnswerer::ModelInput {
  // The string content of this input.
  std::string text;
  // Index 0 is reserved for the query, i.e. this index is 0 if and only if
  // this input is the query. If the input is a passage, `index` holds the
  // passage's position in the original passage vector (where a lower index
  // means higher relevance), plus 1 to offset for the query.
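  // For example, with passages {p0, p1} the inputs are the query (index 0),
  // p0 (index 1), and p1 (index 2).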
  size_t index;
  // The size of `text` in tokens.
  uint32_t token_count;
};

// Manages sessions for generating an answer for a given query and multiple
// URLs.
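// Flow: a session is started per URL, each session tokenizes its
// query/passages and adds them as context, every session is scored, and only
// the highest-scored session continues to full decoding.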
class MlAnswerer::SessionManager {
 public:
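  // A (session index, score) pair; the score is nullopt if no score is
  // available.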
  using SessionScoreType = std::tuple<int, std::optional<float>>;

  SessionManager(std::string query,
                 Context context,
                 ComputeAnswerCallback callback,
                 base::WeakPtr<ModelQualityLogsUploaderService> logs_uploader)
      : query_(std::move(query)),
        context_(std::move(context)),
        callback_(std::move(callback)),
        origin_task_runner_(base::SequencedTaskRunner::GetCurrentDefault()),
        logs_uploader_(logs_uploader),
        weak_ptr_factory_(this) {}

  ~SessionManager() {
    // Explicitly invalidate weak pointers to prevent callbacks that may be
    // triggered by the destructor logic.
    weak_ptr_factory_.InvalidateWeakPtrs();
    // If the callback has not been run yet, run it with a cancelled status.
    if (!callback_.is_null()) {
      FinishAndResetSessions(AnswererResult(
          ComputeAnswerStatus::kExecutionCancelled, query_, Answer()));
    }
  }

  // Adds a session that contains the query and passage context. The session
  // lives until this manager is reset or destroyed.
  void AddSession(
      std::unique_ptr<OptimizationGuideModelExecutor::Session> session,
      std::string url) {
    sessions_.push_back(std::move(session));
    urls_.push_back(url);
  }

  // Runs speculative decoding by first getting a score for each URL candidate
  // and then continuing decoding with only the highest-scored session.
  void RunSpeculativeDecoding() {
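    // Each session scores the passage-id token against its own context, which
    // serves as a proxy for how well that session's passages can support an
    // answer. The barrier invokes SortAndDecode once every session reports.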
    const size_t num_sessions = GetNumberOfSessions();
    base::OnceCallback<void(const std::vector<SessionScoreType>&)> cb =
        base::BindOnce(&SessionManager::SortAndDecode,
                       weak_ptr_factory_.GetWeakPtr());
    const auto barrier_cb =
        base::BarrierCallback<SessionScoreType>(num_sessions, std::move(cb));
    for (size_t s_index = 0; s_index < num_sessions; s_index++) {
      VLOG(3) << "Running Score for session " << s_index;
      sessions_[s_index]->Score(
          kPassageIdToken, base::BindOnce(
                               [](size_t index, std::optional<float> score) {
                                 VLOG(3) << "Score complete for " << index;
                                 return std::make_tuple(index, score);
                               },
                               s_index)
                               .Then(barrier_cb));
    }
  }

  size_t GetNumberOfSessions() { return sessions_.size(); }

  base::WeakPtr<MlAnswerer::SessionManager> GetWeakPtr() {
    return weak_ptr_factory_.GetWeakPtr();
  }

  // Runs the callback with the result.
  void FinishCallback(AnswererResult answer_result) {
    CHECK(!callback_.is_null());
    origin_task_runner_->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(callback_), std::move(answer_result)));
  }

  // Finishes and cleans up sessions.
  void FinishAndResetSessions(AnswererResult answer_result) {
    FinishCallback(std::move(answer_result));

    // Destroy all existing sessions.
    VLOG(3) << "Sessions cleared.";
    sessions_.clear();
    urls_.clear();
  }

  // Called once all sessions have been started and added.
  void OnSessionsStarted(std::vector<int> args) { RunSpeculativeDecoding(); }

  // Called when the token counts of the query and all passages of a session
  // have been computed.
  void OnTokenCountRetrieved(std::unique_ptr<Session> session,
                             const std::string url,
                             base::OnceCallback<void(int)> session_added_cb,
                             std::vector<ModelInput> inputs) {
    HistoryAnswerRequest request;
    int token_limit = session->GetTokenLimits().min_context_tokens;
    // Reserve space for the preamble text.
    int token_count = kPreambleTokenBufferSize;

    // Sort the inputs by their index in the original vector so that more
    // relevant passages are prioritized.
    std::ranges::sort(inputs, /*comp=*/{}, &ModelInput::index);

    // Add the query to the request. The query always has index 0.
    token_count += inputs[0].token_count;
    request.set_query(inputs[0].text);

    // Add as many passages as the input window can fit.
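    // With a hypothetical 1000-token limit, the budget is:
    //   100 (preamble) + query tokens + sum of (passage tokens + 10) <= 1000.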
    for (size_t i = 1; i < inputs.size(); ++i) {
      token_count += (inputs[i].token_count + kExtraTokensPerPassage);
      if (token_count > token_limit) {
        break;
      }

      auto* passage = request.add_passages();
      passage->set_text(inputs[i].text);
      passage->set_passage_id(GetPassageIdStr(i));
    }

    VLOG(3) << "Running AddContext for query: `" << request.query() << "`";
    session->AddContext(request);
    AddSession(std::move(session), url);
    std::move(session_added_cb).Run(1);
  }

 private:
  // Callback to be repeatedly called during streaming execution.
  void StreamingExecutionCallback(
      size_t session_index,
      optimization_guide::OptimizationGuideModelStreamingExecutionResult result,
      std::unique_ptr<optimization_guide::proto::HistoryAnswerLoggingData>
          logging_data) {
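    // Attach the logging data to a model-quality log entry so that it can be
    // uploaded by the logs uploader service.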
    auto log_entry = std::make_unique<optimization_guide::ModelQualityLogEntry>(
        logs_uploader_);
    log_entry->log_ai_data_request()->set_allocated_history_answer(
        logging_data.release());
    if (!result.response.has_value()) {
      ComputeAnswerStatus status = ComputeAnswerStatus::kExecutionFailure;
      auto error = result.response.error().error();
      if (error == ModelExecutionError::kFiltered) {
        status = ComputeAnswerStatus::kFiltered;
      }
      FinishCallback(AnswererResult(status, query_, Answer(),
                                    std::move(log_entry), "", {}));
    } else if (result.response->is_complete) {
      auto response = optimization_guide::ParsedAnyMetadata<
          optimization_guide::proto::HistoryAnswerResponse>(
          std::move(result.response).value().response);
      AnswererResult answerer_result(ComputeAnswerStatus::kSuccess, query_,
                                     response->answer(), std::move(log_entry),
                                     urls_[session_index], {});
      answerer_result.PopulateScrollToTextFragment(
          context_.url_passages_map[answerer_result.url]);
      FinishCallback(std::move(answerer_result));
    }
  }

  // Decodes with the highest-scored session.
  void SortAndDecode(const std::vector<SessionScoreType>& session_scores) {
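    // `max_index` stays at the sentinel value session_scores.size() unless
    // some session produced a score greater than zero.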
    size_t max_index = session_scores.size();
    float max_score = 0.0;
    for (size_t i = 0; i < session_scores.size(); i++) {
      const std::optional<float> score = std::get<1>(session_scores[i]);
      if (score.has_value()) {
        VLOG(3) << "Session: " << std::get<0>(session_scores[i])
                << " Score: " << *score;
        VLOG(3) << "URL: " << urls_[std::get<0>(session_scores[i])];
        if (*score > max_score) {
          max_score = *score;
          max_index = i;
        }
      }
    }

    if (max_index == session_scores.size()) {
      FinishAndResetSessions(AnswererResult{
          ComputeAnswerStatus::kExecutionFailure, query_, Answer()});
      return;
    }

    // Return an unanswerable status because the highest score is below the
    // threshold.
    if (max_score < GetMlAnswerScoreThreshold()) {
      FinishAndResetSessions(
          AnswererResult{ComputeAnswerStatus::kUnanswerable, query_, Answer()});
      return;
    }

    // Continue decoding using the session with the highest score. Use a dummy
    // request here since both the query and passages have already been added
    // to the context.
    if (!sessions_.empty()) {
      optimization_guide::proto::HistoryAnswerRequest request;
      const size_t session_index = std::get<0>(session_scores[max_index]);
      VLOG(3) << "Running ExecuteModel for session " << session_index;
      optimization_guide::ExecuteModelSessionWithLogging(
          sessions_[session_index].get(), request,
          base::BindRepeating(&SessionManager::StreamingExecutionCallback,
                              weak_ptr_factory_.GetWeakPtr(), session_index));
    } else {
      // If the sessions have already been cleaned up, run the callback with a
      // cancelled status.
      FinishAndResetSessions(AnswererResult{
          ComputeAnswerStatus::kExecutionCancelled, query_, Answer()});
    }
  }

  std::vector<std::unique_ptr<OptimizationGuideModelExecutor::Session>>
      sessions_;
  // URLs associated with sessions by index.
  std::vector<std::string> urls_;
  std::string query_;
  Context context_;
  ComputeAnswerCallback callback_;
  const scoped_refptr<base::SequencedTaskRunner> origin_task_runner_;
  base::WeakPtr<ModelQualityLogsUploaderService> logs_uploader_;
  base::WeakPtrFactory<SessionManager> weak_ptr_factory_;
};

MlAnswerer::MlAnswerer(OptimizationGuideModelExecutor* model_executor,
                       ModelQualityLogsUploaderService* logs_uploader)
    : model_executor_(model_executor) {
  if (logs_uploader) {
    logs_uploader_ = logs_uploader->GetWeakPtr();
  }
}

MlAnswerer::~MlAnswerer() = default;

int64_t MlAnswerer::GetModelVersion() {
  // This can be replaced with the real implementation.
  return 0;
}

void MlAnswerer::ComputeAnswer(std::string query,
                               Context context,
                               ComputeAnswerCallback callback) {
  CHECK(model_executor_);

  // Assign a new session manager (and destroy the existing one).
  session_manager_ = std::make_unique<SessionManager>(
      query, context, std::move(callback), logs_uploader_);

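  // The barrier runs OnSessionsStarted only after a session for every URL has
  // finished tokenizing its inputs and adding them as context.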
  const auto sessions_started_callback = base::BarrierCallback<int>(
      context.url_passages_map.size(),
      base::BindOnce(&MlAnswerer::SessionManager::OnSessionsStarted,
                     session_manager_->GetWeakPtr()));

  const SessionConfigParams session_config{
      .execution_mode = SessionConfigParams::ExecutionMode::kOnDeviceOnly};
  // Start a session for each URL.
  for (const auto& url_and_passages : context.url_passages_map) {
    std::unique_ptr<Session> session = model_executor_->StartSession(
        optimization_guide::ModelBasedCapabilityKey::kHistorySearch,
        session_config);
    if (session == nullptr) {
      session_manager_->FinishAndResetSessions(AnswererResult(
          ComputeAnswerStatus::kModelUnavailable, query, Answer()));
      return;
    }

    StartAndAddSession(query, url_and_passages.first, url_and_passages.second,
                       std::move(session), sessions_started_callback);
  }
}

void MlAnswerer::StartAndAddSession(
    const std::string& query,
    const std::string& url,
    const std::vector<std::string>& passages,
    std::unique_ptr<Session> session,
    base::OnceCallback<void(int)> session_started) {
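  // Keep a raw pointer for the GetSizeInTokens calls below; ownership of the
  // session moves into the barrier callback, which returns it to the session
  // manager in OnTokenCountRetrieved.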
  Session* raw_session = session.get();
  const auto token_count_callback = base::BarrierCallback<ModelInput>(
      passages.size() + 1,  // We need token counts for the passages + query.
      base::BindOnce(&MlAnswerer::SessionManager::OnTokenCountRetrieved,
                     session_manager_->GetWeakPtr(), std::move(session), url,
                     std::move(session_started)));

  const auto make_model_input = [](std::string text, size_t index,
                                   std::optional<uint32_t> token_count) {
    VLOG(3) << "Created model input for " << index;
    return ModelInput{text, index, token_count.value_or(0)};
  };

  // Get the token count for the query; the query is always assigned index 0
  // in its ModelInput.
  raw_session->GetSizeInTokens(
      query,
      base::BindOnce(make_model_input, query, 0).Then(token_count_callback));

  // Get token counts for the passages; each passage is assigned its index + 1
  // in its ModelInput, reserving index 0 for the query.
  VLOG(3) << "Running GetSizeInTokens for " << passages.size()
          << " passages...";
  for (size_t i = 0; i < passages.size(); ++i) {
    raw_session->GetSizeInTokens(
        passages[i], base::BindOnce(make_model_input, passages[i], i + 1)
                         .Then(token_count_callback));
  }
}

}  // namespace history_embeddings