blob: 34e397dd6bced2e5334f88edcb332a78c0616829 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/tabs/organization/request_factory.h"
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "base/command_line.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/tabs/organization/logging_util.h"
#include "chrome/browser/ui/tabs/organization/tab_organization_request.h"
#include "chrome/browser/ui/tabs/organization/tab_organization_session.h"
#include "chrome/browser/ui/tabs/tab_group_model.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
#include "chrome/browser/ui/ui_features.h"
#include "chrome/browser/ui/webui/tab_search/tab_search.mojom.h"
#include "chrome/browser/ui/webui/tab_search/tab_search_prefs.h"
#include "components/optimization_guide/core/model_quality/model_execution_logging_wrappers.h"
#include "components/optimization_guide/core/model_quality/model_quality_log_entry.h"
#include "components/optimization_guide/core/model_quality/model_quality_logs_uploader_service.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/tab_organization.pb.h"
#include "components/prefs/pref_service.h"
#include "components/tab_groups/tab_group_id.h"
#include "content/public/browser/web_contents.h"
namespace {
// Returns whether tab organization may be served by the optimization guide
// model: the model-execution feature must be enabled, and |profile| must have
// an OptimizationGuideKeyedService. The service lookup only happens when the
// feature is enabled.
bool CanUseOptimizationGuide(Profile* profile) {
  if (!base::FeatureList::IsEnabled(
          optimization_guide::features::kOptimizationGuideModelExecution)) {
    return false;
  }
  return OptimizationGuideKeyedServiceFactory::GetForProfile(profile) !=
         nullptr;
}
// Callback bound into every TabOrganizationResponse built from a model result.
//
// Session details are written into the tab organization quality section of
// |log_entry| only when all of the following hold:
//  - |log_entry| is non-null;
//  - the session's request has a response with at least one organization;
//  - the session itself holds at least one tab organization.
// The entry (possibly null) is then passed to ModelQualityLogEntry::Upload()
// unconditionally. |profile| is part of the bound signature but is unused.
void OnLogResults(
    Profile* profile,
    std::unique_ptr<optimization_guide::ModelQualityLogEntry> log_entry,
    const TabOrganizationSession* session) {
  // Short-circuit order matters: |session| is only touched when |log_entry|
  // is present.
  const bool has_session_results =
      log_entry && session->request() && session->request()->response() &&
      session->request()->response()->organizations.size() != 0 &&
      session->tab_organizations().size() != 0;
  if (has_session_results) {
    AddSessionDetailsToQuality(log_entry->log_ai_data_request()
                                   ->mutable_tab_organization()
                                   ->mutable_quality(),
                               session);
  }
  optimization_guide::ModelQualityLogEntry::Upload(std::move(log_entry));
}
// Handles the result of a tab organization model execution.
//
// Steps:
//  1. Create a ModelQualityLogEntry and seed it with |logging_data| when that
//     payload is present.
//  2. Convert the model's TabOrganizationResponse proto into a local
//     TabOrganizationResponse: one Organization per proto tab group, holding
//     the group label, its tab ids, and an optional existing TabGroupId.
//  3. Move the log entry into an OnLogResults callback owned by the local
//     response, and pass the response to |on_completion|.
//
// If the result carries no response, or the response cannot be parsed as a
// TabOrganizationResponse, |on_failure| runs instead. In that case the log
// entry is destroyed without an explicit Upload() call.
void OnTabOrganizationModelExecutionResult(
    Profile* profile,
    TabOrganizationRequest::BackendCompletionCallback on_completion,
    TabOrganizationRequest::BackendFailureCallback on_failure,
    optimization_guide::OptimizationGuideModelExecutionResult result,
    std::unique_ptr<optimization_guide::proto::TabOrganizationLoggingData>
        logging_data) {
  OptimizationGuideKeyedService* optimization_guide_keyed_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
  // CanUseOptimizationGuide() checked for the service before execution
  // started. Re-check rather than dereference a possibly-null pointer here;
  // without a service the entry simply has no uploader.
  optimization_guide::ModelQualityLogsUploaderService* logs_uploader =
      optimization_guide_keyed_service
          ? optimization_guide_keyed_service
                ->GetModelQualityLogsUploaderService()
          : nullptr;
  auto log_entry = std::make_unique<optimization_guide::ModelQualityLogEntry>(
      logs_uploader ? logs_uploader->GetWeakPtr() : nullptr);
  // |logging_data| arrives as a unique_ptr; copy it only when it is non-null.
  if (logging_data) {
    *log_entry->log_ai_data_request()->mutable_tab_organization() =
        *logging_data;
  }
  if (!result.response.has_value()) {
    std::move(on_failure).Run();
    return;
  }
  auto response = optimization_guide::ParsedAnyMetadata<
      optimization_guide::proto::TabOrganizationResponse>(
      result.response.value());
  if (!response) {
    std::move(on_failure).Run();
    return;
  }
  std::vector<TabOrganizationResponse::Organization> organizations;
  for (const auto& tab_group : response->tab_groups()) {
    std::vector<TabData::TabID> response_tab_ids;
    for (const auto& tab : tab_group.tabs()) {
      response_tab_ids.emplace_back(tab.tab_id());
    }
    // A group_id that parses as a base::Token refers to an existing tab
    // group. Otherwise |group_id| stays empty (presumably meaning "create a
    // new group" — confirm against TabOrganizationResponse::Organization).
    std::optional<tab_groups::TabGroupId> group_id;
    const std::optional<base::Token> group_id_token =
        base::Token::FromString(tab_group.group_id());
    if (group_id_token.has_value()) {
      group_id = tab_groups::TabGroupId::FromRawToken(group_id_token.value());
    }
    organizations.emplace_back(base::UTF8ToUTF16(tab_group.label()),
                               std::move(response_tab_ids), group_id);
  }
  // Read through |log_entry| so an absent |logging_data| yields an empty id.
  const std::string execution_id = log_entry->log_ai_data_request()
                                       ->tab_organization()
                                       .model_execution_info()
                                       .execution_id();
  std::unique_ptr<TabOrganizationResponse> local_response =
      std::make_unique<TabOrganizationResponse>(
          std::move(organizations), base::UTF8ToUTF16(execution_id),
          base::BindOnce(OnLogResults, profile, std::move(log_entry)));
  std::move(on_completion).Run(std::move(local_response));
}
void PerformTabOrganizationExecution(
Profile* profile,
const TabOrganizationRequest* request,
TabOrganizationRequest::BackendCompletionCallback on_completion,
TabOrganizationRequest::BackendFailureCallback on_failure) {
if (!CanUseOptimizationGuide(profile)) {
std::move(on_failure).Run();
return;
}
optimization_guide::proto::TabOrganizationRequest tab_organization_request;
int valid_tabs = 0;
for (const std::unique_ptr<TabData>& tab_data : request->tab_datas()) {
if (!tab_data->IsValidForOrganizing()) {
continue;
}
valid_tabs++;
auto* tab = tab_organization_request.add_tabs();
tab->set_tab_id(tab_data->tab_id());
tab->set_title(
base::UTF16ToUTF8(tab_data->tab()->GetContents()->GetTitle()));
tab->set_url(tab_data->original_url().spec());
}
// When the user only has one valid tab, and it cannot be added to existing
// groups, complete without running the model to show the "No groups found"
// error state.
bool should_request_organization = valid_tabs > 1;
if (valid_tabs == 1) {
const auto* tab_group_model =
request->tab_datas()[0]->original_tab_strip_model()->group_model();
should_request_organization =
tab_group_model && !tab_group_model->ListTabGroups().empty();
}
if (!should_request_organization) {
std::vector<TabOrganizationResponse::Organization> organizations;
std::unique_ptr<TabOrganizationResponse> response =
std::make_unique<TabOrganizationResponse>(std::move(organizations));
std::move(on_completion).Run(std::move(response));
return;
}
for (const std::unique_ptr<GroupData>& group_data : request->group_datas()) {
auto* group = tab_organization_request.add_pre_existing_tab_groups();
group->set_group_id(group_data->id.ToString());
group->set_label(base::UTF16ToUTF8(group_data->label));
for (const std::unique_ptr<TabData>& tab_data : group_data->tabs) {
auto* tab = group->add_tabs();
tab->set_tab_id(tab_data->tab_id());
tab->set_title(
base::UTF16ToUTF8(tab_data->tab()->GetContents()->GetTitle()));
tab->set_url(tab_data->original_url().spec());
}
}
if (request->base_tab_id().has_value()) {
tab_organization_request.set_active_tab_id(request->base_tab_id().value());
}
if (base::FeatureList::IsEnabled(features::kTabOrganizationModelStrategy)) {
const int32_t strategy_int = profile->GetPrefs()->GetInteger(
tab_search_prefs::kTabOrganizationModelStrategy);
auto strategy =
static_cast<tab_search::mojom::TabOrganizationModelStrategy>(
strategy_int);
switch (strategy) {
case tab_search::mojom::TabOrganizationModelStrategy::kTopic:
tab_organization_request.set_model_strategy(
optimization_guide::proto::
TabOrganizationRequest_TabOrganizationModelStrategy_STRATEGY_UNSPECIFIED);
break;
case tab_search::mojom::TabOrganizationModelStrategy::kTask:
tab_organization_request.set_model_strategy(
optimization_guide::proto::
TabOrganizationRequest_TabOrganizationModelStrategy_STRATEGY_TASK_BASED);
break;
case tab_search::mojom::TabOrganizationModelStrategy::kDomain:
tab_organization_request.set_model_strategy(
optimization_guide::proto::
TabOrganizationRequest_TabOrganizationModelStrategy_STRATEGY_DOMAIN_BASED);
break;
default:
tab_organization_request.set_model_strategy(
optimization_guide::proto::
TabOrganizationRequest_TabOrganizationModelStrategy_STRATEGY_UNSPECIFIED);
break;
}
}
if (base::FeatureList::IsEnabled(features::kTabOrganizationUserInstruction)) {
if (request->user_instruction().has_value()) {
tab_organization_request.set_user_command(
request->user_instruction().value());
}
}
tab_organization_request.set_allow_reorganizing_existing_groups(true);
OptimizationGuideKeyedService* optimization_guide_keyed_service =
OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
ExecuteModelWithLogging(
optimization_guide_keyed_service,
optimization_guide::ModelBasedCapabilityKey::kTabOrganization,
tab_organization_request, /*execution_timeout=*/std::nullopt,
base::BindOnce(OnTabOrganizationModelExecutionResult, profile,
std::move(on_completion), std::move(on_failure)));
}
} // anonymous namespace
// Out-of-line defaulted destructor; members are destroyed by their own
// destructors.
TabOrganizationRequestFactory::~TabOrganizationRequestFactory() = default;
// static
// Chooses the request factory for |profile|. If the optimization guide is
// unusable, returns the local TwoTabsRequestFactory fallback. Otherwise
// returns a factory that runs the tab organization model.
std::unique_ptr<TabOrganizationRequestFactory>
TabOrganizationRequestFactory::GetForProfile(Profile* profile) {
  if (!CanUseOptimizationGuide(profile)) {
    return std::make_unique<TwoTabsRequestFactory>();
  }
  return std::make_unique<OptimizationGuideTabOrganizationRequestFactory>();
}
// Out-of-line defaulted destructor; members are destroyed by their own
// destructors.
TwoTabsRequestFactory::~TwoTabsRequestFactory() = default;
std::unique_ptr<TabOrganizationRequest> TwoTabsRequestFactory::CreateRequest(
    Profile* profile) {
  // Local strategy that runs no model. The first two tabs of the request are
  // placed together in a single organization labelled "Organization".
  // Requests holding fewer than two tabs fail.
  TabOrganizationRequest::BackendStartRequest start_request = base::BindOnce(
      [](const TabOrganizationRequest* request,
         TabOrganizationRequest::BackendCompletionCallback on_completion,
         TabOrganizationRequest::BackendFailureCallback on_failure) {
        constexpr size_t kTabsToOrganize = 2;
        const auto& tab_datas = request->tab_datas();
        if (tab_datas.size() < kTabsToOrganize) {
          std::move(on_failure).Run();
          return;
        }
        std::vector<TabData::TabID> response_tab_ids;
        response_tab_ids.reserve(kTabsToOrganize);
        for (size_t index = 0; index < kTabsToOrganize; ++index) {
          response_tab_ids.push_back(tab_datas[index]->tab_id());
        }
        std::vector<TabOrganizationResponse::Organization> organizations;
        organizations.emplace_back(u"Organization",
                                   std::move(response_tab_ids));
        std::move(on_completion)
            .Run(std::make_unique<TabOrganizationResponse>(
                std::move(organizations)));
      });
  return std::make_unique<TabOrganizationRequest>(std::move(start_request));
}
// Out-of-line defaulted destructor; members are destroyed by their own
// destructors.
OptimizationGuideTabOrganizationRequestFactory::
~OptimizationGuideTabOrganizationRequestFactory() = default;
std::unique_ptr<TabOrganizationRequest>
OptimizationGuideTabOrganizationRequestFactory::CreateRequest(
    Profile* profile) {
  // Binds |profile| into PerformTabOrganizationExecution now. The remaining
  // arguments (request, completion and failure callbacks) are supplied when
  // the BackendStartRequest is run.
  TabOrganizationRequest::BackendStartRequest start_request =
      base::BindOnce(&PerformTabOrganizationExecution, profile);
  return std::make_unique<TabOrganizationRequest>(std::move(start_request));
}