// blob: 50650fc55b68a960df9b0ef5d6fdcd3a0b226de2 [file] [log] [blame]
// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/http_request.h"

#include <netinet/in.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <base/functional/bind.h>
#include <base/memory/weak_ptr.h>
#include <base/strings/string_number_conversions.h>
#include <base/strings/stringprintf.h>
#include <base/time/time.h>
#include <base/types/expected.h>
#include <brillo/http/mock_connection.h>
#include <brillo/http/mock_transport.h>
#include <brillo/mime_utils.h>
#include <brillo/streams/mock_stream.h>
#include <curl/curl.h>
#include <gtest/gtest.h>
#include <net-base/dns_client.h>
#include <net-base/http_url.h>
#include <net-base/ip_address.h>

#include "shill/mock_control.h"
#include "shill/mock_manager.h"
#include "shill/test_event_dispatcher.h"
using ::testing::_;
using ::testing::AtLeast;
using ::testing::DoAll;
using ::testing::Eq;
using ::testing::Invoke;
using ::testing::Mock;
using ::testing::NiceMock;
using ::testing::Return;
using ::testing::ReturnRef;
using ::testing::SetArgPointee;
using ::testing::StrEq;
using ::testing::StrictMock;
using ::testing::Test;
using ::testing::Unused;
using ::testing::WithArg;
namespace shill {
namespace {
// Host name contained in kTextURL; the tests expect DNS queries for this name.
constexpr std::string_view kTextSiteName = "www.chromium.org";
// URL whose host is a name rather than an IP literal.
constexpr std::string_view kTextURL =
"http://www.chromium.org/path/to/resource";
// URLs whose host is an IP address literal.
constexpr std::string_view kIPv4AddressURL = "http://10.1.1.1";
constexpr std::string_view kIPv6AddressURL = "http://[2001:db8::1]/example";
// Interface name handed to HttpRequest and expected by
// MockTransport::SetInterface().
constexpr std::string_view kInterfaceName = "int0";
// Logging tag passed to HttpRequest::Start().
constexpr std::string_view kLoggingTag = "int0 IPv4 attempt=1";
// IPv4 DNS server addresses used by the tests: 8.8.8.8 and 8.8.4.4.
constexpr net_base::IPAddress kIPv4DNS0(net_base::IPv4Address(8, 8, 8, 8));
constexpr net_base::IPAddress kIPv4DNS1(net_base::IPv4Address(8, 8, 4, 4));
// IPv6 DNS server addresses used by the tests, spelled out byte by byte:
// 2001:4860:4860::8888 and 2001:4860:4860::8844.
// clang-format off
constexpr net_base::IPAddress kIPv6DNS0(net_base::IPv6Address(0x20, 0x01,
0x48, 0x60,
0x48, 0x60,
0x00, 0x00,
0x00, 0x00,
0x00, 0x00,
0x00, 0x00,
0x88, 0x88));
constexpr net_base::IPAddress kIPv6DNS1(net_base::IPv6Address(0x20, 0x01,
0x48, 0x60,
0x48, 0x60,
0x00, 0x00,
0x00, 0x00,
0x00, 0x00,
0x00, 0x00,
0x88, 0x44));
// clang-format on
} // namespace
// Matches an argument for which `byte_string.Equals(arg)` is true.
MATCHER_P(ByteStringMatches, byte_string, "") {
  const bool is_equal = byte_string.Equals(arg);
  return is_equal;
}
// Matches an argument for which `arg.Equals(callback)` is true.
MATCHER_P(CallbackEq, callback, "") {
  const bool is_same_callback = arg.Equals(callback);
  return is_same_callback;
}
// DNSClient stand-in for the tests. It performs no work of its own and only
// remembers the hostname and name server it was constructed with. The
// WeakPtrFactory lets a test query whether the object is still valid.
class FakeDNSClient : public net_base::DNSClient {
 public:
  FakeDNSClient(std::string_view hostname,
                std::optional<net_base::IPAddress> dns)
      : hostname_(std::string(hostname)), dns_(dns) {}

  // Returns a weak pointer to this client.
  base::WeakPtr<FakeDNSClient> AsWeakPtr() {
    base::WeakPtr<FakeDNSClient> weak_self = weak_factory_.GetWeakPtr();
    return weak_self;
  }

