blob: 70c107155ebae486970a51accaa79110ececda1f [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_model_download_progress_manager.h"
#include "base/functional/bind.h"
#include "base/types/pass_key.h"
#include "chrome/browser/ai/ai_utils.h"
#include "components/update_client/crx_update_item.h"
#include "components/update_client/update_client.h"
namespace on_device_ai {
namespace {
// Returns true if `item` describes an update event whose byte counters can be
// used for download progress reporting.
//
// An event qualifies only when the component is in one of the
// downloading/updating states (or `kUpToDate`) AND both `downloaded_bytes` and
// `total_bytes` are non-negative. Events in any other state are ignored.
bool IsDownloadEvent(const component_updater::CrxUpdateItem& item) {
  // There is intentionally no `default:` label so that switch-exhaustiveness
  // diagnostics can flag a newly added `ComponentState` that has not been
  // classified here.
  bool is_download_state = false;
  switch (item.state) {
    case update_client::ComponentState::kDownloading:
    case update_client::ComponentState::kUpdating:
    case update_client::ComponentState::kUpToDate:
    case update_client::ComponentState::kDownloadingDiff:
    case update_client::ComponentState::kUpdatingDiff:
      is_download_state = true;
      break;
    case update_client::ComponentState::kNew:
    case update_client::ComponentState::kChecking:
    case update_client::ComponentState::kCanUpdate:
    case update_client::ComponentState::kUpdated:
    case update_client::ComponentState::kUpdateError:
    case update_client::ComponentState::kRun:
    case update_client::ComponentState::kLastStatus:
      break;
  }
  // Single exit point: every path returns a value, so control can never flow
  // off the end of this non-void function, even if `item.state` holds a value
  // outside the enumerators listed above.
  return is_download_state && item.downloaded_bytes >= 0 &&
         item.total_bytes >= 0;
}
} // namespace
// Out-of-line defaulted constructor; no custom initialization is performed.
AIModelDownloadProgressManager::AIModelDownloadProgressManager() = default;
// Out-of-line defaulted destructor; members (including any reporters still
// held in `reporters_`) are destroyed by their own destructors.
AIModelDownloadProgressManager::~AIModelDownloadProgressManager() = default;
// Creates a `Reporter` for `observer_remote` covering `component_ids` and
// stores it in `reporters_`, which owns it from then on.
//
// `component_update_service` is handed straight to the `Reporter`, whose
// constructor requires it to be non-null.
void AIModelDownloadProgressManager::AddObserver(
    component_updater::ComponentUpdateService* component_update_service,
    mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
        observer_remote,
    base::flat_set<std::string> component_ids) {
  // Build the reporter first, then transfer ownership into the set.
  auto new_reporter = std::make_unique<Reporter>(
      *this, component_update_service, std::move(observer_remote),
      std::move(component_ids));
  reporters_.emplace(std::move(new_reporter));
}
// Removes `reporter` from `reporters_`. The set holds its reporters through
// `std::unique_ptr`, so erasing the entry destroys the `Reporter`; callers must
// not use `reporter` after this returns.
void AIModelDownloadProgressManager::RemoveReporter(Reporter* reporter) {
// A null reporter is a programming error.
CHECK(reporter);
reporters_.erase(reporter);
}
// Returns how many reporters are currently held in `reporters_`.
int AIModelDownloadProgressManager::GetNumberOfReporters() {
  // The container reports an unsigned size; the conversion to the `int` return
  // type is spelled out here.
  return static_cast<int>(reporters_.size());
}
// Builds a reporter that relays download progress for `component_ids` to
// `observer_remote`.
//
// `manager` is the owner that is asked to remove this reporter once the remote
// disconnects. `component_update_service` must be non-null; it is queried for
// the component ids it already has and, if anything is left to watch, observed
// for update events.
AIModelDownloadProgressManager::Reporter::Reporter(
    AIModelDownloadProgressManager& manager,
    component_updater::ComponentUpdateService* component_update_service,
    mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
        observer_remote,
    base::flat_set<std::string> component_ids)
    : manager_(manager),
      observer_remote_(std::move(observer_remote)),
      component_ids_(std::move(component_ids)) {
  CHECK(component_update_service);

  // When the remote end goes away, have the manager drop this reporter. The
  // weak pointer keeps the callback from running on a destroyed reporter.
  observer_remote_.set_disconnect_handler(base::BindOnce(
      &AIModelDownloadProgressManager::Reporter::OnRemoteDisconnect,
      weak_ptr_factory_.GetWeakPtr()));

  // Components the update service already lists are treated as installed and
  // are therefore not watched.
  const auto& installed_ids = component_update_service->GetComponentIDs();
  for (const auto& installed_id : installed_ids) {
    component_ids_.erase(installed_id);
  }

  if (!component_ids_.empty()) {
    // Something still has to be downloaded: subscribe to progress updates.
    component_updater_observation_.Observe(component_update_service);
    return;
  }

  // Nothing to observe, so report the full range right away: first zero, then
  // the normalized maximum.
  observer_remote_->OnDownloadProgressUpdate(
      0, AIUtils::kNormalizedDownloadProgressMax);
  observer_remote_->OnDownloadProgressUpdate(
      AIUtils::kNormalizedDownloadProgressMax,
      AIUtils::kNormalizedDownloadProgressMax);
}
// Out-of-line defaulted destructor; members are destroyed by their own
// destructors.
AIModelDownloadProgressManager::Reporter::~Reporter() = default;
// Disconnect handler for `observer_remote_` (installed in the constructor).
// Asks the owning manager to remove this reporter.
void AIModelDownloadProgressManager::Reporter::OnRemoteDisconnect() {
// Destroy `this` when the `ModelDownloadProgressObserver` is garbage
// collected in the renderer.
// NOTE: `this` is destroyed by the call below, so nothing may touch members
// after it.
manager_->RemoveReporter(this);
}
// Returns the sum of the most recently recorded downloaded byte counts of all
// components seen so far (the values of `observed_downloaded_bytes_`).
int64_t AIModelDownloadProgressManager::Reporter::GetDownloadedBytes() {
  int64_t total_downloaded = 0;
  for (const auto& entry : observed_downloaded_bytes_) {
    // `entry.first` is the component id; only the byte count is needed.
    total_downloaded += entry.second;
  }
  return total_downloaded;
}
// Records the byte counters carried by `item`.
//
// For a component that has been seen before, only its downloaded byte count is
// refreshed. For a first sighting, the component is added to
// `observed_downloaded_bytes_` and its total is added to
// `components_total_bytes_`. Once every watched component has been seen,
// reporting is armed: the bytes already downloaded at that point are excluded
// from the totals and the initial zero-progress event is sent.
//
// `item.downloaded_bytes` and `item.total_bytes` must be non-negative.
void AIModelDownloadProgressManager::Reporter::ProcessEvent(
    const component_updater::CrxUpdateItem& item) {
  CHECK_GE(item.downloaded_bytes, 0);
  CHECK_GE(item.total_bytes, 0);

  // Known component: just refresh its downloaded byte count.
  if (auto known = observed_downloaded_bytes_.find(item.id);
      known != observed_downloaded_bytes_.end()) {
    known->second = item.downloaded_bytes;
    return;
  }

  // A first sighting can only happen while totals are still being collected:
  // once we're ready to report, every watched component is already in the
  // `observed_downloaded_bytes_` map.
  CHECK(!ready_to_report_);

  const bool newly_inserted =
      observed_downloaded_bytes_.insert({item.id, item.downloaded_bytes})
          .second;
  CHECK(newly_inserted);
  components_total_bytes_ += item.total_bytes;

  // Reporting can start as soon as each watched component has been observed at
  // least once.
  ready_to_report_ = observed_downloaded_bytes_.size() == component_ids_.size();
  if (!ready_to_report_) {
    return;
  }

  last_reported_progress_ = 0;
  last_progress_time_ = base::TimeTicks::Now();

  // Bytes downloaded before this point must not count towards progress.
  // Remember them for later calculations and take them out of
  // `components_total_bytes_` now.
  already_downloaded_bytes_ = GetDownloadedBytes();
  components_total_bytes_ -= already_downloaded_bytes_;

  // The zero progress event must always be the first one sent.
  observer_remote_->OnDownloadProgressUpdate(
      0, AIUtils::kNormalizedDownloadProgressMax);
}
// Handles an update event from the component update service.
//
// Events that are not download events, or that concern a component this
// reporter does not watch, are ignored. Otherwise the event is recorded and,
// once reporting is armed, a normalized progress value is sent to the observer.
// Updates short of 100% are sent at most once per 50ms, and a value that is not
// greater than the last one sent is never sent again.
void AIModelDownloadProgressManager::Reporter::OnEvent(
    const component_updater::CrxUpdateItem& item) {
  const bool is_relevant =
      IsDownloadEvent(item) && component_ids_.contains(item.id);
  if (!is_relevant) {
    return;
  }

  ProcessEvent(item);

  // Nothing can be reported until the total number of bytes to download is
  // known.
  if (!ready_to_report_) {
    return;
  }

  // Bytes downloaded so far, excluding those that were already downloaded
  // before the total was determined.
  int64_t bytes_so_far = GetDownloadedBytes() - already_downloaded_bytes_;
  CHECK_GE(bytes_so_far, 0);
  CHECK_LE(bytes_so_far, components_total_bytes_);

  // Throttle intermediate updates: anything short of 100% is only reported if
  // more than 50ms has passed since the last reported progress event.
  const bool is_complete = bytes_so_far == components_total_bytes_;
  if (!is_complete) {
    const auto min_report_interval = base::Milliseconds(50);
    base::TimeTicks now = base::TimeTicks::Now();
    const auto elapsed = now - last_progress_time_;
    if (elapsed <= min_report_interval) {
      return;
    }
    last_progress_time_ = now;
  }

  // Determine the normalized progress.
  //
  // A `components_total_bytes_` of zero means zero out of zero bytes were
  // downloaded, i.e. 100%. That case keeps the maximum value and skips
  // `NormalizeModelDownloadProgress`, which would otherwise divide by zero.
  int normalized_progress = AIUtils::kNormalizedDownloadProgressMax;
  if (components_total_bytes_ != 0) {
    normalized_progress = AIUtils::NormalizeModelDownloadProgress(
        bytes_so_far, components_total_bytes_);
  }

  // Skip values that were already sent. Progress is never expected to go
  // backwards.
  if (normalized_progress <= last_reported_progress_) {
    CHECK(normalized_progress == last_reported_progress_);
    return;
  }
  last_reported_progress_ = normalized_progress;

  // Forward the new progress value to the observer.
  observer_remote_->OnDownloadProgressUpdate(
      normalized_progress, AIUtils::kNormalizedDownloadProgressMax);
}
} // namespace on_device_ai