| // Copyright 2013 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include "base/bind.h"
#include "base/files/file_path.h"
#include "base/path_service.h"
#include "base/posix/eintr_wrapper.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_restrictions.h"
#include "ipc/unix_domain_socket_util.h"
#include "testing/gtest/include/gtest/gtest.h"
| |
| namespace { |
| |
| class SocketAcceptor : public base::MessageLoopForIO::Watcher { |
| public: |
| SocketAcceptor(int fd, base::MessageLoopProxy* target_thread) |
| : server_fd_(-1), |
| target_thread_(target_thread), |
| started_watching_event_(false, false), |
| accepted_event_(false, false) { |
| target_thread->PostTask(FROM_HERE, |
| base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd)); |
| } |
| |
| virtual ~SocketAcceptor() { |
| Close(); |
| } |
| |
| int server_fd() const { return server_fd_; } |
| |
| void WaitUntilReady() { |
| started_watching_event_.Wait(); |
| } |
| |
| void WaitForAccept() { |
| accepted_event_.Wait(); |
| } |
| |
| void Close() { |
| if (watcher_.get()) { |
| target_thread_->PostTask(FROM_HERE, |
| base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this), |
| watcher_.release())); |
| } |
| } |
| |
| private: |
| void StartWatching(int fd) { |
| watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher); |
| base::MessageLoopForIO::current()->WatchFileDescriptor( |
| fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this); |
| started_watching_event_.Signal(); |
| } |
| void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher* watcher) { |
| watcher->StopWatchingFileDescriptor(); |
| delete watcher; |
| } |
| virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE { |
| ASSERT_EQ(-1, server_fd_); |
| IPC::ServerAcceptConnection(fd, &server_fd_); |
| watcher_->StopWatchingFileDescriptor(); |
| accepted_event_.Signal(); |
| } |
| virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {} |
| |
| int server_fd_; |
| base::MessageLoopProxy* target_thread_; |
| scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_; |
| base::WaitableEvent started_watching_event_; |
| base::WaitableEvent accepted_event_; |
| |
| DISALLOW_COPY_AND_ASSIGN(SocketAcceptor); |
| }; |
| |
| const base::FilePath GetChannelDir() { |
| #if defined(OS_ANDROID) |
| base::FilePath tmp_dir; |
| PathService::Get(base::DIR_CACHE, &tmp_dir); |
| return tmp_dir; |
| #else |
| return base::FilePath("/var/tmp"); |
| #endif |
| } |
| |
| class TestUnixSocketConnection { |
| public: |
| TestUnixSocketConnection() |
| : worker_("WorkerThread"), |
| server_listen_fd_(-1), |
| server_fd_(-1), |
| client_fd_(-1) { |
| socket_name_ = GetChannelDir().Append("TestSocket"); |
| base::Thread::Options options; |
| options.message_loop_type = base::MessageLoop::TYPE_IO; |
| worker_.StartWithOptions(options); |
| } |
| |
| bool CreateServerSocket() { |
| IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_); |
| if (server_listen_fd_ < 0) |
| return false; |
| struct stat socket_stat; |
| stat(socket_name_.value().c_str(), &socket_stat); |
| EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode)); |
| acceptor_.reset(new SocketAcceptor(server_listen_fd_, |
| worker_.message_loop_proxy().get())); |
| acceptor_->WaitUntilReady(); |
| return true; |
| } |
| |
| bool CreateClientSocket() { |
| DCHECK(server_listen_fd_ >= 0); |
| IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_); |
| if (client_fd_ < 0) |
| return false; |
| acceptor_->WaitForAccept(); |
| server_fd_ = acceptor_->server_fd(); |
| return server_fd_ >= 0; |
| } |
| |
| virtual ~TestUnixSocketConnection() { |
| if (client_fd_ >= 0) |
| close(client_fd_); |
| if (server_fd_ >= 0) |
| close(server_fd_); |
| if (server_listen_fd_ >= 0) { |
| close(server_listen_fd_); |
| unlink(socket_name_.value().c_str()); |
| } |
| } |
| |
| int client_fd() const { return client_fd_; } |
| int server_fd() const { return server_fd_; } |
| |
| private: |
| base::Thread worker_; |
| base::FilePath socket_name_; |
| int server_listen_fd_; |
| int server_fd_; |
| int client_fd_; |
| scoped_ptr<SocketAcceptor> acceptor_; |
| }; |
| |
| // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that |
| // IPC::CreateClientUnixDomainSocket can successfully connect to. |
| TEST(UnixDomainSocketUtil, Connect) { |
| TestUnixSocketConnection connection; |
| ASSERT_TRUE(connection.CreateServerSocket()); |
| ASSERT_TRUE(connection.CreateClientSocket()); |
| } |
| |
| // Ensure that messages can be sent across the resulting socket. |
| TEST(UnixDomainSocketUtil, SendReceive) { |
| TestUnixSocketConnection connection; |
| ASSERT_TRUE(connection.CreateServerSocket()); |
| ASSERT_TRUE(connection.CreateClientSocket()); |
| |
| const char buffer[] = "Hello, server!"; |
| size_t buf_len = sizeof(buffer); |
| size_t sent_bytes = |
| HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0)); |
| ASSERT_EQ(buf_len, sent_bytes); |
| char recv_buf[sizeof(buffer)]; |
| size_t received_bytes = |
| HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0)); |
| ASSERT_EQ(buf_len, received_bytes); |
| ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len)); |
| } |
| |
| } // namespace |