  // Hostname given at construction; the view refers to storage owned by this
  // object.
  std::string_view hostname() {
    return std::string_view(hostname_);
  }

  // Name server given at construction, if any.
  std::optional<net_base::IPAddress> dns() {
    return dns_;
  }

 private:
  std::string hostname_;
  std::optional<net_base::IPAddress> dns_;
  base::WeakPtrFactory<FakeDNSClient> weak_factory_{this};
};
// DNSClientFactory mock whose default Resolve() action creates a
// FakeDNSClient, stores the completion callback, and keeps a weak pointer to
// the created client so tests can inspect it later.
class FakeDNSClientFactory : public net_base::DNSClientFactory {
 public:
  using Callbacks = std::vector<net_base::DNSClient::Callback>;

  FakeDNSClientFactory() {
    ON_CALL(*this, Resolve)
        .WillByDefault([this](
                           net_base::IPFamily family,
                           std::string_view hostname,
                           net_base::DNSClient::CallbackWithDuration callback,
                           const net_base::DNSClient::Options& options,
                           net_base::AresInterface* ares_interface) {
          callbacks_.emplace_back(std::move(callback));
          auto ret =
              std::make_unique<FakeDNSClient>(hostname, options.name_server);
          clients_.push_back(ret->AsWeakPtr());
          return ret;
        });
  }

  MOCK_METHOD(std::unique_ptr<net_base::DNSClient>,
              Resolve,
              (net_base::IPFamily family,
               std::string_view hostname,
               net_base::DNSClient::CallbackWithDuration callback,
               const net_base::DNSClient::Options& options,
               net_base::AresInterface* ares_interface),
              (override));

  // Runs the most recently stored Resolve() callback with a 100 ms duration
  // and |result|. Fails the test (instead of invoking undefined behavior) when
  // no callback is pending.
  void TriggerCallback(const net_base::DNSClient::Result& result) {
    ASSERT_FALSE(callbacks_.empty()) << "No pending DNS callback to trigger.";
    // Detach the callback from the container before running it, so that a
    // Resolve() call made while the callback runs cannot invalidate it or
    // cause the wrong element to be popped afterwards.
    auto callback = std::move(callbacks_.back());
    callbacks_.pop_back();
    std::move(callback).Run(base::Milliseconds(100), result);
  }

  // Returns weak pointers to every client created by Resolve() so far; a
  // pointer is invalidated once its client has been destroyed.
  std::vector<base::WeakPtr<FakeDNSClient>> GetWeakPtrsToExistingClients()
      const {
    return clients_;
  }

 private:
  std::vector<net_base::DNSClient::CallbackWithDuration> callbacks_;
  std::vector<base::WeakPtr<FakeDNSClient>> clients_;
};
// Fixture for HttpRequest tests. It wires an HttpRequest to a mock brillo
// transport/connection and a FakeDNSClientFactory, and provides helpers to set
// expectations and to inject DNS and HTTP results.
class HttpRequestTest : public Test {
 public:
  HttpRequestTest()
      : transport_(std::make_shared<brillo::http::MockTransport>()),
        brillo_connection_(
            std::make_shared<brillo::http::MockConnection>(transport_)),
        manager_(&control_, &dispatcher_, nullptr) {}

 protected:
  // Mockable receiver for the request's completion callback.
  class CallbackTarget {
   public:
    MOCK_METHOD(void, RequestCallback, (HttpRequest::Result result));

    // Returns a one-shot callback that forwards to RequestCallback().
    base::OnceCallback<void(HttpRequest::Result result)> callback() {
      return base::BindOnce(&CallbackTarget::RequestCallback,
                            base::Unretained(this));
    }
  };

  HttpRequest* request() { return request_.get(); }
  std::shared_ptr<brillo::http::MockTransport> transport() {
    return transport_;
  }

  // Expectations

  // Checks that the request is not running and that every DNS client created
  // through the factory has been invalidated.
  void ExpectStopped() {
    EXPECT_FALSE(request_->is_running());
    if (dns_client_factory_) {
      for (const auto& client :
           dns_client_factory_->GetWeakPtrsToExistingClients()) {
        EXPECT_TRUE(client.WasInvalidated());
      }
    }
  }

