blob: fdf53520a2c125fdfb51c982d81a5d421a3c9522 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/tabs/organization/request_factory.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/command_line.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/tabs/organization/logging_util.h"
#include "chrome/browser/ui/tabs/organization/tab_organization_request.h"
#include "chrome/browser/ui/tabs/organization/tab_organization_session.h"
#include "components/optimization_guide/core/model_quality/feature_type_map.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/tab_organization.pb.h"
#include "content/public/browser/web_contents.h"
namespace {

// Returns true if the model execution feature is enabled and |profile| has an
// OptimizationGuideKeyedService to run tab organization requests against.
bool CanUseOptimizationGuide(Profile* profile) {
  return base::FeatureList::IsEnabled(
             optimization_guide::features::kOptimizationGuideModelExecution) &&
         OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
}

// Uploads |model_quality_log_entry| once the outcome of |session| is known.
//
// If |session| has a request whose response contains at least one
// organization, and the session itself has tab organizations, the session
// details are written into the entry's TabOrganizationQuality data before the
// upload. Otherwise the entry is uploaded without those details. Does nothing
// when there is no entry or no OptimizationGuideKeyedService for |profile|.
void OnLogResults(Profile* profile,
                  std::unique_ptr<optimization_guide::ModelQualityLogEntry>
                      model_quality_log_entry,
                  const TabOrganizationSession* session) {
  if (!model_quality_log_entry) {
    return;
  }
  OptimizationGuideKeyedService* optimization_guide_keyed_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
  if (!optimization_guide_keyed_service) {
    return;
  }

  // The entry is known to be non-null here; the only open question is whether
  // there are session details worth attaching. A null |session| is treated
  // like a session without a response instead of being dereferenced.
  const bool has_session_details =
      session && session->request() && session->request()->response() &&
      !session->request()->response()->organizations.empty() &&
      !session->tab_organizations().empty();
  if (has_session_details) {
    optimization_guide::proto::TabOrganizationQuality* quality =
        model_quality_log_entry
            ->quality_data<optimization_guide::TabOrganizationFeatureTypeMap>();
    AddSessionDetailsToQuality(quality, session);
  }
  optimization_guide_keyed_service->UploadModelQualityLogs(
      std::move(model_quality_log_entry));
}

// Handles the result of executing the tab organization model.
//
// Runs |on_failure| when:
//   - |profile| has no OptimizationGuideKeyedService;
//   - the execution returned an error; or
//   - the result could not be parsed as a TabOrganizationResponse proto.
// In the last two cases |log_entry| is uploaded first, if present.
//
// On success, the proto is converted into a TabOrganizationResponse and passed
// to |on_completion|. Ownership of |log_entry| moves into the response's
// logging callback (OnLogResults), so it is uploaded later together with the
// session outcome.
void OnTabOrganizationModelExecutionResult(
    Profile* profile,
    TabOrganizationRequest::BackendCompletionCallback on_completion,
    TabOrganizationRequest::BackendFailureCallback on_failure,
    optimization_guide::OptimizationGuideModelExecutionResult result,
    std::unique_ptr<optimization_guide::ModelQualityLogEntry> log_entry) {
  OptimizationGuideKeyedService* optimization_guide_keyed_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
  if (!optimization_guide_keyed_service) {
    std::move(on_failure).Run();
    return;
  }
  if (!result.has_value()) {
    // TODO(b/322206302): remove this when this is fixed in the
    // ModelQualityLogEntry API
    if (log_entry) {
      optimization_guide_keyed_service->UploadModelQualityLogs(
          std::move(log_entry));
    }
    std::move(on_failure).Run();
    return;
  }
  auto response = optimization_guide::ParsedAnyMetadata<
      optimization_guide::proto::TabOrganizationResponse>(result.value());
  if (!response) {
    if (log_entry) {
      optimization_guide_keyed_service->UploadModelQualityLogs(
          std::move(log_entry));
    }
    std::move(on_failure).Run();
    return;
  }

  // Convert each proto organization (label + tab ids) into the local type.
  std::vector<TabOrganizationResponse::Organization> organizations;
  organizations.reserve(response->tab_organizations_size());
  for (const auto& tab_organization : response->tab_organizations()) {
    std::vector<TabData::TabID> response_tab_ids;
    response_tab_ids.reserve(tab_organization.tabs_size());
    for (const auto& tab : tab_organization.tabs()) {
      response_tab_ids.emplace_back(tab.tab_id());
    }
    organizations.emplace_back(base::UTF8ToUTF16(tab_organization.label()),
                               std::move(response_tab_ids));
  }

  // The failure paths above allow |log_entry| to be null, so read the
  // execution id only when an entry (and its request data) is present rather
  // than dereferencing unconditionally. An empty id is used otherwise.
  std::string execution_id;
  if (log_entry && log_entry->log_ai_data_request()) {
    execution_id =
        log_entry->log_ai_data_request()->model_execution_info().execution_id();
  }

  std::unique_ptr<TabOrganizationResponse> local_response =
      std::make_unique<TabOrganizationResponse>(
          std::move(organizations), base::UTF8ToUTF16(execution_id),
          base::BindOnce(OnLogResults, profile, std::move(log_entry)));
  std::move(on_completion).Run(std::move(local_response));
}

// Backend start routine for Optimization Guide backed requests.
//
// Builds a TabOrganizationRequest proto containing:
//   - the id, title and original URL of every tab in |request| that is valid
//     for organizing; and
//   - the active tab id, when |request| has a base tab.
// It then asks the Optimization Guide to execute the tab organization model on
// that proto. The result is handled by OnTabOrganizationModelExecutionResult.
// Runs |on_failure| immediately if the Optimization Guide cannot be used.
//
// NOTE(review): |profile| is bound as a raw pointer into the result callback;
// this assumes the profile outlives the model execution. Confirm.
void PerformTabOrganizationExecution(
    Profile* profile,
    const TabOrganizationRequest* request,
    TabOrganizationRequest::BackendCompletionCallback on_completion,
    TabOrganizationRequest::BackendFailureCallback on_failure) {
  if (!CanUseOptimizationGuide(profile)) {
    std::move(on_failure).Run();
    return;
  }
  optimization_guide::proto::TabOrganizationRequest tab_organization_request;
  for (const std::unique_ptr<TabData>& tab_data : request->tab_datas()) {
    if (!tab_data->IsValidForOrganizing()) {
      continue;
    }
    // NOTE(review): assumes IsValidForOrganizing() guarantees a non-null
    // web_contents(). Confirm.
    auto* tab = tab_organization_request.add_tabs();
    tab->set_tab_id(tab_data->tab_id());
    tab->set_title(base::UTF16ToUTF8(tab_data->web_contents()->GetTitle()));
    tab->set_url(tab_data->original_url().spec());
  }
  if (request->base_tab_id().has_value()) {
    tab_organization_request.set_active_tab_id(request->base_tab_id().value());
  }
  // CanUseOptimizationGuide() above already verified the service exists.
  OptimizationGuideKeyedService* optimization_guide_keyed_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
  optimization_guide_keyed_service->ExecuteModel(
      optimization_guide::proto::ModelExecutionFeature::
          MODEL_EXECUTION_FEATURE_TAB_ORGANIZATION,
      tab_organization_request,
      base::BindOnce(OnTabOrganizationModelExecutionResult, profile,
                     std::move(on_completion), std::move(on_failure)));
}

}  // anonymous namespace
TabOrganizationRequestFactory::~TabOrganizationRequestFactory() = default;

