blob: b4f29ea13bd6fdcfc1a4c92e15ced04f96e3982a [file]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/contextual_tasks/internal/contextual_tasks_service_impl.h"
#include <optional>
#include <utility>
#include "base/containers/flat_set.h"
#include "base/feature_list.h"
#include "base/functional/bind.h"
#include "base/metrics/histogram_functions.h"
#include "base/notreached.h"
#include "base/strings/string_util.h"
#include "base/task/single_thread_task_runner.h"
#include "base/uuid.h"
#include "components/contextual_search/contextual_search_service.h"
#include "components/contextual_tasks/internal/composite_context_decorator.h"
#include "components/contextual_tasks/internal/conversions.h"
#include "components/contextual_tasks/public/account_utils.h"
#include "components/contextual_tasks/public/contextual_task.h"
#include "components/contextual_tasks/public/contextual_task_context.h"
#include "components/contextual_tasks/public/contextual_tasks_service.h"
#include "components/contextual_tasks/public/features.h"
#include "components/contextual_tasks/public/utils.h"
#include "components/omnibox/browser/aim_eligibility_service.h"
#include "components/omnibox/common/logger.h"
#include "components/prefs/pref_service.h"
#include "components/sessions/core/session_id.h"
#include "components/signin/public/identity_manager/identity_manager.h"
#include "components/sync/base/data_type.h"
#include "components/sync/base/report_unrecoverable_error.h"
#include "components/sync/model/client_tag_based_data_type_processor.h"
#include "components/sync/protocol/gemini_thread_specifics.pb.h"
#include "net/base/url_util.h"
#include "url/gurl.h"
namespace contextual_tasks {
namespace {
// Outcome of reconciling a task's current URL resources with a list received
// from the server (see MergeUrlResources()).
struct MergeUrlResourcesResult {
  // The reconciled list, in the order the incoming list supplied it.
  std::vector<UrlResource> final_resources;
  // Incoming resources that are new, or whose fields differ from the
  // existing resource they were paired with.
  std::vector<UrlResource> added_or_updated_resources;
  // url_ids of existing resources that no incoming resource was paired with.
  std::vector<base::Uuid> removed_resource_ids;
  // True when anything was added, updated, removed, or reordered.
  bool has_changes{false};
};
// Helper to find the index of a matching resource in existing_resources.
// Returns -1 if no match is found.
int FindMatchingResourceIndex(
const UrlResource& incoming_res,
const std::vector<UrlResource>& existing_resources,
const std::vector<bool>& existing_matched) {
for (size_t i = 0; i < existing_resources.size(); ++i) {
if (existing_matched[i]) {
continue;
}
const auto& existing_res = existing_resources[i];
// 1. Match by url_id.
if (incoming_res.url_id.is_valid() &&
existing_res.url_id == incoming_res.url_id) {
return static_cast<int>(i);
}
// 2. Match by context_id.
if (incoming_res.context_id.has_value() &&
existing_res.context_id == incoming_res.context_id) {
return static_cast<int>(i);
}
// 3. Match by url.
if (incoming_res.url.is_valid() && existing_res.url == incoming_res.url) {
return static_cast<int>(i);
}
}
return -1;
}
// Reconciles `incoming_resources` (the server's view) against
// `existing_resources` (the task's current view).
//
// Each incoming resource is paired with at most one existing resource via
// FindMatchingResourceIndex(). For a paired resource, fields left unset in the
// incoming copy are inherited from the existing one. If any field then
// differs, the resource is reported as updated. Unpaired incoming resources
// are new and receive a fresh url_id when they lack a valid one. Existing
// resources that nothing paired with are reported as removed. If nothing was
// added, updated or removed, a difference in ordering still counts as a
// change.
//
// `final_resources` in the result is the completed incoming list, in the
// incoming order.
MergeUrlResourcesResult MergeUrlResources(
    const std::vector<UrlResource>& existing_resources,
    std::vector<UrlResource> incoming_resources) {
  MergeUrlResourcesResult merged;
  // claimed[i] becomes true once existing_resources[i] has been paired with an
  // incoming resource, so each existing entry is paired at most once.
  std::vector<bool> claimed(existing_resources.size(), false);

  for (UrlResource& candidate : incoming_resources) {
    const int match =
        FindMatchingResourceIndex(candidate, existing_resources, claimed);
    if (match == -1) {
      // Not known locally: a brand-new resource. Make sure it has an ID.
      merged.has_changes = true;
      if (!candidate.url_id.is_valid()) {
        candidate.url_id = base::Uuid::GenerateRandomV4();
      }
      merged.added_or_updated_resources.push_back(candidate);
      continue;
    }

    claimed[match] = true;
    const UrlResource& previous = existing_resources[match];

    // Fill every field the incoming copy left unset from the local copy.
    if (!candidate.url_id.is_valid()) {
      candidate.url_id = previous.url_id;
    }
    if (!candidate.url.is_valid()) {
      candidate.url = previous.url;
    }
    if (!candidate.tab_id.has_value()) {
      candidate.tab_id = previous.tab_id;
    }
    if (!candidate.title.has_value()) {
      candidate.title = previous.title;
    }
    if (!candidate.context_id.has_value()) {
      candidate.context_id = previous.context_id;
    }

    const bool modified = candidate.url_id != previous.url_id ||
                          candidate.url != previous.url ||
                          candidate.tab_id != previous.tab_id ||
                          candidate.title != previous.title ||
                          candidate.context_id != previous.context_id;
    if (modified) {
      merged.has_changes = true;
      merged.added_or_updated_resources.push_back(candidate);
    }
  }

  // Any existing resource that was never claimed is absent from the incoming
  // list.
  for (size_t i = 0; i < claimed.size(); ++i) {
    if (claimed[i]) {
      continue;
    }
    merged.has_changes = true;
    merged.removed_resource_ids.push_back(existing_resources[i].url_id);
  }

  // With no additions, updates or removals, a reordering is still a change.
  if (!merged.has_changes &&
      incoming_resources.size() == existing_resources.size()) {
    for (size_t i = 0; i < incoming_resources.size(); ++i) {
      if (existing_resources[i].url_id != incoming_resources[i].url_id) {
        merged.has_changes = true;
        break;
      }
    }
  }

  merged.final_resources = std::move(incoming_resources);
  return merged;
}
// Records `count` to the "ContextualTasks.ActiveTasksCount" histogram via
// UmaHistogramCounts100.
void RecordNumberOfActiveTasks(int count) {
  constexpr char kActiveTasksHistogram[] = "ContextualTasks.ActiveTasksCount";
  base::UmaHistogramCounts100(kActiveTasksHistogram, count);
}
// Builds a task with a freshly generated random ID that holds `thread` and
// takes the thread's title as its own. `is_ephemeral` is forwarded to the
// task constructor.
ContextualTask CreateTaskForThread(const Thread& thread, bool is_ephemeral) {
  const base::Uuid fresh_id = base::Uuid::GenerateRandomV4();
  ContextualTask result(fresh_id, is_ephemeral);
  result.AddThread(thread);
  result.SetTitle(thread.title);
  return result;
}
} // namespace
// Creates the AI_THREAD and GEMINI_THREAD sync bridges, starts observing both,
// and arms a barrier that runs OnDataStoresLoaded() once both bridges have
// reported their data store as loaded.
ContextualTasksServiceImpl::ContextualTasksServiceImpl(
    version_info::Channel channel,
    syncer::RepeatingDataTypeStoreFactory data_type_store_factory,
    std::unique_ptr<CompositeContextDecorator> composite_context_decorator,
    AimEligibilityService* aim_eligibility_service,
    signin::IdentityManager* identity_manager,
    PrefService* pref_service,
    bool supports_ephemeral_only,
    base::RepeatingCallback<size_t()> get_active_task_count_callback,
    base::RepeatingCallback<bool()> is_gemini_threads_enabled)
    : composite_context_decorator_(std::move(composite_context_decorator)),
      get_active_task_count_callback_(
          std::move(get_active_task_count_callback)),
      // Taken by value, so move it in like the other callback instead of
      // copying it.
      is_gemini_threads_enabled_callback_(
          std::move(is_gemini_threads_enabled)),
      aim_eligibility_service_(aim_eligibility_service),
      identity_manager_(identity_manager),
      pref_service_(pref_service),
      supports_ephemeral_only_(supports_ephemeral_only) {
  // Shared by both change processors to report unrecoverable sync errors.
  // Held by value rather than as a reference bound to a temporary.
  const base::RepeatingClosure dump_stack =
      base::BindRepeating(&syncer::ReportUnrecoverableError, channel);

  auto ai_thread_processor =
      std::make_unique<syncer::ClientTagBasedDataTypeProcessor>(
          syncer::AI_THREAD, dump_stack);
  ai_thread_sync_bridge_ = std::make_unique<AiThreadSyncBridge>(
      std::move(ai_thread_processor), data_type_store_factory);
  ai_thread_observation_.Observe(ai_thread_sync_bridge_.get());

  auto gemini_thread_processor =
      std::make_unique<syncer::ClientTagBasedDataTypeProcessor>(
          syncer::GEMINI_THREAD, dump_stack);
  gemini_thread_sync_bridge_ = std::make_unique<GeminiThreadSyncBridge>(
      std::move(gemini_thread_processor), data_type_store_factory);
  gemini_thread_observation_.Observe(gemini_thread_sync_bridge_.get());

  // Wait for both AiThreadSyncBridge and GeminiThreadSyncBridge to finish
  // loading their data store.
  on_data_loaded_barrier_ = base::BarrierClosure(
      2, base::BindOnce(&ContextualTasksServiceImpl::OnDataStoresLoaded,
                        weak_ptr_factory_.GetWeakPtr()));
}
// Tells every registered observer that this service is about to be
// destroyed.
ContextualTasksServiceImpl::~ContextualTasksServiceImpl() {
  observers_.Notify(&ContextualTasksService::Observer::OnWillBeDestroyed);
}
// Collects, in this order: whether the kContextualTasks feature is enabled,
// AIM eligibility, cobrowse eligibility, and whether context sharing is
// enabled for `pref_service_`.
FeatureEligibility ContextualTasksServiceImpl::GetFeatureEligibility() {
  const auto feature_enabled =
      base::FeatureList::IsEnabled(contextual_tasks::kContextualTasks);
  const auto aim_eligible = aim_eligibility_service_->IsAimEligible();
  const auto cobrowse_eligible =
      aim_eligibility_service_->IsCobrowseEligible();
  const auto context_sharing_enabled =
      contextual_search::ContextualSearchService::IsContextSharingEnabled(
          pref_service_);
  return {feature_enabled, aim_eligible, cobrowse_eligible,
          context_sharing_enabled};
}
// True once OnDataStoresLoaded() has run, i.e. after both sync bridges have
// loaded their data stores.
bool ContextualTasksServiceImpl::IsInitialized() {
  const bool initialized = is_initialized_;
  return initialized;
}
// Creates an empty task with a random ID (ephemeral iff the service only
// supports ephemeral tasks), stores it, and posts an OnTaskAdded notification.
// Returns a copy of the stored task.
ContextualTask ContextualTasksServiceImpl::CreateTask() {
  // The histogram is emitted before the new task is stored, so the active
  // count does not yet include it; count it explicitly.
  const size_t active_including_new = get_active_task_count_callback_.Run() + 1;
  RecordNumberOfActiveTasks(active_including_new);
  return AddTaskAndNotify(
      ContextualTask(base::Uuid::GenerateRandomV4(), supports_ephemeral_only_));
}
// Like CreateTask(), but the task is also ephemeral when `url` is not for the
// primary account (as decided by IsUrlForPrimaryAccount()).
ContextualTask ContextualTasksServiceImpl::CreateTaskFromUrl(const GURL& url) {
  // The histogram is emitted before the new task is stored, so the active
  // count does not yet include it; count it explicitly.
  RecordNumberOfActiveTasks(get_active_task_count_callback_.Run() + 1);
  const base::Uuid new_task_id = base::Uuid::GenerateRandomV4();
  const bool ephemeral =
      supports_ephemeral_only_ ? true
                               : !IsUrlForPrimaryAccount(identity_manager_, url);
  return AddTaskAndNotify(ContextualTask(new_task_id, ephemeral));
}
// Looks up `task_id` and asynchronously (posted to the current thread's task
// runner) hands a copy of the task, or std::nullopt if unknown, to `callback`.
void ContextualTasksServiceImpl::GetTaskById(
    const base::Uuid& task_id,
    base::OnceCallback<void(std::optional<ContextualTask>)> callback) const {
  std::optional<ContextualTask> found;
  if (auto entry = tasks_.find(task_id); entry != tasks_.end()) {
    found = entry->second;
  }
  auto reply = base::BindOnce(std::move(callback), std::move(found));
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(FROM_HERE,
                                                              std::move(reply));
}
// Asynchronously (posted to the current thread's task runner) hands
// `callback` copies of all non-ephemeral tasks. When the service only
// supports ephemeral tasks, the list is always empty.
void ContextualTasksServiceImpl::GetTasks(
    base::OnceCallback<void(std::vector<ContextualTask>)> callback) const {
  std::vector<ContextualTask> tasks;
  // The ephemeral-only check does not depend on the task, so test it once.
  // Tasks are only copied once they are known to be included.
  if (!supports_ephemeral_only_) {
    for (const auto& entry : tasks_) {
      const ContextualTask& task = entry.second;
      if (!task.IsEphemeral()) {
        tasks.push_back(task);
      }
    }
  }
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(std::move(callback), std::move(tasks)));
}
// Removes `task_id` (and its tab mappings) as a locally triggered change; see
// RemoveTaskInternal(), which posts the OnTaskRemoved notification.
void ContextualTasksServiceImpl::DeleteTask(const base::Uuid& task_id) {
  const TriggerSource source = TriggerSource::kLocal;
  RemoveTaskInternal(task_id, source);
}
// Adds or refreshes the thread (`thread_type`, `server_id`) on `task_id`.
// FindOrCreateTask() creates the task if needed, or redirects to an existing
// task that already owns `server_id`. `conversation_turn_id` and `title`
// replace the stored values when provided; otherwise the existing thread's
// values (if any) are kept. The thread's last-turn time is set to now.
// Observers get a posted OnTaskAdded (new task) or OnTaskUpdated notification
// with TriggerSource::kLocal.
void ContextualTasksServiceImpl::UpdateThreadForTask(
    const base::Uuid& task_id,
    ThreadType thread_type,
    const std::string& server_id,
    std::optional<std::string> conversation_turn_id,
    std::optional<std::string> title) {
  auto [it, is_new_task] = FindOrCreateTask(task_id, thread_type, server_id);
  std::optional<Thread> thread = it->second.GetThread();
  // If the task already has a thread that is not the one being updated
  // (different server ID or type), leave the task untouched.
  if (thread.has_value() &&
      (thread->server_id != server_id || thread->type != thread_type)) {
    return;
  }
  // Prefer the supplied values; otherwise fall back to the existing thread's.
  // `thread` is a local copy, so its fields can be moved from.
  std::string new_title;
  if (title.has_value()) {
    new_title = std::move(*title);
  } else if (thread.has_value()) {
    new_title = std::move(thread->title);
  }
  std::optional<std::string> new_conversation_turn_id;
  if (conversation_turn_id.has_value()) {
    new_conversation_turn_id = std::move(conversation_turn_id);
  } else if (thread.has_value()) {
    new_conversation_turn_id = std::move(thread->conversation_turn_id);
  }
  // Add or update the thread information within the task.
  it->second.AddThread(Thread(thread_type, server_id, std::move(new_title),
                              base::Time::Now().InMillisecondsSinceUnixEpoch(),
                              std::move(new_conversation_turn_id)));
  auto notify = is_new_task ? &ContextualTasksServiceImpl::NotifyTaskAdded
                            : &ContextualTasksServiceImpl::NotifyTaskUpdated;
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(notify, weak_ptr_factory_.GetWeakPtr(),
                                it->second, TriggerSource::kLocal));
}
// Removes the thread identified by (`type`, `server_id`) from `task_id`.
// The thread is deleted through the sync bridge for its type. The task itself
// is deleted once it has no thread left. This is a no-op if the task is
// unknown, has no thread, or its thread is not the one described by `type`
// and `server_id`.
void ContextualTasksServiceImpl::RemoveThreadFromTask(
    const base::Uuid& task_id,
    ThreadType type,
    const std::string& server_id) {
  auto it = tasks_.find(task_id);
  if (it == tasks_.end()) {
    return;
  }
  std::optional<Thread> thread = it->second.GetThread();
  // Only delete the thread through the sync bridge when it is the one the
  // caller asked to remove. Otherwise an unrelated thread would be deleted
  // from sync while the task kept it locally. This also avoids dereferencing
  // an empty optional.
  if (!thread.has_value() || thread->type != type ||
      thread->server_id != server_id) {
    return;
  }
  switch (type) {
    case ThreadType::kAiMode:
      ai_thread_sync_bridge_->DeleteThread(*thread);
      break;
    case ThreadType::kGemini:
      gemini_thread_sync_bridge_->DeleteThread(*thread);
      break;
    case ThreadType::kUnknown:
    default:
      NOTREACHED();
  }
  it->second.RemoveThread(type, server_id);
  // If the task no longer has any thread, remove it.
  if (!it->second.GetThread()) {
    DeleteTask(task_id);
  }
}
// Returns a copy of the first task (in `tasks_` order) whose thread has the
// given type and server ID, or std::nullopt if there is none.
std::optional<ContextualTask> ContextualTasksServiceImpl::GetTaskFromServerId(
    ThreadType thread_type,
    const std::string& server_id) {
  for (const auto& [id, candidate] : tasks_) {
    const std::optional<Thread> candidate_thread = candidate.GetThread();
    const bool is_match = candidate_thread.has_value() &&
                          candidate_thread->type == thread_type &&
                          candidate_thread->server_id == server_id;
    if (is_match) {
      return candidate;
    }
  }
  return std::nullopt;
}
// Adds `url` as a new URL resource (with a fresh random ID) to `task_id`.
// Posts an OnTaskUpdated(kLocal) notification only when AddUrlResource()
// reports that the resource was added. Unknown task IDs are ignored.
void ContextualTasksServiceImpl::AttachUrlToTask(const base::Uuid& task_id,
                                                 const GURL& url) {
  auto task_it = tasks_.find(task_id);
  if (task_it == tasks_.end()) {
    return;
  }
  ContextualTask& task = task_it->second;
  UrlResource resource(base::Uuid::GenerateRandomV4(), url);
  if (!task.AddUrlResource(resource)) {
    return;
  }
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated,
                                weak_ptr_factory_.GetWeakPtr(), task,
                                TriggerSource::kLocal));
}
// Removes `url` from `task_id`'s URL resources. Posts an
// OnTaskUpdated(kLocal) notification only when RemoveUrl() reports that a
// resource was removed. Unknown task IDs are ignored.
void ContextualTasksServiceImpl::DetachUrlFromTask(const base::Uuid& task_id,
                                                   const GURL& url) {
  auto task_it = tasks_.find(task_id);
  if (task_it == tasks_.end()) {
    return;
  }
  const bool removed = task_it->second.RemoveUrl(url).has_value();
  if (!removed) {
    return;
  }
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated,
                                weak_ptr_factory_.GetWeakPtr(),
                                task_it->second, TriggerSource::kLocal));
}
// Replaces `task_id`'s URL resources with the server-provided list, merged
// against the current ones by MergeUrlResources(). Posts an
// OnTaskUpdated(kLocal) notification only if the merge detected a change.
// Unknown task IDs are ignored.
void ContextualTasksServiceImpl::SetUrlResourcesFromServer(
    const base::Uuid& task_id,
    std::vector<UrlResource> url_resources) {
  auto task_it = tasks_.find(task_id);
  if (task_it == tasks_.end()) {
    return;
  }
  ContextualTask& target = task_it->second;
  MergeUrlResourcesResult merge =
      MergeUrlResources(target.GetUrlResources(), std::move(url_resources));
  if (!merge.has_changes) {
    return;
  }
  // Update the local in-memory task state.
  target.SetUrlResourcesFromServer(std::move(merge.final_resources));
  auto notify =
      base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated,
                     weak_ptr_factory_.GetWeakPtr(), target,
                     TriggerSource::kLocal);
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(notify));
}
// Associates `tab_id` with `task_id`, first disassociating it from any other
// task it was associated with. Records "ContextualTasks.TabAffiliationCount"
// with the task's resulting tab count and posts an OnTaskAssociatedToTab
// notification. Unknown task IDs are ignored.
void ContextualTasksServiceImpl::AssociateTabWithTask(const base::Uuid& task_id,
                                                      SessionID tab_id) {
  auto target = tasks_.find(task_id);
  if (target == tasks_.end()) {
    return;
  }
  if (std::optional<ContextualTask> previous = GetContextualTaskForTab(tab_id);
      previous && previous->GetTaskId() != task_id) {
    DisassociateTabFromTask(previous->GetTaskId(), tab_id);
  }
  tab_to_task_[tab_id] = task_id;
  target->second.AddTabId(tab_id);
  const auto affiliated_tab_count = GetTabsAssociatedWithTask(task_id).size();
  base::UmaHistogramCounts100("ContextualTasks.TabAffiliationCount",
                              affiliated_tab_count);
  auto notify =
      base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAssociatedToTab,
                     weak_ptr_factory_.GetWeakPtr(), task_id, tab_id);
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(notify));
}
// Drops the tab -> task mapping for `tab_id` and removes the tab from
// `task_id` (if the task exists). With
// kContextualTasksRemoveTasksWithoutThreadsOrTabAssociations enabled, a task
// left with neither a thread nor tabs is removed. An
// OnTaskDisassociatedFromTab notification is always posted.
void ContextualTasksServiceImpl::DisassociateTabFromTask(
    const base::Uuid& task_id,
    SessionID tab_id) {
  tab_to_task_.erase(tab_id);
  if (auto task_it = tasks_.find(task_id); task_it != tasks_.end()) {
    ContextualTask& task = task_it->second;
    task.RemoveTabId(tab_id);
    const bool prune_orphan =
        base::FeatureList::IsEnabled(
            kContextualTasksRemoveTasksWithoutThreadsOrTabAssociations) &&
        !task.GetThread() && task.GetTabIds().empty();
    if (prune_orphan) {
      // `task` must not be touched after this call; it erases the entry.
      RemoveTaskInternal(task_id, TriggerSource::kLocal);
    }
  }
  auto notify = base::BindOnce(
      &ContextualTasksServiceImpl::NotifyTaskDisassociatedFromTab,
      weak_ptr_factory_.GetWeakPtr(), task_id, tab_id);
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(notify));
}
// Disassociates every tab currently mapped to `task_id`, one at a time via
// DisassociateTabFromTask().
void ContextualTasksServiceImpl::DisassociateAllTabsFromTask(
    const base::Uuid& task_id) {
  // Snapshot first: each disassociation mutates `tab_to_task_`.
  const std::vector<SessionID> tabs = GetTabsAssociatedWithTask(task_id);
  for (SessionID tab : tabs) {
    DisassociateTabFromTask(task_id, tab);
  }
}
// Returns a copy of the task `tab_id` is mapped to. Returns std::nullopt if
// the tab has no mapping or the mapped task no longer exists.
std::optional<ContextualTask>
ContextualTasksServiceImpl::GetContextualTaskForTab(SessionID tab_id) const {
  const auto mapping = tab_to_task_.find(tab_id);
  if (mapping == tab_to_task_.end()) {
    return std::nullopt;
  }
  const auto task_it = tasks_.find(mapping->second);
  if (task_it == tasks_.end()) {
    return std::nullopt;
  }
  return task_it->second;
}
// Returns every tab whose tab -> task mapping points at `task_id`, in
// `tab_to_task_` iteration order.
std::vector<SessionID> ContextualTasksServiceImpl::GetTabsAssociatedWithTask(
    const base::Uuid& task_id) const {
  std::vector<SessionID> tabs;
  for (const auto& [tab, owner_task_id] : tab_to_task_) {
    if (owner_task_id == task_id) {
      tabs.push_back(tab);
    }
  }
  return tabs;
}
// Builds a ContextualTaskContext for `task_id` and lets the composite
// decorator enrich it from `sources` before invoking `context_callback`. For
// an unknown task, `context_callback` is posted with a null context instead.
void ContextualTasksServiceImpl::GetContextForTask(
    const base::Uuid& task_id,
    const std::set<ContextualTaskContextSource>& sources,
    std::unique_ptr<ContextDecorationParams> params,
    base::OnceCallback<void(std::unique_ptr<ContextualTaskContext>)>
        context_callback) {
  auto task_it = tasks_.find(task_id);
  if (task_it != tasks_.end()) {
    auto context = std::make_unique<ContextualTaskContext>(task_it->second);
    composite_context_decorator_->DecorateContext(
        std::move(context), sources, std::move(params),
        std::move(context_callback));
    return;
  }
  std::unique_ptr<ContextualTaskContext> no_context;
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE,
      base::BindOnce(std::move(context_callback), std::move(no_context)));
}
// Resolves the URL to open for `task_id` and passes it to `callback`
// (asynchronously, via GetTaskById()).
// - No task or no thread: the default AIM URL for `locale`/`entry_point`.
// - AI Mode thread: the default AIM URL with "q" (thread title), optional
//   "mstk" (conversation turn ID) and "mtid" (server ID) query parameters.
// - Gemini thread: the Gemini base URL with the server ID (minus a leading
//   "c_") appended as a path segment.
void ContextualTasksServiceImpl::GetThreadUrlFromTaskId(
    const base::Uuid& task_id,
    const std::string& locale,
    omnibox::ChromeAimEntryPoint entry_point,
    base::OnceCallback<void(GURL)> callback) {
  GetTaskById(
      task_id,
      base::BindOnce(
          [](const std::string& locale,
             omnibox::ChromeAimEntryPoint entry_point,
             base::OnceCallback<void(GURL)> callback,
             std::optional<ContextualTask> task) {
            OMNIBOX_LOG("nav_trace")
                << "ContextualTasks navigation trace: "
                   "GetThreadUrlFromTaskId callback called";
            // AIM is the default if no task or thread is found.
            GURL url = GetDefaultAimUrl(locale, entry_point);
            if (!task) {
              OMNIBOX_LOG("nav_trace")
                  << "ContextualTasks navigation trace: "
                     "GetThreadUrlFromTaskId returning early, no task. "
                     "Returning default: "
                  << url;
              std::move(callback).Run(url);
              return;
            }
            std::optional<Thread> thread = task->GetThread();
            if (!thread) {
              OMNIBOX_LOG("nav_trace")
                  << "ContextualTasks navigation trace: "
                     "GetThreadUrlFromTaskId returning early, no thread or "
                     "Gemini thread. Returning default: "
                  << url;
              std::move(callback).Run(url);
              return;
            }
            if (thread->type == ThreadType::kAiMode) {
              // Attach the thread ID and the most recent turn ID to the
              // URL. A query parameter needs to be present, but its
              // value is not used for continued threads.
              url = net::AppendQueryParameter(url, "q", thread->title);
              // There are rare cases where the mstk may not be present before
              // this url is requested. If that is the case, do not attempt to
              // include it in the url.
              if (thread->conversation_turn_id) {
                url = net::AppendQueryParameter(
                    url, "mstk", thread->conversation_turn_id.value());
              }
              url = net::AppendQueryParameter(url, "mtid", thread->server_id);
            } else if (thread->type == ThreadType::kGemini) {
              url = GURL(GetContextualTasksGeminiBaseUrl());
              GURL::Replacements replacements;
              std::string path = url.GetPath();
              // An invalid or path-less base URL can yield an empty path;
              // calling back() on an empty string is undefined behavior.
              if (path.empty() || path.back() != '/') {
                path += '/';
              }
              std::string server_id = thread->server_id;
              if (base::StartsWith(server_id, "c_")) {
                server_id.erase(0, 2);
              }
              path += server_id;
              replacements.SetPathStr(path);
              url = url.ReplaceComponents(replacements);
            }
            OMNIBOX_LOG("nav_trace") << "ContextualTasks navigation trace: "
                                        "GetThreadUrlFromTaskId returning URL: "
                                     << url;
            std::move(callback).Run(url);
          },
          locale, entry_point, std::move(callback)));
}
// Registers `observer` for task notifications.
void ContextualTasksServiceImpl::AddObserver(
    ContextualTasksService::Observer* observer) {
  auto& registry = observers_;
  registry.AddObserver(observer);
}