  // Checks that exactly one live DNS client exists per entry of |dns_list|
  // (in any order) and that each of them queries |hostname|.
  void VerifyDNSRequests(std::string_view hostname,
                         const std::vector<net_base::IPAddress>& dns_list) {
    ASSERT_NE(nullptr, dns_client_factory_);
    auto dns_clients = dns_client_factory_->GetWeakPtrsToExistingClients();
    auto remaining_dns = dns_list;
    EXPECT_EQ(remaining_dns.size(), dns_clients.size());
    for (const auto& client : dns_clients) {
      ASSERT_NE(nullptr, client);
      EXPECT_EQ(hostname, client->hostname());
      ASSERT_TRUE(client->dns().has_value());
      auto it =
          std::find(remaining_dns.begin(), remaining_dns.end(), *client->dns());
      // This must be fatal: erasing the end iterator is undefined behavior.
      ASSERT_NE(it, remaining_dns.end());
      remaining_dns.erase(it);
    }
    EXPECT_TRUE(remaining_dns.empty());
  }

  // Expects the completion callback to be invoked with |error|.
  void ExpectRequestErrorCallback(HttpRequest::Error error) {
    EXPECT_CALL(target_, RequestCallback(Eq(base::unexpected(error))));
  }

  // Action run when the completion callback fires on the success path: checks
  // status code, status text, content type and body of the response.
  void InvokeResultVerifySuccess(HttpRequest::Result result) {
    ASSERT_TRUE(result.has_value());
    std::unique_ptr<brillo::http::Response> response = std::move(*result);
    EXPECT_CALL(*brillo_connection_, GetResponseStatusCode())
        .WillOnce(Return(brillo::http::status_code::Partial));
    EXPECT_EQ(brillo::http::status_code::Partial, response->GetStatusCode());
    EXPECT_CALL(*brillo_connection_, GetResponseStatusText())
        .WillOnce(Return("Partial completion"));
    EXPECT_EQ("Partial completion", response->GetStatusText());
    EXPECT_CALL(*brillo_connection_,
                GetResponseHeader(brillo::http::response_header::kContentType))
        .WillOnce(Return(brillo::mime::text::kHtml));
    EXPECT_EQ(brillo::mime::text::kHtml, response->GetContentType());
    EXPECT_EQ(expected_response_, response->ExtractDataAsString());
  }

  // Expects the completion callback to be invoked once with a successful
  // response whose body equals |resp_data|.
  void ExpectRequestSuccessCallback(const std::string& resp_data) {
    expected_response_ = resp_data;
    EXPECT_CALL(target_, RequestCallback)
        .WillOnce(Invoke(this, &HttpRequestTest::InvokeResultVerifySuccess));
  }

  // Creates the HttpRequest under test with a fresh FakeDNSClientFactory.
  // Expects SetInterface(|interface_name|) on the transport and no call to
  // UseCustomCertificate.
  void CreateRequest(std::string_view interface_name,
                     net_base::IPFamily ip_family,
                     const std::vector<net_base::IPAddress>& dns_list) {
    auto dns_client_factory = std::make_unique<FakeDNSClientFactory>();
    // Keep a non-owning pointer for inspection; ownership moves to the
    // HttpRequest below.
    dns_client_factory_ = dns_client_factory.get();
    EXPECT_CALL(*transport_, SetInterface(Eq(interface_name)));
    EXPECT_CALL(*transport_, UseCustomCertificate).Times(0);
    request_ = std::make_unique<HttpRequest>(
        &dispatcher_, interface_name, ip_family, dns_list, false, transport_,
        std::move(dns_client_factory));
  }

  // Injects a failed DNS result for the server |dns| into the request.
  void GetDNSResultFailure(net_base::IPAddress dns,
                           net_base::DNSClient::Error error) {
    request_->GetDNSResult(dns, base::Milliseconds(100),
                           base::unexpected(error));
  }

  // Injects a successful DNS result for the server |dns| into the request.
  void GetDNSResultSuccess(net_base::IPAddress dns,
                           const std::vector<net_base::IPAddress>& addresses) {
    request_->GetDNSResult(dns, base::Milliseconds(100), addresses);
  }

  // Parses |url_string| and starts the request with no extra headers.
  void StartRequest(std::string_view url_string) {
    auto url = net_base::HttpUrl::CreateFromString(url_string);
    ASSERT_TRUE(url.has_value());
    request_->Start(kLoggingTag, *url, {}, target_.callback());
  }

