[Cast MRP] STOP_SESSION message from client.

This patch is a continuation on patch 1419204 to implement STOP_SESSION
message.

Bug: 809249
Change-Id: I86b897a793e17b7b3495f8103904af3697f224ae
Reviewed-on: https://chromium-review.googlesource.com/c/1423425
Commit-Queue: John Williams <jrw@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
Reviewed-by: Takumi Fujimoto <takumif@chromium.org>
Cr-Commit-Position: refs/heads/master@{#634471}
diff --git a/chrome/browser/media/router/providers/cast/cast_activity_manager.cc b/chrome/browser/media/router/providers/cast/cast_activity_manager.cc
index 8569e0c..9e084fc 100644
--- a/chrome/browser/media/router/providers/cast/cast_activity_manager.cc
+++ b/chrome/browser/media/router/providers/cast/cast_activity_manager.cc
@@ -4,6 +4,9 @@
 
 #include "chrome/browser/media/router/providers/cast/cast_activity_manager.h"
 
+#include <memory>
+#include <vector>
+
 #include "base/bind.h"
 #include "chrome/browser/media/router/data_decoder_util.h"
 #include "chrome/common/media_router/discovery/media_sink_service_base.h"
@@ -202,20 +205,6 @@
   connection_binding_.Close();
 }
 
-CastActivityRecord::CastActivityRecord(
-    const MediaRoute& route,
-    const std::string& app_id,
-    MediaSinkServiceBase* media_sink_service,
-    cast_channel::CastMessageHandler* message_handler,
-    CastSessionTracker* session_tracker,
-    DataDecoder* data_decoder)
-    : route_(route),
-      app_id_(app_id),
-      media_sink_service_(media_sink_service),
-      message_handler_(message_handler),
-      session_tracker_(session_tracker),
-      data_decoder_(data_decoder) {}
-
 CastActivityRecord::~CastActivityRecord() {}
 
 mojom::RoutePresentationConnectionPtr CastActivityRecord::AddClient(
@@ -284,6 +273,21 @@
       cast_message.client_id, std::move(callback));
 }
 
+void CastActivityRecord::SendStopSessionMessageToReceiver(
+    const base::Optional<std::string>& client_id,
+    mojom::MediaRouteProvider::TerminateRouteCallback callback) {
+  const std::string& sink_id = route_.media_sink_id();
+  const MediaSinkInternal* sink = media_sink_service_->GetSinkById(sink_id);
+  DCHECK(sink);
+  DCHECK(session_id_);
+
+  message_handler_->StopSession(
+      sink->cast_data().cast_channel_id, *session_id_, client_id,
+      base::BindOnce(&CastActivityManager::HandleStopSessionResponse,
+                     activity_manager_->GetWeakPtr(), route_.media_route_id(),
+                     std::move(callback)));
+}
+
 void CastActivityRecord::SendMessageToClient(
     const std::string& client_id,
     blink::mojom::PresentationConnectionMessagePtr message) {
@@ -306,6 +310,22 @@
     client.second->TerminateConnection();
 }
 
+CastActivityRecord::CastActivityRecord(
+    const MediaRoute& route,
+    const std::string& app_id,
+    MediaSinkServiceBase* media_sink_service,
+    cast_channel::CastMessageHandler* message_handler,
+    CastSessionTracker* session_tracker,
+    DataDecoder* data_decoder,
+    CastActivityManager* owner)
+    : route_(route),
+      app_id_(app_id),
+      media_sink_service_(media_sink_service),
+      message_handler_(message_handler),
+      session_tracker_(session_tracker),
+      data_decoder_(data_decoder),
+      activity_manager_(owner) {}
+
 CastSession* CastActivityRecord::GetSession() {
   DCHECK(session_id_);
   CastSession* session = session_tracker_->GetSessionById(*session_id_);
@@ -412,8 +432,7 @@
         existing_route_id,
         base::BindOnce(
             &CastActivityManager::LaunchSessionAfterTerminatingExisting,
-            weak_ptr_factory_.GetWeakPtr(), existing_route_id,
-            std::move(params)));
+            GetWeakPtr(), existing_route_id, std::move(params)));
   }
 }
 