// Unregisters a previously added `observer`.
void ContextualTasksServiceImpl::RemoveObserver(
    ContextualTasksService::Observer* observer) {
  auto& registry = observers_;
  registry.RemoveObserver(observer);
}
// Returns the controller delegate of the AI_THREAD bridge's change processor.
base::WeakPtr<syncer::DataTypeControllerDelegate>
ContextualTasksServiceImpl::GetAiThreadControllerDelegate() {
  auto* processor = ai_thread_sync_bridge_->change_processor();
  return processor->GetControllerDelegate();
}

// Returns the controller delegate of the GEMINI_THREAD bridge's change
// processor.
base::WeakPtr<syncer::DataTypeControllerDelegate>
ContextualTasksServiceImpl::GetGeminiThreadControllerDelegate() {
  auto* processor = gemini_thread_sync_bridge_->change_processor();
  return processor->GetControllerDelegate();
}
// Returns the result of the injected Gemini-threads callback. A null callback
// means "not eligible".
bool ContextualTasksServiceImpl::IsGeminiThreadsEligible() {
  if (!is_gemini_threads_enabled_callback_) {
    return false;
  }
  return is_gemini_threads_enabled_callback_.Run();
}
// Replaces the AI thread sync bridge. Observation of the old bridge is reset
// first so it is not left observing a destroyed bridge (UAF). The replacement
// bridge is not observed.
void ContextualTasksServiceImpl::SetAiThreadSyncBridgeForTesting(
    std::unique_ptr<AiThreadSyncBridge> bridge) {
  ai_thread_observation_.Reset();
  ai_thread_sync_bridge_ = std::move(bridge);
}

