// blob: 622e98c65e4622ec30e46a98cdd1c54ed7d4deb1 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/quic/quic_socket_data_provider.h"
#include <algorithm>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include "base/functional/callback.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "net/base/hex_utils.h"
#include "net/socket/socket_test_util.h"
#include "net/third_party/quiche/src/quiche/quic/core/crypto/null_encrypter.h"
#include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_server_config.h"
#include "net/third_party/quiche/src/quiche/quic/core/http/quic_spdy_session.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
#include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
#include "net/third_party/quiche/src/quiche/quic/tools/quic_simple_server_session.h"
namespace quic {
namespace {
class HttpStreamPrinter : public Http3DebugVisitor {
public:
explicit HttpStreamPrinter(std::ostream* output) : output_(output) {}
void OnControlStreamCreated(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnControlStreamCreated: " << stream_id << "\n";
}
}
void OnQpackEncoderStreamCreated(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnQpackEncoderStreamCreated: " << stream_id << "\n";
}
}
void OnQpackDecoderStreamCreated(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnQpackDecoderStreamCreated: " << stream_id << "\n";
}
}
void OnPeerControlStreamCreated(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnPeerControlStreamCreated: " << stream_id << "\n";
}
}
void OnPeerQpackEncoderStreamCreated(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnPeerQpackEncoderStreamCreated: " << stream_id << "\n";
}
}
void OnPeerQpackDecoderStreamCreated(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnPeerQpackDecoderStreamCreated: " << stream_id << "\n";
}
}
void OnSettingsFrameReceived(const SettingsFrame& frame) override {
if (output_) {
*output_ << "OnSettingsFrameReceived: " << frame << "\n";
}
}
void OnGoAwayFrameReceived(const GoAwayFrame& frame) override {
if (output_) {
*output_ << "OnGoAwayFrameReceived" << "\n";
}
}
void OnPriorityUpdateFrameReceived(
const PriorityUpdateFrame& frame) override {
if (output_) {
*output_ << "OnPriorityUpdateFrameReceived: " << frame << "\n";
}
}
void OnDataFrameReceived(QuicStreamId stream_id,
QuicByteCount payload_length) override {
if (output_) {
*output_ << "OnDataFrameReceived: " << stream_id << ", "
<< "received: " << payload_length << " bytes" << "\n";
}
}
void OnHeadersFrameReceived(
QuicStreamId stream_id,
QuicByteCount compressed_headers_length) override {
if (output_) {
*output_ << "OnHeadersFrameReceived: " << stream_id << ", "
<< "received: " << compressed_headers_length << " bytes" << "\n";
}
}
void OnHeadersDecoded(QuicStreamId stream_id,
const QuicHeaderList& headers) override {
if (output_) {
*output_ << "OnHeadersDecoded: " << stream_id << ", "
<< "received: " << headers.DebugString() << "\n";
}
}
void OnUnknownFrameReceived(QuicStreamId stream_id,
uint64_t frame_type,
QuicByteCount payload_length) override {
if (output_) {
*output_ << "OnUnknownFrameReceived: " << stream_id << ", "
<< "frame_type: " << frame_type << ","
<< "received: " << payload_length << "bytes " << "\n ";
}
}
void OnSettingsFrameSent(const SettingsFrame& frame) override {
if (output_) {
*output_ << "OnSettingsFrameSent: " << frame << "\n";
}
}
void OnSettingsFrameResumed(const SettingsFrame& frame) override {
if (output_) {
*output_ << "OnSettingsFrameResumed: " << frame << "\n";
}
}
void OnGoAwayFrameSent(QuicStreamId stream_id) override {
if (output_) {
*output_ << "OnGoAwayFrameSent: " << stream_id << "\n";
}
}
void OnPriorityUpdateFrameSent(const PriorityUpdateFrame& frame) override {
if (output_) {
*output_ << "OnPriorityUpdateFrameSent: " << frame << "\n";
}
}
void OnDataFrameSent(QuicStreamId stream_id,
QuicByteCount payload_length) override {
if (output_) {
*output_ << "OnDataFrameSent: " << stream_id << ", "
<< "sent: " << payload_length << " bytes" << "\n";
}
}
void OnHeadersFrameSent(
QuicStreamId stream_id,
const quiche::HttpHeaderBlock& header_block) override {
if (output_) {
*output_ << "OnHeadersFrameSent: " << stream_id << ", "
<< "sent: " << header_block.DebugString() << "\n";
}
}
private:
mutable raw_ptr<std::ostream> output_ = nullptr;
};
} // namespace
} // namespace quic
namespace net::test {
// Out-of-line defaulted destructor for the test-only server session.
QuicSimpleServerSessionForTest::~QuicSimpleServerSessionForTest() = default;
// Always return true if CryptoStream is created, so that the pending stream can
// handle initial setting frames for tests.
bool QuicSimpleServerSessionForTest::IsEncryptionEstablished() const {
return GetCryptoStream() != nullptr;
}
QuicSocketDataProvider::Expectation::Expectation(
std::string name,
Type type,
int rv,
std::unique_ptr<quic::QuicEncryptedPacket> packet)
: name_(std::move(name)),
type_(type),
rv_(rv),
packet_(std::move(packet)) {}
// Defaulted move constructor. Expectations are built as temporaries and then
// pushed into `expectations_` by the Add* methods of QuicSocketDataProvider.
QuicSocketDataProvider::Expectation::Expectation(
QuicSocketDataProvider::Expectation&&) = default;
// Out-of-line defaulted destructor.
QuicSocketDataProvider::Expectation::~Expectation() = default;
// Records that this expectation must not be consumed before the expectation
// called `name`. Returns `*this` so that calls can be chained. The name is
// resolved to an index later, in QuicSocketDataProvider::Reset().
QuicSocketDataProvider::Expectation& QuicSocketDataProvider::Expectation::After(
    std::string name) {
  after_.emplace(std::move(name));
  return *this;
}
// Returns the upper-case name of `type`, for use in log and failure messages.
std::string QuicSocketDataProvider::Expectation::TypeToString(
    QuicSocketDataProvider::Expectation::Type type) {
  // Each handled kind returns from inside the switch; falling out of it hits
  // NOTREACHED() below.
  switch (type) {
    case Type::READ: {
      return std::string("READ");
    }
    case Type::WRITE: {
      return std::string("WRITE");
    }
    case Type::PAUSE: {
      return std::string("PAUSE");
    }
  }
  NOTREACHED();
}
// Marks this expectation as consumed. Consuming the same expectation twice is
// a fatal error.
void QuicSocketDataProvider::Expectation::Consume() {
  CHECK(!consumed_);
  consumed_ = true;
  VLOG(1) << "Consuming " << TypeToString(type_) << " expectation " << name_;
}
// Sets up the packet printer for `version` and the crypto state needed to build
// QUIC server sessions, then creates two independent sessions: one that is fed
// the data actually written and one that is fed the data an expectation says
// should have been written (see VerifyWriteData()).
QuicSocketDataProvider::QuicSocketDataProvider(quic::ParsedQuicVersion version)
: printer_(version),
crypto_config_(quic::QuicCryptoServerConfig::TESTING,
quic::QuicRandom::GetInstance(),
quic::test::crypto_test_utils::ProofSourceForTesting(),
quic::KeyExchangeSource::Default()),
compressed_certs_cache_(
quic::QuicCompressedCertsCache::kQuicCompressedCertsCacheSize) {
// GenSimpleServerSession() reads `crypto_config_` and
// `compressed_certs_cache_`, so the sessions are created in the body, after the
// member initializers above have run.
session_for_actual_ = GenSimpleServerSession();
session_for_expected_ = GenSimpleServerSession();
}
// Out-of-line defaulted destructor.
QuicSocketDataProvider::~QuicSocketDataProvider() = default;
// Builds and initializes a server-perspective QUIC session on top of a mock
// connection. The sessions produced here are only used to pretty-print packet
// contents (see PrintWithQuicSession()).
std::unique_ptr<quic::QuicSimpleServerSession>
QuicSocketDataProvider::GenSimpleServerSession() {
  // The raw pointer is handed to the session constructor below.
  // NOTE(review): presumably the session takes ownership of the connection --
  // confirm against quic::QuicSession.
  auto* mock_connection = new quic::test::MockQuicConnection(  // IN-TEST
      &helper_, &alarm_factory_, quic::Perspective::IS_SERVER);
  mock_connection->AdvanceTime(quic::QuicTime::Delta::FromSeconds(1));

  auto null_encrypter =
      std::make_unique<quic::NullEncrypter>(mock_connection->perspective());
  mock_connection->SetEncrypter(quic::ENCRYPTION_FORWARD_SECURE,
                                std::move(null_encrypter));

  auto server_session = std::make_unique<QuicSimpleServerSessionForTest>(
      config_, quic::CurrentSupportedVersions(), mock_connection, &owner_,
      &stream_helper_, &crypto_config_, &compressed_certs_cache_,
      &memory_cache_backend_);
  server_session->Initialize();
  return server_session;
}
// Formats the packet bytes in `data` by running them through `session` with an
// HttpStreamPrinter attached as the session's debug visitor, and returns the
// text produced by `printer_`.
std::string QuicSocketDataProvider::PrintWithQuicSession(
    quic::QuicSimpleServerSession* session,
    std::string data) {
  std::ostringstream visitor_output;
  quic::HttpStreamPrinter stream_printer(&visitor_output);
  session->set_debug_visitor(&stream_printer);
  std::string formatted =
      printer_.PrintWithQuicSession(data, visitor_output, session);
  // Detach the visitor before `stream_printer` goes out of scope so the
  // session is not left pointing at a destroyed object.
  session->set_debug_visitor(nullptr);
  return formatted;
}
// Appends a READ expectation called `name` that yields `packet` with result OK.
// Returns a reference to the new expectation so it can be configured further.
QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddRead(
    std::string name,
    std::unique_ptr<quic::QuicEncryptedPacket> packet) {
  Expectation read_expectation(std::move(name), Expectation::Type::READ, OK,
                               std::move(packet));
  expectations_.push_back(std::move(read_expectation));
  return expectations_.back();
}
// Overload taking a received packet: adds the READ expectation as above and
// additionally sets the expectation's TOS byte from the packet's ECN
// codepoint.
QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddRead(
    std::string name,
    std::unique_ptr<quic::QuicReceivedPacket> packet) {
  // Read the codepoint before ownership of `packet` is given away.
  const uint8_t tos_byte = static_cast<uint8_t>(packet->ecn_codepoint());
  std::unique_ptr<quic::QuicEncryptedPacket> encrypted_packet(
      std::move(packet));
  Expectation& read_expectation =
      AddRead(std::move(name), std::move(encrypted_packet));
  return read_expectation.TosByte(tos_byte);
}
// Appends a READ expectation called `name` that carries no packet and completes
// with the error `rv`. `rv` must be neither OK nor ERR_IO_PENDING.
QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddReadError(
    std::string name,
    int rv) {
  CHECK_NE(rv, OK);
  CHECK_NE(rv, ERR_IO_PENDING);
  Expectation read_error(std::move(name), Expectation::Type::READ, rv,
                         nullptr);
  expectations_.push_back(std::move(read_error));
  return expectations_.back();
}
// Appends a WRITE expectation called `name`. The data written by the code under
// test is compared against `packet` (see VerifyWriteData()); `rv` is stored as
// the expectation's result code.
QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddWrite(
    std::string name,
    std::unique_ptr<quic::QuicEncryptedPacket> packet,
    int rv) {
  Expectation write_expectation(std::move(name), Expectation::Type::WRITE, rv,
                                std::move(packet));
  expectations_.push_back(std::move(write_expectation));
  return expectations_.back();
}
// Appends a WRITE expectation called `name` that carries no packet (so the
// written data is not verified) and completes with the error `rv`. `rv` must be
// neither OK nor ERR_IO_PENDING.
QuicSocketDataProvider::Expectation& QuicSocketDataProvider::AddWriteError(
    std::string name,
    int rv) {
  CHECK_NE(rv, OK);
  CHECK_NE(rv, ERR_IO_PENDING);
  Expectation write_error(std::move(name), Expectation::Type::WRITE, rv,
                          nullptr);
  expectations_.push_back(std::move(write_error));
  return expectations_.back();
}
// Appends a PAUSE expectation called `name` and returns its index in
// `expectations_`, which identifies the pause point in RunUntilPause().
QuicSocketDataProvider::PausePoint QuicSocketDataProvider::AddPause(
    std::string name) {
  // The new element lands at the current size, i.e. at `size() - 1` once it
  // has been appended.
  const PausePoint pause_index = expectations_.size();
  Expectation pause_expectation(std::move(name), Expectation::Type::PAUSE, OK,
                                nullptr);
  expectations_.push_back(std::move(pause_expectation));
  return pause_index;
}
// Returns true when no expectation (of any type) is left unconsumed. Trivially
// true when there are no expectations at all.
bool QuicSocketDataProvider::AllDataConsumed() const {
  const auto is_unconsumed = [](const Expectation& expectation) {
    return !expectation.consumed();
  };
  return std::none_of(expectations_.begin(), expectations_.end(),
                      is_unconsumed);
}
// Runs the message loop until the provider pauses, unless it is already
// paused, and then requires that the pause reached is `pause_point`.
void QuicSocketDataProvider::RunUntilPause(
    QuicSocketDataProvider::PausePoint pause_point) {
  const bool already_paused = paused_at_.has_value();
  if (!already_paused) {
    // MaybeConsumeExpectations() quits this loop once a pause is reached.
    run_until_run_loop_ = std::make_unique<base::RunLoop>();
    run_until_run_loop_->Run();
    run_until_run_loop_ = nullptr;
  }
  CHECK(paused_at_.has_value() && *paused_at_ == pause_point)
      << "Did not pause at '" << expectations_[pause_point].name() << "'.";
}
void QuicSocketDataProvider::Resume() {
CHECK(paused_at_.has_value());
VLOG(1) << "Resuming from pause point " << expectations_[*paused_at_].name();
expectations_[*paused_at_].Consume();
paused_at_ = std::nullopt;
ExpectationConsumed();
}
// Runs the message loop until every expectation has been consumed, and raises
// a fatal test failure listing the leftovers if some remain afterwards.
void QuicSocketDataProvider::RunUntilAllConsumed() {
  if (!AllDataConsumed()) {
    // MaybeConsumeExpectations() quits this loop once everything is consumed.
    run_until_run_loop_ = std::make_unique<base::RunLoop>();
    run_until_run_loop_->Run();
    run_until_run_loop_ = nullptr;
  }

  if (AllDataConsumed()) {
    return;
  }

  // If that run timed out, then there will still be un-consumed data.
  std::vector<size_t> unconsumed;
  for (size_t index = 0; index < expectations_.size(); ++index) {
    if (!expectations_[index].consumed()) {
      unconsumed.push_back(index);
    }
  }
  FAIL() << "All expectations were not consumed; remaining: "
         << ExpectationList(unconsumed);
}
// Handles a read request. Returns the next ready READ expectation if there is
// one; otherwise reports ERR_IO_PENDING and leaves the read marked as pending
// so that it can be completed later from MaybeConsumeExpectations(). Only one
// read may be outstanding at a time.
MockRead QuicSocketDataProvider::OnRead() {
  CHECK(!read_pending_);
  // ConsumeNextRead() requires the pending flag to be set.
  read_pending_ = true;
  std::optional<MockRead> next_read = ConsumeNextRead();
  if (next_read.has_value()) {
    read_pending_ = false;
    return *next_read;
  }
  return MockRead(ASYNC, ERR_IO_PENDING);
}
// Handles a write of `data`. Only one write may be outstanding at a time.
//
// Returns:
//  * the result of the next ready WRITE expectation, if there is one (this is
//    a synchronous ERR_UNEXPECTED if the data does not match the expectation);
//  * ERR_IO_PENDING if no WRITE is ready but execution is paused -- the write
//    stays pending and is completed from MaybeConsumeExpectations();
//  * a synchronous ERR_UNEXPECTED, plus a test failure, otherwise.
MockWriteResult QuicSocketDataProvider::OnWrite(const std::string& data) {
  CHECK(!write_pending_.has_value());
  // ConsumeNextWrite() and VerifyWriteData() read the data from
  // `write_pending_`.
  write_pending_ = data;
  std::optional<MockWriteResult> next_write = ConsumeNextWrite();
  if (next_write.has_value()) {
    write_pending_ = std::nullopt;
    return *next_write;
  }

  // If Write() was called when no corresponding expectation exists, that's
  // an error unless execution is currently paused, in which case it's just
  // pending. This rarely occurs because the only other type of expectation
  // that might be blocking a WRITE is a READ, and QUIC implementations
  // typically eagerly consume READs.
  if (paused_at_.has_value()) {
    return MockWriteResult(ASYNC, ERR_IO_PENDING);
  }

  // The failure is reported synchronously, so this write is finished. Clear
  // the pending state: otherwise the next OnWrite() would trip the CHECK above
  // and MaybeConsumeExpectations() could later match this stale data against
  // an expectation and complete a write that is no longer outstanding.
  write_pending_ = std::nullopt;
  ADD_FAILURE() << "Write call when none is expected:\n"
                << printer_.PrintWrite(data);
  return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
}
// Reads are not tracked separately from other expectations: this returns true
// only once every expectation, of any type, has been consumed.
bool QuicSocketDataProvider::AllReadDataConsumed() const {
return AllDataConsumed();
}
// Writes are not tracked separately from other expectations: this returns true
// only once every expectation, of any type, has been consumed.
bool QuicSocketDataProvider::AllWriteDataConsumed() const {
return AllDataConsumed();
}
// Forgets any outstanding read, so that MaybeConsumeExpectations() will not
// try to complete it and a later OnRead() is accepted.
void QuicSocketDataProvider::CancelPendingRead() {
read_pending_ = false;
}
void QuicSocketDataProvider::Reset() {
// Note that `Reset` is a parent-class method with a confusing name. It is
// used to initialize the socket data provider before it is used.
// Map names to index, and incidentally check for duplicate names.
std::map<std::string, size_t> names;
for (size_t i = 0; i < expectations_.size(); i++) {
Expectation& expectation = expectations_[i];
auto [_, inserted] = names.insert({expectation.name(), i});
CHECK(inserted) << "Another expectation named " << expectation.name()
<< " exists.";
}
// Calculate `dependencies_` mapping indices in `expectations_` to indices
// of the expectations they depend on.
dependencies_.clear();
for (size_t i = 0; i < expectations_.size(); i++) {
Expectation& expectation = expectations_[i];
if (expectation.after().empty()) {
// If no other dependencies are given, make the expectation depend on
// the previous expectation.
if (i > 0) {
dependencies_[i].insert(i - 1);
}
} else {
for (auto& after : expectation.after()) {
const auto dep = names.find(after);
CHECK(dep != names.end()) << "No expectation named " << after;
dependencies_[i].insert(dep->second);
}
}
}
pending_maybe_consume_expectations_ = false;
read_pending_ = false;
write_pending_ = std::nullopt;
MaybeConsumeExpectations();
}
// Returns the index of the single unconsumed expectation of `type` whose
// dependencies have all been consumed, or std::nullopt if there is none.
// Having more than one such expectation is ambiguous and is a fatal error.
std::optional<size_t> QuicSocketDataProvider::FindReadyExpectations(
    Expectation::Type type) {
  std::vector<size_t> matches;
  for (size_t i = 0; i < expectations_.size(); i++) {
    const Expectation& expectation = expectations_[i];
    if (expectation.consumed() || expectation.type() != type) {
      continue;
    }

    // Use find() rather than operator[] so that this read-only scan does not
    // add an empty entry to `dependencies_` for expectations without
    // dependencies. No entry means nothing blocks the expectation.
    const auto deps = dependencies_.find(i);
    const bool blocked =
        deps != dependencies_.end() &&
        std::any_of(deps->second.begin(), deps->second.end(),
                    [this](size_t dep) {
                      return !expectations_[dep].consumed();
                    });
    if (!blocked) {
      matches.push_back(i);
    }
  }

  if (matches.size() > 1) {
    // The strings are built only on this failure path; the CHECK below always
    // fails here and reports them.
    std::string exp_type = Expectation::TypeToString(type);
    std::string names = ExpectationList(matches);
    CHECK(matches.size() <= 1)
        << "Multiple expectations of type " << exp_type
        << " are ready: " << names << ". Use .After() to disambiguate.";
  }

  if (matches.empty()) {
    return std::nullopt;
  }
  return matches.front();
}
std::optional<MockRead> QuicSocketDataProvider::ConsumeNextRead() {
CHECK(read_pending_);
std::optional<size_t> ready = FindReadyExpectations(Expectation::Type::READ);
if (!ready.has_value()) {
return std::nullopt;
}
// If there's exactly one matching expectation, return it.
Expectation& ready_expectation = expectations_[*ready];
MockRead read(ready_expectation.mode(), ready_expectation.rv());
if (ready_expectation.packet()) {
read.data = ready_expectation.packet()->AsStringPiece();
}
read.tos = ready_expectation.tos_byte();
ready_expectation.Consume();
ExpectationConsumed();
return read;
}
std::optional<MockWriteResult> QuicSocketDataProvider::ConsumeNextWrite() {
CHECK(write_pending_.has_value());
std::optional<size_t> ready = FindReadyExpectations(Expectation::Type::WRITE);
if (!ready.has_value()) {
return std::nullopt;
}
// If there's exactly one matching expectation, check if it matches the
// write and return it.
Expectation& ready_expectation = expectations_[*ready];
if (ready_expectation.packet()) {
if (!VerifyWriteData(ready_expectation)) {
return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
}
}
MockWriteResult write(ready_expectation.mode(),
ready_expectation.packet()
? ready_expectation.packet()->length()
: ready_expectation.rv());
ready_expectation.Consume();
ExpectationConsumed();
return write;
}
// Makes whatever progress is currently possible: completes a pending read
// and/or a pending write if a matching expectation has become ready, enters a
// pause if a PAUSE expectation is ready, and quits the active run loop (if any)
// when a pause is reached or when everything has been consumed. Called
// directly from Reset() and as a posted task from ExpectationConsumed().
void QuicSocketDataProvider::MaybeConsumeExpectations() {
// Clear the flag first so that expectations consumed below can schedule
// another run of this method.
pending_maybe_consume_expectations_ = false;
if (read_pending_) {
std::optional<MockRead> next_read = ConsumeNextRead();
if (next_read.has_value()) {
// Update the state before notifying the socket.
read_pending_ = false;
if (socket()) {
socket()->OnReadComplete(*next_read);
}
}
}
if (write_pending_.has_value()) {
std::optional<MockWriteResult> next_write = ConsumeNextWrite();
if (next_write.has_value()) {
// Update the state before notifying the socket.
write_pending_ = std::nullopt;
if (socket()) {
socket()->OnWriteComplete(next_write->result);
}
}
}
// Only look for a new pause when not already paused.
if (!paused_at_) {
std::optional<size_t> ready =
FindReadyExpectations(Expectation::Type::PAUSE);
if (ready.has_value()) {
VLOG(1) << "Pausing at " << expectations_[*ready].name();
// The PAUSE expectation itself stays unconsumed until Resume().
paused_at_ = *ready;
// Wake up RunUntilPause() / RunUntilAllConsumed() if one is waiting.
if (run_until_run_loop_) {
run_until_run_loop_->Quit();
}
}
}
// Wake up a waiting run loop once nothing is left to consume.
if (run_until_run_loop_ && AllDataConsumed()) {
run_until_run_loop_->Quit();
}
}
// Schedules a run of MaybeConsumeExpectations() after an expectation has been
// consumed. At most one such task is outstanding at a time.
void QuicSocketDataProvider::ExpectationConsumed() {
  if (!pending_maybe_consume_expectations_) {
    pending_maybe_consume_expectations_ = true;
    // Call `MaybeConsumeExpectations` in a task. That method may trigger
    // consumption of other expectations, and that consumption must happen
    // _after_ the current call to `Read` or `Write` has finished.
    base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE,
        base::BindOnce(&QuicSocketDataProvider::MaybeConsumeExpectations,
                       weak_factory_.GetWeakPtr()));
  }
}
// Compares the pending write data with the packet of `expectation`, which must
// have a packet. Appends the formatted form of both to `actual_log_` and
// `expected_log_`, records a non-fatal test failure on mismatch, and returns
// whether the raw bytes are equal.
bool QuicSocketDataProvider::VerifyWriteData(
    QuicSocketDataProvider::Expectation& expectation) {
  const auto& expected_packet = expectation.packet();
  std::string expected_data(expected_packet->data(),
                            expected_packet->length());
  std::string& actual_data = *write_pending_;

  // The logs are appended to rather than replaced, so a failure message shows
  // everything formatted so far.
  const std::string actual_formatted =
      PrintWithQuicSession(session_for_actual_.get(), actual_data);
  actual_log_.append(actual_formatted);
  const std::string expected_formatted =
      PrintWithQuicSession(session_for_expected_.get(), expected_data);
  expected_log_.append(expected_formatted);

  const bool write_matches = (actual_data == expected_data);
  EXPECT_TRUE(write_matches)
      << "Expectation '" << expectation.name()
      << "' not met. Actual formatted write data:\n"
      << actual_log_ << "But expectation '" << expectation.name()
      << "' expected formatted write data:\n"
      << expected_log_ << "Actual raw write data:\n"
      << HexDump(actual_data) << "Expected raw write data:\n"
      << HexDump(expected_data);
  return write_matches;
}
// Returns the names of the expectations at `indices`, joined with ", ".
// Returns an empty string when `indices` is empty.
std::string QuicSocketDataProvider::ExpectationList(
    const std::vector<size_t>& indices) {
  std::ostringstream joined;
  // Empty before the first name, ", " before every later one.
  const char* separator = "";
  for (const size_t index : indices) {
    joined << separator << expectations_[index].name();
    separator = ", ";
  }
  return joined.str();
}
} // namespace net::test