| // Copyright 2021 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chromeos/services/libassistant/audio/audio_input_impl.h" |
| |
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
| |
| #include "base/bind.h" |
| #include "base/command_line.h" |
| #include "base/containers/cxx20_erase.h" |
| #include "base/logging.h" |
| #include "base/memory/weak_ptr.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/strings/string_util.h" |
| #include "base/timer/timer.h" |
| #include "chromeos/services/assistant/public/cpp/assistant_browser_delegate.h" |
| #include "chromeos/services/assistant/public/cpp/features.h" |
| #include "chromeos/services/libassistant/audio/audio_input_stream.h" |
| #include "libassistant/shared/public/platform_audio_buffer.h" |
| #include "media/audio/audio_device_description.h" |
| #include "media/base/audio_parameters.h" |
| #include "media/base/audio_sample_types.h" |
| #include "media/base/channel_layout.h" |
| #include "media/mojo/mojom/audio_stream_factory.mojom.h" |
| #include "services/audio/public/cpp/device_factory.h" |
| #include "third_party/abseil-cpp/absl/types/optional.h" |
| |
| namespace chromeos { |
| namespace libassistant { |
| |
| namespace { |
| |
| constexpr assistant_client::BufferFormat kFormatMono{ |
| 16000 /* sample_rate */, assistant_client::INTERLEAVED_S16, 1 /* channels */ |
| }; |
| |
| constexpr assistant_client::BufferFormat kFormatStereo{ |
| 44100 /* sample_rate */, assistant_client::INTERLEAVED_S16, 2 /* channels */ |
| }; |
| |
| assistant_client::BufferFormat g_current_format = kFormatMono; |
| |
| class DspHotwordStateManager : public AudioInputImpl::HotwordStateManager { |
| public: |
| explicit DspHotwordStateManager(AudioInputImpl* input) |
| : AudioInputImpl::HotwordStateManager(input) {} |
| |
| DspHotwordStateManager(const DspHotwordStateManager&) = delete; |
| DspHotwordStateManager& operator=(const DspHotwordStateManager&) = delete; |
| |
| // HotwordStateManager overrides: |
| void OnConversationTurnStarted() override { |
| if (second_phase_timer_.IsRunning()) { |
| DCHECK(stream_state_ == StreamState::HOTWORD); |
| second_phase_timer_.Stop(); |
| } else { |
| // Handles user click on mic button. |
| input_->RecreateAudioInputStream(false /* use_dsp */); |
| } |
| stream_state_ = StreamState::NORMAL; |
| } |
| |
| void OnConversationTurnFinished() override { |
| input_->RecreateAudioInputStream(true /* use_dsp */); |
| if (stream_state_ == StreamState::HOTWORD) { |
| // If |stream_state_| remains unchanged, that indicates the first stage |
| // DSP hotword detection was rejected by Libassistant. |
| RecordDspHotwordDetection(DspHotwordDetectionStatus::SOFTWARE_REJECTED); |
| } |
| stream_state_ = StreamState::HOTWORD; |
| } |
| |
| void OnCaptureDataArrived() override { |
| if (stream_state_ == StreamState::HOTWORD && |
| !second_phase_timer_.IsRunning()) { |
| RecordDspHotwordDetection(DspHotwordDetectionStatus::HARDWARE_ACCEPTED); |
| // 1s from now, if OnConversationTurnStarted is not called, we assume that |
| // libassistant has rejected the hotword supplied by DSP. Thus, we reset |
| // and reopen the device on hotword state. |
| second_phase_timer_.Start( |
| FROM_HERE, base::TimeDelta::FromSeconds(1), |
| base::BindRepeating( |
| &DspHotwordStateManager::OnConversationTurnFinished, |
| base::Unretained(this))); |
| } |
| } |
| |
| void RecreateAudioInputStream() override { |
| input_->RecreateAudioInputStream(stream_state_ == StreamState::HOTWORD); |
| } |
| |
| private: |
| enum class StreamState { |
| HOTWORD, |
| NORMAL, |
| }; |
| |
| // Defines possible detection states of Dsp hotword. These values are |
| // persisted to logs. Entries should not be renumbered and numeric values |
| // should never be reused. Only append to this enum is allowed if the possible |
| // source grows. |
| enum class DspHotwordDetectionStatus { |
| HARDWARE_ACCEPTED = 0, |
| SOFTWARE_REJECTED = 1, |
| kMaxValue = SOFTWARE_REJECTED |
| }; |
| |
| // Helper function to record UMA metrics for Dsp hotword detection. |
| void RecordDspHotwordDetection(DspHotwordDetectionStatus status) { |
| base::UmaHistogramEnumeration("Assistant.DspHotwordDetection", status); |
| } |
| |
| StreamState stream_state_ = StreamState::HOTWORD; |
| base::OneShotTimer second_phase_timer_; |
| }; |
| |
| class AudioInputBufferImpl : public assistant_client::AudioBuffer { |
| public: |
| AudioInputBufferImpl(std::vector<int16_t>&& data, uint32_t frame_count) |
| : data_(std::move(data)), frame_count_(frame_count) {} |
| AudioInputBufferImpl(const AudioInputBufferImpl&) = delete; |
| AudioInputBufferImpl& operator=(const AudioInputBufferImpl&) = delete; |
| AudioInputBufferImpl(AudioInputBufferImpl&&) = default; |
| AudioInputBufferImpl& operator=(AudioInputBufferImpl&&) = default; |
| ~AudioInputBufferImpl() override = default; |
| |
| // assistant_client::AudioBuffer overrides: |
| assistant_client::BufferFormat GetFormat() const override { |
| return g_current_format; |
| } |
| const void* GetData() const override { return data_.data(); } |
| void* GetWritableData() override { |
| NOTREACHED(); |
| return nullptr; |
| } |
| int GetFrameCount() const override { return frame_count_; } |
| |
| private: |
| std::vector<int16_t> data_; |
| int frame_count_; |
| }; |
| |
| AudioInputBufferImpl ToAudioInputBuffer(const media::AudioBus* audio_source) { |
| std::vector<int16_t> buffer(audio_source->channels() * |
| audio_source->frames()); |
| audio_source->ToInterleaved<media::SignedInt16SampleTypeTraits>( |
| audio_source->frames(), buffer.data()); |
| return AudioInputBufferImpl(std::move(buffer), audio_source->frames()); |
| } |
| |
| } // namespace |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // AudioCapturer |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| // Helper class that will receive the callbacks from the audio source, |
| // and forward the audio data to Libassistant. |
| // Note that all callback methods in this object run on the audio service |
| // thread, so this class should be treated carefully. |
| // All public methods can be called from other threads, and the |
| // |on_capture_callback| will be invoked on the given callback thread. |
| class AudioCapturer : public media::AudioCapturerSource::CaptureCallback { |
| public: |
| explicit AudioCapturer( |
| base::RepeatingCallback<void()> on_capture_callback, |
| scoped_refptr<base::SequencedTaskRunner> callback_task_runner) |
| : on_capture_callback_(on_capture_callback), |
| callback_task_runner_(callback_task_runner) {} |
| AudioCapturer(const AudioCapturer&) = delete; |
| AudioCapturer& operator=(const AudioCapturer&) = delete; |
| ~AudioCapturer() override = default; |
| |
| void AddObserver(assistant_client::AudioInput::Observer* observer) { |
| base::AutoLock lock(observers_lock_); |
| observers_.push_back(observer); |
| } |
| |
| void RemoveObserver(assistant_client::AudioInput::Observer* observer) { |
| base::AutoLock lock(observers_lock_); |
| base::Erase(observers_, observer); |
| } |
| |
| int num_observers() { |
| base::AutoLock lock(observers_lock_); |
| return observers_.size(); |
| } |
| |
| int captured_frames_count() { return captured_frames_count_; } |
| |
| private: |
| // media::AudioCapturerSource::CaptureCallback implementation: |
| // Runs on audio service thread. |
| void Capture(const media::AudioBus* audio_source, |
| base::TimeTicks audio_capture_time, |
| double volume, |
| bool key_pressed) override { |
| DCHECK_EQ(g_current_format.num_channels, audio_source->channels()); |
| |
| callback_task_runner_->PostTask(FROM_HERE, on_capture_callback_); |
| |
| UpdateCapturedFramesCount(audio_source->frames()); |
| |
| AudioInputBufferImpl input_buffer(ToAudioInputBuffer(audio_source)); |
| int64_t time = ToLibassistantTime(audio_capture_time); |
| |
| base::AutoLock lock(observers_lock_); |
| for (auto* observer : observers_) |
| observer->OnAudioBufferAvailable(input_buffer, time); |
| } |
| |
| // Runs on audio service thread. |
| void OnCaptureError(media::AudioCapturerSource::ErrorCode code, |
| const std::string& message) override { |
| LOG(ERROR) << "Capture error " << message |
| << ", code=" << static_cast<uint32_t>(code); |
| base::AutoLock lock(observers_lock_); |
| for (auto* observer : observers_) |
| observer->OnAudioError(assistant_client::AudioInput::Error::FATAL_ERROR); |
| } |
| |
| // Runs on audio service thread. |
| void OnCaptureMuted(bool is_muted) override {} |
| |
| int64_t ToLibassistantTime(base::TimeTicks audio_capture_time) const { |
| // Only provide accurate timestamp when eraser is enabled, otherwise it |
| // seems break normal libassistant voice recognition. |
| if (assistant::features::IsAudioEraserEnabled()) |
| return audio_capture_time.since_origin().InMicroseconds(); |
| return 0; |
| } |
| |
| void UpdateCapturedFramesCount(int num_arrived_frames) { |
| captured_frames_count_ += num_arrived_frames; |
| if (VLOG_IS_ON(1)) { |
| auto now = base::TimeTicks::Now(); |
| if ((now - last_frame_count_report_time_) > |
| base::TimeDelta::FromMinutes(2)) { |
| VLOG(1) << "Captured frames: " << captured_frames_count_; |
| last_frame_count_report_time_ = now; |
| } |
| } |
| } |
| |
| // This is the total number of frames captured during the life time of this |
| // object. We don't worry about overflow because this count is only used for |
| // logging purposes. If in the future this changes, we should re-evaluate. |
| int captured_frames_count_ = 0; |
| base::TimeTicks last_frame_count_report_time_; |
| |
| base::Lock observers_lock_; |
| std::vector<assistant_client::AudioInput::Observer*> observers_ |
| GUARDED_BY(observers_lock_); |
| |
| // |on_capture_callback| must always be called from the main thread. |
| base::RepeatingCallback<void()> on_capture_callback_; |
| scoped_refptr<base::SequencedTaskRunner> callback_task_runner_; |
| }; |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // AudioInputImpl |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| AudioInputImpl::HotwordStateManager::HotwordStateManager( |
| AudioInputImpl* audio_input) |
| : input_(audio_input) {} |
| |
| void AudioInputImpl::HotwordStateManager::RecreateAudioInputStream() { |
| input_->RecreateAudioInputStream(/*use_dsp=*/false); |
| } |
| |
| AudioInputImpl::AudioInputImpl(const absl::optional<std::string>& device_id) |
| : task_runner_(base::SequencedTaskRunnerHandle::Get()), |
| preferred_device_id_(device_id), |
| weak_factory_(this) { |
| DETACH_FROM_SEQUENCE(observer_sequence_checker_); |
| |
| audio_capturer_ = std::make_unique<AudioCapturer>( |
| base::BindRepeating(&AudioInputImpl::OnCaptureDataArrived, |
| weak_factory_.GetWeakPtr()), |
| /*callback_task_runner=*/base::SequencedTaskRunnerHandle::Get()); |
| |
| RecreateStateManager(); |
| if (assistant::features::IsStereoAudioInputEnabled()) |
| g_current_format = kFormatStereo; |
| else |
| g_current_format = kFormatMono; |
| } |
| |
| AudioInputImpl::~AudioInputImpl() { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| StopRecording(); |
| } |
| |
| void AudioInputImpl::RecreateStateManager() { |
| if (IsHotwordAvailable()) { |
| state_manager_ = std::make_unique<DspHotwordStateManager>(this); |
| } else { |
| state_manager_ = std::make_unique<HotwordStateManager>(this); |
| } |
| } |
| |
| void AudioInputImpl::OnCaptureDataArrived() { |
| state_manager_->OnCaptureDataArrived(); |
| } |
| |
| void AudioInputImpl::Initialize(mojom::PlatformDelegate* platform_delegate) { |
| platform_delegate_ = platform_delegate; |
| DCHECK(platform_delegate_); |
| UpdateRecordingState(); |
| } |
| |
| // Run on LibAssistant thread. |
| assistant_client::BufferFormat AudioInputImpl::GetFormat() const { |
| return g_current_format; |
| } |
| |
| // Run on LibAssistant thread. |
| void AudioInputImpl::AddObserver( |
| assistant_client::AudioInput::Observer* observer) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(observer_sequence_checker_); |
| VLOG(1) << "Add observer"; |
| |
| audio_capturer_->AddObserver(observer); |
| |
| if (audio_capturer_->num_observers() == 1) { |
| // Post to main thread runner to start audio recording. Assistant thread |
| // does not have thread context defined in //base and will fail sequence |
| // check in AudioCapturerSource::Start(). |
| task_runner_->PostTask(FROM_HERE, |
| base::BindOnce(&AudioInputImpl::UpdateRecordingState, |
| weak_factory_.GetWeakPtr())); |
| } |
| } |
| |
| // Run on LibAssistant thread. |
| void AudioInputImpl::RemoveObserver( |
| assistant_client::AudioInput::Observer* observer) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(observer_sequence_checker_); |
| VLOG(1) << "Remove observer"; |
| |
| audio_capturer_->RemoveObserver(observer); |
| |
| if (audio_capturer_->num_observers() == 0) { |
| task_runner_->PostTask(FROM_HERE, |
| base::BindOnce(&AudioInputImpl::UpdateRecordingState, |
| weak_factory_.GetWeakPtr())); |
| |
| // Reset the sequence checker since assistant may call from different thread |
| // after restart. |
| DETACH_FROM_SEQUENCE(observer_sequence_checker_); |
| } |
| } |
| |
| void AudioInputImpl::SetMicState(bool mic_open) { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| if (mic_open_ == mic_open) |
| return; |
| |
| mic_open_ = mic_open; |
| UpdateRecordingState(); |
| } |
| |
| void AudioInputImpl::OnConversationTurnStarted() { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| state_manager_->OnConversationTurnStarted(); |
| } |
| |
| void AudioInputImpl::OnConversationTurnFinished() { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| state_manager_->OnConversationTurnFinished(); |
| } |
| |
| void AudioInputImpl::OnHotwordEnabled(bool enable) { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| |
| if (hotword_enabled_ == enable) |
| return; |
| |
| hotword_enabled_ = enable; |
| UpdateRecordingState(); |
| } |
| |
| void AudioInputImpl::SetDeviceId(const absl::optional<std::string>& device_id) { |
| DVLOG(1) << "Set audio input preferred_device_id to " |
| << device_id.value_or("<null>"); |
| auto new_device_id = device_id; |
| |
| constexpr char kAssistantForceDefaultAudioInput[] = |
| "assistant-force-default-audio-input"; |
| auto* command_line = base::CommandLine::ForCurrentProcess(); |
| if (command_line->HasSwitch(kAssistantForceDefaultAudioInput)) { |
| // Sometimes there may not be a preferred audio device, |
| // e.g. if the device does not have built-in mic and using a bluetooth |
| // microphone, in this case we do not want to open the bluetooth device by |
| // default to drain the battery; also if running linux chromeos chrome |
| // build, there won't be cras and we won't have a device id set. Force using |
| // default audio input in these cases to mimic the common Assistant hotword |
| // behaviors. |
| DVLOG(1) << "Force audio input preferred_device_id to default."; |
| new_device_id = media::AudioDeviceDescription::kDefaultDeviceId; |
| } |
| |
| if (preferred_device_id_ == new_device_id) |
| return; |
| |
| preferred_device_id_ = new_device_id; |
| |
| UpdateRecordingState(); |
| if (HasOpenAudioStream()) |
| state_manager_->RecreateAudioInputStream(); |
| } |
| |
| void AudioInputImpl::SetHotwordDeviceId( |
| const absl::optional<std::string>& device_id) { |
| if (hotword_device_id_ == device_id) |
| return; |
| |
| hotword_device_id_ = device_id; |
| RecreateStateManager(); |
| if (HasOpenAudioStream()) |
| state_manager_->RecreateAudioInputStream(); |
| } |
| |
| void AudioInputImpl::OnLidStateChanged(mojom::LidState new_state) { |
| // Lid switch event still gets fired during system suspend, which enables |
| // us to stop DSP recording correctly when user closes lid after the device |
| // goes to sleep. |
| if (new_state != lid_state_) { |
| lid_state_ = new_state; |
| UpdateRecordingState(); |
| } |
| } |
| |
| void AudioInputImpl::RecreateAudioInputStream(bool use_dsp) { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| StopRecording(); |
| |
| open_audio_stream_ = std::make_unique<AudioInputStream>( |
| platform_delegate_, GetDeviceId(use_dsp), |
| ShouldEnableDeadStreamDetection(use_dsp), GetFormat(), |
| /*capture_callback=*/audio_capturer_.get()); |
| |
| VLOG(1) << open_audio_stream_->device_id() << " start recording"; |
| } |
| |
| bool AudioInputImpl::IsHotwordAvailable() const { |
| return assistant::features::IsDspHotwordEnabled() && |
| hotword_device_id_.has_value(); |
| } |
| |
| bool AudioInputImpl::IsRecordingForTesting() const { |
| return HasOpenAudioStream(); |
| } |
| |
| bool AudioInputImpl::IsUsingHotwordDeviceForTesting() const { |
| return IsRecordingForTesting() // IN-TEST |
| && GetOpenDeviceId() == hotword_device_id_ && IsHotwordAvailable(); |
| } |
| |
| bool AudioInputImpl::IsMicOpenForTesting() const { |
| return mic_open_; |
| } |
| |
| absl::optional<std::string> AudioInputImpl::GetOpenDeviceIdForTesting() const { |
| return GetOpenDeviceId(); |
| } |
| |
| absl::optional<bool> AudioInputImpl::IsUsingDeadStreamDetectionForTesting() |
| const { |
| if (!open_audio_stream_) |
| return absl::nullopt; |
| return open_audio_stream_->has_dead_stream_detection(); |
| } |
| |
| void AudioInputImpl::StartRecording() { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| DCHECK(!HasOpenAudioStream()); |
| RecreateAudioInputStream(IsHotwordAvailable()); |
| } |
| |
| void AudioInputImpl::StopRecording() { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| if (open_audio_stream_) { |
| VLOG(1) << open_audio_stream_->device_id() << " stop recording"; |
| VLOG(1) << open_audio_stream_->device_id() << " ending captured frames: " |
| << audio_capturer_->captured_frames_count(); |
| open_audio_stream_.reset(); |
| } |
| } |
| |
| void AudioInputImpl::UpdateRecordingState() { |
| DCHECK(task_runner_->RunsTasksInCurrentSequence()); |
| |
| bool has_observers = (audio_capturer_->num_observers() > 0); |
| bool is_lid_closed = (lid_state_ == mojom::LidState::kClosed); |
| bool should_enable_hotword = |
| hotword_enabled_ && preferred_device_id_.has_value(); |
| bool has_delegate = (platform_delegate_ != nullptr); |
| bool should_start = !is_lid_closed && (should_enable_hotword || mic_open_) && |
| has_observers && has_delegate; |
| |
| VLOG(1) << "UpdateRecordingState: " |
| << " is_lid_closed: " << is_lid_closed << "\n" |
| << " hotword_enabled: " << hotword_enabled_ << "\n" |
| << " preferred_device_id: '" |
| << preferred_device_id_.value_or("<unset>") << "'\n" |
| << " hotword_device_id: '" |
| << hotword_device_id_.value_or("<unset>") << "'\n" |
| << " mic_open: " << mic_open_ << "\n" |
| << " has_observers: " << has_observers << "\n" |
| << " has_delegate: " << has_delegate << "\n" |
| << " => should_start: " << should_start; |
| |
| if (!HasOpenAudioStream() && should_start) |
| StartRecording(); |
| else if (HasOpenAudioStream() && !should_start) |
| StopRecording(); |
| } |
| |
| std::string AudioInputImpl::GetDeviceId(bool use_dsp) const { |
| if (use_dsp && hotword_device_id_.has_value()) |
| return hotword_device_id_.value(); |
| else if (preferred_device_id_.has_value()) |
| return preferred_device_id_.value(); |
| else |
| return media::AudioDeviceDescription::kDefaultDeviceId; |
| } |
| |
| absl::optional<std::string> AudioInputImpl::GetOpenDeviceId() const { |
| if (!open_audio_stream_) |
| return absl::nullopt; |
| return open_audio_stream_->device_id(); |
| } |
| |
| bool AudioInputImpl::ShouldEnableDeadStreamDetection(bool use_dsp) const { |
| if (use_dsp && hotword_device_id_.has_value()) { |
| // The DSP device won't provide data until it detects a hotword, so |
| // we disable its the dead stream detection. |
| return false; |
| } |
| return true; |
| } |
| |
| bool AudioInputImpl::HasOpenAudioStream() const { |
| return open_audio_stream_ != nullptr; |
| } |
| |
| } // namespace libassistant |
| } // namespace chromeos |