| // Copyright 2025 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include <memory>
#include <utility>
#include <vector>

#include "base/files/file.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/numerics/safe_conversions.h"
#include "base/sequence_checker.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "build/build_config.h"
#include "services/audio/ml_model_manager.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"
| |
| namespace audio { |
| |
| // Holds the TFLite model and the buffer that backs it. |
| // The buffer must outlive the model. |
| struct ModelWithBuffer { |
| explicit ModelWithBuffer(size_t buffer_size) : buffer(buffer_size) {} |
| ~ModelWithBuffer() = default; |
| |
| std::vector<uint8_t> buffer; |
| std::unique_ptr<tflite::FlatBufferModel> model; |
| int num_active_clients = 0; |
| }; |
| |
| namespace { |
| |
| // Reads the model contents from the given base::File. |
| // This function is intended to run on a background thread. |
| // Returns a struct containing the model and its backing buffer. |
| std::unique_ptr<ModelWithBuffer> ReadModelContents(base::File model_file) { |
| if (!model_file.IsValid()) { |
| LOG(ERROR) << "Invalid model file."; |
| return nullptr; |
| } |
| int64_t length = model_file.GetLength(); |
| if (length <= 0) { |
| LOG(ERROR) << "Invalid model file length."; |
| return nullptr; |
| } |
| auto model_with_buffer = std::make_unique<ModelWithBuffer>(length); |
| if (!model_file.ReadAndCheck(0, model_with_buffer->buffer)) { |
| LOG(ERROR) << "Failed to read model file contents."; |
| return nullptr; |
| } |
| model_with_buffer->model = tflite::FlatBufferModel::BuildFromBuffer( |
| reinterpret_cast<const char*>(model_with_buffer->buffer.data()), |
| model_with_buffer->buffer.size()); |
| |
| if (!model_with_buffer->model) { |
| LOG(ERROR) << "Failed to build FlatBufferModel from buffer."; |
| return nullptr; |
| } |
| return model_with_buffer; |
| } |
| |
| class MlModelHandleImpl : public MlModelHandle { |
| public: |
| MlModelHandleImpl(tflite::FlatBufferModel* model, |
| base::OnceClosure on_destruction_closure) |
| : model_(model), |
| on_destruction_closure_(std::move(on_destruction_closure)) { |
| CHECK(model); |
| } |
| ~MlModelHandleImpl() override { |
| // Explicitly reset the model pointer as the pointed-to memory may be |
| // affected by the destruction closure. |
| model_ = nullptr; |
| std::move(on_destruction_closure_).Run(); |
| } |
| const tflite::FlatBufferModel* Get() override { return model_; } |
| |
| private: |
| raw_ptr<tflite::FlatBufferModel> model_; |
| base::OnceClosure on_destruction_closure_; |
| }; |
| |
| } // namespace |
| |
| MlModelManagerImpl::MlModelManagerImpl() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| } |
| |
| MlModelManagerImpl::~MlModelManagerImpl() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| CHECK(!used_serving_model_ && retired_models_.size() == 0) |
| << "MlModelManagerImpl has existing clients at destruction time"; |
| } |
| |
| void MlModelManagerImpl::BindReceiver( |
| mojo::PendingReceiver<mojom::MlModelManager> receiver) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| CHECK(!receiver_.has_value()); |
| receiver_.emplace(this, std::move(receiver)); |
| } |
| |
| void MlModelManagerImpl::OnResidualEchoEstimationModelRead( |
| std::unique_ptr<ModelWithBuffer> model_with_buffer) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!model_with_buffer) { |
| // Model load failed, ignore. |
| return; |
| } |
| if (used_serving_model_) { |
| retired_models_.emplace(used_serving_model_.get(), |
| std::move(used_serving_model_)); |
| } |
| unused_serving_model_ = std::move(model_with_buffer); |
| } |
| |
| void MlModelManagerImpl::StopServingResidualEchoEstimationModel() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // Stop any ongoing model loading: This stop signal makes it obsolete. |
| CancelModelLoadingTasks(); |
| |
| if (used_serving_model_) { |
| retired_models_.emplace(used_serving_model_.get(), |
| std::move(used_serving_model_)); |
| } |
| unused_serving_model_.reset(); |
| used_serving_model_.reset(); |
| } |
| |
| void MlModelManagerImpl::SetResidualEchoEstimationModel( |
| base::File tflite_file) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // Stop loading any older models: |
| // - They are soon replaced with this new file, and |
| // - we don't want races due to different file operation durations. |
| CancelModelLoadingTasks(); |
| |
| base::ThreadPool::PostTaskAndReplyWithResult( |
| FROM_HERE, {base::MayBlock(), base::TaskPriority::BEST_EFFORT}, |
| base::BindOnce(&ReadModelContents, std::move(tflite_file)), |
| base::BindOnce(&MlModelManagerImpl::OnResidualEchoEstimationModelRead, |
| weak_factory_.GetWeakPtr())); |
| } |
| |
| void MlModelManagerImpl::OnModelHandleDestruction(ModelWithBuffer* model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| CHECK(model); |
| // Find the model, and update the client count. |
| |
| if (model == used_serving_model_.get()) { |
| // The model is actively serving. |
| CHECK_GT(used_serving_model_->num_active_clients, 0); |
| --used_serving_model_->num_active_clients; |
| if (used_serving_model_->num_active_clients == 0) { |
| unused_serving_model_ = std::move(used_serving_model_); |
| } |
| return; |
| } |
| // If we get here, the model is one of the retired models. |
| auto iter = retired_models_.find(model); |
| CHECK(iter != retired_models_.end()); |
| ModelWithBuffer& retired_model = *(*iter).second; |
| CHECK_GT(retired_model.num_active_clients, 0); |
| --(retired_model.num_active_clients); |
| if (retired_model.num_active_clients == 0) { |
| // All clients are gone, the model can be deleted. |
| retired_models_.erase(iter); |
| } |
| } |
| |
| std::unique_ptr<MlModelHandle> |
| MlModelManagerImpl::GetResidualEchoEstimationModel() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (!unused_serving_model_ && !used_serving_model_) { |
| return nullptr; |
| } |
| if (unused_serving_model_) { |
| used_serving_model_ = std::move(unused_serving_model_); |
| } |
| ++(used_serving_model_->num_active_clients); |
| return std::make_unique<MlModelHandleImpl>( |
| used_serving_model_->model.get(), |
| base::BindOnce(&MlModelManagerImpl::OnModelHandleDestruction, |
| // Safe because the MlModelManager API requires clients to |
| // destroy their model handles within the manager lifetime. |
| base::Unretained(this), used_serving_model_.get())); |
| } |
| |
| void MlModelManagerImpl::CancelModelLoadingTasks() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| weak_factory_.InvalidateWeakPtrs(); |
| } |
| |
| bool MlModelManagerImpl::HasPendingTasksForTesting() const { |
| return weak_factory_.HasWeakPtrs(); |
| } |
| |
| } // namespace audio |