// static
// Chooses the request strategy for |profile|. The Optimization Guide backed
// factory is used when model execution is available. Otherwise the two-tab
// fallback factory is returned.
std::unique_ptr<TabOrganizationRequestFactory>
TabOrganizationRequestFactory::GetForProfile(Profile* profile) {
  if (!CanUseOptimizationGuide(profile)) {
    // Fallback that needs no model and groups the first two tabs.
    return std::make_unique<TwoTabsRequestFactory>();
  }
  return std::make_unique<OptimizationGuideTabOrganizationRequestFactory>();
}
TwoTabsRequestFactory::~TwoTabsRequestFactory() = default;

// Creates a request that does not consult any model. If the request holds at
// least two tabs, the backend completes with a single organization labeled
// "Organization" that contains the first two tabs. Otherwise it reports
// failure. |profile| is not used by this strategy.
std::unique_ptr<TabOrganizationRequest> TwoTabsRequestFactory::CreateRequest(
    Profile* profile) {
  // For this request strategy only the first 2 tabs will be added to an
  // organization.
  TabOrganizationRequest::BackendStartRequest start_request = base::BindOnce(
      [](const TabOrganizationRequest* request,
         TabOrganizationRequest::BackendCompletionCallback on_completion,
         TabOrganizationRequest::BackendFailureCallback on_failure) {
        // Number of leading tabs placed in the organization. Both the size
        // check and the copy below use it, so the two cannot drift apart.
        constexpr size_t kTabsToOrganize = 2;
        const auto& tab_datas = request->tab_datas();
        if (tab_datas.size() < kTabsToOrganize) {
          std::move(on_failure).Run();
          return;
        }

        std::vector<TabData::TabID> response_tab_ids;
        response_tab_ids.reserve(kTabsToOrganize);
        std::transform(tab_datas.begin(), tab_datas.begin() + kTabsToOrganize,
                       std::back_inserter(response_tab_ids),
                       [](const std::unique_ptr<TabData>& tab_data) {
                         return tab_data->tab_id();
                       });

        std::vector<TabOrganizationResponse::Organization> organizations;
        organizations.emplace_back(u"Organization",
                                   std::move(response_tab_ids));
        std::unique_ptr<TabOrganizationResponse> response =
            std::make_unique<TabOrganizationResponse>(
                std::move(organizations));
        std::move(on_completion).Run(std::move(response));
      });
  return std::make_unique<TabOrganizationRequest>(std::move(start_request));
}
OptimizationGuideTabOrganizationRequestFactory::
    ~OptimizationGuideTabOrganizationRequestFactory() = default;

// Creates a request whose backend start step runs the Optimization Guide tab
// organization model for |profile| (see PerformTabOrganizationExecution).
std::unique_ptr<TabOrganizationRequest>
OptimizationGuideTabOrganizationRequestFactory::CreateRequest(
    Profile* profile) {
  TabOrganizationRequest::BackendStartRequest start_request =
      base::BindOnce(&PerformTabOrganizationExecution, profile);
  return std::make_unique<TabOrganizationRequest>(std::move(start_request));
}