blob: e319b3af57730ac10be73d092a7fe02f69889663 [file] [log] [blame]
// Copyright (c) 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/device/hid/hid_connection.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "base/bind.h"
#include "base/callback.h"
#include "base/memory/ref_counted_memory.h"
#include "base/run_loop.h"
#include "base/scoped_observer.h"
#include "base/strings/utf_string_conversions.h"
#include "base/test/scoped_task_environment.h"
#include "base/test/test_io_thread.h"
#include "device/test/usb_test_gadget.h"
#include "device/usb/usb_device.h"
#include "services/device/hid/hid_service.h"
#include "services/device/public/mojom/hid.mojom.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace device {
namespace {
// Helper class that can be used to block until a HID device with a particular
// serial number is available. Example usage:
//
// DeviceCatcher device_catcher(hid_service, base::ASCIIToUTF16("ABC123"));
// std::string device_guid = device_catcher.WaitForDevice();
// /* Call HidService::Connect(device_guid, callback) to open the device. */
//
class DeviceCatcher : HidService::Observer {
public:
DeviceCatcher(HidService* hid_service, const base::string16& serial_number)
: serial_number_(base::UTF16ToUTF8(serial_number)), observer_(this) {
hid_service->GetDevices(
base::BindOnce(&DeviceCatcher::OnEnumerationComplete,
base::Unretained(this), hid_service));
}
const std::string& WaitForDevice() {
run_loop_.Run();
observer_.RemoveAll();
return device_guid_;
}
private:
void OnEnumerationComplete(HidService* hid_service,
std::vector<mojom::HidDeviceInfoPtr> devices) {
for (auto& device_info : devices) {
if (device_info->serial_number == serial_number_) {
device_guid_ = device_info->guid;
run_loop_.Quit();
break;
}
}
observer_.Add(hid_service);
}
void OnDeviceAdded(mojom::HidDeviceInfoPtr device_info) override {
if (device_info->serial_number == serial_number_) {
device_guid_ = device_info->guid;
run_loop_.Quit();
}
}
std::string serial_number_;
ScopedObserver<HidService, HidService::Observer> observer_;
base::RunLoop run_loop_;
std::string device_guid_;
};
// Receives the result of HidService::Connect() and lets the test block until
// that result has been delivered.
class TestConnectCallback {
 public:
  TestConnectCallback() = default;
  ~TestConnectCallback() = default;

  // Returns a callback that forwards to SetConnection(). It is bound with
  // base::Unretained, so it must not be run after |this| is destroyed.
  HidService::ConnectCallback GetCallback() {
    return base::Bind(&TestConnectCallback::SetConnection,
                      base::Unretained(this));
  }

  // Runs the run loop until SetConnection() has been called, then hands back
  // whatever connection it received. The result is not checked here; callers
  // should test it for null.
  scoped_refptr<HidConnection> WaitForConnection() {
    run_loop_.Run();
    return connection_;
  }

  // Stores |connection| and releases WaitForConnection().
  void SetConnection(scoped_refptr<HidConnection> connection) {
    connection_ = connection;
    run_loop_.Quit();
  }

 private:
  base::RunLoop run_loop_;
  scoped_refptr<HidConnection> connection_;
};
// Captures the outcome of one HidConnection::Read() or Write() call and lets
// the test block until it arrives. Each instance owns a single base::RunLoop,
// so a fresh instance should be used for every I/O operation.
class TestIoCallback {
 public:
  TestIoCallback() = default;
  ~TestIoCallback() = default;

  // Read completion handler: records the status, the received report buffer
  // and its size, then releases WaitForResult().
  void SetReadResult(bool success,
                     scoped_refptr<base::RefCountedBytes> buffer,
                     size_t size) {
    result_ = success;
    buffer_ = buffer;
    size_ = size;
    run_loop_.Quit();
  }

  // Write completion handler: only a status is reported. buffer_ therefore
  // stays null and size_ stays 0.
  void SetWriteResult(bool success) {
    result_ = success;
    run_loop_.Quit();
  }

  // Blocks until one of the Set*Result() handlers has run and returns the
  // success flag it reported.
  bool WaitForResult() {
    run_loop_.Run();
    return result_;
  }

