// Source blob: 3b0305ea86d8baa42c6b2ae804ebb0dc3130554a
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "base/at_exit.h"
#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "mojo/public/cpp/bindings/callback.h"
#include "mojo/services/public/cpp/network/udp_socket_wrapper.h"
#include "mojo/services/public/interfaces/network/network_service.mojom.h"
#include "mojo/services/public/interfaces/network/udp_socket.mojom.h"
#include "mojo/shell/shell_test_helper.h"
#include "net/base/net_errors.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/gurl.h"
namespace mojo {
namespace service {
namespace {
// Builds an IPv4 NetAddress for 127.0.0.1 with the port field set to 0.
NetAddressPtr GetLocalHostWithAnyPort() {
  static const uint8_t kLoopbackOctets[] = {127, 0, 0, 1};
  const size_t kOctetCount =
      sizeof(kLoopbackOctets) / sizeof(kLoopbackOctets[0]);

  NetAddressPtr local_host(NetAddress::New());
  local_host->family = NET_ADDRESS_FAMILY_IPV4;
  local_host->ipv4 = NetAddressIPv4::New();
  local_host->ipv4->port = 0;
  local_host->ipv4->addr.resize(kOctetCount);
  for (size_t octet = 0; octet < kOctetCount; ++octet)
    local_host->ipv4->addr[octet] = kLoopbackOctets[octet];
  return local_host.Pass();
}
// Returns an array of |size| bytes in which byte i holds
// (initial + i) modulo 256.
Array<uint8_t> CreateTestMessage(uint8_t initial, size_t size) {
  Array<uint8_t> message(size);
  // A uint8_t counter wraps at 256, which yields the same sequence as
  // computing (i + initial) % 256 for every index.
  uint8_t next_value = initial;
  for (size_t offset = 0; offset < size; ++offset) {
    message[offset] = next_value;
    ++next_value;
  }
  return message.Pass();
}
// Base class for test helpers that hand out a callback of type |CallbackType|
// and let the test spin a RunLoop until that callback has been run.
//
// A subclass supplies a StateBase-derived Runnable (which records the
// callback arguments) and passes it to Initialize(). The Runnable holds a
// back-pointer to this object; the pointer is cleared in the destructor so
// that copies of the callback that are run later do not touch a destroyed
// object.
template <typename CallbackType>
class TestCallbackBase {
 public:
  TestCallbackBase() : state_(nullptr), run_loop_(nullptr), ran_(false) {}

  ~TestCallbackBase() {
    // |state_| is null until Initialize() has been called, so guard the
    // dereference. When set, detach so a late Run() becomes a no-op.
    if (state_)
      state_->set_test_callback(nullptr);
  }

  // Returns a copy of the callback to pass to the method under test.
  CallbackType callback() const { return callback_; }

  // Returns immediately if the callback has already been run; otherwise runs
  // a nested RunLoop until NotifyRun() quits it.
  void WaitForResult() {
    if (ran_)
      return;

    base::RunLoop run_loop;
    run_loop_ = &run_loop;
    run_loop.Run();
    run_loop_ = nullptr;
  }

 protected:
  // Runnable base shared by all subclasses. Subclasses implement Run(), store
  // the arguments on the owning test callback and then call NotifyRun().
  struct StateBase : public CallbackType::Runnable {
    StateBase() : test_callback_(nullptr) {}
    virtual ~StateBase() {}

    void set_test_callback(TestCallbackBase* test_callback) {
      test_callback_ = test_callback;
    }

   protected:
    // Marks the owning callback as run and quits its RunLoop if one is
    // currently waiting. Does nothing once the owner has detached.
    void NotifyRun() const {
      if (test_callback_) {
        test_callback_->ran_ = true;
        if (test_callback_->run_loop_)
          test_callback_->run_loop_->Quit();
      }
    }

    // Not owned; null after the owning TestCallbackBase is destroyed.
    TestCallbackBase* test_callback_;

   private:
    DISALLOW_COPY_AND_ASSIGN(StateBase);
  };

