| // Copyright 2018 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/cast_channel/cast_message_handler.h" |
| |
| #include <tuple> |
| #include <utility> |
| #include <vector> |
| |
| #include "base/bind.h" |
| #include "base/rand_util.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/time/default_tick_clock.h" |
| #include "components/cast_channel/cast_socket_service.h" |
| #include "services/data_decoder/public/cpp/safe_json_parser.h" |
| #include "services/service_manager/public/cpp/connector.h" |
| |
| namespace cast_channel { |
| |
| namespace { |
| |
| // The max launch timeout amount for session launch requests. |
| constexpr base::TimeDelta kLaunchMaxTimeout = base::TimeDelta::FromMinutes(2); |
| |
| void ReportParseError(const std::string& error) { |
| DVLOG(2) << "Error parsing JSON message: " << error; |
| } |
| |
| } // namespace |
| |
| GetAppAvailabilityRequest::GetAppAvailabilityRequest( |
| int request_id, |
| GetAppAvailabilityCallback callback, |
| const base::TickClock* clock, |
| const std::string& app_id) |
| : PendingRequest(request_id, std::move(callback), clock), app_id(app_id) {} |
| |
| GetAppAvailabilityRequest::~GetAppAvailabilityRequest() = default; |
| |
| VirtualConnection::VirtualConnection(int channel_id, |
| const std::string& source_id, |
| const std::string& destination_id) |
| : channel_id(channel_id), |
| source_id(source_id), |
| destination_id(destination_id) {} |
| VirtualConnection::~VirtualConnection() = default; |
| |
| bool VirtualConnection::operator<(const VirtualConnection& other) const { |
| return std::tie(channel_id, source_id, destination_id) < |
| std::tie(other.channel_id, other.source_id, other.destination_id); |
| } |
| |
| InternalMessage::InternalMessage(CastMessageType type, base::Value message) |
| : type(type), message(std::move(message)) {} |
| InternalMessage::~InternalMessage() = default; |
| |
| CastMessageHandler::CastMessageHandler( |
| CastSocketService* socket_service, |
| std::unique_ptr<service_manager::Connector> connector, |
| const base::Token& data_decoder_batch_id, |
| const std::string& user_agent, |
| const std::string& browser_version, |
| const std::string& locale) |
| : sender_id_(base::StringPrintf("sender-%d", base::RandInt(0, 1000000))), |
| connector_(std::move(connector)), |
| data_decoder_batch_id_(data_decoder_batch_id), |
| user_agent_(user_agent), |
| browser_version_(browser_version), |
| locale_(locale), |
| socket_service_(socket_service), |
| clock_(base::DefaultTickClock::GetInstance()), |
| weak_ptr_factory_(this) { |
| DETACH_FROM_SEQUENCE(sequence_checker_); |
| socket_service_->task_runner()->PostTask( |
| FROM_HERE, base::BindOnce(&CastSocketService::AddObserver, |
| base::Unretained(socket_service_), |
| base::Unretained(this))); |
| } |
| |
| CastMessageHandler::~CastMessageHandler() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| socket_service_->RemoveObserver(this); |
| } |
| |
| void CastMessageHandler::EnsureConnection(int channel_id, |
| const std::string& source_id, |
| const std::string& destination_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << "Socket not found: " << channel_id; |
| return; |
| } |
| |
| DoEnsureConnection(socket, source_id, destination_id); |
| } |
| |
| CastMessageHandler::PendingRequests* |
| CastMessageHandler::GetOrCreatePendingRequests(int channel_id) { |
| CastMessageHandler::PendingRequests* requests = nullptr; |
| auto pending_it = pending_requests_.find(channel_id); |
| if (pending_it != pending_requests_.end()) { |
| return pending_it->second.get(); |
| } |
| |
| auto new_requests = std::make_unique<CastMessageHandler::PendingRequests>(); |
| requests = new_requests.get(); |
| pending_requests_.emplace(channel_id, std::move(new_requests)); |
| return requests; |
| } |
| |
| void CastMessageHandler::RequestAppAvailability( |
| CastSocket* socket, |
| const std::string& app_id, |
| GetAppAvailabilityCallback callback) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| int channel_id = socket->id(); |
| auto* requests = GetOrCreatePendingRequests(channel_id); |
| int request_id = NextRequestId(); |
| |
| DVLOG(2) << __func__ << ", channel_id: " << channel_id |
| << ", app_id: " << app_id << ", request_id: " << request_id; |
| if (requests->AddAppAvailabilityRequest( |
| std::make_unique<GetAppAvailabilityRequest>( |
| request_id, std::move(callback), clock_, app_id))) { |
| SendCastMessage(socket, CreateGetAppAvailabilityRequest( |
| sender_id_, request_id, app_id)); |
| } |
| } |
| |
| void CastMessageHandler::RequestReceiverStatus(int channel_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return; |
| } |
| |
| int request_id = NextRequestId(); |
| SendCastMessage(socket, CreateReceiverStatusRequest(sender_id_, request_id)); |
| } |
| |
| void CastMessageHandler::SendBroadcastMessage( |
| int channel_id, |
| const std::vector<std::string>& app_ids, |
| const BroadcastRequest& request) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return; |
| } |
| |
| int request_id = NextRequestId(); |
| DVLOG(2) << __func__ << ", channel_id: " << channel_id |
| << ", request_id: " << request_id; |
| |
| // Note: Even though the message is formatted like a request, we don't care |
| // about the response, as broadcasts are fire-and-forget. |
| CastMessage message = |
| CreateBroadcastRequest(sender_id_, request_id, app_ids, request); |
| SendCastMessage(socket, message); |
| } |
| |
| void CastMessageHandler::LaunchSession(int channel_id, |
| const std::string& app_id, |
| base::TimeDelta launch_timeout, |
| LaunchSessionCallback callback) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| std::move(callback).Run(LaunchSessionResponse()); |
| return; |
| } |
| |
| auto* requests = GetOrCreatePendingRequests(channel_id); |
| int request_id = NextRequestId(); |
| // Cap the max launch timeout to avoid long-living pending requests. |
| launch_timeout = std::min(launch_timeout, kLaunchMaxTimeout); |
| DVLOG(2) << __func__ << ", channel_id: " << channel_id |
| << ", request_id: " << request_id; |
| if (requests->AddLaunchRequest(std::make_unique<LaunchSessionRequest>( |
| request_id, std::move(callback), clock_), |
| launch_timeout)) { |
| SendCastMessage( |
| socket, CreateLaunchRequest(sender_id_, request_id, app_id, locale_)); |
| } |
| } |
| |
| void CastMessageHandler::StopSession( |
| int channel_id, |
| const std::string& session_id, |
| const base::Optional<std::string>& client_id, |
| ResultCallback callback) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return; |
| } |
| |
| auto* requests = GetOrCreatePendingRequests(channel_id); |
| int request_id = NextRequestId(); |
| DVLOG(2) << __func__ << ", channel_id: " << channel_id |
| << ", request_id: " << request_id; |
| if (requests->AddStopRequest(std::make_unique<StopSessionRequest>( |
| request_id, std::move(callback), clock_))) { |
| SendCastMessage(socket, CreateStopRequest(client_id.value_or(sender_id_), |
| request_id, session_id)); |
| } |
| } |
| |
| Result CastMessageHandler::SendAppMessage(int channel_id, |
| const CastMessage& message) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(!IsCastInternalNamespace(message.namespace_())) |
| << ": unexpected app message namespace: " << message.namespace_(); |
| |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return Result::kFailed; |
| } |
| |
| SendCastMessage(socket, message); |
| return Result::kOk; |
| } |
| |
| base::Optional<int> CastMessageHandler::SendMediaRequest( |
| int channel_id, |
| const base::Value& body, |
| const std::string& source_id, |
| const std::string& destination_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return base::nullopt; |
| } |
| |
| int request_id = NextRequestId(); |
| SendCastMessage( |
| socket, CreateMediaRequest(body, request_id, source_id, destination_id)); |
| return request_id; |
| } |
| |
| Result CastMessageHandler::SendSetVolumeRequest(int channel_id, |
| const base::Value& body, |
| const std::string& source_id, |
| ResultCallback callback) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| CastSocket* socket = socket_service_->GetSocket(channel_id); |
| if (!socket) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return Result::kFailed; |
| } |
| |
| auto* requests = GetOrCreatePendingRequests(channel_id); |
| int request_id = NextRequestId(); |
| |
| requests->AddVolumeRequest(std::make_unique<SetVolumeRequest>( |
| request_id, std::move(callback), clock_)); |
| SendCastMessage(socket, CreateSetVolumeRequest(body, request_id, source_id)); |
| return Result::kOk; |
| } |
| |
| void CastMessageHandler::AddObserver(Observer* observer) { |
| observers_.AddObserver(observer); |
| } |
| |
| void CastMessageHandler::RemoveObserver(Observer* observer) { |
| observers_.RemoveObserver(observer); |
| } |
| |
| void CastMessageHandler::OnError(const CastSocket& socket, |
| ChannelError error_state) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| int channel_id = socket.id(); |
| |
| base::EraseIf(virtual_connections_, |
| [&channel_id](const VirtualConnection& connection) { |
| return connection.channel_id == channel_id; |
| }); |
| |
| pending_requests_.erase(channel_id); |
| } |
| |
| void CastMessageHandler::OnMessage(const CastSocket& socket, |
| const CastMessage& message) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DVLOG(2) << __func__ << ", channel_id: " << socket.id() |
| << ", message: " << CastMessageToString(message); |
| if (IsCastInternalNamespace(message.namespace_())) { |
| if (message.payload_type() == |
| cast_channel::CastMessage_PayloadType_STRING) { |
| data_decoder::SafeJsonParser::ParseBatch( |
| connector_.get(), message.payload_utf8(), |
| base::BindRepeating(&CastMessageHandler::HandleCastInternalMessage, |
| weak_ptr_factory_.GetWeakPtr(), socket.id(), |
| message.source_id(), message.destination_id()), |
| base::BindRepeating(&ReportParseError), data_decoder_batch_id_); |
| } else { |
| DLOG(ERROR) << "Dropping internal message with binary payload: " |
| << message.namespace_(); |
| } |
| } else { |
| DVLOG(2) << "Got app message from cast channel with namespace: " |
| << message.namespace_(); |
| for (auto& observer : observers_) |
| observer.OnAppMessage(socket.id(), message); |
| } |
| } |
| |
| void CastMessageHandler::OnReadyStateChanged(const CastSocket& socket) { |
| if (socket.ready_state() == ReadyState::CLOSED) |
| pending_requests_.erase(socket.id()); |
| } |
| |
| void CastMessageHandler::HandleCastInternalMessage( |
| int channel_id, |
| const std::string& source_id, |
| const std::string& destination_id, |
| std::unique_ptr<base::Value> payload) { |
| if (!payload->is_dict()) { |
| ReportParseError("Parsed message not a dictionary"); |
| return; |
| } |
| |
| // Check if the socket still exists as it might have been removed during |
| // message parsing. |
| if (!socket_service_->GetSocket(channel_id)) { |
| DVLOG(2) << __func__ << ": socket not found: " << channel_id; |
| return; |
| } |
| |
| base::Optional<int> request_id = GetRequestIdFromResponse(*payload); |
| if (request_id) { |
| auto requests_it = pending_requests_.find(channel_id); |
| if (requests_it != pending_requests_.end()) |
| requests_it->second->HandlePendingRequest(*request_id, *payload); |
| } |
| |
| CastMessageType type = ParseMessageTypeFromPayload(*payload); |
| if (type == CastMessageType::kOther) { |
| DVLOG(2) << "Unknown message type: " << *payload; |
| return; |
| } |
| |
| if (type == CastMessageType::kCloseConnection) { |
| // Source / destination is flipped. |
| virtual_connections_.erase( |
| VirtualConnection(channel_id, destination_id, source_id)); |
| return; |
| } |
| |
| InternalMessage internal_message(type, std::move(*payload)); |
| for (auto& observer : observers_) |
| observer.OnInternalMessage(channel_id, internal_message); |
| } |
| |
| void CastMessageHandler::SendCastMessage(CastSocket* socket, |
| const CastMessage& message) { |
| // A virtual connection must be opened to the receiver before other messages |
| // can be sent. |
| DoEnsureConnection(socket, message.source_id(), message.destination_id()); |
| socket->transport()->SendMessage( |
| message, base::BindOnce(&CastMessageHandler::OnMessageSent, |
| weak_ptr_factory_.GetWeakPtr())); |
| } |
| |
| void CastMessageHandler::DoEnsureConnection(CastSocket* socket, |
| const std::string& source_id, |
| const std::string& destination_id) { |
| VirtualConnection connection(socket->id(), source_id, destination_id); |
| if (virtual_connections_.find(connection) != virtual_connections_.end()) |
| return; |
| |
| DVLOG(1) << "Creating VC for channel: " << connection.channel_id |
| << ", source: " << connection.source_id |
| << ", dest: " << connection.destination_id; |
| CastMessage virtual_connection_request = CreateVirtualConnectionRequest( |
| connection.source_id, connection.destination_id, |
| connection.destination_id == kPlatformReceiverId |
| ? VirtualConnectionType::kStrong |
| : VirtualConnectionType::kInvisible, |
| user_agent_, browser_version_); |
| socket->transport()->SendMessage( |
| virtual_connection_request, |
| base::BindOnce(&CastMessageHandler::OnMessageSent, |
| weak_ptr_factory_.GetWeakPtr())); |
| |
| // We assume the virtual connection request will succeed; otherwise this |
| // will eventually self-correct. |
| virtual_connections_.insert(connection); |
| } |
| |
| void CastMessageHandler::OnMessageSent(int result) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DVLOG_IF(2, result < 0) << "SendMessage failed with code: " << result; |
| } |
| |
| CastMessageHandler::PendingRequests::PendingRequests() {} |
| CastMessageHandler::PendingRequests::~PendingRequests() { |
| for (auto& request : pending_app_availability_requests_) { |
| std::move(request->callback) |
| .Run(request->app_id, GetAppAvailabilityResult::kUnknown); |
| } |
| |
| if (pending_launch_session_request_) { |
| LaunchSessionResponse response; |
| response.result = LaunchSessionResponse::kError; |
| std::move(pending_launch_session_request_->callback) |
| .Run(std::move(response)); |
| } |
| |
| if (pending_stop_session_request_) |
| std::move(pending_stop_session_request_->callback).Run(Result::kFailed); |
| |
| for (auto& request : pending_volume_requests_by_id_) |
| std::move(request.second->callback).Run(Result::kFailed); |
| } |
| |
| bool CastMessageHandler::PendingRequests::AddAppAvailabilityRequest( |
| std::unique_ptr<GetAppAvailabilityRequest> request) { |
| const std::string& app_id = request->app_id; |
| int request_id = request->request_id; |
| request->timeout_timer.Start( |
| FROM_HERE, kRequestTimeout, |
| base::BindOnce( |
| &CastMessageHandler::PendingRequests::AppAvailabilityTimedOut, |
| base::Unretained(this), request_id)); |
| |
| // Look for a request with the given app ID. |
| bool found = std::find_if(pending_app_availability_requests_.begin(), |
| pending_app_availability_requests_.end(), |
| [&app_id](const auto& old_request) { |
| return old_request->app_id == app_id; |
| }) != pending_app_availability_requests_.end(); |
| pending_app_availability_requests_.emplace_back(std::move(request)); |
| return !found; |
| } |
| |
| bool CastMessageHandler::PendingRequests::AddLaunchRequest( |
| std::unique_ptr<LaunchSessionRequest> request, |
| base::TimeDelta timeout) { |
| if (pending_launch_session_request_) |
| return false; |
| |
| int request_id = request->request_id; |
| request->timeout_timer.Start( |
| FROM_HERE, timeout, |
| base::BindOnce( |
| &CastMessageHandler::PendingRequests::LaunchSessionTimedOut, |
| base::Unretained(this), request_id)); |
| pending_launch_session_request_ = std::move(request); |
| return true; |
| } |
| |
| bool CastMessageHandler::PendingRequests::AddStopRequest( |
| std::unique_ptr<StopSessionRequest> request) { |
| if (pending_stop_session_request_) |
| return false; |
| |
| int request_id = request->request_id; |
| request->timeout_timer.Start( |
| FROM_HERE, kRequestTimeout, |
| base::BindOnce(&CastMessageHandler::PendingRequests::StopSessionTimedOut, |
| base::Unretained(this), request_id)); |
| pending_stop_session_request_ = std::move(request); |
| return true; |
| } |
| |
| void CastMessageHandler::PendingRequests::AddVolumeRequest( |
| std::unique_ptr<SetVolumeRequest> request) { |
| int request_id = request->request_id; |
| request->timeout_timer.Start( |
| FROM_HERE, kRequestTimeout, |
| base::BindOnce(&CastMessageHandler::PendingRequests::SetVolumeTimedOut, |
| base::Unretained(this), request_id)); |
| pending_volume_requests_by_id_.emplace(request_id, std::move(request)); |
| } |
| |
| void CastMessageHandler::PendingRequests::HandlePendingRequest( |
| int request_id, |
| const base::Value& response) { |
| // Look up an app availability request by its |request_id|. |
| auto app_availability_it = |
| std::find_if(pending_app_availability_requests_.begin(), |
| pending_app_availability_requests_.end(), |
| [request_id](const auto& request_ptr) { |
| return request_ptr->request_id == request_id; |
| }); |
| // If we found a request, process and remove all requests with the same |
| // |app_id|, which will of course include the one we just found. |
| if (app_availability_it != pending_app_availability_requests_.end()) { |
| std::string app_id = (*app_availability_it)->app_id; |
| GetAppAvailabilityResult result = |
| GetAppAvailabilityResultFromResponse(response, app_id); |
| base::EraseIf(pending_app_availability_requests_, |
| [&app_id, result](const auto& request_ptr) { |
| if (request_ptr->app_id == app_id) { |
| std::move(request_ptr->callback).Run(app_id, result); |
| return true; |
| } |
| return false; |
| }); |
| return; |
| } |
| |
| if (pending_launch_session_request_ && |
| pending_launch_session_request_->request_id == request_id) { |
| std::move(pending_launch_session_request_->callback) |
| .Run(GetLaunchSessionResponse(response)); |
| pending_launch_session_request_.reset(); |
| return; |
| } |
| |
| if (pending_stop_session_request_ && |
| pending_stop_session_request_->request_id == request_id) { |
| std::move(pending_stop_session_request_->callback).Run(Result::kOk); |
| pending_stop_session_request_.reset(); |
| return; |
| } |
| |
| auto volume_it = pending_volume_requests_by_id_.find(request_id); |
| if (volume_it != pending_volume_requests_by_id_.end()) { |
| std::move(volume_it->second->callback).Run(Result::kOk); |
| pending_volume_requests_by_id_.erase(volume_it); |
| return; |
| } |
| } |
| |
| void CastMessageHandler::PendingRequests::AppAvailabilityTimedOut( |
| int request_id) { |
| DVLOG(1) << __func__ << ", request_id: " << request_id; |
| |
| auto it = std::find_if(pending_app_availability_requests_.begin(), |
| pending_app_availability_requests_.end(), |
| [&request_id](const auto& request) { |
| return request->request_id == request_id; |
| }); |
| |
| CHECK(it != pending_app_availability_requests_.end()); |
| std::move((*it)->callback) |
| .Run((*it)->app_id, GetAppAvailabilityResult::kUnknown); |
| pending_app_availability_requests_.erase(it); |
| } |
| |
| void CastMessageHandler::PendingRequests::LaunchSessionTimedOut( |
| int request_id) { |
| DVLOG(1) << __func__ << ", request_id: " << request_id; |
| CHECK(pending_launch_session_request_); |
| CHECK(pending_launch_session_request_->request_id == request_id); |
| |
| LaunchSessionResponse response; |
| response.result = LaunchSessionResponse::kTimedOut; |
| std::move(pending_launch_session_request_->callback).Run(std::move(response)); |
| pending_launch_session_request_.reset(); |
| } |
| |
| void CastMessageHandler::PendingRequests::StopSessionTimedOut(int request_id) { |
| DVLOG(1) << __func__ << ", request_id: " << request_id; |
| CHECK(pending_stop_session_request_); |
| CHECK(pending_stop_session_request_->request_id == request_id); |
| |
| std::move(pending_stop_session_request_->callback).Run(Result::kFailed); |
| pending_stop_session_request_.reset(); |
| } |
| |
| void CastMessageHandler::PendingRequests::SetVolumeTimedOut(int request_id) { |
| DVLOG(1) << __func__ << ", request_id: " << request_id; |
| auto it = pending_volume_requests_by_id_.find(request_id); |
| DCHECK(it != pending_volume_requests_by_id_.end()); |
| std::move(it->second->callback).Run(Result::kFailed); |
| pending_volume_requests_by_id_.erase(it); |
| } |
| |
| } // namespace cast_channel |