// Replaces the Gemini thread sync bridge. Same contract as
// SetAiThreadSyncBridgeForTesting(): the old bridge stops being observed and
// the replacement is not observed.
void ContextualTasksServiceImpl::SetGeminiThreadSyncBridgeForTesting(
    std::unique_ptr<GeminiThreadSyncBridge> bridge) {
  gemini_thread_observation_.Reset();
  gemini_thread_sync_bridge_ = std::move(bridge);
}
// The AI thread bridge finished loading; counts toward the two-bridge barrier
// that triggers OnDataStoresLoaded().
void ContextualTasksServiceImpl::OnThreadDataStoreLoaded() {
  auto& barrier = on_data_loaded_barrier_;
  barrier.Run();
}
// Applies AI thread additions/updates received from sync.
//
// A task whose thread matches an incoming entity (same server ID and type)
// gets its thread refreshed when the conversation turn ID or title changed.
// Every incoming entity without a matching task gets a new task. Update
// notifications are sent synchronously, but only after the scan over
// `tasks_`. An observer that mutates the service therefore cannot invalidate
// the iteration. Add notifications are posted.
void ContextualTasksServiceImpl::OnThreadAddedOrUpdatedRemotely(
    const std::vector<proto::AiThreadEntity>& threads) {
  // Index the incoming entities by server ID. If an ID repeats, the first
  // occurrence is kept.
  std::map<std::string, const proto::AiThreadEntity&> thread_map;
  for (const auto& thread : threads) {
    thread_map.emplace(thread.specifics().server_id(), thread);
  }
  std::vector<ContextualTask> updated_tasks;
  for (auto& task_entry : tasks_) {
    ContextualTask& task = task_entry.second;
    const std::optional<Thread> old_thread = task.GetThread();
    if (!old_thread) {
      continue;
    }
    auto it = thread_map.find(old_thread->server_id);
    if (it == thread_map.end() ||
        ToThreadType(it->second.specifics().type()) != old_thread->type) {
      continue;
    }
    // Check if the thread has changed for the task.
    const proto::AiThreadEntity& new_thread_entity = it->second;
    if (old_thread->conversation_turn_id !=
            new_thread_entity.specifics().conversation_turn_id() ||
        old_thread->title != new_thread_entity.specifics().title()) {
      // Use the entity's type, which was just verified to equal the task's
      // thread type, instead of assuming kAiMode. This matches how new tasks
      // are created below.
      task.AddThread(Thread(
          ToThreadType(new_thread_entity.specifics().type()),
          new_thread_entity.specifics().server_id(),
          new_thread_entity.specifics().title(),
          new_thread_entity.specifics().last_turn_time_unix_epoch_millis(),
          new_thread_entity.specifics().conversation_turn_id()));
      updated_tasks.push_back(task);
    }
    // Remove the thread from the map. Any remaining threads will have tasks
    // created for them below. Erase by iterator rather than by a key that
    // aliases the node being erased.
    thread_map.erase(it);
  }
  for (const ContextualTask& task : updated_tasks) {
    NotifyTaskUpdated(task, TriggerSource::kRemote);
  }
  // Create tasks for any of the threads that were added or updated and didn't
  // have an associated task.
  for (const auto& [thread_id, thread_entity] : thread_map) {
    Thread thread(ToThreadType(thread_entity.specifics().type()),
                  thread_entity.specifics().server_id(),
                  thread_entity.specifics().title(),
                  thread_entity.specifics().last_turn_time_unix_epoch_millis(),
                  thread_entity.specifics().conversation_turn_id());
    ContextualTask new_task =
        CreateTaskForThread(thread, supports_ephemeral_only_);
    const auto it =
        tasks_.emplace(new_task.GetTaskId(), std::move(new_task)).first;
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded,
                                  weak_ptr_factory_.GetWeakPtr(), it->second,
                                  TriggerSource::kRemote));
  }
}
// Sync removed AI threads: drop the AI Mode tasks that own them.
void ContextualTasksServiceImpl::OnThreadRemovedRemotely(
    const std::vector<base::Uuid>& thread_ids) {
  constexpr ThreadType kRemovedType = ThreadType::kAiMode;
  OnThreadRemovedRemotelyInternal(kRemovedType, thread_ids);
}
// The Gemini thread bridge finished loading; counts toward the two-bridge
// barrier that triggers OnDataStoresLoaded().
void ContextualTasksServiceImpl::OnGeminiThreadDataStoreLoaded() {
  auto& barrier = on_data_loaded_barrier_;
  barrier.Run();
}
// Applies Gemini thread additions/updates received from sync.
//
// A task whose Gemini thread matches an incoming specifics entry (by
// conversation ID) gets its thread refreshed when the title changed. Every
// incoming entry without a matching task gets a new task. Update
// notifications are sent synchronously, but only after the scan over
// `tasks_`. An observer that mutates the service therefore cannot invalidate
// the iteration. Add notifications are posted.
void ContextualTasksServiceImpl::OnGeminiThreadAddedOrUpdatedRemotely(
    const std::vector<sync_pb::GeminiThreadSpecifics>& thread_specifics) {
  // Index incoming specifics by conversation ID. If an ID repeats, the first
  // occurrence is kept.
  std::map<std::string, const sync_pb::GeminiThreadSpecifics&> thread_map;
  for (const auto& specifics : thread_specifics) {
    thread_map.emplace(specifics.conversation_id(), specifics);
  }
  // Update existing tasks.
  std::vector<ContextualTask> updated_tasks;
  for (auto& task_entry : tasks_) {
    ContextualTask& task = task_entry.second;
    const std::optional<Thread> old_thread = task.GetThread();
    if (!old_thread || old_thread->type != ThreadType::kGemini) {
      continue;
    }
    auto it = thread_map.find(old_thread->server_id);
    if (it == thread_map.end()) {
      continue;
    }
    // Check if the thread has changed for the task.
    const sync_pb::GeminiThreadSpecifics& new_thread_entity = it->second;
    if (old_thread->title != new_thread_entity.title()) {
      task.AddThread(
          Thread(ThreadType::kGemini, new_thread_entity.conversation_id(),
                 new_thread_entity.title(),
                 new_thread_entity.last_turn_time_unix_epoch_millis()));
      updated_tasks.push_back(task);
    }
    // Erase by iterator rather than by a key that aliases the erased node.
    thread_map.erase(it);
  }
  for (const ContextualTask& task : updated_tasks) {
    NotifyTaskUpdated(task, TriggerSource::kRemote);
  }
  // Create new task for specifics which didn't have an existing task.
  for (const auto& [thread_id, specifics] : thread_map) {
    Thread thread(ThreadType::kGemini, specifics.conversation_id(),
                  specifics.title(),
                  specifics.last_turn_time_unix_epoch_millis());
    ContextualTask new_task =
        CreateTaskForThread(thread, supports_ephemeral_only_);
    const auto it =
        tasks_.emplace(new_task.GetTaskId(), std::move(new_task)).first;
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded,
                                  weak_ptr_factory_.GetWeakPtr(), it->second,
                                  TriggerSource::kRemote));
  }
}
// Sync removed Gemini threads: drop the Gemini tasks that own them.
void ContextualTasksServiceImpl::OnGeminiThreadRemovedRemotely(
    const std::vector<base::Uuid>& thread_ids) {
  constexpr ThreadType kRemovedType = ThreadType::kGemini;
  OnThreadRemovedRemotelyInternal(kRemovedType, thread_ids);
}
// Returns an iterator to the task to update, plus whether it was just created:
// - the task with `task_id` if it exists;
// - otherwise, an existing task already holding the (`thread_type`,
//   `server_id`) thread, to avoid duplicates;
// - otherwise, a newly inserted task with `task_id`.
std::pair<std::map<base::Uuid, ContextualTask>::iterator, bool>
ContextualTasksServiceImpl::FindOrCreateTask(const base::Uuid& task_id,
                                             ThreadType thread_type,
                                             const std::string& server_id) {
  if (auto direct = tasks_.find(task_id); direct != tasks_.end()) {
    return {direct, /*is_new_task=*/false};
  }
  // TODO(nyquist): This is a temporary solution to avoid creating
  // duplicate tasks. We should remove this once we have a better solution
  // for handling out-of-sync tasks.
  if (std::optional<ContextualTask> owner =
          GetTaskFromServerId(thread_type, server_id)) {
    return {tasks_.find(owner->GetTaskId()), /*is_new_task=*/false};
  }
  auto inserted =
      tasks_.emplace(task_id, ContextualTask(task_id, supports_ephemeral_only_))
          .first;
  return {inserted, /*is_new_task=*/true};
}
// Erases `task_id` together with the tab -> task mappings of its tabs, then
// posts an OnTaskRemoved notification tagged with `source`. Unknown IDs are
// ignored.
void ContextualTasksServiceImpl::RemoveTaskInternal(const base::Uuid& task_id,
                                                    TriggerSource source) {
  auto doomed = tasks_.find(task_id);
  if (doomed == tasks_.end()) {
    return;
  }
  // Drop the reverse mappings while the task (and its tab list) still exists.
  for (const auto& tab : doomed->second.GetTabIds()) {
    tab_to_task_.erase(tab);
  }
  tasks_.erase(doomed);
  auto notify = base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskRemoved,
                               weak_ptr_factory_.GetWeakPtr(), task_id, source);
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(notify));
}
// Removes (as kRemote) every task whose thread has type `thread_type_filter`
// and a server ID equal to the lowercase string form of one of `thread_ids`.
void ContextualTasksServiceImpl::OnThreadRemovedRemotelyInternal(
    ThreadType thread_type_filter,
    const std::vector<base::Uuid>& thread_ids) {
  std::set<std::string> removed_server_ids;
  for (const base::Uuid& thread_id : thread_ids) {
    removed_server_ids.insert(thread_id.AsLowercaseString());
  }
  // Collect first: RemoveTaskInternal() erases from `tasks_`, which must not
  // happen while iterating over it.
  std::vector<base::Uuid> doomed_task_ids;
  for (const auto& entry : tasks_) {
    const ContextualTask& task = entry.second;
    const std::optional<Thread> thread = task.GetThread();
    const bool owns_removed_thread =
        thread && thread->type == thread_type_filter &&
        removed_server_ids.find(thread->server_id) != removed_server_ids.end();
    if (owns_removed_thread) {
      doomed_task_ids.push_back(task.GetTaskId());
    }
  }
  for (const base::Uuid& doomed_id : doomed_task_ids) {
    RemoveTaskInternal(doomed_id, TriggerSource::kRemote);
  }
}
// Number of tab -> task mappings currently held.
size_t ContextualTasksServiceImpl::GetTabIdMapSizeForTesting() const {
  const size_t mapping_count = tab_to_task_.size();
  return mapping_count;
}
void ContextualTasksServiceImpl::NotifyTaskAdded(const ContextualTask& task,
TriggerSource source) {
for (auto& observer : observers_) {
observer.OnTaskAdded(task, source);
}
}
void ContextualTasksServiceImpl::NotifyTaskUpdated(const ContextualTask& task,
TriggerSource source) {
for (auto& observer : observers_) {
observer.OnTaskUpdated(task, source);
}
}
void ContextualTasksServiceImpl::NotifyTaskRemoved(const base::Uuid& task_id,
TriggerSource source) {
for (auto& observer : observers_) {
observer.OnTaskRemoved(task_id, source);
}
}
void ContextualTasksServiceImpl::NotifyTaskAssociatedToTab(
const base::Uuid& task_id,
SessionID tab_id) {
observers_.Notify(&ContextualTasksService::Observer::OnTaskAssociatedToTab,
task_id, tab_id);
}
void ContextualTasksServiceImpl::NotifyTaskDisassociatedFromTab(
const base::Uuid& task_id,
SessionID tab_id) {
observers_.Notify(
&ContextualTasksService::Observer::OnTaskDisassociatedFromTab, task_id,
tab_id);
}
// Stores `task` (unless a task with the same ID already exists), posts an
// OnTaskAdded(kLocal) notification for the stored task, and returns a copy of
// it.
ContextualTask ContextualTasksServiceImpl::AddTaskAndNotify(
    ContextualTask task) {
  // `task` is owned by value, so move it into the map instead of copying.
  // The key is copied out first so it cannot alias the moved-from object.
  const base::Uuid task_id = task.GetTaskId();
  auto it = tasks_.emplace(task_id, std::move(task)).first;
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded,
                                weak_ptr_factory_.GetWeakPtr(), it->second,
                                TriggerSource::kLocal));
  return it->second;
}
void ContextualTasksServiceImpl::OnDataStoresLoaded() {
is_initialized_ = true;
std::vector<ContextualTask> tasks = BuildTasks();
for (const auto& task : tasks) {
tasks_.emplace(task.GetTaskId(), task);
}
for (auto& observer : observers_) {
observer.OnContextualTasksServiceInitialized();
}
}
std::vector<ContextualTask> ContextualTasksServiceImpl::BuildTasks() const {
// First attempt to add threads to tasks that were persisted. Any threads that
// do not have a task will have one created.
base::flat_set<std::string> used_thread_ids;
std::vector<Thread> threads = ai_thread_sync_bridge_->GetThreads();
std::vector<ContextualTask> tasks;
for (const auto& thread : threads) {
if (used_thread_ids.contains(thread.server_id)) {
continue;
}
// While these tasks are not persisted, they're also not considered
// ephemeral since they are built using a user's threads.
// TODO(485520978): Use a UUIDv5 based on the thread ID here so the UUID
// is deterministic between restarts.
tasks.push_back(CreateTaskForThread(thread, supports_ephemeral_only_));
}
for (const auto& thread : gemini_thread_sync_bridge_->GetThreads()) {
if (used_thread_ids.contains(thread.server_id)) {
continue;
}
tasks.push_back(CreateTaskForThread(thread, supports_ephemeral_only_));
}
return tasks;
}
} // namespace contextual_tasks