blob: 7f91723893cb36d688ed2953693df945dfe1626f [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <set>
#include <string>
#include "base/check.h"
#include "base/containers/contains.h"
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "third_party/protobuf/src/google/protobuf/compiler/code_generator.h"
#include "third_party/protobuf/src/google/protobuf/compiler/cpp/helpers.h"
#include "third_party/protobuf/src/google/protobuf/compiler/cpp/names.h"
#include "third_party/protobuf/src/google/protobuf/compiler/importer.h"
#include "third_party/protobuf/src/google/protobuf/compiler/plugin.h"
#include "third_party/protobuf/src/google/protobuf/descriptor.h"
#include "third_party/protobuf/src/google/protobuf/io/printer.h"
namespace {
using google::protobuf::Descriptor;
using google::protobuf::FieldDescriptor;
using google::protobuf::FileDescriptor;
using google::protobuf::compiler::GeneratorContext;
using google::protobuf::compiler::cpp::ClassName;
using google::protobuf::compiler::cpp::FieldName;
using google::protobuf::compiler::cpp::Namespace;
using google::protobuf::compiler::cpp::NamespaceOpener;
using google::protobuf::compiler::cpp::QualifiedClassName;
using google::protobuf::io::Printer;
using google::protobuf::io::ZeroCopyOutputStream;
// Returns if the descriptor is for a synthetic 'map entry' message type,
// which is internally created by the protobuf library to support map fields.
// Map fields are instead handled explicitly in the generation via the
// `is_map()` case.
bool IsSyntheticMapEntry(const Descriptor& message) {
return message.map_key() != nullptr;
}
// Code generator for the proto_test_extras protoc plugin. For each input
// `<path>.proto` it writes:
//   * `<path>.test.h`: a gMock `Equals<ClassName>` matcher and a `PrintTo()`
//     declaration for every message in the file (synthetic map-entry types
//     excluded).
//   * `<path>.test.cc`: the `PrintTo()` definitions, which stream
//     `Serialize(msg).DebugString()`.
class ProtoGmockGenerator : public google::protobuf::compiler::CodeGenerator {
 public:
  ProtoGmockGenerator() = default;
  ~ProtoGmockGenerator() override = default;

  // Called by protoc for each file to generate. Writes the `.test.h` and
  // `.test.cc` outputs through `context`.
  //
  // `options` and `error` are not consulted; this always returns true.
  bool Generate(const FileDescriptor* file,
                const std::string& options,  // Options from build system
                GeneratorContext* context,
                std::string* error) const override {
    CHECK(file);
    base::FilePath proto_file_path = base::FilePath::FromASCII(file->name());
    base::FilePath h_file_path =
        proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("test.h"));
    const std::unique_ptr<ZeroCopyOutputStream> h_stream(
        context->Open(h_file_path.AsUTF8Unsafe()));
    Printer h_printer(h_stream.get(), Printer::Options{'$', nullptr});

    // Upper-cased header path with '.', '-', '/' and '\' mapped to '_', plus a
    // trailing '_'. The path always contains the '.' of ".test.h", so
    // ReplaceChars() always reports that it replaced something.
    std::string include_guard =
        base::ToUpperASCII(h_file_path.AsUTF8Unsafe()) + "_";
    CHECK(base::ReplaceChars(include_guard, ".-/\\", "_", &include_guard));

    // gmock, this proto's own generated header, the shared matcher helpers,
    // and the `.test.h` of every imported proto, so that matchers for message
    // types defined in imported files are declared. std::set keeps the list
    // sorted and free of duplicates.
    std::set<std::string> includes = {
        "testing/gmock/include/gmock/gmock.h",
        proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("pb.h"))
            .AsUTF8Unsafe(),
        "components/proto_extras/proto_matchers.h"};
    for (int i = 0; i < file->dependency_count(); i++) {
      base::FilePath dependency_proto_file_path =
          base::FilePath::FromASCII(file->dependency(i)->name());
      includes.insert(dependency_proto_file_path
                          .ReplaceExtension(FILE_PATH_LITERAL("test.h"))
                          .AsUTF8Unsafe());
    }

    h_printer.Emit(
        {{"include_guard", include_guard},
         {"proto_file_path", proto_file_path.AsUTF8Unsafe()},
         {"includes",
          [&] {
            for (const auto& include : includes) {
              h_printer.Print("#include \"$f$\"\n", "f", include);
            }
          }},
         {"matchers",
          [&] {
            NamespaceOpener ns(Namespace(file), &h_printer);
            for (int i = 0; i < file->message_type_count(); i++) {
              PrintMatchersRecursive(*file->message_type(i), &h_printer);
            }
          }},
         {"print_to_declarations",
          [&] {
            NamespaceOpener ns(Namespace(file), &h_printer);
            for (int i = 0; i < file->message_type_count(); i++) {
              PrintToDeclarationRecursive(*file->message_type(i), &h_printer);
            }
          }}},
        R"(// Generated by the proto_test_extras plugin. DO NOT EDIT!
// source: $proto_file_path$
#ifndef $include_guard$
#define $include_guard$
#include <iosfwd>
$includes$
$matchers$
$print_to_declarations$
#endif // $include_guard$
)");

    base::FilePath cc_file_path =
        proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("test.cc"));
    const std::unique_ptr<ZeroCopyOutputStream> cc_stream(
        context->Open(cc_file_path.AsUTF8Unsafe()));
    Printer cc_printer(cc_stream.get(), Printer::Options{'$', nullptr});
    cc_printer.Emit(
        {{"proto_file_path", proto_file_path.AsUTF8Unsafe()},
         // The .cc includes the header generated above.
         {"header", h_file_path.AsUTF8Unsafe()},
         {"to_value_header",
          proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("to_value.h"))
              .AsUTF8Unsafe()},
         {"print_to_definitions",
          [&] {
            NamespaceOpener ns(Namespace(file), &cc_printer);
            for (int i = 0; i < file->message_type_count(); i++) {
              PrintToDefinitionRecursive(*file->message_type(i), &cc_printer);
            }
          }}},
        R"(// Generated by the proto_test_extras plugin. DO NOT EDIT!
// source: $proto_file_path$
#include "$header$"
#include <ostream>
#include "base/values.h"
#include "$to_value_header$"
$print_to_definitions$
)");
    return true;
  }

  // Returns the unqualified name of the matcher generated for `message`:
  // "Equals" followed by ClassName(&message).
  std::string UniqueMatcherName(const Descriptor& message) const {
    return base::StrCat({"Equals", ClassName(&message)});
  }

  // Returns a fully qualified reference to the matcher for `message`, with the
  // MATCHER_P parameter type explicitly set to the message's qualified class,
  // e.g. "<ns>::EqualsFoo<<ns>::Foo>".
  std::string QualifiedMatcher(const Descriptor& message) const {
    return base::StrCat({Namespace(&message), "::", UniqueMatcherName(message),
                         "<", QualifiedClassName(&message), ">"});
  }

  // Emits one matcher expression (without a trailing separator) that checks
  // `field` of `message_class_name` against the same field of `expected`:
  //   * map fields       -> ::proto_extras::HasMapField
  //   * repeated fields  -> ::proto_extras::HasRepeatedField
  //   * fields with presence -> ::proto_extras::HasOptionalField
  //   * everything else  -> testing::Property
  // Message-typed values are compared with their generated nested matcher;
  // other values use testing::Eq.
  void PrintFieldMatcher(const FieldDescriptor& field,
                         const std::string& message_class_name,
                         Printer* printer) const {
    std::string field_accessor = FieldName(&field);
    if (field.is_map()) {
      // Only the map's value type decides whether a nested matcher is needed.
      const FieldDescriptor& map_value = *field.message_type()->map_value();
      std::string maybe_nested_matcher;
      if (IsMessageOrGroup(map_value)) {
        maybe_nested_matcher =
            ", &" + QualifiedMatcher(*map_value.message_type());
      }
      printer->Emit({{"message_class_name", message_class_name},
                     {"field_accessor", field_accessor},
                     {"maybe_nested_matcher", maybe_nested_matcher}},
                    R"(::proto_extras::HasMapField(
"$field_accessor$",
&$message_class_name$::$field_accessor$,
expected$maybe_nested_matcher$))");
      return;
    }
    if (field.is_repeated()) {
      // Repeated message fields are RepeatedPtrFields and get a nested
      // matcher; repeated scalars are RepeatedFields compared by value.
      std::string maybe_nested_matcher;
      std::string resolve_field_function =
          "::proto_extras::ResolveRepeatedField";
      if (IsMessageOrGroup(field)) {
        maybe_nested_matcher = ", &" + QualifiedMatcher(*field.message_type());
        resolve_field_function = "::proto_extras::ResolveRepeatedPtrField";
      }
      printer->Emit(
          {
              {"message_class_name", message_class_name},
              {"field_accessor", field_accessor},
              {"resolve_field_function", resolve_field_function},
              {"maybe_nested_matcher", maybe_nested_matcher},
          },
          R"(::proto_extras::HasRepeatedField(
"$field_accessor$",
$resolve_field_function$(&$message_class_name$::$field_accessor$),
expected$maybe_nested_matcher$))");
    } else if (field.has_presence()) {
      const std::string matcher = IsMessageOrGroup(field)
                                      ? QualifiedMatcher(*field.message_type())
                                      : std::string("testing::Eq");
      printer->Emit(
          {
              {"field_accessor", field_accessor},
              {"message_class_name", message_class_name},
              {"matcher", matcher},
          },
          R"(::proto_extras::HasOptionalField(
"$field_accessor$",
&$message_class_name$::has_$field_accessor$,
&$message_class_name$::$field_accessor$,
expected.has_$field_accessor$(),
$matcher$(expected.$field_accessor$())))");
    } else {
      printer->Emit(
          {{"message_class_name", message_class_name},
           {"field_accessor", field_accessor},
           {"value_matcher",
            [&] {
              if (IsMessageOrGroup(field)) {
                printer->Emit({{"nested_matcher",
                                QualifiedMatcher(*field.message_type())},
                               {"field_accessor", field_accessor}},
                              "$nested_matcher$(expected.$field_accessor$())");
              } else {
                printer->Print("testing::Eq(expected.$field_accessor$())",
                               "field_accessor", field_accessor);
              }
            }}},
          R"(testing::Property(
"$field_accessor$",
&$message_class_name$::$field_accessor$,
$value_matcher$))");
    }
  }

  // Emits a MATCHER_P named UniqueMatcherName(message) that ANDs together one
  // field matcher per field of `message`.
  void PrintMatcher(const Descriptor& message, Printer* printer) const {
    const std::string message_class_name = ClassName(&message);
    printer->Emit({{"matcher_name", UniqueMatcherName(message)}},
                  R"(
MATCHER_P($matcher_name$, expected, "") {
return testing::ExplainMatchResult(
testing::AllOf()");
    {
      auto unindent = printer->WithIndent(10);
      const int field_count = message.field_count();
      if (field_count == 0) {
        // A message without fields matches any value of its type. Use
        // testing::_ so AllOf() still gets one matcher; testing::IsTrue()
        // would need the message to be convertible to bool, which generated
        // message classes are not, so it fails to compile when instantiated.
        printer->Print("testing::_");
      } else {
        for (int i = 0; i < field_count; ++i) {
          PrintFieldMatcher(*message.field(i), message_class_name, printer);
          printer->Print(i + 1 < field_count ? ",\n" : "\n");
        }
      }
    }
    printer->Print(R"(
), arg, result_listener);
}
)");
  }

  // Emits matchers for `message` and all of its nested types, skipping
  // synthetic map-entry types.
  void PrintMatchersRecursive(const Descriptor& message,
                              Printer* printer) const {
    if (IsSyntheticMapEntry(message)) {
      return;
    }
    // Nested types need to be printed before the parent message so that the
    // parent can reference them.
    // NOTE(review): only nested-before-parent ordering is handled. A field
    // whose message type is declared later in the file (a later top-level
    // message, a later sibling, or an enclosing message) is referenced through
    // QualifiedMatcher() before that matcher has been emitted. TODO confirm
    // whether such files need dependency ordering or forward declarations.
    for (int i = 0; i < message.nested_type_count(); i++) {
      PrintMatchersRecursive(*message.nested_type(i), printer);
    }
    PrintMatcher(message, printer);
  }

  // Emits the `PrintTo()` declaration for `message` into the header.
  void PrintToDeclaration(const Descriptor& message, Printer* printer) const {
    printer->Emit({{"message_type", ClassName(&message)}},
                  R"(
void PrintTo(const $message_type$& msg, std::ostream* os);)");
  }

  // Emits the `PrintTo()` definition for `message` into the .cc file. It
  // streams the DebugString() of whatever Serialize(msg) returns (a helper
  // expected to come from the included to_value header).
  void PrintToDefinition(const Descriptor& message, Printer* printer) const {
    printer->Emit({{"message_type", ClassName(&message)}},
                  R"(
void PrintTo(const $message_type$& msg, std::ostream* os) {
*os << Serialize(msg).DebugString();
})");
  }

  // Emits `PrintTo()` declarations for `message` followed by its nested types,
  // skipping synthetic map-entry types.
  void PrintToDeclarationRecursive(const Descriptor& message,
                                   Printer* printer) const {
    if (IsSyntheticMapEntry(message)) {
      return;
    }
    PrintToDeclaration(message, printer);
    for (int i = 0; i < message.nested_type_count(); i++) {
      PrintToDeclarationRecursive(*message.nested_type(i), printer);
    }
  }

  // Emits `PrintTo()` definitions for `message` followed by its nested types,
  // skipping synthetic map-entry types.
  void PrintToDefinitionRecursive(const Descriptor& message,
                                  Printer* printer) const {
    if (IsSyntheticMapEntry(message)) {
      return;
    }
    PrintToDefinition(message, printer);
    for (int i = 0; i < message.nested_type_count(); i++) {
      PrintToDefinitionRecursive(*message.nested_type(i), printer);
    }
  }

 private:
  // Returns true if `field` stores a message value (TYPE_MESSAGE, or the
  // legacy TYPE_GROUP encoding). Such values are compared with a generated
  // nested matcher instead of testing::Eq.
  static bool IsMessageOrGroup(const FieldDescriptor& field) {
    return field.type() == FieldDescriptor::Type::TYPE_MESSAGE ||
           field.type() == FieldDescriptor::Type::TYPE_GROUP;
  }
};
} // namespace
int main(int argc, char** argv) {
ProtoGmockGenerator generator;
return google::protobuf::compiler::PluginMain(argc, argv, &generator);
}