  // Takes ownership of |state|, and guarantees that it lives at least as long
  // as this object.
  void Initialize(StateBase* state) {
    state_ = state;
    state_->set_test_callback(this);
    callback_ = CallbackType(
        static_cast<typename CallbackType::Runnable*>(state_));
  }

 private:
  // The lifespan is managed by |callback_| (and its copies).
  StateBase* state_;
  CallbackType callback_;
  // Points at the RunLoop on WaitForResult()'s stack while waiting, else null.
  base::RunLoop* run_loop_;
  // Set to true by NotifyRun() the first time the callback is run.
  bool ran_;

  DISALLOW_COPY_AND_ASSIGN(TestCallbackBase);
};
// Test callback for methods whose only reply argument is a NetworkErrorPtr.
class TestCallback : public TestCallbackBase<Callback<void(NetworkErrorPtr)>> {
 public:
  TestCallback() { Initialize(new State()); }
  ~TestCallback() {}

  // The error delivered to the callback; unset until the callback has run.
  const NetworkErrorPtr& result() const { return result_; }

 private:
  struct State : public StateBase {
    ~State() override {}

    void Run(NetworkErrorPtr result) const override {
      // Record the argument only while the owning TestCallback is attached.
      if (TestCallback* owner = static_cast<TestCallback*>(test_callback_))
        owner->result_ = result.Pass();
      NotifyRun();
    }
  };

  NetworkErrorPtr result_;
};
// Test callback for methods that reply with a NetworkErrorPtr and a
// NetAddressPtr.
class TestCallbackWithAddress
    : public TestCallbackBase<Callback<void(NetworkErrorPtr, NetAddressPtr)>> {
 public:
  TestCallbackWithAddress() { Initialize(new State()); }
  ~TestCallbackWithAddress() {}

  // Values delivered to the callback; unset until the callback has run.
  const NetworkErrorPtr& result() const { return result_; }
  const NetAddressPtr& net_address() const { return net_address_; }

 private:
  struct State : public StateBase {
    ~State() override {}

    void Run(NetworkErrorPtr result, NetAddressPtr net_address) const override {
      // Record the arguments only while the owning callback is attached.
      if (TestCallbackWithAddress* owner =
              static_cast<TestCallbackWithAddress*>(test_callback_)) {
        owner->result_ = result.Pass();
        owner->net_address_ = net_address.Pass();
      }
      NotifyRun();
    }
  };

  NetworkErrorPtr result_;
  NetAddressPtr net_address_;
};
// Test callback for methods that reply with a single uint32_t.
class TestCallbackWithUint32
    : public TestCallbackBase<Callback<void(uint32_t)>> {
 public:
  TestCallbackWithUint32() : result_(0) { Initialize(new State()); }
  ~TestCallbackWithUint32() {}

  // The value delivered to the callback; 0 until the callback has run.
  uint32_t result() const { return result_; }

 private:
  struct State : public StateBase {
    ~State() override {}

    void Run(uint32_t result) const override {
      // Record the argument only while the owning callback is attached.
      if (TestCallbackWithUint32* owner =
              static_cast<TestCallbackWithUint32*>(test_callback_)) {
        owner->result_ = result;
      }
      NotifyRun();
    }
  };

  uint32_t result_;
};
// Test callback for receive-style replies: a NetworkErrorPtr, the sender's
// NetAddressPtr and the received bytes.
class TestReceiveCallback
    : public TestCallbackBase<
          Callback<void(NetworkErrorPtr, NetAddressPtr, Array<uint8_t>)>> {
 public:
  TestReceiveCallback() { Initialize(new State()); }
  ~TestReceiveCallback() {}

  // Values delivered to the callback; unset until the callback has run.
  const NetworkErrorPtr& result() const { return result_; }
  const NetAddressPtr& src_addr() const { return src_addr_; }
  const Array<uint8_t>& data() const { return data_; }

 private:
  struct State : public StateBase {
    ~State() override {}

    void Run(NetworkErrorPtr result,
             NetAddressPtr src_addr,
             Array<uint8_t> data) const override {
      // Record the arguments only while the owning callback is attached.
      if (TestReceiveCallback* owner =
              static_cast<TestReceiveCallback*>(test_callback_)) {
        owner->result_ = result.Pass();
        owner->src_addr_ = src_addr.Pass();
        owner->data_ = data.Pass();
      }
      NotifyRun();
    }
  };

  NetworkErrorPtr result_;
  NetAddressPtr src_addr_;
  Array<uint8_t> data_;
};
// Test fixture: connects to "mojo:network_service", creates one UDP socket
// (|udp_socket_|) and registers |udp_socket_client_| as its client so that
// OnReceived() notifications are queued for the tests to inspect.
class UDPSocketTest : public testing::Test {
 public:
  UDPSocketTest() {}
  ~UDPSocketTest() override {}