@@ -431,19 +450,17 @@
            << ", sink ID = " << sink.sink().id() << ", app ID = " << app_id
            << ", origin = " << params.origin << ", tab ID = " << params.tab_id;
 
-  auto activity = std::make_unique<CastActivityRecord>(
+  std::unique_ptr<CastActivityRecord> activity(new CastActivityRecord(
       route, app_id, media_sink_service_, message_handler_, session_tracker_,
-      data_decoder_.get());
+      data_decoder_.get(), this));
   auto* activity_ptr = activity.get();
   activities_.emplace(route_id, std::move(activity));
   NotifyAllOnRoutesUpdated();
-
   base::TimeDelta launch_timeout = cast_source.launch_timeout();
   message_handler_->LaunchSession(
       sink.cast_data().cast_channel_id, app_id, launch_timeout,
       base::BindOnce(&CastActivityManager::HandleLaunchSessionResponse,
-                     weak_ptr_factory_.GetWeakPtr(), route_id, sink,
-                     cast_source));
+                     GetWeakPtr(), route_id, sink, cast_source));
 
   mojom::RoutePresentationConnectionPtr presentation_connection;
   const std::string& client_id = cast_source.client_id();
@@ -521,22 +538,16 @@
   }
 
   const MediaSinkInternal* sink = media_sink_service_->GetSinkByRoute(route);
-  if (!sink) {
-    RemoveActivity(activity_it);
-    std::move(callback).Run(base::nullopt, RouteRequestResult::OK);
-    return;
-  }
+  CHECK(sink);
 
   for (auto& client : activity->connected_clients()) {
     client.second->SendMessageToClient(
         CreateReceiverActionStopMessage(client.first, *sink, hash_token_));
   }
 
-  message_handler_->StopSession(
-      sink->cast_data().cast_channel_id, *session_id,
-      base::BindOnce(&CastActivityManager::HandleStopSessionResponse,
-                     weak_ptr_factory_.GetWeakPtr(), route_id,
-                     std::move(callback)));
+  activity->SendStopSessionMessageToReceiver(
+      base::nullopt,  // TODO(jrw): Get the real client ID.
+      std::move(callback));
 }
 
 CastActivityManager::ActivityMap::iterator
@@ -674,9 +685,9 @@
   MediaRoute route(route_id, source, sink_id, /* description */ std::string(),
                    /* is_local */ false, /* for_display */ true);
 
-  auto record = std::make_unique<CastActivityRecord>(
+  std::unique_ptr<CastActivityRecord> record(new CastActivityRecord(
       route, app_id, media_sink_service_, message_handler_, session_tracker_,
-      data_decoder_.get());
+      data_decoder_.get(), this));
   record->SetOrUpdateSession(session, sink, hash_token_);
   activities_.emplace(route_id, std::move(record));
 }
@@ -744,13 +755,12 @@
     DVLOG(2) << "Sending new_session message for route " << route_id
              << ", client_id: " << client_id;
     activity_it->second->SendMessageToClient(
-        client_id, CreateNewSessionMessage(*session, cast_source.client_id(),
-                                           sink, hash_token_));
+        client_id,
+        CreateNewSessionMessage(*session, client_id, sink, hash_token_));
 
-    // TODO(imcheng): Query media status.
+    // TODO(jrw): Query media status.
     message_handler_->EnsureConnection(sink.cast_data().cast_channel_id,
-                                       cast_source.client_id(),
-                                       session->transport_id());
+                                       client_id, session->transport_id());
   }
 
   activity_it->second->SetOrUpdateSession(*session, sink, hash_token_);
