blob: 3dc9215ee728425c3a64118acaad53bf514680c4 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_model_download_progress_manager.h"
#include "base/functional/bind.h"
#include "chrome/browser/ai/ai_utils.h"
namespace on_device_ai {
// Defaulted. The manager owns its reporters through the unique_ptrs stored
// in `reporters_`, so destroying the manager destroys every live Reporter.
AIModelDownloadProgressManager::AIModelDownloadProgressManager() = default;
AIModelDownloadProgressManager::~AIModelDownloadProgressManager() = default;
// Starts relaying the combined download progress of `components` to
// `observer_remote`.
//
// A new Reporter takes ownership of both arguments and is kept alive in
// `reporters_` until it asks to be removed (see RemoveReporter()). Note that
// the Reporter constructor may already send progress updates to the observer.
void AIModelDownloadProgressManager::AddObserver(
    mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
        observer_remote,
    base::flat_set<std::unique_ptr<Component>> components) {
  auto reporter = std::make_unique<Reporter>(*this, std::move(observer_remote),
                                             std::move(components));
  reporters_.emplace(std::move(reporter));
}
// Drops `reporter` from `reporters_`. Because `reporters_` holds the owning
// unique_ptr, this destroys the Reporter. Callers (e.g. the Reporter itself
// from OnRemoteDisconnect()) must not touch `reporter` afterwards.
// `reporter` must be non-null.
void AIModelDownloadProgressManager::RemoveReporter(Reporter* reporter) {
  CHECK(reporter);
  reporters_.erase(reporter);
}
// Returns how many reporters are currently alive, i.e. how many observers
// have not yet disconnected.
int AIModelDownloadProgressManager::GetNumberOfReporters() {
  const size_t live_reporters = reporters_.size();
  // The interface reports an int; make the narrowing explicit.
  return static_cast<int>(live_reporters);
}
// Defaulted special members for Component. They are defined out of line here
// rather than in the header; the move constructor allows Components to be
// transferred by value.
AIModelDownloadProgressManager::Component::Component() = default;
AIModelDownloadProgressManager::Component::~Component() = default;
AIModelDownloadProgressManager::Component::Component(Component&&) = default;
// Records a new downloaded-byte count for this component. Repeating the
// current value is a no-op. Otherwise the value must be non-negative and
// strictly greater than the previous one (both CHECKed). It is then stored
// and the event callback is given a chance to fire.
void AIModelDownloadProgressManager::Component::SetDownloadedBytes(
    int64_t downloaded_bytes) {
  if (downloaded_bytes != downloaded_bytes_) {
    // `downloaded_bytes` should be monotonically increasing.
    CHECK_GT(downloaded_bytes, downloaded_bytes_.value_or(-1));
    CHECK_GE(downloaded_bytes, 0);
    downloaded_bytes_ = downloaded_bytes;
    MaybeRunEventCallback();
  }
}
// Records the component's total size in bytes. Re-setting the same value is a
// no-op. Setting a different value is only allowed once (CHECKed) and must be
// non-negative. After storing it, the event callback may fire.
void AIModelDownloadProgressManager::Component::SetTotalBytes(
    int64_t total_bytes) {
  if (total_bytes != total_bytes_) {
    // `total_bytes_` should never change after it's been set.
    CHECK(!total_bytes_.has_value());
    CHECK_GE(total_bytes, 0);
    total_bytes_ = total_bytes;
    MaybeRunEventCallback();
  }
}
// Installs `event_callback`, replacing any previous one, and immediately tries
// to run it. If this component's bytes are already determined, the callback
// therefore fires synchronously from inside this call.
void AIModelDownloadProgressManager::Component::SetEventCallback(
    EventCallback event_callback) {
  event_callback_ = std::move(event_callback);
  MaybeRunEventCallback();
}
// Notifies the event callback, if both a callback is installed and the
// component's bytes are determined.
//
// The callback receives:
//  - this component,
//  - whether this is the first notification since the bytes became
//    determined,
//  - the growth in downloaded bytes since the previous notification (an unset
//    previous count counts as 0).
// The "last notified" bookkeeping is updated after the callback returns.
void AIModelDownloadProgressManager::Component::MaybeRunEventCallback() {
  const bool can_notify = determined_bytes() && event_callback_;
  if (!can_notify) {
    return;
  }
  const bool first_notification = !last_determined_bytes_;
  const int64_t bytes_since_last_notification =
      downloaded_bytes_.value() - last_downloaded_bytes_.value_or(0);
  event_callback_.Run(*this, first_notification,
                      bytes_since_last_notification);
  last_downloaded_bytes_ = downloaded_bytes_;
  last_determined_bytes_ = true;
}
// Creates a reporter that forwards the combined download progress of
// `components` to `observer_remote`.
//
// Components that are already complete are dropped. Every remaining
// component is then given an event callback. Because SetEventCallback() runs
// that callback synchronously when a component's bytes are already
// determined, OnEvent() may be invoked from inside this constructor. If
// nothing is left to watch, the observer immediately receives a 0% update
// followed by a 100% update.
AIModelDownloadProgressManager::Reporter::Reporter(
    AIModelDownloadProgressManager& manager,
    mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
        observer_remote,
    base::flat_set<std::unique_ptr<Component>> components)
    : manager_(manager),
      observer_remote_(std::move(observer_remote)),
      components_(std::move(components)) {
  observer_remote_.set_disconnect_handler(base::BindOnce(
      &Reporter::OnRemoteDisconnect, weak_ptr_factory_.GetWeakPtr()));

  // Don't watch any components that are already installed.
  //
  // This pruning must be complete before any event callback is attached.
  // ReadyToReport() compares the number of determined components with
  // `components_.size()`. If components were still being erased while
  // callbacks fired, the reporter could become "ready" only through an erase.
  // That would skip the branch in ProcessEvent() that excludes
  // already-downloaded bytes and sends the initial 0% update.
  for (auto iter = components_.begin(); iter != components_.end();) {
    if ((*iter)->is_complete()) {
      iter = components_.erase(iter);
    } else {
      ++iter;
    }
  }

  // Watch the remaining components for progress updates. OnEvent() does not
  // modify `components_`, so iterating while callbacks fire is safe.
  for (const std::unique_ptr<Component>& component : components_) {
    component->SetEventCallback(base::BindRepeating(
        &Reporter::OnEvent, weak_ptr_factory_.GetWeakPtr()));
  }

  // If there are no components to observe, just send zero and one hundred
  // percent.
  if (components_.empty()) {
    observer_remote_->OnDownloadProgressUpdate(
        0, AIUtils::kNormalizedDownloadProgressMax);
    observer_remote_->OnDownloadProgressUpdate(
        AIUtils::kNormalizedDownloadProgressMax,
        AIUtils::kNormalizedDownloadProgressMax);
  }
}
// Defaulted. Destroying the reporter destroys the components it owns via
// `components_`. The weak pointers handed out by `weak_ptr_factory_` are
// invalidated, so the event callbacks and the disconnect handler bound to
// them no longer reach this object.
AIModelDownloadProgressManager::Reporter::~Reporter() = default;
// Disconnect handler installed by the constructor. Asks the owning manager to
// erase this reporter, which destroys `this`, so no member may be accessed
// after the call.
void AIModelDownloadProgressManager::Reporter::OnRemoteDisconnect() {
  // Destroy `this` when the `ModelDownloadProgressObserver` is garbage
  // collected in the renderer.
  manager_->RemoveReporter(this);
}
bool AIModelDownloadProgressManager::Reporter::ReadyToReport() {
// If we have observed the downloaded bytes of all our components then we're
// ready to start reporting.
return determined_components_ == static_cast<int>(components_.size());
}
// Folds one component event into the reporter's aggregate byte counters.
// Called from OnEvent() before it decides whether to notify the observer.
//
// `just_determined` is true for the first event a component delivers after
// its bytes became determined. `downloaded_bytes_delta` is the growth in that
// component's downloaded bytes since its previous event. On the first event
// this is presumably its full downloaded count, assuming
// `last_downloaded_bytes_` starts unset; confirm in the header.
//
// When the last watched component becomes determined, this method:
//  - rebases the counters so bytes already on disk are excluded from
//    progress, and
//  - sends the mandatory initial 0% update.
void AIModelDownloadProgressManager::Reporter::ProcessEvent(
    const Component& component,
    bool just_determined,
    int64_t downloaded_bytes_delta) {
  // Should only receive events for components that have their bytes determined.
  CHECK(component.determined_bytes());
  CHECK_GE(component.downloaded_bytes(), 0);
  CHECK_GE(component.total_bytes(), 0);
  components_downloaded_bytes_ += downloaded_bytes_delta;
  // If we haven't just determined bytes, that means we've already seen this
  // component and we don't need to do anything further.
  if (!just_determined) {
    return;
  }
  // We shouldn't already be ready to report if a component's bytes have just
  // been determined.
  CHECK(!ReadyToReport());
  components_total_bytes_ += component.total_bytes();
  determined_components_++;
  if (!ReadyToReport()) {
    return;
  }
  // Every component is now determined: reset the reporting state. Seeding
  // `last_progress_time_` with "now" means the 50ms throttle in OnEvent()
  // measures from this moment.
  last_reported_progress_ = 0;
  last_progress_time_ = base::TimeTicks::Now();
  // We don't want to include already downloaded bytes in our progress
  // calculation, so determine it for later calculations and remove it now
  // from components_total_bytes_.
  components_total_bytes_ -= components_downloaded_bytes_;
  components_downloaded_bytes_ = 0;
  // NOTE(review): this reaches 0 when every watched component was already
  // fully downloaded. OnEvent() then calls
  // NormalizeModelDownloadProgress(0, 0). Confirm that returns the max value;
  // otherwise it equals `last_reported_progress_` and 100% is never sent.
  CHECK_GE(components_total_bytes_, 0);
  // Must always fire the zero progress event first.
  observer_remote_->OnDownloadProgressUpdate(
      0, AIUtils::kNormalizedDownloadProgressMax);
}
// Event callback bound to every watched component. It updates the aggregate
// counters (via ProcessEvent()) and, once the combined total is known, sends
// a normalized progress update to the observer.
//
// Updates are suppressed when:
//  - the normalized value has not changed, or
//  - the download is not finished and no more than 50ms have passed since
//    the previous update.
// The final 100% update is never throttled.
void AIModelDownloadProgressManager::Reporter::OnEvent(
    Component& component,
    bool just_determined,
    int64_t downloaded_bytes_delta) {
  ProcessEvent(component, just_determined, downloaded_bytes_delta);

  // Nothing is reported until every component's total size is known.
  if (!ReadyToReport()) {
    return;
  }
  CHECK_GE(components_downloaded_bytes_, 0);
  CHECK_LE(components_downloaded_bytes_, components_total_bytes_);

  const bool download_finished =
      components_downloaded_bytes_ == components_total_bytes_;
  if (!download_finished) {
    const base::TimeTicks now = base::TimeTicks::Now();
    const bool within_throttle_window =
        now - last_progress_time_ <= base::Milliseconds(50);
    if (within_throttle_window) {
      return;
    }
    last_progress_time_ = now;
  }

  const int normalized_progress = AIUtils::NormalizeModelDownloadProgress(
      components_downloaded_bytes_, components_total_bytes_);
  if (normalized_progress == last_reported_progress_) {
    // This value has already been sent.
    return;
  }
  CHECK_GT(normalized_progress, last_reported_progress_);
  last_reported_progress_ = normalized_progress;

  observer_remote_->OnDownloadProgressUpdate(
      normalized_progress, AIUtils::kNormalizedDownloadProgressMax);
}
} // namespace on_device_ai