| // Copyright 2020 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include <map> |
| #include <vector> |
| |
| #include "base/optional.h" |
| #include "base/run_loop.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/test/bind.h" |
| #include "base/test/scoped_feature_list.h" |
| #include "content/browser/direct_sockets/direct_sockets_service_impl.h" |
| #include "content/browser/renderer_host/frame_tree_node.h" |
| #include "content/browser/renderer_host/render_frame_host_impl.h" |
| #include "content/public/browser/browser_context.h" |
| #include "content/public/browser/storage_partition.h" |
| #include "content/public/browser/web_contents.h" |
| #include "content/public/common/content_features.h" |
| #include "content/public/test/browser_test.h" |
| #include "content/public/test/browser_test_utils.h" |
| #include "content/public/test/content_browser_test.h" |
| #include "content/public/test/content_browser_test_utils.h" |
| #include "content/shell/browser/shell.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "mojo/public/cpp/system/data_pipe.h" |
| #include "net/base/ip_address.h" |
| #include "net/base/ip_endpoint.h" |
| #include "net/base/net_errors.h" |
| #include "net/net_buildflags.h" |
| #include "net/test/embedded_test_server/embedded_test_server.h" |
| #include "net/traffic_annotation/network_traffic_annotation.h" |
| #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" |
| #include "services/network/public/mojom/host_resolver.mojom.h" |
| #include "services/network/public/mojom/network_context.mojom.h" |
| #include "services/network/public/mojom/tcp_socket.mojom.h" |
| #include "testing/gmock/include/gmock/gmock-matchers.h" |
| #include "url/gurl.h" |
| |
| // The tests in this file use the Network Service implementation of |
| // NetworkContext, to test sending and receiving of data over TCP sockets. |
| |
| using testing::StartsWith; |
| |
| namespace content { |
| |
| namespace { |
| |
| net::Error UnconditionallyPermitConnection( |
| const blink::mojom::DirectSocketOptions& options) { |
| DCHECK(options.remote_hostname.has_value()); |
| return net::OK; |
| } |
| |
| class ReadWriteWaiter { |
| public: |
| ReadWriteWaiter( |
| uint32_t required_receive_bytes, |
| uint32_t required_send_bytes, |
| mojo::Remote<network::mojom::TCPServerSocket>& tcp_server_socket) |
| : required_receive_bytes_(required_receive_bytes), |
| required_send_bytes_(required_send_bytes) { |
| tcp_server_socket->Accept( |
| /*observer=*/mojo::NullRemote(), |
| base::BindRepeating(&ReadWriteWaiter::OnAccept, |
| base::Unretained(this))); |
| } |
| |
| void Await() { run_loop_.Run(); } |
| |
| private: |
| void OnAccept( |
| int result, |
| const base::Optional<net::IPEndPoint>& remote_addr, |
| mojo::PendingRemote<network::mojom::TCPConnectedSocket> accepted_socket, |
| mojo::ScopedDataPipeConsumerHandle consumer_handle, |
| mojo::ScopedDataPipeProducerHandle producer_handle) { |
| DCHECK_EQ(result, net::OK); |
| DCHECK(!accepted_socket_); |
| accepted_socket_.Bind(std::move(accepted_socket)); |
| |
| if (required_receive_bytes_ > 0) { |
| receive_stream_ = std::move(consumer_handle); |
| read_watcher_ = std::make_unique<mojo::SimpleWatcher>( |
| FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL); |
| read_watcher_->Watch( |
| receive_stream_.get(), |
| MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED, |
| MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED, |
| base::BindRepeating(&ReadWriteWaiter::OnReadReady, |
| base::Unretained(this))); |
| read_watcher_->ArmOrNotify(); |
| } |
| |
| if (required_send_bytes_ > 0) { |
| send_stream_ = std::move(producer_handle); |
| write_watcher_ = std::make_unique<mojo::SimpleWatcher>( |
| FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL); |
| write_watcher_->Watch( |
| send_stream_.get(), |
| MOJO_HANDLE_SIGNAL_WRITABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED, |
| MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED, |
| base::BindRepeating(&ReadWriteWaiter::OnWriteReady, |
| base::Unretained(this))); |
| write_watcher_->ArmOrNotify(); |
| } |
| } |
| |
| void OnReadReady(MojoResult result, const mojo::HandleSignalsState& state) { |
| ReadData(); |
| } |
| |
| void OnWriteReady(MojoResult result, const mojo::HandleSignalsState& state) { |
| WriteData(); |
| } |
| |
| void ReadData() { |
| while (true) { |
| DCHECK(receive_stream_.is_valid()); |
| DCHECK_LT(bytes_received_, required_receive_bytes_); |
| const void* buffer = nullptr; |
| uint32_t num_bytes = 0; |
| MojoResult mojo_result = receive_stream_->BeginReadData( |
| &buffer, &num_bytes, MOJO_READ_DATA_FLAG_NONE); |
| if (mojo_result == MOJO_RESULT_SHOULD_WAIT) { |
| read_watcher_->ArmOrNotify(); |
| return; |
| } |
| DCHECK_EQ(mojo_result, MOJO_RESULT_OK); |
| |
| // This is guaranteed by Mojo. |
| DCHECK_GT(num_bytes, 0u); |
| |
| const unsigned char* current = static_cast<const unsigned char*>(buffer); |
| const unsigned char* const end = current + num_bytes; |
| while (current < end) { |
| EXPECT_EQ(*current, bytes_received_ % 256); |
| ++current; |
| ++bytes_received_; |
| } |
| |
| mojo_result = receive_stream_->EndReadData(num_bytes); |
| DCHECK_EQ(mojo_result, MOJO_RESULT_OK); |
| |
| if (bytes_received_ == required_receive_bytes_) { |
| if (bytes_sent_ == required_send_bytes_) |
| run_loop_.Quit(); |
| return; |
| } |
| } |
| } |
| |
| void WriteData() { |
| while (true) { |
| DCHECK(send_stream_.is_valid()); |
| DCHECK_LT(bytes_sent_, required_send_bytes_); |
| void* buffer = nullptr; |
| uint32_t num_bytes = 0; |
| MojoResult mojo_result = send_stream_->BeginWriteData( |
| &buffer, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE); |
| if (mojo_result == MOJO_RESULT_SHOULD_WAIT) { |
| write_watcher_->ArmOrNotify(); |
| return; |
| } |
| DCHECK_EQ(mojo_result, MOJO_RESULT_OK); |
| |
| // This is guaranteed by Mojo. |
| DCHECK_GT(num_bytes, 0u); |
| |
| num_bytes = std::min(num_bytes, required_send_bytes_ - bytes_sent_); |
| |
| unsigned char* current = static_cast<unsigned char*>(buffer); |
| unsigned char* const end = current + num_bytes; |
| while (current != end) { |
| *current = bytes_sent_ % 256; |
| ++current; |
| ++bytes_sent_; |
| } |
| |
| mojo_result = send_stream_->EndWriteData(num_bytes); |
| DCHECK_EQ(mojo_result, MOJO_RESULT_OK); |
| |
| if (bytes_sent_ == required_send_bytes_) { |
| if (bytes_received_ == required_receive_bytes_) |
| run_loop_.Quit(); |
| return; |
| } |
| } |
| } |
| |
| const uint32_t required_receive_bytes_; |
| const uint32_t required_send_bytes_; |
| base::RunLoop run_loop_; |
| mojo::Remote<network::mojom::TCPConnectedSocket> accepted_socket_; |
| mojo::ScopedDataPipeConsumerHandle receive_stream_; |
| mojo::ScopedDataPipeProducerHandle send_stream_; |
| std::unique_ptr<mojo::SimpleWatcher> read_watcher_; |
| std::unique_ptr<mojo::SimpleWatcher> write_watcher_; |
| uint32_t bytes_received_ = 0; |
| uint32_t bytes_sent_ = 0; |
| }; |
| |
| } // anonymous namespace |
| |
| class DirectSocketsTcpBrowserTest : public ContentBrowserTest { |
| public: |
| DirectSocketsTcpBrowserTest() { |
| feature_list_.InitAndEnableFeature(features::kDirectSockets); |
| } |
| ~DirectSocketsTcpBrowserTest() override = default; |
| |
| GURL GetTestOpenPageURL() { |
| return embedded_test_server()->GetURL("/direct_sockets/open.html"); |
| } |
| |
| GURL GetTestPageURL() { |
| return embedded_test_server()->GetURL("/direct_sockets/tcp.html"); |
| } |
| |
| network::mojom::NetworkContext* GetNetworkContext() { |
| return browser_context()->GetDefaultStoragePartition()->GetNetworkContext(); |
| } |
| |
| std::string CreateMDNSHostName() { |
| DCHECK(!mdns_responder_.is_bound()); |
| GetNetworkContext()->CreateMdnsResponder( |
| mdns_responder_.BindNewPipeAndPassReceiver()); |
| |
| std::string name; |
| base::RunLoop run_loop; |
| mdns_responder_->CreateNameForAddress( |
| net::IPAddress::IPv4Localhost(), |
| base::BindLambdaForTesting( |
| [&name, &run_loop](const std::string& name_out, |
| bool announcement_scheduled) { |
| name = name_out; |
| run_loop.Quit(); |
| })); |
| run_loop.Run(); |
| return name; |
| } |
| |
| // Returns the port listening for TCP connections. |
| uint16_t StartTcpServer() { |
| net::IPEndPoint local_addr; |
| base::RunLoop run_loop; |
| GetNetworkContext()->CreateTCPServerSocket( |
| net::IPEndPoint(net::IPAddress::IPv4Localhost(), |
| /*port=*/0), |
| /*backlog=*/5, |
| net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), |
| tcp_server_socket_.BindNewPipeAndPassReceiver(), |
| base::BindLambdaForTesting( |
| [&local_addr, &run_loop]( |
| int32_t result, |
| const base::Optional<net::IPEndPoint>& local_addr_out) { |
| DCHECK_EQ(result, net::OK); |
| DCHECK(local_addr_out.has_value()); |
| local_addr = *local_addr_out; |
| run_loop.Quit(); |
| })); |
| run_loop.Run(); |
| return local_addr.port(); |
| } |
| |
| mojo::Remote<network::mojom::TCPServerSocket>& tcp_server_socket() { |
| return tcp_server_socket_; |
| } |
| |
| protected: |
| void SetUp() override { |
| DirectSocketsServiceImpl::SetEnterpriseManagedForTesting(false); |
| |
| embedded_test_server()->AddDefaultHandlers(GetTestDataFilePath()); |
| ASSERT_TRUE(embedded_test_server()->Start()); |
| |
| ContentBrowserTest::SetUp(); |
| } |
| |
| private: |
| BrowserContext* browser_context() { |
| return shell()->web_contents()->GetBrowserContext(); |
| } |
| |
| base::test::ScopedFeatureList feature_list_; |
| mojo::Remote<network::mojom::MdnsResponder> mdns_responder_; |
| mojo::Remote<network::mojom::TCPServerSocket> tcp_server_socket_; |
| }; |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, OpenTcp_Success) { |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestOpenPageURL())); |
| |
| DirectSocketsServiceImpl::SetPermissionCallbackForTesting( |
| base::BindRepeating(&UnconditionallyPermitConnection)); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| const std::string script = base::StringPrintf( |
| "openTcp({remoteAddress: '127.0.0.1', remotePort: %d})", listening_port); |
| |
| EXPECT_THAT(EvalJs(shell(), script).ExtractString(), |
| StartsWith("openTcp succeeded")); |
| } |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, OpenTcp_Success_Global) { |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestOpenPageURL())); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| const std::string script = base::StringPrintf( |
| "openTcp({remoteAddress: '127.0.0.1', remotePort: %d})", listening_port); |
| |
| EXPECT_THAT(EvalJs(shell(), script).ExtractString(), |
| StartsWith("openTcp succeeded")); |
| } |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, OpenTcp_MDNS) { |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestOpenPageURL())); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| const std::string name = CreateMDNSHostName(); |
| EXPECT_TRUE(base::EndsWith(name, ".local")); |
| |
| const std::string script = |
| base::StringPrintf("openTcp({remoteAddress: '%s', remotePort: %d})", |
| name.c_str(), listening_port); |
| |
| #if BUILDFLAG(ENABLE_MDNS) |
| EXPECT_THAT(EvalJs(shell(), script).ExtractString(), |
| StartsWith("openTcp succeeded")); |
| #else |
| EXPECT_EQ("openTcp failed: NotAllowedError: Permission denied", |
| EvalJs(shell(), script)); |
| #endif // BUILDFLAG(ENABLE_MDNS) |
| } |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, CloseTcp) { |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestPageURL())); |
| |
| DirectSocketsServiceImpl::SetPermissionCallbackForTesting( |
| base::BindRepeating(&UnconditionallyPermitConnection)); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| const std::string script = base::StringPrintf( |
| "closeTcp({remoteAddress: '127.0.0.1', remotePort: %d})", listening_port); |
| |
| EXPECT_EQ("closeTcp succeeded", EvalJs(shell(), script)); |
| } |
| |
| // Tests that we can close the writer, then the socket. |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, CloseTcpWriter) { |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestPageURL())); |
| |
| DirectSocketsServiceImpl::SetPermissionCallbackForTesting( |
| base::BindRepeating(&UnconditionallyPermitConnection)); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| const std::string script = base::StringPrintf( |
| "closeTcp({remoteAddress: '127.0.0.1', remotePort: %d}, " |
| "/*closeWriter=*/true)", |
| listening_port); |
| |
| EXPECT_EQ("closeTcp succeeded", EvalJs(shell(), script)); |
| } |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, WriteTcp) { |
| const uint32_t kRequiredBytes = 10000; |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestPageURL())); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| ReadWriteWaiter waiter(/*required_receive_bytes=*/kRequiredBytes, |
| /*required_send_bytes=*/0, tcp_server_socket()); |
| |
| const std::string script = base::StringPrintf( |
| "writeTcp({remoteAddress: '127.0.0.1', remotePort: %d}, %u)", |
| listening_port, kRequiredBytes); |
| EXPECT_EQ("write succeeded", EvalJs(shell(), script)); |
| waiter.Await(); |
| } |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, ReadTcp) { |
| const uint32_t kRequiredBytes = 150000; |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestPageURL())); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| ReadWriteWaiter waiter(/*required_receive_bytes=*/0, |
| /*required_send_bytes=*/kRequiredBytes, |
| tcp_server_socket()); |
| |
| const std::string script = base::StringPrintf( |
| "readTcp({remoteAddress: '127.0.0.1', remotePort: %d}, %u)", |
| listening_port, kRequiredBytes); |
| EXPECT_EQ("read succeeded", EvalJs(shell(), script)); |
| waiter.Await(); |
| } |
| |
| IN_PROC_BROWSER_TEST_F(DirectSocketsTcpBrowserTest, ReadWriteTcp) { |
| const uint32_t kRequiredBytes = 1000; |
| EXPECT_TRUE(NavigateToURL(shell(), GetTestPageURL())); |
| |
| const uint16_t listening_port = StartTcpServer(); |
| ReadWriteWaiter waiter(/*required_receive_bytes=*/kRequiredBytes, |
| /*required_send_bytes=*/kRequiredBytes, |
| tcp_server_socket()); |
| |
| const std::string script = base::StringPrintf( |
| "readWriteTcp({remoteAddress: '127.0.0.1', remotePort: %d}, %u)", |
| listening_port, kRequiredBytes); |
| EXPECT_EQ("readWrite succeeded", EvalJs(shell(), script)); |
| waiter.Await(); |
| } |
| |
| } // namespace content |