| // Copyright 2025 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include <memory>
#include <set>
#include <string>
| |
| #include "base/check.h" |
| #include "base/containers/contains.h" |
| #include "base/files/file_path.h" |
| #include "base/logging.h" |
| #include "base/memory/raw_ptr.h" |
| #include "base/notreached.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_util.h" |
| #include "base/strings/utf_string_conversions.h" |
| #include "third_party/protobuf/src/google/protobuf/compiler/code_generator.h" |
| #include "third_party/protobuf/src/google/protobuf/compiler/cpp/helpers.h" |
| #include "third_party/protobuf/src/google/protobuf/compiler/cpp/names.h" |
| #include "third_party/protobuf/src/google/protobuf/compiler/importer.h" |
| #include "third_party/protobuf/src/google/protobuf/compiler/plugin.h" |
| #include "third_party/protobuf/src/google/protobuf/descriptor.h" |
| #include "third_party/protobuf/src/google/protobuf/io/printer.h" |
| |
| namespace { |
| |
| using google::protobuf::Descriptor; |
| using google::protobuf::FieldDescriptor; |
| using google::protobuf::FileDescriptor; |
| using google::protobuf::compiler::GeneratorContext; |
| using google::protobuf::compiler::cpp::ClassName; |
| using google::protobuf::compiler::cpp::FieldName; |
| using google::protobuf::compiler::cpp::Namespace; |
| using google::protobuf::compiler::cpp::NamespaceOpener; |
| using google::protobuf::compiler::cpp::QualifiedClassName; |
| using google::protobuf::io::Printer; |
| using google::protobuf::io::ZeroCopyOutputStream; |
| |
| // Returns if the descriptor is for a synthetic 'map entry' message type, |
| // which is internally created by the protobuf library to support map fields. |
| // Map fields are instead handled explicitly in the generation via the |
| // `is_map()` case. |
| bool IsSyntheticMapEntry(const Descriptor& message) { |
| return message.map_key() != nullptr; |
| } |
| |
| class ProtoGmockGenerator : public google::protobuf::compiler::CodeGenerator { |
| public: |
| ProtoGmockGenerator() = default; |
| ~ProtoGmockGenerator() override = default; |
| |
| bool Generate(const FileDescriptor* file, |
| const std::string& options, // Options from build system |
| GeneratorContext* context, |
| std::string* error) const override { |
| CHECK(file); |
| |
| base::FilePath proto_file_path = base::FilePath::FromASCII(file->name()); |
| |
| base::FilePath h_file_path = |
| proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("test.h")); |
| const std::unique_ptr<ZeroCopyOutputStream> h_stream( |
| context->Open(h_file_path.AsUTF8Unsafe())); |
| Printer h_printer(h_stream.get(), Printer::Options{'$', nullptr}); |
| |
| std::string include_guard = |
| base::ToUpperASCII(h_file_path.AsUTF8Unsafe()) + "_"; |
| CHECK(base::ReplaceChars(include_guard, ".-/\\", "_", &include_guard)); |
| |
| std::set<std::string> includes = { |
| "testing/gmock/include/gmock/gmock.h", |
| proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("pb.h")) |
| .AsUTF8Unsafe(), |
| "components/proto_extras/proto_matchers.h"}; |
| for (int i = 0; i < file->dependency_count(); i++) { |
| base::FilePath dependency_proto_file_path = |
| base::FilePath::FromASCII(file->dependency(i)->name()); |
| includes.insert(dependency_proto_file_path |
| .ReplaceExtension(FILE_PATH_LITERAL("test.h")) |
| .AsUTF8Unsafe()); |
| } |
| |
| h_printer.Emit( |
| {{"include_guard", include_guard}, |
| {"proto_file_path", proto_file_path.AsUTF8Unsafe()}, |
| {"includes", |
| [&] { |
| for (const auto& include : includes) { |
| h_printer.Print("#include \"$f$\"\n", "f", include); |
| } |
| }}, |
| {"matchers", |
| [&] { |
| NamespaceOpener ns(Namespace(file), &h_printer); |
| for (int i = 0; i < file->message_type_count(); i++) { |
| PrintMatchersRecursive(*file->message_type(i), &h_printer); |
| } |
| }}, |
| {"print_to_declarations", |
| [&] { |
| NamespaceOpener ns(Namespace(file), &h_printer); |
| for (int i = 0; i < file->message_type_count(); i++) { |
| PrintToDeclarationRecursive(*file->message_type(i), &h_printer); |
| } |
| }}}, |
| R"(// Generated by the proto_test_extras plugin. DO NOT EDIT! |
| // source: $proto_file_path$ |
| |
| #ifndef $include_guard$ |
| #define $include_guard$ |
| |
| #include <iosfwd> |
| |
| $includes$ |
| |
| $matchers$ |
| |
| $print_to_declarations$ |
| |
| #endif // $include_guard$ |
| )"); |
| |
| base::FilePath cc_file_path = |
| proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("test.cc")); |
| const std::unique_ptr<ZeroCopyOutputStream> cc_stream( |
| context->Open(cc_file_path.AsUTF8Unsafe())); |
| Printer cc_printer(cc_stream.get(), Printer::Options{'$', nullptr}); |
| cc_printer.Emit( |
| {{"proto_file_path", proto_file_path.AsUTF8Unsafe()}, |
| {"header", |
| proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("test.h")) |
| .AsUTF8Unsafe()}, |
| {"to_value_header", |
| proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("to_value.h")) |
| .AsUTF8Unsafe()}, |
| {"print_to_definitions", |
| [&] { |
| NamespaceOpener ns(Namespace(file), &cc_printer); |
| for (int i = 0; i < file->message_type_count(); i++) { |
| PrintToDefinitionRecursive(*file->message_type(i), &cc_printer); |
| } |
| }}}, |
| R"(// Generated by the proto_test_extras plugin. DO NOT EDIT! |
| // source: $proto_file_path$ |
| |
| #include "$header$" |
| |
| #include <ostream> |
| |
| #include "base/values.h" |
| #include "$to_value_header$" |
| |
| $print_to_definitions$ |
| |
| )"); |
| |
| return true; |
| } |
| |
| std::string UniqueMatcherName(const Descriptor& message) const { |
| return base::StrCat({"Equals", ClassName(&message)}); |
| } |
| |
| std::string QualifiedMatcher(const Descriptor& message) const { |
| return base::StrCat({Namespace(&message), "::", "Equals", |
| ClassName(&message), "<", QualifiedClassName(&message), |
| ">"}); |
| } |
| |
| void PrintFieldMatcher(const FieldDescriptor& field, |
| const std::string& message_class_name, |
| Printer* printer) const { |
| std::string field_accessor = FieldName(&field); |
| |
| if (field.is_map()) { |
| std::string maybe_nested_matcher; |
| if (field.message_type()->map_value()->type() == |
| FieldDescriptor::Type::TYPE_MESSAGE || |
| field.message_type()->map_value()->type() == |
| FieldDescriptor::Type::TYPE_GROUP) { |
| maybe_nested_matcher = |
| ", &" + QualifiedMatcher( |
| *field.message_type()->map_value()->message_type()); |
| } |
| printer->Emit({{"message_class_name", message_class_name}, |
| {"field_accessor", field_accessor}, |
| {"maybe_nested_matcher", maybe_nested_matcher}}, |
| R"(::proto_extras::HasMapField( |
| "$field_accessor$", |
| &$message_class_name$::$field_accessor$, |
| expected$maybe_nested_matcher$))"); |
| return; |
| } |
| if (field.is_repeated()) { |
| std::string maybe_nested_matcher; |
| if (field.type() == FieldDescriptor::Type::TYPE_MESSAGE || |
| field.type() == FieldDescriptor::Type::TYPE_GROUP) { |
| maybe_nested_matcher = ", &" + QualifiedMatcher(*field.message_type()); |
| } |
| std::string resolve_field_function; |
| if (field.type() == FieldDescriptor::Type::TYPE_MESSAGE || |
| field.type() == FieldDescriptor::Type::TYPE_GROUP) { |
| resolve_field_function = "::proto_extras::ResolveRepeatedPtrField"; |
| } else { |
| resolve_field_function = "::proto_extras::ResolveRepeatedField"; |
| } |
| printer->Emit( |
| { |
| {"message_class_name", message_class_name}, |
| {"field_accessor", field_accessor}, |
| {"resolve_field_function", resolve_field_function}, |
| {"maybe_nested_matcher", maybe_nested_matcher}, |
| }, |
| R"(::proto_extras::HasRepeatedField( |
| "$field_accessor$", |
| $resolve_field_function$(&$message_class_name$::$field_accessor$), |
| expected$maybe_nested_matcher$))"); |
| } else if (field.has_presence()) { |
| std::string matcher; |
| if (field.type() == FieldDescriptor::Type::TYPE_MESSAGE || |
| field.type() == FieldDescriptor::Type::TYPE_GROUP) { |
| matcher = QualifiedMatcher(*field.message_type()); |
| } else { |
| matcher = "testing::Eq"; |
| } |
| printer->Emit( |
| { |
| {"field_accessor", field_accessor}, |
| {"message_class_name", message_class_name}, |
| {"matcher", matcher}, |
| }, |
| R"(::proto_extras::HasOptionalField( |
| "$field_accessor$", |
| &$message_class_name$::has_$field_accessor$, |
| &$message_class_name$::$field_accessor$, |
| expected.has_$field_accessor$(), |
| $matcher$(expected.$field_accessor$())))"); |
| } else { |
| printer->Emit( |
| {{"message_class_name", message_class_name}, |
| {"field_accessor", field_accessor}, |
| {"value_matcher", |
| [&] { |
| if (field.type() == FieldDescriptor::Type::TYPE_MESSAGE || |
| field.type() == FieldDescriptor::Type::TYPE_GROUP) { |
| printer->Emit({{"nested_matcher", |
| QualifiedMatcher(*field.message_type())}, |
| {"field_accessor", field_accessor}}, |
| "$nested_matcher$(expected.$field_accessor$())"); |
| } else { |
| printer->Print("testing::Eq(expected.$field_accessor$())", |
| "field_accessor", field_accessor); |
| } |
| }}}, |
| R"(testing::Property( |
| "$field_accessor$", |
| &$message_class_name$::$field_accessor$, |
| $value_matcher$))"); |
| } |
| } |
| |
| void PrintMatcher(const Descriptor& message, Printer* printer) const { |
| std::string message_class_name = ClassName(&message); |
| std::string matcher_name = UniqueMatcherName(message); |
| |
| printer->Emit( |
| {{"message_type", message_class_name}, {"matcher_name", matcher_name}}, |
| R"( |
| MATCHER_P($matcher_name$, expected, "") { |
| return testing::ExplainMatchResult( |
| testing::AllOf()"); |
| { |
| auto unindent = printer->WithIndent(10); |
| if (message.field_count() == 0) { |
| printer->Print("testing::IsTrue()"); |
| } else { |
| for (int i = 0; i < message.field_count(); ++i) { |
| const FieldDescriptor& field = *message.field(i); |
| PrintFieldMatcher(field, message_class_name, printer); |
| if (i < message.field_count() - 1) { |
| printer->Print(",\n"); |
| } else { |
| printer->Print("\n"); |
| } |
| } |
| } |
| } |
| |
| printer->Print(R"( |
| ), arg, result_listener); |
| } |
| )"); |
| } |
| |
| void PrintMatchersRecursive(const Descriptor& message, |
| Printer* printer) const { |
| if (IsSyntheticMapEntry(message)) { |
| return; |
| } |
| // Nested types need to be printed before the parent message so that the |
| // parent can reference them. |
| for (int i = 0; i < message.nested_type_count(); i++) { |
| PrintMatchersRecursive(*message.nested_type(i), printer); |
| } |
| PrintMatcher(message, printer); |
| } |
| |
| void PrintToDeclaration(const Descriptor& message, Printer* printer) const { |
| std::string message_class_name = ClassName(&message); |
| printer->Emit({{"message_type", message_class_name}}, |
| R"( |
| void PrintTo(const $message_type$& msg, std::ostream* os);)"); |
| } |
| |
| void PrintToDefinition(const Descriptor& message, Printer* printer) const { |
| std::string message_class_name = ClassName(&message); |
| printer->Emit({{"message_type", message_class_name}}, |
| R"( |
| void PrintTo(const $message_type$& msg, std::ostream* os) { |
| *os << Serialize(msg).DebugString(); |
| })"); |
| } |
| |
| void PrintToDeclarationRecursive(const Descriptor& message, |
| Printer* printer) const { |
| if (IsSyntheticMapEntry(message)) { |
| return; |
| } |
| PrintToDeclaration(message, printer); |
| for (int i = 0; i < message.nested_type_count(); i++) { |
| PrintToDeclarationRecursive(*message.nested_type(i), printer); |
| } |
| } |
| |
| void PrintToDefinitionRecursive(const Descriptor& message, |
| Printer* printer) const { |
| if (IsSyntheticMapEntry(message)) { |
| return; |
| } |
| PrintToDefinition(message, printer); |
| for (int i = 0; i < message.nested_type_count(); i++) { |
| PrintToDefinitionRecursive(*message.nested_type(i), printer); |
| } |
| } |
| }; |
| } // namespace |
| |
| int main(int argc, char** argv) { |
| ProtoGmockGenerator generator; |
| return google::protobuf::compiler::PluginMain(argc, argv, &generator); |
| } |