  void SetUp() override {
    test_helper_.Init();

    test_helper_.application_manager()->ConnectToService(
        GURL("mojo:network_service"), &network_service_);

    network_service_->CreateUDPSocket(GetProxy(&udp_socket_));
    udp_socket_.set_client(&udp_socket_client_);
  }

 protected:
  // One OnReceived() notification, stored verbatim.
  struct ReceiveResult {
    NetworkErrorPtr result;
    NetAddressPtr addr;
    Array<uint8_t> data;
  };

  // UDPSocketClient implementation that queues every notification and can
  // quit a waiting RunLoop once enough of them have arrived.
  class UDPSocketClientImpl : public UDPSocketClient {
   public:
    UDPSocketClientImpl() : run_loop_(nullptr), expected_receive_count_(0) {}

    ~UDPSocketClientImpl() override {
      // Entries that no test took ownership of are still owned by the queue.
      while (!results_.empty()) {
        delete results_.front();
        results_.pop();
      }
    }

    void OnReceived(NetworkErrorPtr result,
                    NetAddressPtr src_addr,
                    Array<uint8_t> data) override {
      ReceiveResult* entry = new ReceiveResult();
      entry->result = result.Pass();
      entry->addr = src_addr.Pass();
      entry->data = data.Pass();

      results_.push(entry);

      // Use >= rather than == so the waiter is released even when the queue
      // already holds more entries than requested. Clearing |run_loop_|
      // ensures Quit() is issued at most once per wait.
      if (run_loop_ && results_.size() >= expected_receive_count_) {
        expected_receive_count_ = 0;
        run_loop_->Quit();
        run_loop_ = nullptr;
      }
    }

    // Non-null only while WaitForReceiveResults() is blocked in Run().
    base::RunLoop* run_loop_;
    // Owns its entries until a test pops them and takes ownership.
    std::queue<ReceiveResult*> results_;
    // Queue size at which the waiting RunLoop should be quit.
    size_t expected_receive_count_;

    DISALLOW_COPY_AND_ASSIGN(UDPSocketClientImpl);
  };

  // Exposes the queue of received notifications. A caller that pops an entry
  // becomes responsible for deleting it.
  std::queue<ReceiveResult*>* GetReceiveResults() {
    return &udp_socket_client_.results_;
  }

  // Blocks until at least |count| notifications are queued. Returns
  // immediately when that many (or more) are already available; the previous
  // equality check would wait forever in the "more" case.
  void WaitForReceiveResults(size_t count) {
    if (GetReceiveResults()->size() >= count)
      return;

    udp_socket_client_.expected_receive_count_ = count;
    base::RunLoop run_loop;
    udp_socket_client_.run_loop_ = &run_loop;
    run_loop.Run();
    udp_socket_client_.run_loop_ = nullptr;
  }

  base::ShadowingAtExitManager at_exit_;
  shell::ShellTestHelper test_helper_;

  NetworkServicePtr network_service_;
  UDPSocketPtr udp_socket_;
  UDPSocketClientImpl udp_socket_client_;

