blob: cd8d71107920c1ef73d658cee7fa3425802f0e91 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/history_embeddings/ml_intent_classifier.h"
#include <memory>
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/memory/weak_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/history_embeddings/intent_classifier.h"
#include "components/optimization_guide/core/model_execution/feature_keys.h"
#include "components/optimization_guide/core/model_execution/on_device_capability.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/history_query_intent.pb.h"
#include "components/optimization_guide/proto/history_query_intent_model_metadata.pb.h"
#include "components/optimization_guide/proto/model_quality_metadata.pb.h"
namespace history_embeddings {
namespace {
// File-local shorthand for the optimization_guide names used by the
// Execution class defined below in this file.
using ::optimization_guide::ModelBasedCapabilityKey;
using ::optimization_guide::SessionConfigParams;
// `Session` is the alias under which Execution stores its on-device session.
using Session = ::optimization_guide::OnDeviceSession;
using ::optimization_guide::ParsedAnyMetadata;
// Proto messages: the model metadata (score token / threshold), the request
// carrying the query text, and the response carrying the classification.
using ::optimization_guide::proto::HistoryQueryIntentModelMetadata;
using ::optimization_guide::proto::HistoryQueryIntentRequest;
using ::optimization_guide::proto::HistoryQueryIntentResponse;
} // namespace
// Bundles the state of a single intent classification: the on-device session,
// the metadata parsed from that session, and the caller's completion callback.
class MlIntentClassifier::Execution final {
 public:
  Execution() = default;

  // Starts a kHistoryQueryIntent session on `model_executor` and classifies
  // `query`.
  //
  // If no session can be started, `callback` is posted to the current
  // sequence with MODEL_UNAVAILABLE and `false`. Otherwise `callback` is kept
  // and run later from Finish(), once either the scoring path or the
  // execution path has produced a result.
  void Execute(OnDeviceCapability* model_executor,
               std::string query,
               ComputeQueryIntentCallback callback) {
    session_ = model_executor->StartSession(
        ModelBasedCapabilityKey::kHistoryQueryIntent, SessionConfigParams{});
    if (!session_) {
      // Reply asynchronously rather than re-entering the caller.
      auto unavailable_reply =
          base::BindOnce(std::move(callback),
                         ComputeIntentStatus::MODEL_UNAVAILABLE, false);
      base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
          FROM_HERE, std::move(unavailable_reply));
      return;
    }

    callback_ = std::move(callback);

    // The metadata may fail to parse; in that case `model_metadata_` stays
    // empty and the scoring path below is not taken.
    const auto feature_metadata = session_->GetOnDeviceFeatureMetadata();
    model_metadata_ =
        ParsedAnyMetadata<HistoryQueryIntentModelMetadata>(feature_metadata);

    HistoryQueryIntentRequest intent_request;
    intent_request.set_text(std::move(query));

    if (!EnableMlIntentClassifierScore()) {
      // Execution path: results are streamed, hence the repeating callback.
      session_->ExecuteModel(
          intent_request,
          base::BindRepeating(&Execution::OnExecutionResult,
                              weak_ptr_factory_.GetWeakPtr()));
      return;
    }

    // Scoring path: feed the request as context, then score the token named
    // by the model metadata.
    session_->AddContext(intent_request);
    session_->Score(GetTokenToScore(),
                    base::BindOnce(&Execution::OnScoreResult,
                                   weak_ptr_factory_.GetWeakPtr()));
  }

 private:
  // Handles one streamed result from ExecuteModel(). A result without a
  // response value finishes with EXECUTION_FAILURE; a response that is not yet
  // complete is ignored; a complete response finishes with SUCCESS if it
  // parses as HistoryQueryIntentResponse and with EXECUTION_FAILURE otherwise.
  void OnExecutionResult(
      optimization_guide::OptimizationGuideModelStreamingExecutionResult
          result) {
    if (!result.response.has_value()) {
      Finish(ComputeIntentStatus::EXECUTION_FAILURE, false);
      return;
    }
    if (!result.response->is_complete) {
      return;
    }

    auto parsed_response = ParsedAnyMetadata<HistoryQueryIntentResponse>(
        result.response->response);
    if (parsed_response) {
      Finish(ComputeIntentStatus::SUCCESS,
             parsed_response->is_answer_seeking());
    } else {
      Finish(ComputeIntentStatus::EXECUTION_FAILURE, false);
    }
  }

  // Handles the result of Score(). A missing score is an execution failure;
  // otherwise the query counts as answer seeking when the score is strictly
  // greater than the threshold from the model metadata.
  void OnScoreResult(std::optional<float> score) {
    if (score.has_value()) {
      const bool exceeds_threshold = *score > GetIntentScoreThreshold();
      Finish(ComputeIntentStatus::SUCCESS, exceeds_threshold);
    } else {
      Finish(ComputeIntentStatus::EXECUTION_FAILURE, false);
    }
  }

  // Completes the classification. The statement order is deliberate:
  //  1. invalidate weak pointers so no further model callbacks bound in
  //     Execute() reach this object,
  //  2. hand the session to the current sequence for deferred deletion,
  //  3. run the stored callback with the outcome.
  void Finish(ComputeIntentStatus status, bool is_answer_seeking) {
    weak_ptr_factory_.InvalidateWeakPtrs();
    base::SequencedTaskRunner::GetCurrentDefault()->DeleteSoon(
        FROM_HERE, std::move(session_));
    std::move(callback_).Run(status, is_answer_seeking);
  }

  // True when the scoring path should be used: the feature parameter enables
  // it, model metadata was parsed, and that metadata names a score token.
  bool EnableMlIntentClassifierScore() {
    if (!GetFeatureParameters().enable_ml_intent_classifier_score) {
      return false;
    }
    if (!model_metadata_) {
      return false;
    }
    return !model_metadata_->score_token().empty();
  }

  // Returns the token to score. Requires parsed model metadata.
  std::string GetTokenToScore() {
    CHECK(model_metadata_);
    return model_metadata_->score_token();
  }

  // Returns the score threshold. Requires parsed model metadata.
  float GetIntentScoreThreshold() {
    CHECK(model_metadata_);
    return model_metadata_->score_threshold();
  }

  // Caller's completion callback; consumed by Finish().
  ComputeQueryIntentCallback callback_;
  // On-device session; released to DeleteSoon() in Finish().
  std::unique_ptr<Session> session_;
  // Metadata parsed from the session; empty if parsing failed.
  std::optional<HistoryQueryIntentModelMetadata> model_metadata_;
  // Declared last so it is destroyed first among the members.
  base::WeakPtrFactory<Execution> weak_ptr_factory_{this};
};
MlIntentClassifier::MlIntentClassifier(OnDeviceCapability* model_executor)
: model_executor_(model_executor) {}
// Defaulted in this file, after the definition of Execution above, so that
// Execution is a complete type where the destructor is generated.
// NOTE(review): presumably the header only forward-declares Execution and
// holds it via std::unique_ptr — confirm against the header.
MlIntentClassifier::~MlIntentClassifier() = default;
int64_t MlIntentClassifier::GetModelVersion() {
  // Placeholder: a fixed value is reported for now. This can be replaced with
  // the real implementation.
  constexpr int64_t kPlaceholderModelVersion = 0;
  return kPlaceholderModelVersion;
}
void MlIntentClassifier::ComputeQueryIntent(
std::string query,
ComputeQueryIntentCallback callback) {
execution_ = std::make_unique<Execution>();
execution_->Execute(model_executor_, std::move(query), std::move(callback));
}
} // namespace history_embeddings