  // Expects one GET connection to |url| and returns the mock connection.
  void ExpectCreateConnection(std::string_view url) {
    EXPECT_CALL(*transport_,
                CreateConnection(Eq(url), brillo::http::request_type::kGet, _,
                                 "", "", _))
        .WillOnce(Return(brillo_connection_));
  }

  // Completes the pending request successfully with a response whose data
  // stream yields |resp_data_| on the first read and 0 bytes on the second.
  void FinishRequestAsyncSuccess(
      brillo::http::SuccessCallback success_callback) {
    // The requested read size is ignored, so |resp_data_| must fit into the
    // buffer supplied by the reader.
    auto read_data = [this](void* buffer, Unused, size_t* read,
                            Unused) -> bool {
      std::memcpy(buffer, resp_data_.data(), resp_data_.size());
      *read = resp_data_.size();
      return true;
    };
    auto mock_stream = std::make_unique<brillo::MockStream>();
    EXPECT_CALL(*mock_stream, ReadBlocking(_, _, _, _))
        .WillOnce(Invoke(read_data))
        .WillOnce(DoAll(SetArgPointee<2>(0), Return(true)));
    EXPECT_CALL(*brillo_connection_, MockExtractDataStream(_))
        .WillOnce(Return(mock_stream.release()));
    auto resp = std::make_unique<brillo::http::Response>(brillo_connection_);
    // request_id_ has not yet been set. Pass -1 to match default value.
    std::move(success_callback).Run(-1, std::move(resp));
  }

  // Completes the pending request with a curl CURLE_COULDNT_CONNECT error.
  void FinishRequestAsyncFail(brillo::http::ErrorCallback error_callback) {
    brillo::ErrorPtr error;
    brillo::Error::AddTo(&error, FROM_HERE, "curl_easy_error",
                         base::NumberToString(CURLE_COULDNT_CONNECT), "");
    // request_id_ has not yet been set. Pass -1 to match default value.
    std::move(error_callback).Run(-1, error.get());
  }

  // Expects FinishRequestAsync() once and answers it from within the call by
  // running the success callback with |resp_data| as the body.
  void ExpectFinishRequestAsyncSuccess(const std::string& resp_data) {
    resp_data_ = resp_data;
    EXPECT_CALL(*brillo_connection_, FinishRequestAsync(_, _))
        .WillOnce(WithArg<0>([this](auto callback) {
          FinishRequestAsyncSuccess(std::move(callback));
          return 0;
        }));
  }

  // Expects FinishRequestAsync() once and answers it from within the call by
  // running the error callback.
  void ExpectFinishRequestAsyncFail() {
    EXPECT_CALL(*brillo_connection_, FinishRequestAsync(_, _))
        .WillOnce(WithArg<1>([this](auto callback) {
          FinishRequestAsyncFail(std::move(callback));
          return 0;
        }));
  }

 protected:
  // Non-owning; set by CreateRequest(). The factory itself is owned by
  // |request_|. Null until CreateRequest() has been called.
  FakeDNSClientFactory* dns_client_factory_ = nullptr;
  std::shared_ptr<brillo::http::MockTransport> transport_;
  std::shared_ptr<brillo::http::MockConnection> brillo_connection_;
  EventDispatcherForTest dispatcher_;
  MockControl control_;
  MockManager manager_;
  std::unique_ptr<HttpRequest> request_;
  StrictMock<CallbackTarget> target_;
  // Body that InvokeResultVerifySuccess() expects in the response.
  std::string expected_response_;
  // Body served by the mock data stream in FinishRequestAsyncSuccess().
  std::string resp_data_;
};
TEST_F(HttpRequestTest, Constructor) {
  // A request that was only constructed must report the stopped state.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4, dns_servers);
  ExpectStopped();
}
TEST_F(HttpRequestTest, UseCustomCertificate) {
  // Constructing a request with allow_non_google_https=true must ask the
  // transport to use the NSS certificate.
  auto mock_transport = std::make_shared<brillo::http::MockTransport>();
  EXPECT_CALL(*mock_transport,
              UseCustomCertificate(brillo::http::Transport::Certificate::kNss));

  std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  auto dns_factory = std::make_unique<FakeDNSClientFactory>();
  auto http_request = std::make_unique<HttpRequest>(
      &dispatcher_, "wlan0", net_base::IPFamily::kIPv4, dns_servers,
      /*allow_non_google_https=*/true, mock_transport, std::move(dns_factory));
}
TEST_F(HttpRequestTest, IPv4NumericRequestSuccess) {
  // A URL with an IPv4 literal host: ResolveHostToIp must not be called and
  // the response body must reach the callback.
  const std::string response_body = "Sample response.";
  const std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4, dns_servers);