  DISALLOW_COPY_AND_ASSIGN(UDPSocketTest);
};
} // namespace
// Exercises the socket option methods before and after Bind() and checks
// which of them are expected to succeed in each state.
TEST_F(UDPSocketTest, Settings) {
  // --- Before Bind() ---
  TestCallback reuse_before_bind;
  udp_socket_->AllowAddressReuse(reuse_before_bind.callback());
  reuse_before_bind.WaitForResult();
  EXPECT_EQ(net::OK, reuse_before_bind.result()->code);

  // Buffer sizes cannot be set yet: the socket hasn't been bound.
  TestCallback send_buffer_before_bind;
  udp_socket_->SetSendBufferSize(1024, send_buffer_before_bind.callback());
  send_buffer_before_bind.WaitForResult();
  EXPECT_NE(net::OK, send_buffer_before_bind.result()->code);

  // Same for the receive buffer: the socket hasn't been bound.
  TestCallback receive_buffer_before_bind;
  udp_socket_->SetReceiveBufferSize(2048,
                                    receive_buffer_before_bind.callback());
  receive_buffer_before_bind.WaitForResult();
  EXPECT_NE(net::OK, receive_buffer_before_bind.result()->code);

  // --- Bind() ---
  TestCallbackWithAddress bound;
  udp_socket_->Bind(GetLocalHostWithAnyPort(), bound.callback());
  bound.WaitForResult();
  EXPECT_EQ(net::OK, bound.result()->code);
  EXPECT_NE(0u, bound.net_address()->ipv4->port);

  // --- After Bind() ---
  // Address reuse must now be rejected: the socket has been bound.
  TestCallback reuse_after_bind;
  udp_socket_->AllowAddressReuse(reuse_after_bind.callback());
  reuse_after_bind.WaitForResult();
  EXPECT_NE(net::OK, reuse_after_bind.result()->code);

  TestCallback send_buffer_after_bind;
  udp_socket_->SetSendBufferSize(1024, send_buffer_after_bind.callback());
  send_buffer_after_bind.WaitForResult();
  EXPECT_EQ(net::OK, send_buffer_after_bind.result()->code);

  TestCallback receive_buffer_after_bind;
  udp_socket_->SetReceiveBufferSize(2048,
                                    receive_buffer_after_bind.callback());
  receive_buffer_after_bind.WaitForResult();
  EXPECT_EQ(net::OK, receive_buffer_after_bind.result()->code);

  // Whether 0 or 16 is requested, the negotiated value must be positive.
  TestCallbackWithUint32 negotiated_for_zero;
  udp_socket_->NegotiateMaxPendingSendRequests(0,
                                               negotiated_for_zero.callback());
  negotiated_for_zero.WaitForResult();
  EXPECT_GT(negotiated_for_zero.result(), 0u);

  TestCallbackWithUint32 negotiated_for_sixteen;
  udp_socket_->NegotiateMaxPendingSendRequests(
      16, negotiated_for_sixteen.callback());
  negotiated_for_sixteen.WaitForResult();
  EXPECT_GT(negotiated_for_sixteen.result(), 0u);
}
// Binds the fixture socket as a server and a second socket as a client,
// sends kDatagramCount datagrams from client to server one at a time, and
// verifies each received datagram's result code, source address and payload.
TEST_F(UDPSocketTest, TestReadWrite) {
  TestCallbackWithAddress callback1;
  udp_socket_->Bind(GetLocalHostWithAnyPort(), callback1.callback());
  callback1.WaitForResult();
  ASSERT_EQ(net::OK, callback1.result()->code);
  ASSERT_NE(0u, callback1.net_address()->ipv4->port);

  NetAddressPtr server_addr = callback1.net_address().Clone();

  UDPSocketPtr client_socket;
  network_service_->CreateUDPSocket(GetProxy(&client_socket));

  TestCallbackWithAddress callback2;
  client_socket->Bind(GetLocalHostWithAnyPort(), callback2.callback());
  callback2.WaitForResult();
  ASSERT_EQ(net::OK, callback2.result()->code);
  ASSERT_NE(0u, callback2.net_address()->ipv4->port);

  NetAddressPtr client_addr = callback2.net_address().Clone();

  const size_t kDatagramCount = 6;
  const size_t kDatagramSize = 255;
  // Both the send and the receive result codes are expected to equal the
  // datagram size; derive the value from the constant instead of repeating a
  // literal so the two checks cannot drift apart.
  const int kExpectedResultCode = static_cast<int>(kDatagramSize);

  udp_socket_->ReceiveMore(kDatagramCount);

  for (size_t i = 0; i < kDatagramCount; ++i) {
    TestCallback callback;
    client_socket->SendTo(
        server_addr.Clone(),
        CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize),
        callback.callback());
    callback.WaitForResult();
    EXPECT_EQ(kExpectedResultCode, callback.result()->code);
  }

