// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/history_embeddings/ml_answerer.h"
#include <algorithm>
#include "base/barrier_callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/strings/stringprintf.h"
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/optimization_guide/core/model_quality/model_execution_logging_wrappers.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/history_answer.pb.h"
namespace history_embeddings {
using ModelExecutionError = optimization_guide::
OptimizationGuideModelExecutionError::ModelExecutionError;
using optimization_guide::OptimizationGuideModelExecutionError;
using optimization_guide::OptimizationGuideModelStreamingExecutionResult;
using optimization_guide::SessionConfigParams;
using optimization_guide::proto::Answer;
using optimization_guide::proto::HistoryAnswerRequest;
using optimization_guide::proto::Passage;

namespace {

static constexpr std::string kPassageIdToken = "ID";

// Estimated token count of the preamble text in the prompt.
static constexpr size_t kPreambleTokenBufferSize = 100u;

// Estimated token count of overhead text per passage.
static constexpr size_t kExtraTokensPerPassage = 10u;
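
// Formats a passage id as a zero-padded four-digit string, e.g.
// GetPassageIdStr(7) returns "0007".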
std::string GetPassageIdStr(size_t id) {
return base::StringPrintf("%04d", static_cast<int>(id));
}

float GetMlAnswerScoreThreshold() {
return GetFeatureParameters().ml_answerer_min_score;
}

} // namespace

// Helper struct to bundle raw model input (queries/passages) with its metadata.
struct MlAnswerer::ModelInput {
// The string content of this input.
std::string text;

// Index 0 is reserved for the query, i.e. this index is 0 iff this input
// is a query. If the input is a passage, `index` holds the passage's
// position in the original passage vector (where a lower index means
// higher relevance), plus 1 to account for the query.
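// For example, with passages {P0, P1}: the query gets index 0, P0 gets
// index 1, and P1 gets index 2.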
size_t index;

// The size of `text` in tokens.
uint32_t token_count;
};

// Manages sessions for generating an answer for a given query and multiple
// URLs.
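//
// Typical flow: MlAnswerer::ComputeAnswer() creates one SessionManager per
// query, which holds one model session per URL. Once every session has its
// query/passage context added, RunSpeculativeDecoding() scores all sessions
// and fully decodes only the highest-scoring one.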
class MlAnswerer::SessionManager {
public:
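// A (session index, score) tuple; the score is nullopt if scoring failed
// for that session.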
using SessionScoreType = std::tuple<int, std::optional<float>>;

SessionManager(std::string query,
Context context,
ComputeAnswerCallback callback,
base::WeakPtr<ModelQualityLogsUploaderService> logs_uploader)
: query_(std::move(query)),
context_(std::move(context)),
callback_(std::move(callback)),
origin_task_runner_(base::SequencedTaskRunner::GetCurrentDefault()),
logs_uploader_(logs_uploader),
weak_ptr_factory_(this) {}

~SessionManager() {
// Explicitly invalidate weak pointers to prevent callbacks that may be
// triggered by the destructor logic.
weak_ptr_factory_.InvalidateWeakPtrs();
// If the callback has not been run yet, run it with the canceled status.
if (!callback_.is_null()) {
FinishAndResetSessions(AnswererResult(
ComputeAnswerStatus::kExecutionCancelled, query_, Answer()));
}
}

// Adds a session that contains the query and passage context. The session
// is kept alive until this manager is reset or destroyed.
void AddSession(
std::unique_ptr<OptimizationGuideModelExecutor::Session> session,
std::string url) {
sessions_.push_back(std::move(session));
urls_.push_back(url);
}

// Runs speculative decoding by first getting a score for each URL candidate
// and then continuing decoding with only the highest-scoring session.
void RunSpeculativeDecoding() {
const size_t num_sessions = GetNumberOfSessions();
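// The barrier collects one (index, score) tuple per session and invokes
// SortAndDecode once every session has reported.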
base::OnceCallback<void(const std::vector<SessionScoreType>&)> cb =
base::BindOnce(&SessionManager::SortAndDecode,
weak_ptr_factory_.GetWeakPtr());
const auto barrier_cb =
base::BarrierCallback<SessionScoreType>(num_sessions, std::move(cb));
for (size_t s_index = 0; s_index < num_sessions; s_index++) {
VLOG(3) << "Running Score for session " << s_index;
sessions_[s_index]->Score(
kPassageIdToken, base::BindOnce(
[](size_t index, std::optional<float> score) {
VLOG(3) << "Score complete for " << index;
return std::make_tuple(index, score);
},
s_index)
.Then(barrier_cb));
}
}

size_t GetNumberOfSessions() { return sessions_.size(); }

base::WeakPtr<MlAnswerer::SessionManager> GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}

// Runs the callback with the result.
void FinishCallback(AnswererResult answer_result) {
CHECK(!callback_.is_null());
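// Post back to the originating sequence so the callback always runs on the
// sequence that created this manager.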
origin_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(std::move(callback_), std::move(answer_result)));
}

// Finishes and cleans up sessions.
void FinishAndResetSessions(AnswererResult answer_result) {
FinishCallback(std::move(answer_result));
// Destroy all existing sessions.
VLOG(3) << "Sessions cleared.";
sessions_.clear();
urls_.clear();
}

// Called when all sessions have been started and added.
void OnSessionsStarted(std::vector<int> args) { RunSpeculativeDecoding(); }

// Called when the token counts of the query and all passages of a session
// have been computed.
void OnTokenCountRetrieved(std::unique_ptr<Session> session,
const std::string url,
base::OnceCallback<void(int)> session_added_cb,
std::vector<ModelInput> inputs) {
HistoryAnswerRequest request;
int token_limit = session->GetTokenLimits().min_context_tokens;
// Reserve space for preamble text.
int token_count = kPreambleTokenBufferSize;
// Sort the inputs according to their indices in the original vector, so
// we prioritize passages that are more relevant.
std::ranges::sort(inputs, {}, &ModelInput::index);
// Add the query to the request. The query will always have index 0.
token_count += inputs[0].token_count;
request.set_query(inputs[0].text);
// Add passages in relevance order, stopping at the first one that would
// overflow the input window.
for (size_t i = 1; i < inputs.size(); ++i) {
token_count += (inputs[i].token_count + kExtraTokensPerPassage);
if (token_count > token_limit) {
break;
}
auto* passage = request.add_passages();
passage->set_text(inputs[i].text);
passage->set_passage_id(GetPassageIdStr(i));
}
VLOG(3) << "Running AddContext for query: `" << request.query() << "`";
session->AddContext(request);
AddSession(std::move(session), url);
std::move(session_added_cb).Run(1);
}

private:
// Callback invoked repeatedly during streaming execution.
void StreamingExecutionCallback(
size_t session_index,
optimization_guide::OptimizationGuideModelStreamingExecutionResult result,
std::unique_ptr<optimization_guide::proto::HistoryAnswerLoggingData>
logging_data) {
auto log_entry = std::make_unique<optimization_guide::ModelQualityLogEntry>(
logs_uploader_);
log_entry->log_ai_data_request()->set_allocated_history_answer(
logging_data.release());
if (!result.response.has_value()) {
ComputeAnswerStatus status = ComputeAnswerStatus::kExecutionFailure;
auto error = result.response.error().error();
if (error == ModelExecutionError::kFiltered) {
status = ComputeAnswerStatus::kFiltered;
}
FinishCallback(AnswererResult(status, query_, Answer(),
std::move(log_entry), "", {}));
} else if (result.response->is_complete) {
auto response = optimization_guide::ParsedAnyMetadata<
optimization_guide::proto::HistoryAnswerResponse>(
std::move(result.response).value().response);
AnswererResult answerer_result(ComputeAnswerStatus::kSuccess, query_,
response->answer(), std::move(log_entry),
urls_[session_index], {});
answerer_result.PopulateScrollToTextFragment(
context_.url_passages_map[answerer_result.url]);
FinishCallback(std::move(answerer_result));
}
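// Partial (incomplete) streaming results are ignored; this callback fires
// again once the response is complete.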
}

// Decodes with the highest-scoring session.
void SortAndDecode(const std::vector<SessionScoreType>& session_scores) {
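// session_scores.size() is used as a sentinel meaning "no valid score
// found".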
size_t max_index = session_scores.size();
float max_score = 0.0;
for (size_t i = 0; i < session_scores.size(); i++) {
const std::optional<float> score = std::get<1>(session_scores[i]);
if (score.has_value()) {
VLOG(3) << "Session: " << std::get<0>(session_scores[i])
<< " Score: " << *score;
VLOG(3) << "URL: " << urls_[std::get<0>(session_scores[i])];
if (*score > max_score) {
max_score = *score;
max_index = i;
}
}
}
if (max_index == session_scores.size()) {
FinishAndResetSessions(AnswererResult{
ComputeAnswerStatus::kExecutionFailure, query_, Answer()});
return;
}
// Return unanswerable status because the highest score is below the
// threshold.
if (max_score < GetMlAnswerScoreThreshold()) {
FinishAndResetSessions(
AnswererResult{ComputeAnswerStatus::kUnanswerable, query_, Answer()});
return;
}
// Continue decoding using the session with the highest score. Use an empty
// request here since the query and passages were already added to the
// context.
if (!sessions_.empty()) {
optimization_guide::proto::HistoryAnswerRequest request;
const size_t session_index = std::get<0>(session_scores[max_index]);
VLOG(3) << "Running ExecuteModel for session " << session_index;
optimization_guide::ExecuteModelSessionWithLogging(
sessions_[session_index].get(), request,
base::BindRepeating(&SessionManager::StreamingExecutionCallback,
weak_ptr_factory_.GetWeakPtr(), session_index));
} else {
// If the sessions have already been cleaned up, run the callback with the
// canceled status.
FinishAndResetSessions(AnswererResult{
ComputeAnswerStatus::kExecutionCancelled, query_, Answer()});
}
}
std::vector<std::unique_ptr<OptimizationGuideModelExecutor::Session>>
sessions_;
// URLs associated with sessions by index.
std::vector<std::string> urls_;
std::string query_;
Context context_;
ComputeAnswerCallback callback_;
const scoped_refptr<base::SequencedTaskRunner> origin_task_runner_;
base::WeakPtr<ModelQualityLogsUploaderService> logs_uploader_;
base::WeakPtrFactory<SessionManager> weak_ptr_factory_;
};

MlAnswerer::MlAnswerer(OptimizationGuideModelExecutor* model_executor,
ModelQualityLogsUploaderService* logs_uploader)
: model_executor_(model_executor) {
if (logs_uploader) {
logs_uploader_ = logs_uploader->GetWeakPtr();
}
}

MlAnswerer::~MlAnswerer() = default;

int64_t MlAnswerer::GetModelVersion() {
// Placeholder; return the real model version once it is available.
return 0;
}

void MlAnswerer::ComputeAnswer(std::string query,
Context context,
ComputeAnswerCallback callback) {
CHECK(model_executor_);
// Assign a new session manager (and destroy the existing one).
session_manager_ = std::make_unique<SessionManager>(
query, context, std::move(callback), logs_uploader_);
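// The barrier fires OnSessionsStarted, which kicks off speculative
// decoding, only after StartAndAddSession has completed for every URL.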
const auto sessions_started_callback = base::BarrierCallback<int>(
context.url_passages_map.size(),
base::BindOnce(&MlAnswerer::SessionManager::OnSessionsStarted,
session_manager_->GetWeakPtr()));
const SessionConfigParams session_config{
.execution_mode = SessionConfigParams::ExecutionMode::kOnDeviceOnly};
// Start a session for each URL.
for (const auto& url_and_passages : context.url_passages_map) {
std::unique_ptr<Session> session = model_executor_->StartSession(
optimization_guide::ModelBasedCapabilityKey::kHistorySearch,
session_config);
if (session == nullptr) {
session_manager_->FinishAndResetSessions(AnswererResult(
ComputeAnswerStatus::kModelUnavailable, query, Answer()));
return;
}
StartAndAddSession(query, url_and_passages.first, url_and_passages.second,
std::move(session), sessions_started_callback);
}
}

void MlAnswerer::StartAndAddSession(
const std::string& query,
const std::string& url,
const std::vector<std::string>& passages,
std::unique_ptr<Session> session,
base::OnceCallback<void(int)> session_started) {
Session* raw_session = session.get();
const auto token_count_callback = base::BarrierCallback<ModelInput>(
passages.size() + 1, // We need token count for passages + query.
base::BindOnce(&MlAnswerer::SessionManager::OnTokenCountRetrieved,
session_manager_->GetWeakPtr(), std::move(session), url,
std::move(session_started)));
const auto make_model_input = [](std::string text, size_t index,
std::optional<uint32_t> token_count) {
VLOG(3) << "Created model input for " << index;
return ModelInput{text, index, token_count.value_or(0)};
};
// Get the token count for the query; the query is always assigned index 0
// in its ModelInput.
raw_session->GetSizeInTokens(
query,
base::BindOnce(make_model_input, query, 0).Then(token_count_callback));
// Get token counts for the passages; each passage is assigned its original
// index + 1 as its ModelInput index, reserving index 0 for the query.
VLOG(3) << "Running GetSizeInTokens for " << passages.size() << " passages...";
for (size_t i = 0; i < passages.size(); ++i) {
raw_session->GetSizeInTokens(
passages[i], base::BindOnce(make_model_input, passages[i], i + 1)
.Then(token_count_callback));
}
}

} // namespace history_embeddings