Mojo: Validate response message type
Ensures that a response message is actually the type expected by the
original request.
(cherry picked from commit 9b5207569882e59cc81f11fb364753569211dc48)
Fixed: 1358134
Change-Id: I8f8f58168764477fbf7a6d2e8aeb040f07793d45
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3864274
Reviewed-by: Robert Sesek <rsesek@chromium.org>
Commit-Queue: Ken Rockot <rockot@google.com>
Cr-Original-Commit-Position: refs/heads/main@{#1041553}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3866709
Commit-Queue: Srinivas Sista <srinivassista@chromium.org>
Bot-Commit: Rubber Stamper <rubber-stamper@appspot.gserviceaccount.com>
Owners-Override: Srinivas Sista <srinivassista@chromium.org>
Cr-Commit-Position: refs/branch-heads/5112@{#1543}
Cr-Branched-From: b13d3fe7b3c47a56354ef54b221008afa754412e-refs/heads/main@{#1012729}
diff --git a/mojo/public/cpp/bindings/interface_endpoint_client.h b/mojo/public/cpp/bindings/interface_endpoint_client.h
index df33a20..dcc6e2aa 100644
--- a/mojo/public/cpp/bindings/interface_endpoint_client.h
+++ b/mojo/public/cpp/bindings/interface_endpoint_client.h
@@ -221,20 +221,32 @@
void ForgetAsyncRequest(uint64_t request_id);
private:
- // Maps from the id of a response to the MessageReceiver that handles the
- // response.
- using AsyncResponderMap =
- std::map<uint64_t, std::unique_ptr<MessageReceiver>>;
+ struct PendingAsyncResponse {
+ public:
+ PendingAsyncResponse(uint32_t request_message_name,
+ std::unique_ptr<MessageReceiver> responder);
+ PendingAsyncResponse(PendingAsyncResponse&&);
+ PendingAsyncResponse(const PendingAsyncResponse&) = delete;
+ PendingAsyncResponse& operator=(PendingAsyncResponse&&);
+ PendingAsyncResponse& operator=(const PendingAsyncResponse&) = delete;
+ ~PendingAsyncResponse();
+
+ uint32_t request_message_name;
+ std::unique_ptr<MessageReceiver> responder;
+ };
+
+ using AsyncResponderMap = std::map<uint64_t, PendingAsyncResponse>;
struct SyncResponseInfo {
public:
- explicit SyncResponseInfo(bool* in_response_received);
+ SyncResponseInfo(uint32_t request_message_name, bool* in_response_received);
SyncResponseInfo(const SyncResponseInfo&) = delete;
SyncResponseInfo& operator=(const SyncResponseInfo&) = delete;
~SyncResponseInfo();
+ uint32_t request_message_name;
Message response;
// Points to a stack-allocated variable.
diff --git a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc
index b9db8f31..6e87db1 100644
--- a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc
+++ b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc
@@ -28,6 +28,7 @@
#include "mojo/public/cpp/bindings/sync_call_restrictions.h"
#include "mojo/public/cpp/bindings/sync_event_watcher.h"
#include "mojo/public/cpp/bindings/thread_safe_proxy.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/perfetto/protos/perfetto/trace/track_event/chrome_mojo_event_info.pbzero.h"
namespace mojo {
@@ -314,9 +315,27 @@
// ----------------------------------------------------------------------------
+InterfaceEndpointClient::PendingAsyncResponse::PendingAsyncResponse(
+ uint32_t request_message_name,
+ std::unique_ptr<MessageReceiver> responder)
+ : request_message_name(request_message_name),
+ responder(std::move(responder)) {}
+
+InterfaceEndpointClient::PendingAsyncResponse::PendingAsyncResponse(
+ PendingAsyncResponse&&) = default;
+
+InterfaceEndpointClient::PendingAsyncResponse&
+InterfaceEndpointClient::PendingAsyncResponse::operator=(
+ PendingAsyncResponse&&) = default;
+
+InterfaceEndpointClient::PendingAsyncResponse::~PendingAsyncResponse() =
+ default;
+
InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
+ uint32_t request_message_name,
bool* in_response_received)
- : response_received(in_response_received) {}
+ : request_message_name(request_message_name),
+ response_received(in_response_received) {}
InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
@@ -604,6 +623,7 @@
// message before calling |SendMessage()| below.
#endif
+ const uint32_t message_name = message->name();
const bool is_sync = message->has_flag(Message::kFlagIsSync);
const bool exclusive_wait = message->has_flag(Message::kFlagNoInterrupt);
if (!controller_->SendMessage(message))
@@ -620,7 +640,8 @@
controller_->RegisterExternalSyncWaiter(request_id);
}
base::AutoLock lock(async_responders_lock_);
- async_responders_[request_id] = std::move(responder);
+ async_responders_.emplace(
+ request_id, PendingAsyncResponse{message_name, std::move(responder)});
return true;
}
@@ -628,7 +649,8 @@
bool response_received = false;
sync_responses_.insert(std::make_pair(
- request_id, std::make_unique<SyncResponseInfo>(&response_received)));
+ request_id,
+ std::make_unique<SyncResponseInfo>(message_name, &response_received)));
base::WeakPtr<InterfaceEndpointClient> weak_self =
weak_ptr_factory_.GetWeakPtr();
@@ -806,13 +828,13 @@
}
void InterfaceEndpointClient::ForgetAsyncRequest(uint64_t request_id) {
- std::unique_ptr<MessageReceiver> responder;
+ absl::optional<PendingAsyncResponse> response;
{
base::AutoLock lock(async_responders_lock_);
auto it = async_responders_.find(request_id);
if (it == async_responders_.end())
return;
- responder = std::move(it->second);
+ response = std::move(it->second);
async_responders_.erase(it);
}
}
@@ -893,6 +915,10 @@
return false;
if (it->second) {
+ if (message->name() != it->second->request_message_name) {
+ return false;
+ }
+
it->second->response = std::move(*message);
*it->second->response_received = true;
return true;
@@ -903,18 +929,22 @@
sync_responses_.erase(it);
}
- std::unique_ptr<MessageReceiver> responder;
+ absl::optional<PendingAsyncResponse> pending_response;
{
base::AutoLock lock(async_responders_lock_);
auto it = async_responders_.find(request_id);
if (it == async_responders_.end())
return false;
- responder = std::move(it->second);
+ pending_response = std::move(it->second);
async_responders_.erase(it);
}
+ if (message->name() != pending_response->request_message_name) {
+ return false;
+ }
+
internal::MessageDispatchContext dispatch_context(message);
- return responder->Accept(message);
+ return pending_response->responder->Accept(message);
} else {
if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
return control_message_handler_.Accept(message);