Mojo: Validate response message type

Ensures that a response message is actually the type expected by the
original request.

(cherry picked from commit 9b5207569882e59cc81f11fb364753569211dc48)

Fixed: 1358134
Change-Id: I8f8f58168764477fbf7a6d2e8aeb040f07793d45
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3864274
Reviewed-by: Robert Sesek <rsesek@chromium.org>
Commit-Queue: Ken Rockot <rockot@google.com>
Cr-Original-Commit-Position: refs/heads/main@{#1041553}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3866709
Commit-Queue: Srinivas Sista <srinivassista@chromium.org>
Bot-Commit: Rubber Stamper <rubber-stamper@appspot.gserviceaccount.com>
Owners-Override: Srinivas Sista <srinivassista@chromium.org>
Cr-Commit-Position: refs/branch-heads/5112@{#1543}
Cr-Branched-From: b13d3fe7b3c47a56354ef54b221008afa754412e-refs/heads/main@{#1012729}
diff --git a/mojo/public/cpp/bindings/interface_endpoint_client.h b/mojo/public/cpp/bindings/interface_endpoint_client.h
index df33a20..dcc6e2aa 100644
--- a/mojo/public/cpp/bindings/interface_endpoint_client.h
+++ b/mojo/public/cpp/bindings/interface_endpoint_client.h
@@ -221,20 +221,32 @@
   void ForgetAsyncRequest(uint64_t request_id);
 
  private:
-  // Maps from the id of a response to the MessageReceiver that handles the
-  // response.
-  using AsyncResponderMap =
-      std::map<uint64_t, std::unique_ptr<MessageReceiver>>;
+  struct PendingAsyncResponse {
+   public:
+    PendingAsyncResponse(uint32_t request_message_name,
+                         std::unique_ptr<MessageReceiver> responder);
+    PendingAsyncResponse(PendingAsyncResponse&&);
+    PendingAsyncResponse(const PendingAsyncResponse&) = delete;
+    PendingAsyncResponse& operator=(PendingAsyncResponse&&);
+    PendingAsyncResponse& operator=(const PendingAsyncResponse&) = delete;
+    ~PendingAsyncResponse();
+
+    uint32_t request_message_name;
+    std::unique_ptr<MessageReceiver> responder;
+  };
+
+  using AsyncResponderMap = std::map<uint64_t, PendingAsyncResponse>;
 
   struct SyncResponseInfo {
    public:
-    explicit SyncResponseInfo(bool* in_response_received);
+    SyncResponseInfo(uint32_t request_message_name, bool* in_response_received);
 
     SyncResponseInfo(const SyncResponseInfo&) = delete;
     SyncResponseInfo& operator=(const SyncResponseInfo&) = delete;
 
     ~SyncResponseInfo();
 
+    uint32_t request_message_name;
     Message response;
 
     // Points to a stack-allocated variable.
diff --git a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc
index b9db8f31..6e87db1 100644
--- a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc
+++ b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc
@@ -28,6 +28,7 @@
 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
 #include "mojo/public/cpp/bindings/sync_event_watcher.h"
 #include "mojo/public/cpp/bindings/thread_safe_proxy.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
 #include "third_party/perfetto/protos/perfetto/trace/track_event/chrome_mojo_event_info.pbzero.h"
 
 namespace mojo {
@@ -314,9 +315,27 @@
 
 // ----------------------------------------------------------------------------
 
+InterfaceEndpointClient::PendingAsyncResponse::PendingAsyncResponse(
+    uint32_t request_message_name,
+    std::unique_ptr<MessageReceiver> responder)
+    : request_message_name(request_message_name),
+      responder(std::move(responder)) {}
+
+InterfaceEndpointClient::PendingAsyncResponse::PendingAsyncResponse(
+    PendingAsyncResponse&&) = default;
+
+InterfaceEndpointClient::PendingAsyncResponse&
+InterfaceEndpointClient::PendingAsyncResponse::operator=(
+    PendingAsyncResponse&&) = default;
+
+InterfaceEndpointClient::PendingAsyncResponse::~PendingAsyncResponse() =
+    default;
+
 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
+    uint32_t request_message_name,
     bool* in_response_received)
-    : response_received(in_response_received) {}
+    : request_message_name(request_message_name),
+      response_received(in_response_received) {}
 
 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
 
@@ -604,6 +623,7 @@
   // message before calling |SendMessage()| below.
 #endif
 
+  const uint32_t message_name = message->name();
   const bool is_sync = message->has_flag(Message::kFlagIsSync);
   const bool exclusive_wait = message->has_flag(Message::kFlagNoInterrupt);
   if (!controller_->SendMessage(message))
@@ -620,7 +640,8 @@
       controller_->RegisterExternalSyncWaiter(request_id);
     }
     base::AutoLock lock(async_responders_lock_);
-    async_responders_[request_id] = std::move(responder);
+    async_responders_.emplace(
+        request_id, PendingAsyncResponse{message_name, std::move(responder)});
     return true;
   }
 
@@ -628,7 +649,8 @@
 
   bool response_received = false;
   sync_responses_.insert(std::make_pair(
-      request_id, std::make_unique<SyncResponseInfo>(&response_received)));
+      request_id,
+      std::make_unique<SyncResponseInfo>(message_name, &response_received)));
 
   base::WeakPtr<InterfaceEndpointClient> weak_self =
       weak_ptr_factory_.GetWeakPtr();
@@ -806,13 +828,13 @@
 }
 
 void InterfaceEndpointClient::ForgetAsyncRequest(uint64_t request_id) {
-  std::unique_ptr<MessageReceiver> responder;
+  absl::optional<PendingAsyncResponse> response;
   {
     base::AutoLock lock(async_responders_lock_);
     auto it = async_responders_.find(request_id);
     if (it == async_responders_.end())
       return;
-    responder = std::move(it->second);
+    response = std::move(it->second);
     async_responders_.erase(it);
   }
 }
@@ -893,6 +915,10 @@
         return false;
 
       if (it->second) {
+        if (message->name() != it->second->request_message_name) {
+          return false;
+        }
+
         it->second->response = std::move(*message);
         *it->second->response_received = true;
         return true;
@@ -903,18 +929,22 @@
       sync_responses_.erase(it);
     }
 
-    std::unique_ptr<MessageReceiver> responder;
+    absl::optional<PendingAsyncResponse> pending_response;
     {
       base::AutoLock lock(async_responders_lock_);
       auto it = async_responders_.find(request_id);
       if (it == async_responders_.end())
         return false;
-      responder = std::move(it->second);
+      pending_response = std::move(it->second);
       async_responders_.erase(it);
     }
 
+    if (message->name() != pending_response->request_message_name) {
+      return false;
+    }
+
     internal::MessageDispatchContext dispatch_context(message);
-    return responder->Accept(message);
+    return pending_response->responder->Accept(message);
   } else {
     if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
       return control_message_handler_.Accept(message);