| // Copyright 2025 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/contextual_tasks/internal/contextual_tasks_service_impl.h" |
| |
| #include <optional> |
| #include <utility> |
| |
| #include "base/containers/flat_set.h" |
| #include "base/feature_list.h" |
| #include "base/functional/bind.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/notreached.h" |
| #include "base/strings/string_util.h" |
| #include "base/task/single_thread_task_runner.h" |
| #include "base/uuid.h" |
| #include "components/contextual_search/contextual_search_service.h" |
| #include "components/contextual_tasks/internal/composite_context_decorator.h" |
| #include "components/contextual_tasks/internal/conversions.h" |
| #include "components/contextual_tasks/public/account_utils.h" |
| #include "components/contextual_tasks/public/contextual_task.h" |
| #include "components/contextual_tasks/public/contextual_task_context.h" |
| #include "components/contextual_tasks/public/contextual_tasks_service.h" |
| #include "components/contextual_tasks/public/features.h" |
| #include "components/contextual_tasks/public/utils.h" |
| #include "components/omnibox/browser/aim_eligibility_service.h" |
| #include "components/omnibox/common/logger.h" |
| #include "components/prefs/pref_service.h" |
| #include "components/sessions/core/session_id.h" |
| #include "components/signin/public/identity_manager/identity_manager.h" |
| #include "components/sync/base/data_type.h" |
| #include "components/sync/base/report_unrecoverable_error.h" |
| #include "components/sync/model/client_tag_based_data_type_processor.h" |
| #include "components/sync/protocol/gemini_thread_specifics.pb.h" |
| #include "net/base/url_util.h" |
| #include "url/gurl.h" |
| |
| namespace contextual_tasks { |
| |
| namespace { |
| |
// Output of MergeUrlResources().
struct MergeUrlResourcesResult {
  // The merged list, in the order of the incoming resources.
  std::vector<UrlResource> final_resources;
  // Incoming resources that are new, or that matched an existing resource but
  // differ from it in at least one tracked field.
  std::vector<UrlResource> added_or_updated_resources;
  // `url_id`s of existing resources that no incoming resource matched.
  std::vector<base::Uuid> removed_resource_ids;
  // True if anything was added, updated, removed, or reordered.
  bool has_changes = false;
};
| |
| // Helper to find the index of a matching resource in existing_resources. |
| // Returns -1 if no match is found. |
| int FindMatchingResourceIndex( |
| const UrlResource& incoming_res, |
| const std::vector<UrlResource>& existing_resources, |
| const std::vector<bool>& existing_matched) { |
| for (size_t i = 0; i < existing_resources.size(); ++i) { |
| if (existing_matched[i]) { |
| continue; |
| } |
| |
| const auto& existing_res = existing_resources[i]; |
| // 1. Match by url_id. |
| if (incoming_res.url_id.is_valid() && |
| existing_res.url_id == incoming_res.url_id) { |
| return static_cast<int>(i); |
| } |
| |
| // 2. Match by context_id. |
| if (incoming_res.context_id.has_value() && |
| existing_res.context_id == incoming_res.context_id) { |
| return static_cast<int>(i); |
| } |
| |
| // 3. Match by url. |
| if (incoming_res.url.is_valid() && existing_res.url == incoming_res.url) { |
| return static_cast<int>(i); |
| } |
| } |
| |
| return -1; |
| } |
| |
| // Merges incoming resources with existing ones. |
| // Returns a result containing the final list of resources, as well as lists of |
| // added/updated and removed resources to be used for sync notifications. |
| MergeUrlResourcesResult MergeUrlResources( |
| const std::vector<UrlResource>& existing_resources, |
| std::vector<UrlResource> incoming_resources) { |
| MergeUrlResourcesResult result; |
| std::vector<bool> existing_matched(existing_resources.size(), false); |
| result.has_changes = false; |
| |
| for (auto& incoming_res : incoming_resources) { |
| int matched_index = FindMatchingResourceIndex( |
| incoming_res, existing_resources, existing_matched); |
| |
| if (matched_index != -1) { |
| existing_matched[matched_index] = true; |
| const auto& existing_res = existing_resources[matched_index]; |
| |
| // Copy over fields if missing. |
| if (!incoming_res.url_id.is_valid()) { |
| incoming_res.url_id = existing_res.url_id; |
| } |
| if (!incoming_res.url.is_valid()) { |
| incoming_res.url = existing_res.url; |
| } |
| if (!incoming_res.tab_id.has_value()) { |
| incoming_res.tab_id = existing_res.tab_id; |
| } |
| if (!incoming_res.title.has_value()) { |
| incoming_res.title = existing_res.title; |
| } |
| if (!incoming_res.context_id.has_value()) { |
| incoming_res.context_id = existing_res.context_id; |
| } |
| |
| // Check if anything changed. |
| if (incoming_res.url_id != existing_res.url_id || |
| incoming_res.url != existing_res.url || |
| incoming_res.tab_id != existing_res.tab_id || |
| incoming_res.title != existing_res.title || |
| incoming_res.context_id != existing_res.context_id) { |
| result.has_changes = true; |
| result.added_or_updated_resources.push_back(incoming_res); |
| } |
| } else { |
| // New resource. |
| result.has_changes = true; |
| if (!incoming_res.url_id.is_valid()) { |
| incoming_res.url_id = base::Uuid::GenerateRandomV4(); |
| } |
| result.added_or_updated_resources.push_back(incoming_res); |
| } |
| } |
| |
| // Check for removed resources. |
| for (size_t i = 0; i < existing_resources.size(); ++i) { |
| if (!existing_matched[i]) { |
| result.has_changes = true; |
| result.removed_resource_ids.push_back(existing_resources[i].url_id); |
| } |
| } |
| |
| // If no content changes, adds or removes were detected, check if the order |
| // has changed. |
| if (!result.has_changes && |
| existing_resources.size() == incoming_resources.size()) { |
| for (size_t i = 0; i < existing_resources.size(); ++i) { |
| if (existing_resources[i].url_id != incoming_resources[i].url_id) { |
| result.has_changes = true; |
| break; |
| } |
| } |
| } |
| |
| result.final_resources = std::move(incoming_resources); |
| return result; |
| } |
| |
// Records `count` to the "ContextualTasks.ActiveTasksCount" counts histogram.
void RecordNumberOfActiveTasks(int count) {
  base::UmaHistogramCounts100("ContextualTasks.ActiveTasksCount", count);
}
| |
| ContextualTask CreateTaskForThread(const Thread& thread, bool is_ephemeral) { |
| ContextualTask task(base::Uuid::GenerateRandomV4(), is_ephemeral); |
| task.AddThread(thread); |
| task.SetTitle(thread.title); |
| return task; |
| } |
| |
| } // namespace |
| |
| ContextualTasksServiceImpl::ContextualTasksServiceImpl( |
| version_info::Channel channel, |
| syncer::RepeatingDataTypeStoreFactory data_type_store_factory, |
| std::unique_ptr<CompositeContextDecorator> composite_context_decorator, |
| AimEligibilityService* aim_eligibility_service, |
| signin::IdentityManager* identity_manager, |
| PrefService* pref_service, |
| bool supports_ephemeral_only, |
| base::RepeatingCallback<size_t()> get_active_task_count_callback, |
| base::RepeatingCallback<bool()> is_gemini_threads_enabled) |
| : composite_context_decorator_(std::move(composite_context_decorator)), |
| get_active_task_count_callback_( |
| std::move(get_active_task_count_callback)), |
| is_gemini_threads_enabled_callback_(is_gemini_threads_enabled), |
| aim_eligibility_service_(aim_eligibility_service), |
| identity_manager_(identity_manager), |
| pref_service_(pref_service), |
| supports_ephemeral_only_(supports_ephemeral_only) { |
| const base::RepeatingClosure& dump_stack = |
| base::BindRepeating(&syncer::ReportUnrecoverableError, channel); |
| auto ai_thread_processor = |
| std::make_unique<syncer::ClientTagBasedDataTypeProcessor>( |
| syncer::AI_THREAD, dump_stack); |
| ai_thread_sync_bridge_ = std::make_unique<AiThreadSyncBridge>( |
| std::move(ai_thread_processor), data_type_store_factory); |
| ai_thread_observation_.Observe(ai_thread_sync_bridge_.get()); |
| |
| auto gemini_thread_processor = |
| std::make_unique<syncer::ClientTagBasedDataTypeProcessor>( |
| syncer::GEMINI_THREAD, dump_stack); |
| gemini_thread_sync_bridge_ = std::make_unique<GeminiThreadSyncBridge>( |
| std::move(gemini_thread_processor), data_type_store_factory); |
| gemini_thread_observation_.Observe(gemini_thread_sync_bridge_.get()); |
| |
| // Wait for both AiThreadSyncBridge and GeminiThreadSyncBridge to finish |
| // loading their data store. |
| on_data_loaded_barrier_ = base::BarrierClosure( |
| 2, base::BindOnce(&ContextualTasksServiceImpl::OnDataStoresLoaded, |
| weak_ptr_factory_.GetWeakPtr())); |
| } |
| |
| ContextualTasksServiceImpl::~ContextualTasksServiceImpl() { |
| for (auto& observer : observers_) { |
| observer.OnWillBeDestroyed(); |
| } |
| } |
| |
| FeatureEligibility ContextualTasksServiceImpl::GetFeatureEligibility() { |
| return {base::FeatureList::IsEnabled(contextual_tasks::kContextualTasks), |
| aim_eligibility_service_->IsAimEligible(), |
| aim_eligibility_service_->IsCobrowseEligible(), |
| contextual_search::ContextualSearchService::IsContextSharingEnabled( |
| pref_service_)}; |
| } |
| |
// Returns true once OnDataStoresLoaded() has run. That happens when the
// two-count data-loaded barrier created in the constructor has fired.
bool ContextualTasksServiceImpl::IsInitialized() {
  return is_initialized_;
}
| |
| ContextualTask ContextualTasksServiceImpl::CreateTask() { |
| // Note that we are adding 1 to the number of active tasks because the |
| // histogram is recorded before the task is created, but it's not |
| // immediately reflected in the active tasks count. |
| RecordNumberOfActiveTasks(get_active_task_count_callback_.Run() + 1); |
| |
| base::Uuid task_id = base::Uuid::GenerateRandomV4(); |
| ContextualTask task(task_id, supports_ephemeral_only_); |
| return AddTaskAndNotify(std::move(task)); |
| } |
| |
| ContextualTask ContextualTasksServiceImpl::CreateTaskFromUrl(const GURL& url) { |
| // Note that we are adding 1 to the number of active tasks because the |
| // histogram is recorded before the task is created, but it's not |
| // immediately reflected in the active tasks count. |
| RecordNumberOfActiveTasks(get_active_task_count_callback_.Run() + 1); |
| |
| base::Uuid task_id = base::Uuid::GenerateRandomV4(); |
| bool is_ephemeral = supports_ephemeral_only_ || |
| !IsUrlForPrimaryAccount(identity_manager_, url); |
| ContextualTask task(task_id, is_ephemeral); |
| return AddTaskAndNotify(std::move(task)); |
| } |
| |
| void ContextualTasksServiceImpl::GetTaskById( |
| const base::Uuid& task_id, |
| base::OnceCallback<void(std::optional<ContextualTask>)> callback) const { |
| auto it = tasks_.find(task_id); |
| std::optional<ContextualTask> result; |
| if (it != tasks_.end()) { |
| result = it->second; |
| } |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), std::move(result))); |
| } |
| |
| void ContextualTasksServiceImpl::GetTasks( |
| base::OnceCallback<void(std::vector<ContextualTask>)> callback) const { |
| std::vector<ContextualTask> tasks; |
| for (const auto& pair : tasks_) { |
| ContextualTask task = pair.second; |
| if (task.IsEphemeral() || supports_ephemeral_only_) { |
| continue; |
| } |
| tasks.push_back(task); |
| } |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), std::move(tasks))); |
| } |
| |
// Removes the task locally via RemoveTaskInternal(). That also drops its tab
// mappings and posts OnTaskRemoved with TriggerSource::kLocal. Unknown IDs are
// ignored.
void ContextualTasksServiceImpl::DeleteTask(const base::Uuid& task_id) {
  RemoveTaskInternal(task_id, TriggerSource::kLocal);
}
| |
| void ContextualTasksServiceImpl::UpdateThreadForTask( |
| const base::Uuid& task_id, |
| ThreadType thread_type, |
| const std::string& server_id, |
| std::optional<std::string> conversation_turn_id, |
| std::optional<std::string> title) { |
| auto [it, is_new_task] = FindOrCreateTask(task_id, thread_type, server_id); |
| |
| // If a thread already exists and its server ID does not match the new server |
| // ID, it indicates a mismatch or an attempt to update a different thread, so |
| // we return. |
| std::optional<Thread> thread = it->second.GetThread(); |
| // If the task doesn't exist doesn't have the right thread, return early. |
| if (thread.has_value() && |
| (thread->server_id != server_id || thread->type != thread_type)) { |
| return; |
| } |
| |
| // Determine the new title and conversation turn ID. If provided, use them; |
| // otherwise, retain the existing values if a thread already exists. |
| const std::string& new_title = |
| title.value_or(thread.has_value() ? thread->title : ""); |
| std::optional<std::string> new_conversation_turn_id = std::nullopt; |
| if (conversation_turn_id.has_value()) { |
| new_conversation_turn_id = conversation_turn_id; |
| } else if (thread.has_value()) { |
| new_conversation_turn_id = thread->conversation_turn_id; |
| } |
| |
| // Add or update the thread information within the task. |
| it->second.AddThread(Thread(thread_type, server_id, new_title, |
| base::Time::Now().InMillisecondsSinceUnixEpoch(), |
| new_conversation_turn_id)); |
| |
| if (is_new_task) { |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kLocal)); |
| } else { |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kLocal)); |
| } |
| } |
| |
| void ContextualTasksServiceImpl::RemoveThreadFromTask( |
| const base::Uuid& task_id, |
| ThreadType type, |
| const std::string& server_id) { |
| auto it = tasks_.find(task_id); |
| if (it != tasks_.end()) { |
| DCHECK(it->second.GetThread().has_value()); |
| switch (type) { |
| case ThreadType::kAiMode: |
| ai_thread_sync_bridge_->DeleteThread(it->second.GetThread().value()); |
| break; |
| case ThreadType::kGemini: |
| gemini_thread_sync_bridge_->DeleteThread( |
| it->second.GetThread().value()); |
| break; |
| case ThreadType::kUnknown: |
| default: |
| NOTREACHED(); |
| } |
| it->second.RemoveThread(type, server_id); |
| // If the task no longer has any thread, remove it. |
| if (!it->second.GetThread()) { |
| DeleteTask(task_id); |
| } |
| } |
| } |
| |
| std::optional<ContextualTask> ContextualTasksServiceImpl::GetTaskFromServerId( |
| ThreadType thread_type, |
| const std::string& server_id) { |
| for (const auto& pair : tasks_) { |
| std::optional<Thread> thread = pair.second.GetThread(); |
| if (thread.has_value() && thread->type == thread_type && |
| thread->server_id == server_id) { |
| return pair.second; |
| } |
| } |
| return std::nullopt; |
| } |
| |
| void ContextualTasksServiceImpl::AttachUrlToTask(const base::Uuid& task_id, |
| const GURL& url) { |
| auto it = tasks_.find(task_id); |
| if (it != tasks_.end()) { |
| UrlResource url_resource(base::Uuid::GenerateRandomV4(), url); |
| if (it->second.AddUrlResource(url_resource)) { |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kLocal)); |
| } |
| } |
| } |
| |
| void ContextualTasksServiceImpl::DetachUrlFromTask(const base::Uuid& task_id, |
| const GURL& url) { |
| auto it = tasks_.find(task_id); |
| if (it != tasks_.end()) { |
| std::optional<base::Uuid> url_id = it->second.RemoveUrl(url); |
| if (url_id) { |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kLocal)); |
| } |
| } |
| } |
| |
| void ContextualTasksServiceImpl::SetUrlResourcesFromServer( |
| const base::Uuid& task_id, |
| std::vector<UrlResource> url_resources) { |
| auto it = tasks_.find(task_id); |
| if (it == tasks_.end()) { |
| return; |
| } |
| |
| // Merge incoming resources with existing ones and calculate the diff. |
| ContextualTask& task = it->second; |
| MergeUrlResourcesResult result = |
| MergeUrlResources(task.GetUrlResources(), std::move(url_resources)); |
| |
| if (!result.has_changes) { |
| return; |
| } |
| |
| // Update the local in-memory task state. |
| task.SetUrlResourcesFromServer(std::move(result.final_resources)); |
| |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskUpdated, |
| weak_ptr_factory_.GetWeakPtr(), task, |
| TriggerSource::kLocal)); |
| } |
| |
| void ContextualTasksServiceImpl::AssociateTabWithTask(const base::Uuid& task_id, |
| SessionID tab_id) { |
| auto it = tasks_.find(task_id); |
| if (it == tasks_.end()) { |
| return; |
| } |
| |
| std::optional<ContextualTask> current_task = GetContextualTaskForTab(tab_id); |
| if (current_task && current_task->GetTaskId() != task_id) { |
| DisassociateTabFromTask(current_task->GetTaskId(), tab_id); |
| } |
| |
| tab_to_task_[tab_id] = task_id; |
| it->second.AddTabId(tab_id); |
| |
| base::UmaHistogramCounts100("ContextualTasks.TabAffiliationCount", |
| GetTabsAssociatedWithTask(task_id).size()); |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAssociatedToTab, |
| weak_ptr_factory_.GetWeakPtr(), task_id, tab_id)); |
| } |
| |
| void ContextualTasksServiceImpl::DisassociateTabFromTask( |
| const base::Uuid& task_id, |
| SessionID tab_id) { |
| tab_to_task_.erase(tab_id); |
| auto it = tasks_.find(task_id); |
| if (it != tasks_.end()) { |
| it->second.RemoveTabId(tab_id); |
| |
| if (base::FeatureList::IsEnabled( |
| kContextualTasksRemoveTasksWithoutThreadsOrTabAssociations)) { |
| // If the task doesn't have a thread and tabs associated with it, |
| // it can be safely removed here. |
| if (!it->second.GetThread() && it->second.GetTabIds().empty()) { |
| RemoveTaskInternal(task_id, TriggerSource::kLocal); |
| } |
| } |
| } |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce( |
| &ContextualTasksServiceImpl::NotifyTaskDisassociatedFromTab, |
| weak_ptr_factory_.GetWeakPtr(), task_id, tab_id)); |
| } |
| |
| void ContextualTasksServiceImpl::DisassociateAllTabsFromTask( |
| const base::Uuid& task_id) { |
| for (const SessionID& tab_id : GetTabsAssociatedWithTask(task_id)) { |
| DisassociateTabFromTask(task_id, tab_id); |
| } |
| } |
| |
| std::optional<ContextualTask> |
| ContextualTasksServiceImpl::GetContextualTaskForTab(SessionID tab_id) const { |
| auto it = tab_to_task_.find(tab_id); |
| if (it != tab_to_task_.end()) { |
| auto task_it = tasks_.find(it->second); |
| if (task_it != tasks_.end()) { |
| return task_it->second; |
| } |
| } |
| return std::nullopt; |
| } |
| |
| std::vector<SessionID> ContextualTasksServiceImpl::GetTabsAssociatedWithTask( |
| const base::Uuid& task_id) const { |
| std::vector<SessionID> associated_tabs; |
| for (const auto& pair : tab_to_task_) { |
| if (pair.second == task_id) { |
| associated_tabs.push_back(pair.first); |
| } |
| } |
| return associated_tabs; |
| } |
| |
| void ContextualTasksServiceImpl::GetContextForTask( |
| const base::Uuid& task_id, |
| const std::set<ContextualTaskContextSource>& sources, |
| std::unique_ptr<ContextDecorationParams> params, |
| base::OnceCallback<void(std::unique_ptr<ContextualTaskContext>)> |
| context_callback) { |
| auto it = tasks_.find(task_id); |
| if (it == tasks_.end()) { |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(context_callback), |
| std::unique_ptr<ContextualTaskContext>())); |
| return; |
| } |
| |
| composite_context_decorator_->DecorateContext( |
| std::make_unique<ContextualTaskContext>(it->second), sources, |
| std::move(params), std::move(context_callback)); |
| } |
| |
| void ContextualTasksServiceImpl::GetThreadUrlFromTaskId( |
| const base::Uuid& task_id, |
| const std::string& locale, |
| omnibox::ChromeAimEntryPoint entry_point, |
| base::OnceCallback<void(GURL)> callback) { |
| GetTaskById( |
| task_id, |
| base::BindOnce( |
| [](const base::Uuid& task_id, const std::string& locale, |
| omnibox::ChromeAimEntryPoint entry_point, |
| base::OnceCallback<void(GURL)> callback, |
| std::optional<ContextualTask> task) { |
| OMNIBOX_LOG("nav_trace") |
| << "ContextualTasks navigation trace: " |
| "GetThreadUrlFromTaskId callback called"; |
| |
| // AIM is the default if no task or thread is found. |
| GURL url = GetDefaultAimUrl(locale, entry_point); |
| |
| if (!task) { |
| OMNIBOX_LOG("nav_trace") |
| << "ContextualTasks navigation trace: " |
| "GetThreadUrlFromTaskId returning early, no task. " |
| "Returning default: " |
| << url; |
| std::move(callback).Run(url); |
| return; |
| } |
| |
| std::optional<Thread> thread = task->GetThread(); |
| if (!thread) { |
| OMNIBOX_LOG("nav_trace") |
| << "ContextualTasks navigation trace: " |
| "GetThreadUrlFromTaskId returning early, no thread or " |
| "Gemini thread. Returning default: " |
| << url; |
| std::move(callback).Run(url); |
| return; |
| } |
| |
| if (thread->type == ThreadType::kAiMode) { |
| // Attach the thread ID and the most recent turn ID to the |
| // URL. A query parameter needs to be present, but its |
| // value is not used for continued threads. |
| url = net::AppendQueryParameter(url, "q", thread->title); |
| |
| // There are rare cases where the mstk may not be present before |
| // this url is requested. If that is the case, do not attempt to |
| // include it in the url. |
| if (thread->conversation_turn_id) { |
| url = net::AppendQueryParameter( |
| url, "mstk", thread->conversation_turn_id.value()); |
| } |
| url = net::AppendQueryParameter(url, "mtid", thread->server_id); |
| |
| } else if (thread->type == ThreadType::kGemini) { |
| url = GURL(GetContextualTasksGeminiBaseUrl()); |
| GURL::Replacements replacements; |
| std::string path = url.GetPath(); |
| if (path.back() != '/') { |
| path += '/'; |
| } |
| std::string server_id = thread->server_id; |
| if (base::StartsWith(server_id, "c_")) { |
| server_id.erase(0, 2); |
| } |
| path += server_id; |
| replacements.SetPathStr(path); |
| |
| url = url.ReplaceComponents(replacements); |
| } |
| |
| OMNIBOX_LOG("nav_trace") << "ContextualTasks navigation trace: " |
| "GetThreadUrlFromTaskId returning URL: " |
| << url; |
| std::move(callback).Run(url); |
| }, |
| task_id, locale, entry_point, std::move(callback))); |
| } |
| |
// Registers `observer` for task add/update/remove, tab association,
// initialization, and destruction notifications.
void ContextualTasksServiceImpl::AddObserver(
    ContextualTasksService::Observer* observer) {
  observers_.AddObserver(observer);
}
| |
// Unregisters an observer previously passed to AddObserver().
void ContextualTasksServiceImpl::RemoveObserver(
    ContextualTasksService::Observer* observer) {
  observers_.RemoveObserver(observer);
}
| |
// Returns the controller delegate of the AI_THREAD bridge's change processor.
base::WeakPtr<syncer::DataTypeControllerDelegate>
ContextualTasksServiceImpl::GetAiThreadControllerDelegate() {
  return ai_thread_sync_bridge_->change_processor()->GetControllerDelegate();
}
| |
// Returns the controller delegate of the GEMINI_THREAD bridge's change
// processor.
base::WeakPtr<syncer::DataTypeControllerDelegate>
ContextualTasksServiceImpl::GetGeminiThreadControllerDelegate() {
  return gemini_thread_sync_bridge_->change_processor()
      ->GetControllerDelegate();
}
| |
| bool ContextualTasksServiceImpl::IsGeminiThreadsEligible() { |
| return is_gemini_threads_enabled_callback_ && |
| is_gemini_threads_enabled_callback_.Run(); |
| } |
| |
// Test-only: replaces the AI thread sync bridge. The replacement is not
// observed by this service.
void ContextualTasksServiceImpl::SetAiThreadSyncBridgeForTesting(
    std::unique_ptr<AiThreadSyncBridge> bridge) {
  // Stop observing the old bridge before it is destroyed by the assignment
  // below; otherwise the observation would later touch a freed bridge.
  ai_thread_observation_.Reset();
  ai_thread_sync_bridge_ = std::move(bridge);
}
| |
// Test-only: replaces the Gemini thread sync bridge. The replacement is not
// observed by this service.
void ContextualTasksServiceImpl::SetGeminiThreadSyncBridgeForTesting(
    std::unique_ptr<GeminiThreadSyncBridge> bridge) {
  // Stop observing the old bridge before it is destroyed by the assignment.
  gemini_thread_observation_.Reset();
  gemini_thread_sync_bridge_ = std::move(bridge);
}
| |
// Load notification for the AI thread data store, presumably delivered by the
// observed AiThreadSyncBridge. It is one of the two runs the constructor's
// barrier waits for before OnDataStoresLoaded().
void ContextualTasksServiceImpl::OnThreadDataStoreLoaded() {
  on_data_loaded_barrier_.Run();
}
| |
| void ContextualTasksServiceImpl::OnThreadAddedOrUpdatedRemotely( |
| const std::vector<proto::AiThreadEntity>& threads) { |
| std::map<std::string, const proto::AiThreadEntity&> thread_map; |
| for (const auto& thread : threads) { |
| thread_map.emplace(thread.specifics().server_id(), thread); |
| } |
| |
| for (auto& task_entry : tasks_) { |
| ContextualTask& task = task_entry.second; |
| if (!task.GetThread()) { |
| continue; |
| } |
| |
| auto it = thread_map.find(task.GetThread()->server_id); |
| if (it == thread_map.end() || |
| ToThreadType(it->second.specifics().type()) != task.GetThread()->type) { |
| continue; |
| } |
| |
| // Check if the thread has changed for the task. |
| const proto::AiThreadEntity& new_thread_entity = it->second; |
| const std::optional<Thread>& old_thread = task.GetThread(); |
| if (old_thread->conversation_turn_id != |
| new_thread_entity.specifics().conversation_turn_id() || |
| old_thread->title != new_thread_entity.specifics().title()) { |
| task.AddThread(Thread( |
| ThreadType::kAiMode, new_thread_entity.specifics().server_id(), |
| new_thread_entity.specifics().title(), |
| new_thread_entity.specifics().last_turn_time_unix_epoch_millis(), |
| new_thread_entity.specifics().conversation_turn_id())); |
| NotifyTaskUpdated(task, TriggerSource::kRemote); |
| } |
| |
| // Remove the thread from the map. Any remaining threads will have tasks |
| // created for them at the end of this function. |
| thread_map.erase(it->first); |
| } |
| |
| // Create tasks for any of the threads that were added or updated and didn't |
| // have an associated task. |
| for (const auto& [thread_id, thread_entity] : thread_map) { |
| Thread thread(ToThreadType(thread_entity.specifics().type()), |
| thread_entity.specifics().server_id(), |
| thread_entity.specifics().title(), |
| thread_entity.specifics().last_turn_time_unix_epoch_millis(), |
| thread_entity.specifics().conversation_turn_id()); |
| ContextualTask new_task = |
| CreateTaskForThread(thread, supports_ephemeral_only_); |
| const auto it = |
| tasks_.emplace(new_task.GetTaskId(), std::move(new_task)).first; |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kRemote)); |
| } |
| } |
| |
// Removes (with TriggerSource::kRemote) every task whose AI Mode thread's
// server ID equals the lowercase string form of one of `thread_ids`.
void ContextualTasksServiceImpl::OnThreadRemovedRemotely(
    const std::vector<base::Uuid>& thread_ids) {
  OnThreadRemovedRemotelyInternal(ThreadType::kAiMode, thread_ids);
}
| |
// Load notification for the Gemini thread data store, presumably delivered by
// the observed GeminiThreadSyncBridge. It is one of the two runs the
// constructor's barrier waits for before OnDataStoresLoaded().
void ContextualTasksServiceImpl::OnGeminiThreadDataStoreLoaded() {
  on_data_loaded_barrier_.Run();
}
| |
| void ContextualTasksServiceImpl::OnGeminiThreadAddedOrUpdatedRemotely( |
| const std::vector<sync_pb::GeminiThreadSpecifics>& thread_specifics) { |
| std::map<std::string, const sync_pb::GeminiThreadSpecifics&> thread_map; |
| for (const auto& specifics : thread_specifics) { |
| thread_map.emplace(specifics.conversation_id(), specifics); |
| } |
| |
| // Update existing tasks |
| for (auto& task_entry : tasks_) { |
| ContextualTask& task = task_entry.second; |
| if (!task.GetThread() || task.GetThread()->type != ThreadType::kGemini) { |
| continue; |
| } |
| |
| auto it = thread_map.find(task.GetThread()->server_id); |
| if (it == thread_map.end()) { |
| continue; |
| } |
| |
| // Check if the thread has changed for the task. |
| const sync_pb::GeminiThreadSpecifics& new_thread_entity = it->second; |
| const std::optional<Thread>& old_thread = task.GetThread(); |
| if (old_thread->title != new_thread_entity.title()) { |
| task.AddThread( |
| Thread(ThreadType::kGemini, new_thread_entity.conversation_id(), |
| new_thread_entity.title(), |
| new_thread_entity.last_turn_time_unix_epoch_millis())); |
| NotifyTaskUpdated(task, TriggerSource::kRemote); |
| } |
| |
| thread_map.erase(it->first); |
| } |
| |
| // Create new task for specifics which didn't have an existing task. |
| for (const auto& [thread_id, specifics] : thread_map) { |
| Thread thread(ThreadType::kGemini, specifics.conversation_id(), |
| specifics.title(), |
| specifics.last_turn_time_unix_epoch_millis()); |
| ContextualTask new_task = |
| CreateTaskForThread(thread, supports_ephemeral_only_); |
| const auto it = |
| tasks_.emplace(new_task.GetTaskId(), std::move(new_task)).first; |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kRemote)); |
| } |
| } |
| |
// Removes (with TriggerSource::kRemote) every task whose Gemini thread's
// server ID equals the lowercase string form of one of `thread_ids`.
void ContextualTasksServiceImpl::OnGeminiThreadRemovedRemotely(
    const std::vector<base::Uuid>& thread_ids) {
  OnThreadRemovedRemotelyInternal(ThreadType::kGemini, thread_ids);
}
| |
| std::pair<std::map<base::Uuid, ContextualTask>::iterator, bool> |
| ContextualTasksServiceImpl::FindOrCreateTask(const base::Uuid& task_id, |
| ThreadType thread_type, |
| const std::string& server_id) { |
| auto it = tasks_.find(task_id); |
| if (it != tasks_.end()) { |
| return {it, /*is_new_task=*/false}; |
| } |
| |
| // Task not found, but we have a task ID. Create the task on the fly unless |
| // we already have a task for this server ID. |
| std::optional<ContextualTask> existing_task = |
| GetTaskFromServerId(thread_type, server_id); |
| if (existing_task.has_value()) { |
| // TODO(nyquist): This is a temporary solution to avoid creating |
| // duplicate tasks. We should remove this once we have a better solution |
| // for handling out-of-sync tasks. |
| it = tasks_.find(existing_task->GetTaskId()); |
| return {it, /*is_new_task=*/false}; |
| } |
| |
| it = |
| tasks_.emplace(task_id, ContextualTask(task_id, supports_ephemeral_only_)) |
| .first; |
| return {it, /*is_new_task=*/true}; |
| } |
| |
| void ContextualTasksServiceImpl::RemoveTaskInternal(const base::Uuid& task_id, |
| TriggerSource source) { |
| auto task_it = tasks_.find(task_id); |
| if (task_it == tasks_.end()) { |
| return; |
| } |
| |
| const auto& task = task_it->second; |
| for (const auto& tab_id : task.GetTabIds()) { |
| tab_to_task_.erase(tab_id); |
| } |
| |
| tasks_.erase(task_it); |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, |
| base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskRemoved, |
| weak_ptr_factory_.GetWeakPtr(), task_id, source)); |
| } |
| |
| void ContextualTasksServiceImpl::OnThreadRemovedRemotelyInternal( |
| ThreadType thread_type_filter, |
| const std::vector<base::Uuid>& thread_ids) { |
| std::set<std::string> removed_thread_server_ids; |
| for (const auto& id : thread_ids) { |
| removed_thread_server_ids.insert(id.AsLowercaseString()); |
| } |
| |
| std::vector<base::Uuid> tasks_to_delete; |
| for (const auto& task_entry : tasks_) { |
| const ContextualTask& task = task_entry.second; |
| if (task.GetThread() && task.GetThread()->type == thread_type_filter) { |
| if (removed_thread_server_ids.count(task.GetThread()->server_id)) { |
| tasks_to_delete.push_back(task.GetTaskId()); |
| } |
| } |
| } |
| |
| for (const auto& task_id : tasks_to_delete) { |
| RemoveTaskInternal(task_id, TriggerSource::kRemote); |
| } |
| } |
| |
// Test-only: number of tab -> task mappings currently held.
size_t ContextualTasksServiceImpl::GetTabIdMapSizeForTesting() const {
  return tab_to_task_.size();
}
| |
| void ContextualTasksServiceImpl::NotifyTaskAdded(const ContextualTask& task, |
| TriggerSource source) { |
| for (auto& observer : observers_) { |
| observer.OnTaskAdded(task, source); |
| } |
| } |
| |
| void ContextualTasksServiceImpl::NotifyTaskUpdated(const ContextualTask& task, |
| TriggerSource source) { |
| for (auto& observer : observers_) { |
| observer.OnTaskUpdated(task, source); |
| } |
| } |
| |
| void ContextualTasksServiceImpl::NotifyTaskRemoved(const base::Uuid& task_id, |
| TriggerSource source) { |
| for (auto& observer : observers_) { |
| observer.OnTaskRemoved(task_id, source); |
| } |
| } |
| |
// Forwards OnTaskAssociatedToTab(`task_id`, `tab_id`) to every observer.
void ContextualTasksServiceImpl::NotifyTaskAssociatedToTab(
    const base::Uuid& task_id,
    SessionID tab_id) {
  observers_.Notify(&ContextualTasksService::Observer::OnTaskAssociatedToTab,
                    task_id, tab_id);
}
| |
// Forwards OnTaskDisassociatedFromTab(`task_id`, `tab_id`) to every observer.
void ContextualTasksServiceImpl::NotifyTaskDisassociatedFromTab(
    const base::Uuid& task_id,
    SessionID tab_id) {
  observers_.Notify(
      &ContextualTasksService::Observer::OnTaskDisassociatedFromTab, task_id,
      tab_id);
}
| |
| ContextualTask ContextualTasksServiceImpl::AddTaskAndNotify( |
| ContextualTask task) { |
| auto it = tasks_.emplace(task.GetTaskId(), task).first; |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(&ContextualTasksServiceImpl::NotifyTaskAdded, |
| weak_ptr_factory_.GetWeakPtr(), it->second, |
| TriggerSource::kLocal)); |
| return it->second; |
| } |
| |
| void ContextualTasksServiceImpl::OnDataStoresLoaded() { |
| is_initialized_ = true; |
| std::vector<ContextualTask> tasks = BuildTasks(); |
| for (const auto& task : tasks) { |
| tasks_.emplace(task.GetTaskId(), task); |
| } |
| for (auto& observer : observers_) { |
| observer.OnContextualTasksServiceInitialized(); |
| } |
| } |
| |
| std::vector<ContextualTask> ContextualTasksServiceImpl::BuildTasks() const { |
| // First attempt to add threads to tasks that were persisted. Any threads that |
| // do not have a task will have one created. |
| base::flat_set<std::string> used_thread_ids; |
| |
| std::vector<Thread> threads = ai_thread_sync_bridge_->GetThreads(); |
| std::vector<ContextualTask> tasks; |
| for (const auto& thread : threads) { |
| if (used_thread_ids.contains(thread.server_id)) { |
| continue; |
| } |
| // While these tasks are not persisted, they're also not considered |
| // ephemeral since they are built using a user's threads. |
| // TODO(485520978): Use a UUIDv5 based on the thread ID here so the UUID |
| // is deterministic between restarts. |
| tasks.push_back(CreateTaskForThread(thread, supports_ephemeral_only_)); |
| } |
| for (const auto& thread : gemini_thread_sync_bridge_->GetThreads()) { |
| if (used_thread_ids.contains(thread.server_id)) { |
| continue; |
| } |
| tasks.push_back(CreateTaskForThread(thread, supports_ephemeral_only_)); |
| } |
| |
| return tasks; |
| } |
| |
| } // namespace contextual_tasks |