  WaitForReceiveResults(kDatagramCount);
  for (size_t i = 0; i < kDatagramCount; ++i) {
    // Take ownership of the queued entry so it is freed at the end of the
    // iteration.
    scoped_ptr<ReceiveResult> result(GetReceiveResults()->front());
    GetReceiveResults()->pop();

    EXPECT_EQ(kExpectedResultCode, result->result->code);
    EXPECT_TRUE(result->addr.Equals(client_addr));
    EXPECT_TRUE(result->data.Equals(
        CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize)));
  }
}
// Same client-to-server exchange as TestReadWrite, but with both sockets
// driven through UDPSocketWrapper and with batches of growing size issued
// without waiting in between.
TEST_F(UDPSocketTest, TestUDPSocketWrapper) {
  UDPSocketWrapper server_wrapper(udp_socket_.Pass(), 4, 4);

  TestCallbackWithAddress server_bound;
  server_wrapper.Bind(GetLocalHostWithAnyPort(), server_bound.callback());
  server_bound.WaitForResult();
  ASSERT_EQ(net::OK, server_bound.result()->code);
  ASSERT_NE(0u, server_bound.net_address()->ipv4->port);

  NetAddressPtr server_addr = server_bound.net_address().Clone();

  UDPSocketPtr raw_client_socket;
  network_service_->CreateUDPSocket(GetProxy(&raw_client_socket));
  UDPSocketWrapper client_wrapper(raw_client_socket.Pass(), 4, 4);

  TestCallbackWithAddress client_bound;
  client_wrapper.Bind(GetLocalHostWithAnyPort(), client_bound.callback());
  client_bound.WaitForResult();
  ASSERT_EQ(net::OK, client_bound.result()->code);
  ASSERT_NE(0u, client_bound.net_address()->ipv4->port);

  NetAddressPtr client_addr = client_bound.net_address().Clone();

  const size_t kDatagramCount = 16;
  const size_t kDatagramSize = 255;
  const int kExpectedResultCode = static_cast<int>(kDatagramSize);

  // Batches of 1, 2, ..., kDatagramCount - 1 datagrams.
  for (size_t batch_size = 1; batch_size < kDatagramCount; ++batch_size) {
    scoped_ptr<TestCallback[]> sent(new TestCallback[batch_size]);
    scoped_ptr<TestReceiveCallback[]> received(
        new TestReceiveCallback[batch_size]);

    // Queue every send together with a matching receive request.
    for (size_t n = 0; n < batch_size; ++n) {
      client_wrapper.SendTo(
          server_addr.Clone(),
          CreateTestMessage(static_cast<uint8_t>(n), kDatagramSize),
          sent[n].callback());

      server_wrapper.ReceiveFrom(received[n].callback());
    }

    // Only the last receive callback is waited on; every receive callback in
    // the batch is then checked.
    received[batch_size - 1].WaitForResult();

    for (size_t n = 0; n < batch_size; ++n) {
      EXPECT_EQ(kExpectedResultCode, received[n].result()->code);
      EXPECT_TRUE(received[n].src_addr().Equals(client_addr));
      EXPECT_TRUE(received[n].data().Equals(
          CreateTestMessage(static_cast<uint8_t>(n), kDatagramSize)));
    }
  }
}
} // namespace service
} // namespace mojo