  EXPECT_CALL(*transport(), ResolveHostToIp).Times(0);
  ExpectCreateConnection(kIPv4AddressURL);
  ExpectFinishRequestAsyncSuccess(response_body);
  ExpectRequestSuccessCallback(response_body);

  StartRequest(kIPv4AddressURL);
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv4NumericRequestFail) {
  // A URL with an IPv4 literal host whose connection fails: the callback must
  // receive kConnectionFailure.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4, dns_servers);

  ExpectCreateConnection(kIPv4AddressURL);
  ExpectFinishRequestAsyncFail();
  ExpectRequestErrorCallback(HttpRequest::Error::kConnectionFailure);

  StartRequest(kIPv4AddressURL);
  ExpectStopped();
}
// TODO(cros-networking@): Re-enable this test when HttpUrl supports parsing
// IPv6 address hosts.
TEST_F(HttpRequestTest, DISABLED_IPv6NumericRequestSuccess) {
  // A URL with an IPv6 literal host: ResolveHostToIp must not be called and
  // the response body must reach the callback.
  const std::string response_body = "Sample response.";
  const std::vector<net_base::IPAddress> dns_servers = {kIPv6DNS0, kIPv6DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv6, dns_servers);

  EXPECT_CALL(*transport(), ResolveHostToIp).Times(0);
  ExpectCreateConnection(kIPv6AddressURL);
  ExpectFinishRequestAsyncSuccess(response_body);
  ExpectRequestSuccessCallback(response_body);

  StartRequest(kIPv6AddressURL);
  ExpectStopped();
}
// TODO(cros-networking@): Re-enable this test when HttpUrl supports parsing
// IPv6 address hosts.
TEST_F(HttpRequestTest, DISABLED_IPv6NumericRequestFail) {
  // A URL with an IPv6 literal host whose connection fails: the callback must
  // receive kConnectionFailure.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv6DNS0, kIPv6DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv6, dns_servers);

  ExpectCreateConnection(kIPv6AddressURL);
  ExpectFinishRequestAsyncFail();
  ExpectRequestErrorCallback(HttpRequest::Error::kConnectionFailure);

  StartRequest(kIPv6AddressURL);
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv4TextRequestSuccess) {
  // A named host is queried on both IPv4 DNS servers; the first successful
  // answer is handed to the transport as a comma-separated address list and
  // the HTTP request then completes successfully.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4, dns_servers);
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, dns_servers);

  const std::string response_body = "Sample response.";
  ExpectRequestSuccessCallback(response_body);
  auto parsed_url = *net_base::HttpUrl::CreateFromString(kTextURL);
  EXPECT_CALL(*transport(),
              ResolveHostToIp(parsed_url.host(), parsed_url.port(),
                              "10.1.1.1,10.1.1.2,10.1.1.3"));
  ExpectCreateConnection(kTextURL);
  ExpectFinishRequestAsyncSuccess(response_body);

  const std::vector<net_base::IPAddress> resolved_addresses = {
      net_base::IPAddress(net_base::IPv4Address(10, 1, 1, 1)),
      net_base::IPAddress(net_base::IPv4Address(10, 1, 1, 2)),
      net_base::IPAddress(net_base::IPv4Address(10, 1, 1, 3)),
  };
  GetDNSResultSuccess(kIPv4DNS0, resolved_addresses);
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv6TextRequestSuccess) {
  // A named host is queried on both IPv6 DNS servers; the successful answer
  // is handed to the transport and the HTTP request completes successfully.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv6DNS0, kIPv6DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv6, dns_servers);
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, dns_servers);

  const std::string response_body = "Sample response.";
  ExpectRequestSuccessCallback(response_body);
  auto parsed_url = *net_base::HttpUrl::CreateFromString(kTextURL);
  EXPECT_CALL(*transport(), ResolveHostToIp(parsed_url.host(),
                                            parsed_url.port(), "2001:db8::1"));
  ExpectCreateConnection(kTextURL);
  ExpectFinishRequestAsyncSuccess(response_body);

  const auto resolved_address =
      *net_base::IPAddress::CreateFromString("2001:db8::1");
  const std::vector<net_base::IPAddress> resolved_addresses = {
      resolved_address};
  GetDNSResultSuccess(kIPv6DNS0, resolved_addresses);
  ExpectStopped();
}
// multiple addresses
TEST_F(HttpRequestTest, IPv4FailDNSFailure) {
  // Every IPv4 DNS server answers kNoData: the callback must receive
  // kDNSFailure and the request must stop.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4, dns_servers);
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, dns_servers);

  ExpectRequestErrorCallback(HttpRequest::Error::kDNSFailure);
  for (const auto& dns_server : dns_servers) {
    GetDNSResultFailure(dns_server, net_base::DNSClient::Error::kNoData);
  }
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv4FailDNSTimeout) {
  // Every IPv4 DNS server times out: the callback must receive kDNSTimeout
  // and the request must stop.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv4DNS0, kIPv4DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4, dns_servers);
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, dns_servers);

  ExpectRequestErrorCallback(HttpRequest::Error::kDNSTimeout);
  for (const auto& dns_server : dns_servers) {
    GetDNSResultFailure(dns_server, net_base::DNSClient::Error::kTimedOut);
  }
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv6FailDNSFailure) {
  // Every IPv6 DNS server answers kNoData: the callback must receive
  // kDNSFailure and the request must stop.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv6DNS0, kIPv6DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv6, dns_servers);
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, dns_servers);

  ExpectRequestErrorCallback(HttpRequest::Error::kDNSFailure);
  for (const auto& dns_server : dns_servers) {
    GetDNSResultFailure(dns_server, net_base::DNSClient::Error::kNoData);
  }
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv6FailDNSTimeout) {
  // Every IPv6 DNS server times out: the callback must receive kDNSTimeout
  // and the request must stop.
  const std::vector<net_base::IPAddress> dns_servers = {kIPv6DNS0, kIPv6DNS1};
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv6, dns_servers);
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, dns_servers);

  ExpectRequestErrorCallback(HttpRequest::Error::kDNSTimeout);
  for (const auto& dns_server : dns_servers) {
    GetDNSResultFailure(dns_server, net_base::DNSClient::Error::kTimedOut);
  }
  ExpectStopped();
}
TEST_F(HttpRequestTest, IPv4TextRequestSuccessAfterSomeDNSError) {
  // One IPv4 DNS server times out while the other one answers: the request
  // must keep running after the first error and succeed once the second
  // server delivers an address.
  //
  // The request is configured with the same IPv4 DNS servers whose results
  // are injected below, so that the injected results belong to queries that
  // were actually started.
  CreateRequest(kInterfaceName, net_base::IPFamily::kIPv4,
                {kIPv4DNS0, kIPv4DNS1});
  StartRequest(kTextURL);
  VerifyDNSRequests(kTextSiteName, {kIPv4DNS0, kIPv4DNS1});

  // The first server fails: nothing must be resolved on the transport yet and
  // the request must still be waiting for the remaining server.
  EXPECT_CALL(*transport(), ResolveHostToIp).Times(0);
  GetDNSResultFailure(kIPv4DNS0, net_base::DNSClient::Error::kTimedOut);
  EXPECT_TRUE(request_->is_running());
  Mock::VerifyAndClearExpectations(transport().get());

  // The second server succeeds: the HTTP request is issued and completes.
  const std::string resp = "Sample response.";
  ExpectRequestSuccessCallback(resp);
  auto url = *net_base::HttpUrl::CreateFromString(kTextURL);
  EXPECT_CALL(*transport(),
              ResolveHostToIp(url.host(), url.port(), "10.1.1.1"));
  ExpectCreateConnection(kTextURL);
  ExpectFinishRequestAsyncSuccess(resp);
  GetDNSResultSuccess(
      kIPv4DNS1, {net_base::IPAddress(net_base::IPv4Address(10, 1, 1, 1))});
  ExpectStopped();
}
} // namespace shill