@@ -762,6 +772,8 @@
     mojom::MediaRouteProvider::TerminateRouteCallback callback,
     cast_channel::Result result) {
   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+  VLOG(2) << __func__ << ": " << route_id;
+
   auto activity_it = activities_.find(route_id);
   if (activity_it == activities_.end()) {
     // The activity could've been removed via RECEIVER_STATUS message.
@@ -789,6 +801,10 @@
   media_router_->OnIssue(info);
 }
 
+base::WeakPtr<CastActivityManager> CastActivityManager::GetWeakPtr() {
+  return weak_ptr_factory_.GetWeakPtr();
+}
+
 CastActivityManager::DoLaunchSessionParams::DoLaunchSessionParams(
     const MediaRoute& route,
     const CastMediaSource& cast_source,
diff --git a/chrome/browser/media/router/providers/cast/cast_activity_manager.h b/chrome/browser/media/router/providers/cast/cast_activity_manager.h
index 08fe0ccc..0844a01 100644
--- a/chrome/browser/media/router/providers/cast/cast_activity_manager.h
+++ b/chrome/browser/media/router/providers/cast/cast_activity_manager.h
@@ -135,12 +135,6 @@
 // Instances of this class are associated with a specific session and app.
 class CastActivityRecord {
  public:
-  CastActivityRecord(const MediaRoute& route,
-                     const std::string& app_id,
-                     MediaSinkServiceBase* media_sink_service,
-                     cast_channel::CastMessageHandler* message_handler,
-                     CastSessionTracker* session_tracker,
-                     DataDecoder* data_decoder);
   ~CastActivityRecord();
 
   const MediaRoute& route() const { return route_; }
@@ -168,6 +162,10 @@
   void SendSetVolumeRequestToReceiver(const CastInternalMessage& cast_message,
                                       cast_channel::ResultCallback callback);
 
+  void SendStopSessionMessageToReceiver(
+      const base::Optional<std::string>& client_id,
+      mojom::MediaRouteProvider::TerminateRouteCallback callback);
+
   // Adds a new client |client_id| to this session and returns the handles of
   // the two pipes to be held by Blink It is invalid to call this method if the
   // client already exists.
@@ -198,6 +196,16 @@
 
  private:
   friend class CastSessionClient;
+  friend class CastActivityManager;
+
+  // Creates a new record owned by |owner|.
+  CastActivityRecord(const MediaRoute& route,
+                     const std::string& app_id,
+                     MediaSinkServiceBase* media_sink_service,
+                     cast_channel::CastMessageHandler* message_handler,
+                     CastSessionTracker* session_tracker,
+                     DataDecoder* data_decoder,
+                     CastActivityManager* owner);
 
   CastSession* GetSession();
   int GetCastChannelId();
@@ -217,6 +225,7 @@
   cast_channel::CastMessageHandler* const message_handler_;
   CastSessionTracker* const session_tracker_;
   DataDecoder* const data_decoder_;
+  CastActivityManager* const activity_manager_;
 
   DISALLOW_COPY_AND_ASSIGN(CastActivityRecord);
 };
@@ -280,6 +289,8 @@
                             base::Optional<int> request_id) override;
 
  private:
+  friend class CastActivityRecord;
+
   using ActivityMap =
       base::flat_map<MediaRoute::Id, std::unique_ptr<CastActivityRecord>>;
 
@@ -350,6 +361,8 @@
   void SendFailedToCastIssue(const MediaSink::Id& sink_id,
                              const MediaRoute::Id& route_id);
 
+  base::WeakPtr<CastActivityManager> GetWeakPtr();
+
   // These methods return |activities_.end()| when nothing is found.
   ActivityMap::iterator FindActivityByChannelId(int channel_id);
   ActivityMap::iterator FindActivityBySink(const MediaSinkInternal& sink);
diff --git a/chrome/browser/media/router/providers/cast/cast_activity_manager_unittest.cc b/chrome/browser/media/router/providers/cast/cast_activity_manager_unittest.cc
index d5ca539..dddd0cd 100644
--- a/chrome/browser/media/router/providers/cast/cast_activity_manager_unittest.cc
+++ b/chrome/browser/media/router/providers/cast/cast_activity_manager_unittest.cc
@@ -31,6 +31,7 @@
 using testing::IsEmpty;
 using testing::Not;
 using testing::Return;
+using testing::WithArg;
 
 namespace media_router {
 
@@ -179,10 +180,9 @@
     // A launch session request is sent to the sink.
     EXPECT_CALL(message_handler_,
                 LaunchSession(kChannelId, "ABCDEFGH", kDefaultLaunchTimeout, _))
-        .WillOnce(
-            [this](auto chanel_id, auto app_id, auto timeout, auto callback) {
-              launch_session_callback_ = std::move(callback);
-            });
+        .WillOnce(WithArg<3>([this](auto callback) {
+          launch_session_callback_ = std::move(callback);
+        }));
 
     auto source = CastMediaSource::FromMediaSourceId(kSource1);
     ASSERT_TRUE(source);
@@ -245,10 +245,11 @@
   void TerminateSession(cast_channel::Result result) {
     cast_channel::ResultCallback stop_session_callback;
 
-    EXPECT_CALL(message_handler_, StopSession(kChannelId, "theSessionId", _))
-        .WillOnce([&](auto channel_id, auto session_id, auto callback) {
+    EXPECT_CALL(message_handler_, StopSession(kChannelId, "theSessionId",
+                                              base::Optional<std::string>(), _))
+        .WillOnce(WithArg<3>([&](auto callback) {
           stop_session_callback = std::move(callback);
-        });
+        }));
     manager_->TerminateSession(
         route_->media_route_id(),
         base::BindOnce(
@@ -267,7 +268,7 @@
   // not called.
   void TerminateNoSession() {
     // Stop session message not sent because session has not launched yet.
-    EXPECT_CALL(message_handler_, StopSession(_, _, _)).Times(0);
+    EXPECT_CALL(message_handler_, StopSession(_, _, _, _)).Times(0);
     manager_->TerminateSession(
         route_->media_route_id(),
         base::BindOnce(&CastActivityManagerTest::ExpectTerminateResultSuccess,
@@ -299,13 +300,12 @@
   void ExpectSingleRouteUpdate(MediaRoute* route_ptr = nullptr) {
     EXPECT_CALL(mock_router_, OnRoutesUpdated(MediaRouteProviderId::CAST,
                                               Not(IsEmpty()), _, _))
-        .WillOnce([=](auto provider_id, auto routes, auto media_source,
-                      auto joinable_route_ids) {
+        .WillOnce(WithArg<1>([=](auto routes) {
           EXPECT_EQ(1u, routes.size());
           if (route_ptr) {
             *route_ptr = routes[0];
           }
-        });
+        }));
   }
 
   // Expect a call to OnRoutesUpdated() with no routes.
@@ -369,10 +369,10 @@
 
   // Existing session will be terminated.
   cast_channel::ResultCallback stop_session_callback;
-  EXPECT_CALL(message_handler_, StopSession(kChannelId, "theSessionId", _))
-      .WillOnce([&](auto channel_id, auto session_id, auto callback) {
-        stop_session_callback = std::move(callback);
-      });
+  EXPECT_CALL(message_handler_, StopSession(kChannelId, "theSessionId",
+                                            base::Optional<std::string>(), _))
+      .WillOnce(WithArg<3>(
+          [&](auto callback) { stop_session_callback = std::move(callback); }));
 
   // Launch a new session on the same sink.
   auto source = CastMediaSource::FromMediaSourceId(kSource2);
@@ -642,8 +642,7 @@
   EXPECT_CALL(message_handler_,
               SendSetVolumeRequest(kChannelId, IsJson(expected_message),
                                    "theClientId", _))
-      .WillOnce([&](int channel_id, const base::Value& message,
-                    const std::string& client_id, auto callback) {
+      .WillOnce(WithArg<3>([&](auto callback) {
         // Check message created by CastSessionClient::SendResultResponse().
         EXPECT_CALL(*client_connection_, OnMessage(IsCastMessage(R"({
                     "clientId": "theClientId",
@@ -654,7 +653,7 @@
                   })")));
         std::move(callback).Run(cast_channel::Result::kOk);
         return cast_channel::Result::kOk;
-      });
+      }));
   client_connection_->SendMessageToMediaRouter(
       blink::mojom::PresentationConnectionMessage::NewMessage(R"({
         "type": "v2_message",
diff --git a/components/cast_channel/cast_message_handler.cc b/components/cast_channel/cast_message_handler.cc
index 920c3af..8baa632 100644
--- a/components/cast_channel/cast_message_handler.cc
+++ b/components/cast_channel/cast_message_handler.cc
@@ -6,6 +6,7 @@
 
 #include <tuple>
 #include <utility>
+#include <vector>
 
 #include "base/bind.h"
 #include "base/rand_util.h"
@@ -191,9 +192,11 @@
   }
 }
 
-void CastMessageHandler::StopSession(int channel_id,
-                                     const std::string& session_id,
-                                     ResultCallback callback) {
+void CastMessageHandler::StopSession(
+    int channel_id,
+    const std::string& session_id,
+    const base::Optional<std::string>& client_id,
+    ResultCallback callback) {
   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
   CastSocket* socket = socket_service_->GetSocket(channel_id);
   if (!socket) {
@@ -207,8 +210,8 @@
            << ", request_id: " << request_id;
   if (requests->AddStopRequest(std::make_unique<StopSessionRequest>(
           request_id, std::move(callback), clock_))) {
-    SendCastMessage(socket,
-                    CreateStopRequest(sender_id_, request_id, session_id));
+    SendCastMessage(socket, CreateStopRequest(client_id.value_or(sender_id_),
+                                              request_id, session_id));
   }
 }
 
diff --git a/components/cast_channel/cast_message_handler.h b/components/cast_channel/cast_message_handler.h
index e6049ad..9023302 100644
--- a/components/cast_channel/cast_message_handler.h
+++ b/components/cast_channel/cast_message_handler.h
@@ -11,6 +11,7 @@
 #include "base/gtest_prod_util.h"
 #include "base/macros.h"
 #include "base/memory/weak_ptr.h"
+#include "base/optional.h"
 #include "base/sequence_checker.h"
 #include "base/time/tick_clock.h"
 #include "base/timer/timer.h"
@@ -176,6 +177,7 @@
   // request.
   virtual void StopSession(int channel_id,
                            const std::string& session_id,
+                           const base::Optional<std::string>& client_id,
                            ResultCallback callback);
 
   // Sends |message| to the device given by |channel_id|. The caller may use
diff --git a/components/cast_channel/cast_message_handler_unittest.cc b/components/cast_channel/cast_message_handler_unittest.cc
index cab1fdb..908f75e 100644
--- a/components/cast_channel/cast_message_handler_unittest.cc
+++ b/components/cast_channel/cast_message_handler_unittest.cc
@@ -144,7 +144,7 @@
                   R"({"sessionId": "theSessionId", "type": "SET_VOLUME"})"),
               "theSourceId", set_volume_callback_.Get()));
     }
-    handler_.StopSession(channel_id_, "theSessionId",
+    handler_.StopSession(channel_id_, "theSessionId", "theSourceId",
                          stop_session_callback_.Get());
   }
 
diff --git a/components/cast_channel/cast_test_util.h b/components/cast_channel/cast_test_util.h
index 335b78d59..d037bc14 100644
--- a/components/cast_channel/cast_test_util.h
+++ b/components/cast_channel/cast_test_util.h
@@ -7,6 +7,7 @@
 
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "base/bind.h"
 #include "base/macros.h"
@@ -164,9 +165,10 @@
                     const std::string&,
                     base::TimeDelta,
                     LaunchSessionCallback callback));
-  MOCK_METHOD3(StopSession,
+  MOCK_METHOD4(StopSession,
                void(int channel_id,
                     const std::string& session_id,
+                    const base::Optional<std::string>& client_id,
                     ResultCallback callback));
   MOCK_METHOD2(SendAppMessage,
                Result(int channel_id, const CastMessage& message));