// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/speech/on_device_speech_recognition_engine_impl.h"
#include "base/containers/span.h"
#include "base/containers/to_vector.h"
#include "base/functional/bind.h"
#include "base/strings/utf_string_conversions.h"
#include "components/optimization_guide/core/model_execution/model_broker_client.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/browser_thread.h"
#include "content/public/browser/content_browser_client.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/browser/speech_recognition_session_config.h"
#include "content/public/common/content_client.h"
#include "media/base/audio_parameters.h"
#include "media/base/audio_timestamp_helper.h"
#include "media/mojo/mojom/audio_data.mojom.h"
#include "media/mojo/mojom/speech_recognizer.mojom.h"

namespace content {
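
// Core bundles the mojo state that must be used and destroyed on the UI
// thread: the model broker client, the model client, the session remote, and
// the receiver that delivers ASR responses back to the engine.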
OnDeviceSpeechRecognitionEngine::Core::Core(
on_device_model::mojom::AsrStreamResponder* responder)
: asr_stream_responder(responder) {}

OnDeviceSpeechRecognitionEngine::Core::~Core() = default;

void OnDeviceSpeechRecognitionEngine::Core::Deleter::operator()(
Core* core) const {
  // |Core| holds mojo endpoints that are bound on the UI thread, so it must
  // also be deleted on the UI thread.
if (!GetUIThreadTaskRunner({})->DeleteSoon(FROM_HERE, core)) {
delete core;
}
}
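
// Constructed on the engine's own sequence, which is captured as
// |io_task_runner_|. Model-client setup is bounced to the UI thread, where
// the embedder and RenderFrameHost state live.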
OnDeviceSpeechRecognitionEngine::OnDeviceSpeechRecognitionEngine(
const SpeechRecognitionSessionConfig& config)
: ui_task_runner_(GetUIThreadTaskRunner({})),
io_task_runner_(base::SequencedTaskRunner::GetCurrentDefault()),
config_(config),
core_(new Core(this)) {
ui_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&OnDeviceSpeechRecognitionEngine::CreateModelClientOnUI,
weak_factory_.GetWeakPtr(),
config_.initial_context.render_process_id,
config_.initial_context.render_frame_id));
}

OnDeviceSpeechRecognitionEngine::~OnDeviceSpeechRecognitionEngine() = default;
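
// No-op: audio is forwarded from TakeAudioChunk() once the ASR stream has
// been created.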
void OnDeviceSpeechRecognitionEngine::StartRecognition() {}
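
// Releases the ASR stream and |core_|. The custom Deleter guarantees that the
// UI-thread mojo state owned by |core_| is destroyed on the UI thread.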
void OnDeviceSpeechRecognitionEngine::EndRecognition() {
DCHECK_CALLED_ON_VALID_SEQUENCE(main_sequence_checker_);
core_.reset();
asr_stream_.reset();
}
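
// The sample rate is required to configure the ASR stream, so session
// creation is (re)attempted on the UI thread once the parameters are known.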
void OnDeviceSpeechRecognitionEngine::SetAudioParameters(
media::AudioParameters audio_parameters) {
SpeechRecognitionEngine::SetAudioParameters(audio_parameters);
ui_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&OnDeviceSpeechRecognitionEngine::CreateSessionOnUI,
weak_factory_.GetWeakPtr()));
}
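
// Copies the chunk's int16 PCM into an accumulation buffer and forwards
// roughly two seconds of audio at a time to the ASR stream.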
void OnDeviceSpeechRecognitionEngine::TakeAudioChunk(const AudioChunk& data) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(main_sequence_checker_);
  // AudioChunk::NumSamples() already counts every interleaved sample in the
  // chunk, so no per-channel multiplication is needed.
  const size_t num_samples = data.NumSamples();
  std::vector<int16_t> pcm_data_vector(num_samples);
  auto source_bytes = base::as_bytes(base::span(data.AsString()));
  auto dest_bytes = base::as_writable_bytes(base::span(pcm_data_vector));
  CHECK_EQ(source_bytes.size(), dest_bytes.size());
  dest_bytes.copy_from(source_bytes);
accumulated_audio_data_.insert(accumulated_audio_data_.end(),
pcm_data_vector.begin(),
pcm_data_vector.end());
// The duration of audio to accumulate before sending to the on-device
// service.
constexpr base::TimeDelta kAudioBufferDuration = base::Seconds(2);
if (media::AudioTimestampHelper::FramesToTime(
accumulated_audio_data_.size(), audio_parameters_.sample_rate()) >=
kAudioBufferDuration) {
if (asr_stream_.is_bound()) {
asr_stream_->AddAudioChunk(ConvertAccumulatedAudioData());
}
}
}

void OnDeviceSpeechRecognitionEngine::UpdateRecognitionContext(
const media::SpeechRecognitionRecognitionContext& recognition_context) {
// TODO(crbug.com/446260680): Implement biasing support with NanoV3.
}
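
// Converts on-device results into WebSpeech results and forwards them to the
// delegate. Partial results are marked provisional, and every hypothesis is
// reported with a fixed confidence of 1.0.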
void OnDeviceSpeechRecognitionEngine::OnResponse(
std::vector<on_device_model::mojom::SpeechRecognitionResultPtr> result) {
std::vector<media::mojom::WebSpeechRecognitionResultPtr> recognition_results;
for (const auto& r : result) {
auto web_speech_result = media::mojom::WebSpeechRecognitionResult::New();
web_speech_result->is_provisional = !r->is_final;
constexpr float kSpeechRecognitionConfidence = 1.0f;
web_speech_result->hypotheses.emplace_back(
media::mojom::SpeechRecognitionHypothesis::New(
base::UTF8ToUTF16(r->transcript), kSpeechRecognitionConfidence));
recognition_results.push_back(std::move(web_speech_result));
}
delegate_->OnSpeechRecognitionEngineResults(std::move(recognition_results));
}
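
// No more audio will arrive; release the ASR stream and the UI-thread state.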
void OnDeviceSpeechRecognitionEngine::AudioChunksEnded() {
DCHECK_CALLED_ON_VALID_SEQUENCE(main_sequence_checker_);
core_.reset();
asr_stream_.reset();
}

int OnDeviceSpeechRecognitionEngine::GetDesiredAudioChunkDurationMs() const {
constexpr int kAudioPacketIntervalMs = 100;
return kAudioPacketIntervalMs;
}
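
// Runs on the UI thread. Resolves the originating frame, asks the embedder
// for a ModelBrokerClient, and waits for a model client for the prompt API
// capability.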
void OnDeviceSpeechRecognitionEngine::CreateModelClientOnUI(
    int render_process_id,
    int render_frame_id) {
  // Use the IDs captured at post time rather than re-reading |config_| on the
  // UI thread.
  RenderFrameHost* rfh =
      RenderFrameHost::FromID(render_process_id, render_frame_id);
if (!rfh) {
return;
}
core_->model_broker_client =
GetContentClient()->browser()->CreateModelBrokerClient(
rfh->GetBrowserContext());
if (core_->model_broker_client) {
core_->model_broker_client
->GetSubscriber(
optimization_guide::mojom::ModelBasedCapabilityKey::kPromptApi)
.WaitForClient(base::BindOnce(
&OnDeviceSpeechRecognitionEngine::OnModelClientAvailable,
weak_factory_.GetWeakPtr()));
}
}
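
// UI-thread continuation of CreateModelClientOnUI(): stores the model client
// and attempts session creation.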
void OnDeviceSpeechRecognitionEngine::OnModelClientAvailable(
base::WeakPtr<optimization_guide::ModelClient> client) {
core_->model_client = client;
CreateSessionOnUI();
}
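
// Creates the on-device session and its ASR stream. Both a model client and
// valid audio parameters are required, so this is reachable from either
// OnModelClientAvailable() or SetAudioParameters(); |session_created_| guards
// against creating the session twice.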
void OnDeviceSpeechRecognitionEngine::CreateSessionOnUI() {
if (!core_ || !core_->model_client || !audio_parameters_.IsValid() ||
session_created_) {
return;
}
session_created_ = true;
auto params = on_device_model::mojom::SessionParams::New();
params->capabilities.Put(on_device_model::CapabilityFlags::kAudioInput);
core_->model_client->solution().CreateSession(
core_->session.BindNewPipeAndPassReceiver(), std::move(params));
auto asr_options = on_device_model::mojom::AsrStreamOptions::New();
asr_options->sample_rate_hz = audio_parameters_.sample_rate();
mojo::PendingRemote<on_device_model::mojom::AsrStreamInput> remote;
core_->session->AsrStream(
std::move(asr_options), remote.InitWithNewPipeAndPassReceiver(),
core_->asr_stream_responder.BindNewPipeAndPassRemote());
io_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&OnDeviceSpeechRecognitionEngine::OnAsrStreamCreated,
weak_factory_.GetWeakPtr(), std::move(remote)));
}
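
// Back on the engine's own sequence: binds the ASR input stream so that
// TakeAudioChunk() can start forwarding audio.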
void OnDeviceSpeechRecognitionEngine::OnAsrStreamCreated(
mojo::PendingRemote<on_device_model::mojom::AsrStreamInput> remote) {
DCHECK_CALLED_ON_VALID_SEQUENCE(main_sequence_checker_);
asr_stream_.Bind(std::move(remote));
}
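
// Treats a dropped connection to the recognizer as the end of recognition.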
void OnDeviceSpeechRecognitionEngine::OnRecognizerDisconnected() {
DCHECK_CALLED_ON_VALID_SEQUENCE(main_sequence_checker_);
EndRecognition();
}
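
// Converts the buffered int16 PCM into the normalized float format expected
// by the model service and clears the buffer. Only mono audio is supported.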
on_device_model::mojom::AudioDataPtr
OnDeviceSpeechRecognitionEngine::ConvertAccumulatedAudioData() {
CHECK_EQ(audio_parameters_.channels(), 1);
auto signed_buffer = on_device_model::mojom::AudioData::New();
signed_buffer->channel_count = audio_parameters_.channels();
signed_buffer->sample_rate = audio_parameters_.sample_rate();
signed_buffer->frame_count = accumulated_audio_data_.size();
// Normalization factor for converting int16_t audio samples to float.
constexpr float kInt16ToFloatNormalizer = 32768.0f;
signed_buffer->data = base::ToVector(accumulated_audio_data_, [](int16_t s) {
return s / kInt16ToFloatNormalizer;
});
accumulated_audio_data_.clear();
return signed_buffer;
}

}  // namespace content