blob: 80ad4ae2b8ceb07e00f6bf9c252e1a84e99c7b9c [file] [log] [blame]
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "device/fido/hid/fido_hid_device.h"
#include <limits>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/command_line.h"
#include "base/logging.h"
#include "base/threading/thread_task_runner_handle.h"
#include "crypto/random.h"
#include "device/fido/hid/fido_hid_message.h"
#include "mojo/public/cpp/bindings/interface_request.h"
namespace device {
namespace {
// U2F devices only provide a single report so specify a report ID of 0 here.
static constexpr uint8_t kReportId = 0x00;
} // namespace
FidoHidDevice::FidoHidDevice(device::mojom::HidDeviceInfoPtr device_info,
device::mojom::HidManager* hid_manager)
: FidoDevice(),
output_report_size_(device_info->max_output_report_size),
hid_manager_(hid_manager),
device_info_(std::move(device_info)),
weak_factory_(this) {
DCHECK_GE(std::numeric_limits<decltype(output_report_size_)>::max(),
device_info_->max_output_report_size);
// These limits on the report size are enforced in fido_hid_discovery.cc.
DCHECK_LT(kHidInitPacketHeaderSize, output_report_size_);
DCHECK_GE(kHidMaxPacketSize, output_report_size_);
}
FidoHidDevice::~FidoHidDevice() = default;
void FidoHidDevice::DeviceTransact(std::vector<uint8_t> command,
DeviceCallback callback) {
Transition(std::move(command), std::move(callback));
}
void FidoHidDevice::Cancel() {
// If device has not been connected or is already in error state, do nothing.
if (state_ != State::kBusy && state_ != State::kReady)
return;
// Delete any remaining pending requests on this Channel ID.
pending_transactions_ = {};
WriteMessage(
FidoHidMessage::Create(channel_id_, FidoHidDeviceCommand::kCancel,
output_report_size_, std::vector<uint8_t>()),
false /* response_expected */, base::DoNothing());
}
void FidoHidDevice::Transition(std::vector<uint8_t> command,
DeviceCallback callback) {
// This adapter is needed to support the calls to ArmTimeout(). However, it is
// still guaranteed that |callback| will only be invoked once.
auto repeating_callback =
base::AdaptCallbackForRepeating(std::move(callback));
switch (state_) {
case State::kInit:
state_ = State::kBusy;
ArmTimeout(repeating_callback);
Connect(base::BindOnce(&FidoHidDevice::OnConnect,
weak_factory_.GetWeakPtr(), std::move(command),
repeating_callback));
break;
case State::kConnected:
state_ = State::kBusy;
ArmTimeout(repeating_callback);
AllocateChannel(std::move(command), repeating_callback);
break;
case State::kReady: {
state_ = State::kBusy;
ArmTimeout(repeating_callback);
// Write message to the device.
const auto command_type = supported_protocol() == ProtocolVersion::kCtap
? FidoHidDeviceCommand::kCbor
: FidoHidDeviceCommand::kMsg;
WriteMessage(
FidoHidMessage::Create(channel_id_, command_type, output_report_size_,
std::move(command)),
true,
base::BindOnce(&FidoHidDevice::MessageReceived,
weak_factory_.GetWeakPtr(), repeating_callback));
break;
}
case State::kBusy:
pending_transactions_.emplace(std::move(command), repeating_callback);
break;
case State::kDeviceError:
default:
base::WeakPtr<FidoHidDevice> self = weak_factory_.GetWeakPtr();
repeating_callback.Run(base::nullopt);
// Executing callbacks may free |this|. Check |self| first.
while (self && !pending_transactions_.empty()) {
// Respond to any pending requests.
DeviceCallback pending_cb =
std::move(pending_transactions_.front().second);
pending_transactions_.pop();
std::move(pending_cb).Run(base::nullopt);
}
break;
}
}
void FidoHidDevice::Connect(ConnectCallback callback) {
DCHECK(hid_manager_);
hid_manager_->Connect(device_info_->guid, std::move(callback));
}
void FidoHidDevice::OnConnect(std::vector<uint8_t> command,
DeviceCallback callback,
device::mojom::HidConnectionPtr connection) {
if (state_ == State::kDeviceError)
return;
timeout_callback_.Cancel();
if (connection) {
connection_ = std::move(connection);
state_ = State::kConnected;
} else {
state_ = State::kDeviceError;
}
Transition(std::move(command), std::move(callback));
}
void FidoHidDevice::AllocateChannel(std::vector<uint8_t> command,
DeviceCallback callback) {
// Send random nonce to device to verify received message.
std::vector<uint8_t> nonce(8);
crypto::RandBytes(nonce.data(), nonce.size());
WriteMessage(FidoHidMessage::Create(channel_id_, FidoHidDeviceCommand::kInit,
output_report_size_, nonce),
true,
base::BindOnce(&FidoHidDevice::OnAllocateChannel,
weak_factory_.GetWeakPtr(), nonce,
std::move(command), std::move(callback)));
}
void FidoHidDevice::OnAllocateChannel(std::vector<uint8_t> nonce,
std::vector<uint8_t> command,
DeviceCallback callback,
base::Optional<FidoHidMessage> message) {
if (state_ == State::kDeviceError)
return;
timeout_callback_.Cancel();
if (!message || message->cmd() != FidoHidDeviceCommand::kInit) {
state_ = State::kDeviceError;
Transition(std::vector<uint8_t>(), std::move(callback));
return;
}
// Channel allocation response is defined as:
// 0: 8 byte nonce
// 8: 4 byte channel id
// 12: Protocol version id
// 13: Major device version
// 14: Minor device version
// 15: Build device version
// 16: Capabilities
std::vector<uint8_t> payload = message->GetMessagePayload();
if (payload.size() != 17) {
state_ = State::kDeviceError;
Transition(std::vector<uint8_t>(), std::move(callback));
return;
}
auto received_nonce = base::make_span(payload).first(8);
// Received a broadcast message for a different client. Disregard and continue
// reading.
if (!std::equal(nonce.begin(), nonce.end(), received_nonce.begin(),
received_nonce.end())) {
auto repeating_callback =
base::AdaptCallbackForRepeating(std::move(callback));
ArmTimeout(repeating_callback);
ReadMessage(base::BindOnce(&FidoHidDevice::OnAllocateChannel,
weak_factory_.GetWeakPtr(), nonce,
std::move(command), repeating_callback));
return;
}
size_t index = 8;
channel_id_ = payload[index++] << 24;
channel_id_ |= payload[index++] << 16;
channel_id_ |= payload[index++] << 8;
channel_id_ |= payload[index++];
capabilities_ = payload[16];
state_ = State::kReady;
Transition(std::move(command), std::move(callback));
}
void FidoHidDevice::WriteMessage(base::Optional<FidoHidMessage> message,
bool response_expected,
HidMessageCallback callback) {
if (!connection_ || !message || message->NumPackets() == 0) {
std::move(callback).Run(base::nullopt);
return;
}
auto packet = message->PopNextPacket();
DCHECK_LE(packet.size(), output_report_size_);
packet.resize(output_report_size_, 0);
connection_->Write(
kReportId, packet,
base::BindOnce(&FidoHidDevice::PacketWritten, weak_factory_.GetWeakPtr(),
std::move(message), response_expected,
std::move(callback)));
}
void FidoHidDevice::PacketWritten(base::Optional<FidoHidMessage> message,
bool response_expected,
HidMessageCallback callback,
bool success) {
if (success && message->NumPackets() > 0) {
WriteMessage(std::move(message), response_expected, std::move(callback));
} else if (success && response_expected) {
ReadMessage(std::move(callback));
} else {
std::move(callback).Run(base::nullopt);
}
}
void FidoHidDevice::ReadMessage(HidMessageCallback callback) {
if (!connection_) {
state_ = State::kDeviceError;
std::move(callback).Run(base::nullopt);
return;
}
connection_->Read(base::BindOnce(
&FidoHidDevice::OnRead, weak_factory_.GetWeakPtr(), std::move(callback)));
}
void FidoHidDevice::OnRead(HidMessageCallback callback,
bool success,
uint8_t report_id,
const base::Optional<std::vector<uint8_t>>& buf) {
if (!success) {
state_ = State::kDeviceError;
std::move(callback).Run(base::nullopt);
return;
}
DCHECK(buf);
auto read_message = FidoHidMessage::CreateFromSerializedData(*buf);
if (!read_message) {
std::move(callback).Run(base::nullopt);
return;
}
// Received a message from a different channel, so try again.
if (channel_id_ != read_message->channel_id()) {
connection_->Read(base::BindOnce(&FidoHidDevice::OnRead,
weak_factory_.GetWeakPtr(),
std::move(callback)));
return;
}
if (read_message->MessageComplete()) {
std::move(callback).Run(std::move(read_message));
return;
}
// Continue reading additional packets.
connection_->Read(base::BindOnce(
&FidoHidDevice::OnReadContinuation, weak_factory_.GetWeakPtr(),
std::move(read_message), std::move(callback)));
}
void FidoHidDevice::OnReadContinuation(
base::Optional<FidoHidMessage> message,
HidMessageCallback callback,
bool success,
uint8_t report_id,
const base::Optional<std::vector<uint8_t>>& buf) {
if (!success) {
state_ = State::kDeviceError;
std::move(callback).Run(base::nullopt);
return;
}
DCHECK(buf);
message->AddContinuationPacket(*buf);
if (message->MessageComplete()) {
std::move(callback).Run(std::move(message));
return;
}
connection_->Read(base::BindOnce(&FidoHidDevice::OnReadContinuation,
weak_factory_.GetWeakPtr(),
std::move(message), std::move(callback)));
}
void FidoHidDevice::MessageReceived(DeviceCallback callback,
base::Optional<FidoHidMessage> message) {
if (state_ == State::kDeviceError)
return;
timeout_callback_.Cancel();
if (!message) {
state_ = State::kDeviceError;
Transition(std::vector<uint8_t>(), std::move(callback));
return;
}
const auto cmd = message->cmd();
// If received HID packet has keep_alive as command type, re-read after delay.
if (supported_protocol() == ProtocolVersion::kCtap &&
cmd == FidoHidDeviceCommand::kKeepAlive) {
base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&FidoHidDevice::OnKeepAlive, weak_factory_.GetWeakPtr(),
std::move(callback)),
kHidKeepAliveDelay);
return;
}
if (cmd != FidoHidDeviceCommand::kMsg && cmd != FidoHidDeviceCommand::kCbor) {
ProcessHidError(cmd, message->GetMessagePayload());
Transition(std::vector<uint8_t>(), std::move(callback));
return;
}
auto response = message->GetMessagePayload();
state_ = State::kReady;
base::WeakPtr<FidoHidDevice> self = weak_factory_.GetWeakPtr();
std::move(callback).Run(
message ? base::make_optional(message->GetMessagePayload())
: base::nullopt);
// Executing |callback| may have freed |this|. Check |self| first.
if (self && !pending_transactions_.empty()) {
// If any transactions were queued, process the first one.
auto pending_cmd = std::move(pending_transactions_.front().first);
auto pending_cb = std::move(pending_transactions_.front().second);
pending_transactions_.pop();
Transition(std::move(pending_cmd), std::move(pending_cb));
}
}
void FidoHidDevice::TryWink(WinkCallback callback) {
// Only try to wink if device claims support.
if (!(capabilities_ & kWinkCapability) || state_ != State::kReady) {
std::move(callback).Run();
return;
}
WriteMessage(
FidoHidMessage::Create(channel_id_, FidoHidDeviceCommand::kWink,
output_report_size_, std::vector<uint8_t>()),
true,
base::BindOnce(&FidoHidDevice::OnWink, weak_factory_.GetWeakPtr(),
std::move(callback)));
}
void FidoHidDevice::OnKeepAlive(DeviceCallback callback) {
auto repeating_callback =
base::AdaptCallbackForRepeating(std::move(callback));
ArmTimeout(repeating_callback);
ReadMessage(base::BindOnce(&FidoHidDevice::MessageReceived,
weak_factory_.GetWeakPtr(),
std::move(repeating_callback)));
}
void FidoHidDevice::OnWink(WinkCallback callback,
base::Optional<FidoHidMessage> response) {
std::move(callback).Run();
}
void FidoHidDevice::ArmTimeout(DeviceCallback callback) {
DCHECK(timeout_callback_.IsCancelled());
timeout_callback_.Reset(base::BindOnce(&FidoHidDevice::OnTimeout,
weak_factory_.GetWeakPtr(),
std::move(callback)));
// Setup timeout task for 3 seconds.
base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
FROM_HERE, timeout_callback_.callback(), kDeviceTimeout);
}
void FidoHidDevice::OnTimeout(DeviceCallback callback) {
state_ = State::kDeviceError;
Transition(std::vector<uint8_t>(), std::move(callback));
}
void FidoHidDevice::ProcessHidError(FidoHidDeviceCommand cmd,
base::span<const uint8_t> payload) {
if (cmd != FidoHidDeviceCommand::kError || payload.size() != 1) {
DLOG(ERROR) << "Unexpected HID device command received.";
state_ = State::kDeviceError;
return;
}
const auto error_constant = payload[0];
if (error_constant ==
base::strict_cast<uint8_t>(HidErrorConstant::kInvalidCommand) ||
error_constant ==
base::strict_cast<uint8_t>(HidErrorConstant::kInvalidParameter) ||
error_constant ==
base::strict_cast<uint8_t>(HidErrorConstant::kInvalidLength)) {
state_ = State::kMsgError;
} else {
state_ = State::kDeviceError;
}
}
std::string FidoHidDevice::GetId() const {
return GetIdForDevice(*device_info_);
}
FidoTransportProtocol FidoHidDevice::DeviceTransport() const {
return FidoTransportProtocol::kUsbHumanInterfaceDevice;
}
// static
std::string FidoHidDevice::GetIdForDevice(
const device::mojom::HidDeviceInfo& device_info) {
return "hid:" + device_info.guid;
}
base::WeakPtr<FidoDevice> FidoHidDevice::GetWeakPtr() {
return weak_factory_.GetWeakPtr();
}
} // namespace device