blob: b89a1dcb9e3211b61971ae1dc80e917d4bb91261 [file] [log] [blame]
// Copyright (c) 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <locale>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "base/at_exit.h"
#include "base/bind.h"
#include "base/callback.h"
#include "base/command_line.h"
#include "base/files/file_path.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread.h"
#include "base/threading/thread_local.h"
#include "base/threading/thread_task_runner_handle.h"
#include "build/build_config.h"
#include "chrome/test/chromedriver/logging.h"
#include "chrome/test/chromedriver/net/port_server.h"
#include "chrome/test/chromedriver/server/http_handler.h"
#include "chrome/test/chromedriver/version.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/log/net_log_source.h"
#include "net/server/http_server.h"
#include "net/server/http_server_request_info.h"
#include "net/server/http_server_response_info.h"
#include "net/socket/tcp_server_socket.h"
namespace {
const int kBufferSize = 100 * 1024 * 1024; // 100 MB
typedef base::Callback<
void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)>
HttpRequestHandlerFunc;
int ListenOnIPv4(net::ServerSocket* socket, uint16_t port, bool allow_remote) {
std::string binding_ip = net::IPAddress::IPv4Localhost().ToString();
if (allow_remote)
binding_ip = net::IPAddress::IPv4AllZeros().ToString();
return socket->ListenWithAddressAndPort(binding_ip, port, 1);
}
int ListenOnIPv6(net::ServerSocket* socket, uint16_t port, bool allow_remote) {
std::string binding_ip = net::IPAddress::IPv6Localhost().ToString();
if (allow_remote)
binding_ip = net::IPAddress::IPv6AllZeros().ToString();
return socket->ListenWithAddressAndPort(binding_ip, port, 1);
}
class HttpServer : public net::HttpServer::Delegate {
public:
explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func)
: handle_request_func_(handle_request_func),
weak_factory_(this) {}
~HttpServer() override {}
bool Start(uint16_t port, bool allow_remote) {
std::unique_ptr<net::ServerSocket> server_socket(
new net::TCPServerSocket(NULL, net::NetLogSource()));
if (ListenOnIPv4(server_socket.get(), port, allow_remote) != net::OK) {
// This will work on an IPv6-only host, but we will be IPv4-only on
// dual-stack hosts.
// TODO(samuong): change this to listen on both IPv4 and IPv6.
VLOG(0) << "listen on IPv4 failed, trying IPv6";
if (ListenOnIPv6(server_socket.get(), port, allow_remote) != net::OK) {
VLOG(1) << "listen on both IPv4 and IPv6 failed, giving up";
return false;
}
}
server_.reset(new net::HttpServer(std::move(server_socket), this));
net::IPEndPoint address;
return server_->GetLocalAddress(&address) == net::OK;
}
// Overridden from net::HttpServer::Delegate:
void OnConnect(int connection_id) override {
server_->SetSendBufferSize(connection_id, kBufferSize);
server_->SetReceiveBufferSize(connection_id, kBufferSize);
}
void OnHttpRequest(int connection_id,
const net::HttpServerRequestInfo& info) override {
handle_request_func_.Run(
info,
base::Bind(&HttpServer::OnResponse,
weak_factory_.GetWeakPtr(),
connection_id));
}
void OnWebSocketRequest(int connection_id,
const net::HttpServerRequestInfo& info) override {}
void OnWebSocketMessage(int connection_id, const std::string& data) override {
}
void OnClose(int connection_id) override {}
private:
void OnResponse(int connection_id,
std::unique_ptr<net::HttpServerResponseInfo> response) {
// Don't support keep-alive, since there's no way to detect if the
// client is HTTP/1.0. In such cases, the client may hang waiting for
// the connection to close (e.g., python 2.7 urllib).
response->AddHeader("Connection", "close");
server_->SendResponse(connection_id, *response);
// Don't need to call server_->Close(), since SendResponse() will handle
// this for us.
}
HttpRequestHandlerFunc handle_request_func_;
std::unique_ptr<net::HttpServer> server_;
base::WeakPtrFactory<HttpServer> weak_factory_; // Should be last.
};
void SendResponseOnCmdThread(
const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner,
const HttpResponseSenderFunc& send_response_on_io_func,
std::unique_ptr<net::HttpServerResponseInfo> response) {
io_task_runner->PostTask(
FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response)));
}
void HandleRequestOnCmdThread(
HttpHandler* handler,
const std::vector<std::string>& whitelisted_ips,
const net::HttpServerRequestInfo& request,
const HttpResponseSenderFunc& send_response_func) {
if (!whitelisted_ips.empty()) {
std::string peer_address = request.peer.ToStringWithoutPort();
if (peer_address != net::IPAddress::IPv4Localhost().ToString() &&
std::find(whitelisted_ips.begin(), whitelisted_ips.end(),
peer_address) == whitelisted_ips.end()) {
LOG(WARNING) << "unauthorized access from " << request.peer.ToString();
std::unique_ptr<net::HttpServerResponseInfo> response(
new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED));
response->SetBody("Unauthorized access", "text/plain");
send_response_func.Run(std::move(response));
return;
}
}
handler->Handle(request, send_response_func);
}
void HandleRequestOnIOThread(
const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner,
const HttpRequestHandlerFunc& handle_request_on_cmd_func,
const net::HttpServerRequestInfo& request,
const HttpResponseSenderFunc& send_response_func) {
cmd_task_runner->PostTask(
FROM_HERE, base::Bind(handle_request_on_cmd_func, request,
base::Bind(&SendResponseOnCmdThread,
base::ThreadTaskRunnerHandle::Get(),
send_response_func)));
}
base::LazyInstance<base::ThreadLocalPointer<HttpServer> >
lazy_tls_server = LAZY_INSTANCE_INITIALIZER;
void StopServerOnIOThread() {
// Note, |server| may be NULL.
HttpServer* server = lazy_tls_server.Pointer()->Get();
lazy_tls_server.Pointer()->Set(NULL);
delete server;
}
void StartServerOnIOThread(uint16_t port,
bool allow_remote,
const HttpRequestHandlerFunc& handle_request_func) {
std::unique_ptr<HttpServer> temp_server(new HttpServer(handle_request_func));
if (!temp_server->Start(port, allow_remote)) {
printf("Port not available. Exiting...\n");
exit(1);
}
lazy_tls_server.Pointer()->Set(temp_server.release());
}
void RunServer(uint16_t port,
bool allow_remote,
const std::vector<std::string>& whitelisted_ips,
const std::string& url_base,
int adb_port,
std::unique_ptr<PortServer> port_server) {
base::Thread io_thread("ChromeDriver IO");
CHECK(io_thread.StartWithOptions(
base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));
base::MessageLoop cmd_loop;
base::RunLoop cmd_run_loop;
HttpHandler handler(cmd_run_loop.QuitClosure(), io_thread.task_runner(),
url_base, adb_port, std::move(port_server));
HttpRequestHandlerFunc handle_request_func =
base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips);
io_thread.task_runner()->PostTask(
FROM_HERE,
base::Bind(&StartServerOnIOThread, port, allow_remote,
base::Bind(&HandleRequestOnIOThread, cmd_loop.task_runner(),
handle_request_func)));
// Run the command loop. This loop is quit after the response for a shutdown
// request is posted to the IO loop. After the command loop quits, a task
// is posted to the IO loop to stop the server. Lastly, the IO thread is
// destroyed, which waits until all pending tasks have been completed.
// This assumes the response is sent synchronously as part of the IO task.
cmd_run_loop.Run();
io_thread.task_runner()->PostTask(FROM_HERE,
base::Bind(&StopServerOnIOThread));
}
} // namespace
int main(int argc, char *argv[]) {
base::CommandLine::Init(argc, argv);
base::AtExitManager at_exit;
base::CommandLine* cmd_line = base::CommandLine::ForCurrentProcess();
#if defined(OS_LINUX)
// Select the locale from the environment by passing an empty string instead
// of the default "C" locale. This is particularly needed for the keycode
// conversion code to work.
setlocale(LC_ALL, "");
#endif
// Parse command line flags.
uint16_t port = 9515;
int adb_port = 5037;
bool allow_remote = false;
std::vector<std::string> whitelisted_ips;
std::string url_base;
std::unique_ptr<PortServer> port_server;
if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) {
std::string options;
const char* const kOptionAndDescriptions[] = {
"port=PORT", "port to listen on",
"adb-port=PORT", "adb server port",
"log-path=FILE", "write server log to file instead of stderr, "
"increases log level to INFO",
"verbose", "log verbosely",
"version", "print the version number and exit",
"silent", "log nothing",
"url-base", "base URL path prefix for commands, e.g. wd/url",
"port-server", "address of server to contact for reserving a port",
"whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
"which are allowed to connect to ChromeDriver",
};
for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) {
options += base::StringPrintf(
" --%-30s%s\n",
kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]);
}
printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str());
return 0;
}
if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) {
printf("ChromeDriver %s\n", kChromeDriverVersion);
return 0;
}
if (cmd_line->HasSwitch("port")) {
int cmd_line_port;
if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"),
&cmd_line_port) ||
cmd_line_port < 0 || cmd_line_port > 65535) {
printf("Invalid port. Exiting...\n");
return 1;
}
port = static_cast<uint16_t>(cmd_line_port);
}
if (cmd_line->HasSwitch("adb-port")) {
if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"),
&adb_port)) {
printf("Invalid adb-port. Exiting...\n");
return 1;
}
}
if (cmd_line->HasSwitch("port-server")) {
#if defined(OS_LINUX)
std::string address = cmd_line->GetSwitchValueASCII("port-server");
if (address.empty() || address[0] != '@') {
printf("Invalid port-server. Exiting...\n");
return 1;
}
std::string path;
// First character of path is \0 to use Linux's abstract namespace.
path.push_back(0);
path += address.substr(1);
port_server.reset(new PortServer(path));
#else
printf("Warning: port-server not implemented for this platform.\n");
#endif
}
if (cmd_line->HasSwitch("url-base"))
url_base = cmd_line->GetSwitchValueASCII("url-base");
if (url_base.empty() || url_base.front() != '/')
url_base = "/" + url_base;
if (url_base.back() != '/')
url_base = url_base + "/";
if (cmd_line->HasSwitch("whitelisted-ips")) {
allow_remote = true;
std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips");
whitelisted_ips = base::SplitString(
whitelist, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL);
}
if (!cmd_line->HasSwitch("silent")) {
printf("Starting ChromeDriver %s on port %u\n", kChromeDriverVersion, port);
if (!allow_remote) {
printf("Only local connections are allowed.\n");
} else if (!whitelisted_ips.empty()) {
printf("Remote connections are allowed by a whitelist (%s).\n",
cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str());
} else {
printf("All remote connections are allowed. Use a whitelist instead!\n");
}
fflush(stdout);
}
if (!InitLogging()) {
printf("Unable to initialize logging. Exiting...\n");
return 1;
}
RunServer(port, allow_remote, whitelisted_ips, url_base, adb_port,
std::move(port_server));
return 0;
}