blob: d9371b4767b56b744c9d33f902aa44e8f10ea4b7 [file] [log] [blame]
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "device/test/usb_test_gadget.h"
#include <string>
#include <vector>
#include "base/command_line.h"
#include "base/compiler_specific.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/path_service.h"
#include "base/process/process_handle.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "device/usb/usb_device.h"
#include "device/usb/usb_device_handle.h"
#include "device/usb/usb_service.h"
#include "net/proxy/proxy_service.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_builder.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
using ::base::PlatformThread;
using ::base::TimeDelta;
namespace device {
namespace {
static const char kCommandLineSwitch[] = "enable-gadget-tests";
static const int kClaimRetries = 100; // 5 seconds
static const int kDisconnectRetries = 100; // 5 seconds
static const int kRetryPeriod = 50; // 0.05 seconds
static const int kReconnectRetries = 100; // 5 seconds
static const int kUpdateRetries = 100; // 5 seconds
// Wait for the given time delta while still running the main loop. This is
// necessary so that device add/remove events are processed by the UsbService.
void SleepWithRunLoop(base::TimeDelta delta) {
base::RunLoop run_loop;
base::MessageLoop::current()->PostDelayedTask(FROM_HERE,
run_loop.QuitClosure(), delta);
run_loop.Run();
}
struct UsbTestGadgetConfiguration {
UsbTestGadget::Type type;
const char* http_resource;
uint16 product_id;
};
static const struct UsbTestGadgetConfiguration kConfigurations[] = {
{UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0},
{UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1},
{UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2},
{UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3},
{UsbTestGadget::ECHO, "/echo/configure", 0x58F4},
};
class UsbTestGadgetImpl : public UsbTestGadget {
public:
~UsbTestGadgetImpl() override;
bool Unclaim() override;
bool Disconnect() override;
bool Reconnect() override;
bool SetType(Type type) override;
UsbDevice* GetDevice() const override;
const std::string& GetSerialNumber() const override;
protected:
UsbTestGadgetImpl();
private:
scoped_ptr<net::URLFetcher> CreateURLFetcher(
const GURL& url,
net::URLFetcher::RequestType request_type,
net::URLFetcherDelegate* delegate);
int SimplePOSTRequest(const GURL& url, const std::string& form_data);
bool FindUnclaimed();
bool GetVersion(std::string* version);
bool Update();
bool FindClaimed();
bool ReadLocalVersion(std::string* version);
bool ReadLocalPackage(std::string* package);
bool ReadFile(const base::FilePath& file_path, std::string* content);
class Delegate : public net::URLFetcherDelegate {
public:
Delegate() {}
~Delegate() override {}
void WaitForCompletion() {
run_loop_.Run();
}
void OnURLFetchComplete(const net::URLFetcher* source) override {
run_loop_.Quit();
}
private:
base::RunLoop run_loop_;
DISALLOW_COPY_AND_ASSIGN(Delegate);
};
scoped_refptr<UsbDevice> device_;
std::string device_address_;
scoped_ptr<net::URLRequestContext> request_context_;
std::string session_id_;
UsbService* usb_service_;
friend class UsbTestGadget;
DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl);
};
} // namespace
bool UsbTestGadget::IsTestEnabled() {
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
return command_line->HasSwitch(kCommandLineSwitch);
}
scoped_ptr<UsbTestGadget> UsbTestGadget::Claim() {
scoped_ptr<UsbTestGadgetImpl> gadget(new UsbTestGadgetImpl);
int retries = kClaimRetries;
while (!gadget->FindUnclaimed()) {
if (--retries == 0) {
LOG(ERROR) << "Failed to find an unclaimed device.";
return scoped_ptr<UsbTestGadget>();
}
SleepWithRunLoop(TimeDelta::FromMilliseconds(kRetryPeriod));
}
VLOG(1) << "It took " << (kClaimRetries - retries)
<< " retries to find an unclaimed device.";
return gadget.Pass();
}
UsbTestGadgetImpl::UsbTestGadgetImpl() {
net::URLRequestContextBuilder context_builder;
context_builder.set_proxy_service(net::ProxyService::CreateDirect());
request_context_.reset(context_builder.Build());
base::ProcessId process_id = base::GetCurrentProcId();
session_id_ = base::StringPrintf(
"%s:%p", base::HexEncode(&process_id, sizeof(process_id)).c_str(), this);
usb_service_ = UsbService::GetInstance(NULL);
}
UsbTestGadgetImpl::~UsbTestGadgetImpl() {
if (!device_address_.empty()) {
Unclaim();
}
}
UsbDevice* UsbTestGadgetImpl::GetDevice() const {
return device_.get();
}
const std::string& UsbTestGadgetImpl::GetSerialNumber() const {
return device_address_;
}
scoped_ptr<net::URLFetcher> UsbTestGadgetImpl::CreateURLFetcher(
const GURL& url, net::URLFetcher::RequestType request_type,
net::URLFetcherDelegate* delegate) {
scoped_ptr<net::URLFetcher> url_fetcher(
net::URLFetcher::Create(url, request_type, delegate));
url_fetcher->SetRequestContext(new net::TrivialURLRequestContextGetter(
request_context_.get(), base::MessageLoop::current()->task_runner()));
return url_fetcher;
}
int UsbTestGadgetImpl::SimplePOSTRequest(const GURL& url,
const std::string& form_data) {
Delegate delegate;
scoped_ptr<net::URLFetcher> url_fetcher =
CreateURLFetcher(url, net::URLFetcher::POST, &delegate);
url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data);
url_fetcher->Start();
delegate.WaitForCompletion();
return url_fetcher->GetResponseCode();
}
bool UsbTestGadgetImpl::FindUnclaimed() {
std::vector<scoped_refptr<UsbDevice> > devices;
usb_service_->GetDevices(&devices);
for (std::vector<scoped_refptr<UsbDevice> >::const_iterator iter =
devices.begin(); iter != devices.end(); ++iter) {
const scoped_refptr<UsbDevice> &device = *iter;
if (device->vendor_id() == 0x18D1 && device->product_id() == 0x58F0) {
base::string16 serial_utf16;
if (!device->GetSerialNumber(&serial_utf16)) {
continue;
}
const std::string serial = base::UTF16ToUTF8(serial_utf16);
const GURL url("http://" + serial + "/claim");
const std::string form_data = base::StringPrintf(
"session_id=%s",
net::EscapeUrlEncodedData(session_id_, true).c_str());
const int response_code = SimplePOSTRequest(url, form_data);
if (response_code == 200) {
device_address_ = serial;
device_ = device;
break;
}
// The device is probably claimed by another process.
if (response_code != 403) {
LOG(WARNING) << "Unexpected HTTP " << response_code << " from /claim.";
}
}
}
std::string local_version;
std::string version;
if (!ReadLocalVersion(&local_version) ||
!GetVersion(&version)) {
return false;
}
if (version == local_version) {
return true;
}
return Update();
}
bool UsbTestGadgetImpl::GetVersion(std::string* version) {
Delegate delegate;
const GURL url("http://" + device_address_ + "/version");
scoped_ptr<net::URLFetcher> url_fetcher =
CreateURLFetcher(url, net::URLFetcher::GET, &delegate);
url_fetcher->Start();
delegate.WaitForCompletion();
const int response_code = url_fetcher->GetResponseCode();
if (response_code != 200) {
VLOG(2) << "Unexpected HTTP " << response_code << " from /version.";
return false;
}
STLClearObject(version);
if (!url_fetcher->GetResponseAsString(version)) {
VLOG(2) << "Failed to read body from /version.";
return false;
}
return true;
}
bool UsbTestGadgetImpl::Update() {
std::string version;
if (!ReadLocalVersion(&version)) {
return false;
}
LOG(INFO) << "Updating " << device_address_ << " to " << version << "...";
Delegate delegate;
const GURL url("http://" + device_address_ + "/update");
scoped_ptr<net::URLFetcher> url_fetcher =
CreateURLFetcher(url, net::URLFetcher::POST, &delegate);
const std::string mime_header =
base::StringPrintf(
"--foo\r\n"
"Content-Disposition: form-data; name=\"file\"; "
"filename=\"usb_gadget-%s.zip\"\r\n"
"Content-Type: application/octet-stream\r\n"
"\r\n", version.c_str());
const std::string mime_footer("\r\n--foo--\r\n");
std::string package;
if (!ReadLocalPackage(&package)) {
return false;
}
url_fetcher->SetUploadData("multipart/form-data; boundary=foo",
mime_header + package + mime_footer);
url_fetcher->Start();
delegate.WaitForCompletion();
const int response_code = url_fetcher->GetResponseCode();
if (response_code != 200) {
LOG(ERROR) << "Unexpected HTTP " << response_code << " from /update.";
return false;
}
int retries = kUpdateRetries;
std::string new_version;
while (!GetVersion(&new_version) || new_version != version) {
if (--retries == 0) {
LOG(ERROR) << "Device not responding with new version.";
return false;
}
SleepWithRunLoop(TimeDelta::FromMilliseconds(kRetryPeriod));
}
VLOG(1) << "It took " << (kUpdateRetries - retries)
<< " retries to see the new version.";
// Release the old reference to the device and try to open a new one.
device_ = NULL;
retries = kReconnectRetries;
while (!FindClaimed()) {
if (--retries == 0) {
LOG(ERROR) << "Failed to find updated device.";
return false;
}
SleepWithRunLoop(TimeDelta::FromMilliseconds(kRetryPeriod));
}
VLOG(1) << "It took " << (kReconnectRetries - retries)
<< " retries to find the updated device.";
return true;
}
bool UsbTestGadgetImpl::FindClaimed() {
CHECK(!device_.get());
std::string expected_serial = GetSerialNumber();
std::vector<scoped_refptr<UsbDevice> > devices;
usb_service_->GetDevices(&devices);
for (std::vector<scoped_refptr<UsbDevice> >::iterator iter =
devices.begin(); iter != devices.end(); ++iter) {
scoped_refptr<UsbDevice> &device = *iter;
if (device->vendor_id() == 0x18D1) {
const uint16 product_id = device->product_id();
bool found = false;
for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
if (product_id == kConfigurations[i].product_id) {
found = true;
break;
}
}
if (!found) {
continue;
}
base::string16 serial_utf16;
if (!device->GetSerialNumber(&serial_utf16)) {
continue;
}
std::string serial = base::UTF16ToUTF8(serial_utf16);
if (serial != expected_serial) {
continue;
}
device_ = device;
return true;
}
}
return false;
}
bool UsbTestGadgetImpl::ReadLocalVersion(std::string* version) {
base::FilePath file_path;
CHECK(PathService::Get(base::DIR_EXE, &file_path));
file_path = file_path.AppendASCII("usb_gadget.zip.md5");
return ReadFile(file_path, version);
}
bool UsbTestGadgetImpl::ReadLocalPackage(std::string* package) {
base::FilePath file_path;
CHECK(PathService::Get(base::DIR_EXE, &file_path));
file_path = file_path.AppendASCII("usb_gadget.zip");
return ReadFile(file_path, package);
}
bool UsbTestGadgetImpl::ReadFile(const base::FilePath& file_path,
std::string* content) {
base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
if (!file.IsValid()) {
LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": "
<< base::File::ErrorToString(file.error_details());
return false;
}
STLClearObject(content);
int rv;
do {
char buf[4096];
rv = file.ReadAtCurrentPos(buf, sizeof buf);
if (rv == -1) {
LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": "
<< base::File::ErrorToString(file.error_details());
return false;
}
content->append(buf, rv);
} while (rv > 0);
return true;
}
bool UsbTestGadgetImpl::Unclaim() {
VLOG(1) << "Releasing the device at " << device_address_ << ".";
const GURL url("http://" + device_address_ + "/unclaim");
const int response_code = SimplePOSTRequest(url, "");
if (response_code != 200) {
LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim.";
return false;
}
device_address_.clear();
return true;
}
bool UsbTestGadgetImpl::SetType(Type type) {
const struct UsbTestGadgetConfiguration* config = NULL;
for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
if (kConfigurations[i].type == type) {
config = &kConfigurations[i];
}
}
CHECK(config);
const GURL url("http://" + device_address_ + config->http_resource);
const int response_code = SimplePOSTRequest(url, "");
if (response_code != 200) {
LOG(ERROR) << "Unexpected HTTP " << response_code
<< " from " << config->http_resource << ".";
return false;
}
// Release the old reference to the device and try to open a new one.
int retries = kReconnectRetries;
while (true) {
device_ = NULL;
if (FindClaimed() && device_->product_id() == config->product_id) {
break;
}
if (--retries == 0) {
LOG(ERROR) << "Failed to find updated device.";
return false;
}
SleepWithRunLoop(TimeDelta::FromMilliseconds(kRetryPeriod));
}
VLOG(1) << "It took " << (kReconnectRetries - retries)
<< " retries to find the updated device.";
return true;
}
bool UsbTestGadgetImpl::Disconnect() {
const GURL url("http://" + device_address_ + "/disconnect");
const int response_code = SimplePOSTRequest(url, "");
if (response_code != 200) {
LOG(ERROR) << "Unexpected HTTP " << response_code << " from /disconnect.";
return false;
}
// Release the old reference to the device and wait until it can't be found.
int retries = kDisconnectRetries;
while (true) {
device_ = NULL;
if (!FindClaimed()) {
break;
}
if (--retries == 0) {
LOG(ERROR) << "Device did not disconnect.";
return false;
}
SleepWithRunLoop(TimeDelta::FromMilliseconds(kRetryPeriod));
}
VLOG(1) << "It took " << (kDisconnectRetries - retries)
<< " retries for the device to disconnect.";
return true;
}
bool UsbTestGadgetImpl::Reconnect() {
const GURL url("http://" + device_address_ + "/reconnect");
const int response_code = SimplePOSTRequest(url, "");
if (response_code != 200) {
LOG(ERROR) << "Unexpected HTTP " << response_code << " from /reconnect.";
return false;
}
int retries = kDisconnectRetries;
while (true) {
if (FindClaimed()) {
break;
}
if (--retries == 0) {
LOG(ERROR) << "Device did not reconnect.";
return false;
}
SleepWithRunLoop(TimeDelta::FromMilliseconds(kRetryPeriod));
}
VLOG(1) << "It took " << (kDisconnectRetries - retries)
<< " retries for the device to reconnect.";
return true;
}
} // namespace device