  // Callbacks bound with base::Unretained; they must not outlive |this|.
  HidConnection::ReadCallback GetReadCallback() {
    return base::BindOnce(&TestIoCallback::SetReadResult,
                          base::Unretained(this));
  }
  HidConnection::WriteCallback GetWriteCallback() {
    return base::BindOnce(&TestIoCallback::SetWriteResult,
                          base::Unretained(this));
  }

  scoped_refptr<base::RefCountedBytes> buffer() const { return buffer_; }
  size_t size() const { return size_; }

 private:
  base::RunLoop run_loop_;
  // Explicitly initialized: SetWriteResult() never assigns size_, and neither
  // field is set until a callback fires. Without an initializer, reading them
  // in those cases would read indeterminate values.
  bool result_ = false;
  size_t size_ = 0;
  scoped_refptr<base::RefCountedBytes> buffer_;
};
} // namespace
// Fixture for HidConnection tests that need real hardware. When USB gadget
// tests are enabled, SetUp() claims a UsbTestGadget, switches it to HID echo
// mode and waits for the HidService to report it, storing the device GUID in
// device_guid_. When they are disabled, SetUp() does nothing and each test
// must bail out on its own.
class HidConnectionTest : public testing::Test {
 public:
  HidConnectionTest()
      : scoped_task_environment_(
            base::test::ScopedTaskEnvironment::MainThreadType::UI),
        io_thread_(base::TestIOThread::kAutoStart) {}

 protected:
  void SetUp() override {
    if (UsbTestGadget::IsTestEnabled()) {
      service_ = HidService::Create();
      ASSERT_TRUE(service_);

      // The gadget is claimed using the dedicated I/O thread's task runner.
      test_gadget_ = UsbTestGadget::Claim(io_thread_.task_runner());
      ASSERT_TRUE(test_gadget_);
      ASSERT_TRUE(test_gadget_->SetType(UsbTestGadget::HID_ECHO));

      // Match the HID device to the gadget by the USB serial number.
      DeviceCatcher device_catcher(
          service_.get(), test_gadget_->GetDevice()->serial_number());
      device_guid_ = device_catcher.WaitForDevice();
      ASSERT_FALSE(device_guid_.empty());
    }
  }

  base::test::ScopedTaskEnvironment scoped_task_environment_;
  base::TestIOThread io_thread_;
  std::unique_ptr<HidService> service_;
  std::unique_ptr<UsbTestGadget> test_gadget_;
  std::string device_guid_;
};
// Opens the echo gadget and, for several iterations, writes a report with a
// distinct byte pattern. Each time it checks that the report read back has the
// same size and contents.
TEST_F(HidConnectionTest, ReadWrite) {
  if (!UsbTestGadget::IsTestEnabled())
    return;

  TestConnectCallback connect_callback;
  service_->Connect(device_guid_, connect_callback.GetCallback());
  scoped_refptr<HidConnection> conn = connect_callback.WaitForConnection();
  ASSERT_TRUE(conn.get());

  // Byte 0 is written as 0 (presumably the report ID for an unnumbered report;
  // TODO confirm). The remaining bytes carry the payload.
  constexpr size_t kBufferSize = 9;
  constexpr size_t kIterations = 8;
  for (size_t i = 0; i < kIterations; ++i) {
    auto buffer = base::MakeRefCounted<base::RefCountedBytes>(kBufferSize);
    buffer->data()[0] = 0;
    for (size_t j = 1; j < kBufferSize; ++j)
      buffer->data()[j] = static_cast<unsigned char>(i + j - 1);

    TestIoCallback write_callback;
    conn->Write(buffer, write_callback.GetWriteCallback());
    ASSERT_TRUE(write_callback.WaitForResult());

    TestIoCallback read_callback;
    conn->Read(read_callback.GetReadCallback());
    ASSERT_TRUE(read_callback.WaitForResult());
    // Compare against the same constant used for the write, rather than a
    // separate literal whose integer width differs across platforms.
    ASSERT_EQ(kBufferSize, read_callback.size());

    scoped_refptr<base::RefCountedBytes> received = read_callback.buffer();
    ASSERT_TRUE(received.get());
    ASSERT_EQ(0, received->data()[0]);
    for (size_t j = 1; j < kBufferSize; ++j) {
      ASSERT_EQ(static_cast<unsigned char>(i + j - 1), received->data()[j]);
    }
  }
  conn->Close();
}
} // namespace device