// Source blob: 10d7c9965c3d598765fbf7540b06dce7127a30e7
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <string>
#include "base/check.h"
#include "base/containers/contains.h"
#include "base/containers/flat_set.h"
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/notreached.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "third_party/protobuf/src/google/protobuf/compiler/code_generator.h"
#include "third_party/protobuf/src/google/protobuf/compiler/cpp/helpers.h"
#include "third_party/protobuf/src/google/protobuf/compiler/cpp/names.h"
#include "third_party/protobuf/src/google/protobuf/compiler/importer.h"
#include "third_party/protobuf/src/google/protobuf/compiler/plugin.h"
#include "third_party/protobuf/src/google/protobuf/descriptor.h"
#include "third_party/protobuf/src/google/protobuf/io/printer.h"
namespace {
using google::protobuf::Descriptor;
using google::protobuf::FieldDescriptor;
using google::protobuf::FileDescriptor;
using google::protobuf::OneofDescriptor;
using google::protobuf::compiler::GeneratorContext;
using google::protobuf::compiler::cpp::ClassName;
using google::protobuf::compiler::cpp::FieldName;
using google::protobuf::compiler::cpp::Namespace;
using google::protobuf::compiler::cpp::NamespaceOpener;
using google::protobuf::compiler::cpp::QualifiedClassName;
using google::protobuf::compiler::cpp::UnderscoresToCamelCase;
using google::protobuf::io::Printer;
using google::protobuf::io::ZeroCopyOutputStream;
// Feature switches for the generator. Every switch is off unless explicitly
// turned on by whoever fills in the struct.
struct ProtoExtrasGeneratorOptions {
  // Emit `Serialize()` / `MaybeSerialize()` functions producing base::Value.
  bool generate_to_value_serialization{false};
  // Emit `operator<<(std::ostream&, ...)` for each message.
  bool generate_stream_operator{false};
  // Emit `operator==` / `operator!=` for each message.
  bool generate_equality{false};
  // Use the helpers meant for full (non-lite) protobuf messages, e.g. compare
  // messages through ::proto_extras::MessageDifferencerEquals.
  bool protobuf_full_support{false};
};
bool GetAllClassNames(const Descriptor& message,
base::flat_set<std::string>& class_names) {
class_names.insert(ClassName(&message));
for (int i = 0; i < message.nested_type_count(); i++) {
if (!GetAllClassNames(*message.nested_type(i), class_names)) {
return false;
}
}
return true;
}
// Prints, through `printer`, the C++ callable expression that the generated
// code applies to a map key in order to obtain the string used as the key of
// the resulting base::DictValue. The caller's template appends `(key)` right
// after the printed text.
//
// `field` is the key field of a map entry. Valid key types, from:
// - https://protobuf.dev/programming-guides/proto3/#maps
// - https://protobuf.dev/programming-guides/proto2/#maps
// > `key_type` can be any integral or string type (so, any scalar type except
// > for floating point types and bytes). Note that neither enum nor proto
// > messages are valid for `key_type`. The `value_type` can be any type
// > except another map.
void FieldToMapKeyFunction(const FieldDescriptor* field, Printer* printer) {
  using enum FieldDescriptor::Type;
  auto conversion_function = [&]() -> std::string {
    switch (field->type()) {
      case TYPE_STRING:
        return "static_cast<std::string>";
      case TYPE_BOOL:
        // `bool` is a scalar, non-floating-point type and therefore a legal
        // map key. Emit an immediately-invoked lambda so that the key reads
        // "true"/"false" and no additional include is needed in the output.
        return "[](bool key_value) { return std::string(key_value ? "
               "\"true\" : \"false\"); }";
      case TYPE_INT32:
      case TYPE_INT64:
      case TYPE_UINT64:
      case TYPE_UINT32:
      case TYPE_FIXED64:
      case TYPE_FIXED32:
      case TYPE_SFIXED64:
      case TYPE_SFIXED32:
      case TYPE_SINT64:
      case TYPE_SINT32:
        return "base::NumberToString";
      case TYPE_BYTES:
      case TYPE_ENUM:
      case TYPE_DOUBLE:
      case TYPE_FLOAT:
      case TYPE_MESSAGE:
      case TYPE_GROUP:
        NOTREACHED() << "Invalid protobuf map key type.";
    }
    // Only reachable for a value outside of the FieldDescriptor::Type enum.
    NOTREACHED();
  };
  printer->Print(conversion_function());
}
void FieldToValueFunction(const FieldDescriptor* field, Printer* printer) {
using enum FieldDescriptor::Type;
auto conversion_function = [&]() -> std::string {
switch (field->type()) {
case TYPE_DOUBLE:
case TYPE_FLOAT:
return "static_cast<double>";
case TYPE_INT32:
case TYPE_INT64:
case TYPE_UINT64:
case TYPE_UINT32:
case TYPE_FIXED64:
case TYPE_FIXED32:
case TYPE_SFIXED64:
case TYPE_SFIXED32:
case TYPE_SINT64:
case TYPE_SINT32:
return "::proto_extras::ToNumericTypeForValue";
case TYPE_BOOL:
return "static_cast<bool>";
case TYPE_STRING:
return "static_cast<std::string>";
case TYPE_BYTES:
return "base::Base64Encode";
case TYPE_ENUM:
return base::StrCat({QualifiedClassName(field->enum_type()), "_Name"});
case TYPE_MESSAGE:
case TYPE_GROUP:
// The Serialize function for the message is in the namespace of the
// nested message itself.
return base::StrCat({Namespace(field->message_type()), "::Serialize"});
}
NOTREACHED();
};
printer->Print(conversion_function());
}
// Emits, through `printer`, the definitions of `Serialize()` and
// `MaybeSerialize()` for `message`.
//
// The generated `Serialize()` first serializes unknown fields and then writes
// one dictionary entry per field:
//  - map fields become a nested dictionary, written only when non-empty;
//  - repeated fields become a list, written only when non-empty;
//  - fields with presence are written only when `has_<field>()` is true;
//  - string/bytes fields without presence are written only when non-empty;
//  - every other field is always written.
// `options` is accepted for signature parity with the other Create* helpers
// and is not read here.
void CreateToValueSerializationDefinitions(
    const Descriptor& message,
    Printer* printer,
    const ProtoExtrasGeneratorOptions& options) {
  // Callback bound to `$serialize_fields$`: one statement group per field.
  auto emit_field_serializers = [&]() {
    for (int field_index = 0; field_index < message.field_count();
         ++field_index) {
      const FieldDescriptor* field = message.field(field_index);
      std::string field_name = FieldName(field);
      // Prints the converter matching this field's own type.
      auto field_to_value = [&]() { FieldToValueFunction(field, printer); };

      if (field->is_map()) {
        // Map fields are represented by a synthetic entry message holding the
        // key and value fields.
        const Descriptor* entry_type = field->message_type();
        const FieldDescriptor* map_value = entry_type->map_value();
        const FieldDescriptor* map_key = entry_type->map_key();
        printer->Emit(
            {{"field_name", field_name},
             {"map_key_to_value",
              [&] { FieldToMapKeyFunction(map_key, printer); }},
             {"map_value_to_value",
              [&] { FieldToValueFunction(map_value, printer); }}},
            R"(
if (!message.$field_name$().empty()) {
base::DictValue map_dict;
for (const auto& [key, value] : message.$field_name$()) {
map_dict.Set($map_key_to_value$(key), $map_value_to_value$(value));
}
dict.Set("$field_name$", std::move(map_dict));
}
)");
        continue;
      }

      if (field->is_repeated()) {
        printer->Emit(
            {{"field_name", field_name}, {"field_to_value", field_to_value}},
            R"(
if (!message.$field_name$().empty()) {
base::ListValue list;
for (const auto& value : message.$field_name$()) {
list.Append($field_to_value$(value));
}
dict.Set("$field_name$", std::move(list));
}
)");
        continue;
      }

      if (field->has_presence()) {
        printer->Emit(
            {{"field_name", field_name}, {"field_to_value", field_to_value}},
            R"(
if (message.has_$field_name$()) {
dict.Set("$field_name$", $field_to_value$(message.$field_name$()));
}
)");
        continue;
      }

      const bool is_string_like =
          field->type() == FieldDescriptor::Type::TYPE_STRING ||
          field->type() == FieldDescriptor::Type::TYPE_BYTES;
      if (is_string_like) {
        printer->Emit(
            {{"field_name", field_name}, {"field_to_value", field_to_value}},
            R"(
if (!message.$field_name$().empty()) {
dict.Set("$field_name$", $field_to_value$(message.$field_name$()));
}
)");
        continue;
      }

      printer->Emit(
          {{"field_name", field_name}, {"field_to_value", field_to_value}},
          R"(
dict.Set("$field_name$", $field_to_value$(message.$field_name$()));
)");
    }
  };

  printer->Emit({{"message_type", ClassName(&message)},
                 {"serialize_fields", emit_field_serializers}},
                R"(
base::Value Serialize(const $message_type$& message) {
base::DictValue dict;
::proto_extras::SerializeUnknownFields(message, dict);
$serialize_fields$
return base::Value(std::move(dict));
}
void MaybeSerialize(const std::optional<$message_type$>& opt_message,
std::string_view name,
base::DictValue& output_dictionary) {
if (!opt_message.has_value()) {
return;
}
output_dictionary.Set(name, Serialize(*opt_message));
}
)");
}
void CreateOstreamDefinition(const Descriptor& message,
Printer* printer,
const ProtoExtrasGeneratorOptions& options) {
std::string message_type = ClassName(&message);
printer->Emit({{"message_type", message_type}},
R"(
std::ostream& operator<<(std::ostream& out, const $message_type$& message) {
// This relies on Serialize() from *.to_value.h.
return out << Serialize(message).DebugString();
}
)");
}
void CreateEqualityOperatorDefinition(
const Descriptor& message,
Printer* printer,
const ProtoExtrasGeneratorOptions& options) {
std::string message_type = ClassName(&message);
printer->Emit(
{{"message_type", message_type},
{"compare_fields",
[&]() {
// If protobuf_full_support is enabled, use MessageDifferencerEquals
// to compare the messages as the messages should be full Message
// types.
if (options.protobuf_full_support) {
printer->Print(
"if (!::proto_extras::MessageDifferencerEquals(lhs, rhs)) "
"return false;\n");
return;
}
printer->Print(
"if (lhs.unknown_fields() != rhs.unknown_fields()) return "
"false;\n");
// Compare oneof fields using a switch statement.
for (int i = 0; i < message.oneof_decl_count(); ++i) {
const OneofDescriptor* oneof = message.oneof_decl(i);
printer->Emit(
{{"oneof_name", oneof->name()},
{"message_type", message_type},
{"captital_oneof_name", base::ToUpperASCII(oneof->name())},
{"body",
[&]() {
for (int j = 0; j < oneof->field_count(); ++j) {
const FieldDescriptor* field = oneof->field(j);
std::string field_name = FieldName(field);
std::string case_name = UnderscoresToCamelCase(
field->lowercase_name(), /*cap_next_letter=*/true);
printer->Emit(
{
{"message_type", message_type},
{"case_name", case_name},
{"field_name", field_name},
},
R"(
case $message_type$::k$case_name$:
if (lhs.$field_name$() != rhs.$field_name$()) return false;
break;
)");
}
}}},
R"(
if (lhs.$oneof_name$_case() != rhs.$oneof_name$_case()) return false;
switch (lhs.$oneof_name$_case()) {
$body$
case $message_type$::$captital_oneof_name$_NOT_SET:
break;
}
)");
}
// Compare non-oneof fields.
for (int j = 0; j < message.field_count(); j++) {
const FieldDescriptor* field = message.field(j);
// Skip fields that are part of a oneof, as they are handled above.
if (field->containing_oneof()) {
continue;
}
std::string field_name = FieldName(field);
if (field->is_map()) {
printer->Emit({{"field_name", field_name}},
R"(
if (lhs.$field_name$().size() != rhs.$field_name$().size()) return false;
for (const auto& [key, value] : lhs.$field_name$()) {
auto it = rhs.$field_name$().find(key);
if (it == rhs.$field_name$().end()) return false;
if (value != it->second) return false;
}
)");
} else if (field->is_repeated()) {
printer->Emit({{"field_name", field_name}},
R"(
if (lhs.$field_name$().size() != rhs.$field_name$().size()) return false;
for (int i = 0; i < lhs.$field_name$().size(); ++i) {
if (lhs.$field_name$()[i] != rhs.$field_name$()[i]) return false;
}
)");
} else if (field->has_presence()) {
printer->Emit({{"field_name", field_name}},
R"(
if (lhs.has_$field_name$() != rhs.has_$field_name$()) return false;
if (lhs.has_$field_name$() && lhs.$field_name$() != rhs.$field_name$()) return false;
)");
} else {
printer->Emit({{"field_name", field_name}},
R"(
if (lhs.$field_name$() != rhs.$field_name$()) return false;
)");
}
}
}}},
R"(
bool operator==(const $message_type$& lhs, const $message_type$& rhs) {
if (&lhs == &rhs) return true;
$compare_fields$
return true;
}
bool operator!=(const $message_type$& lhs, const $message_type$& rhs) {
return !(lhs == rhs);
}
)");
}
// Returns if the descriptor is for a synthetic 'map entry' message type,
// which is internally created by the protobuf library to support map fields.
// Map fields are instead handled explicitly in the generation via the
// `is_map()` case.
bool IsSyntheticMapEntry(const Descriptor& message) {
return message.map_key() != nullptr;
}
class ProtoExtrasGenerator : public google::protobuf::compiler::CodeGenerator {
public:
ProtoExtrasGenerator() = default;
~ProtoExtrasGenerator() override = default;
bool Generate(const FileDescriptor* file,
const std::string& command_line_options,
GeneratorContext* context,
std::string* error) const override {
CHECK(file);
ProtoExtrasGeneratorOptions generator_options{
.generate_to_value_serialization = base::Contains(
command_line_options, "generate_to_value_serialization"),
.generate_stream_operator =
base::Contains(command_line_options, "generate_stream_operator"),
.generate_equality =
base::Contains(command_line_options, "generate_equality"),
.protobuf_full_support =
base::Contains(command_line_options, "protobuf_full_support"),
};
// The current design of this library assumes that only one of the
// serialization options is enabled.
CHECK(generator_options.generate_to_value_serialization ^
generator_options.generate_equality ^
generator_options.generate_stream_operator);
base::FilePath proto_file_path = base::FilePath::FromASCII(file->name());
base::FilePath::StringType file_suffix;
if (generator_options.generate_to_value_serialization) {
file_suffix = FILE_PATH_LITERAL(".to_value");
} else if (generator_options.generate_stream_operator) {
file_suffix = FILE_PATH_LITERAL(".ostream");
} else {
CHECK(generator_options.generate_equality);
file_suffix = FILE_PATH_LITERAL(".equal");
}
base::FilePath h_file_path =
proto_file_path.ReplaceExtension(file_suffix + FILE_PATH_LITERAL(".h"));
base::FilePath cc_file_path = proto_file_path.ReplaceExtension(
file_suffix + FILE_PATH_LITERAL(".cc"));
const std::unique_ptr<ZeroCopyOutputStream> h_stream(
context->Open(h_file_path.AsUTF8Unsafe()));
const std::unique_ptr<ZeroCopyOutputStream> cc_stream(
context->Open(cc_file_path.AsUTF8Unsafe()));
Printer h_printer(h_stream.get(), Printer::Options{'$', nullptr});
Printer cc_printer(cc_stream.get(), Printer::Options{'$', nullptr});
std::string include_guard =
base::ToUpperASCII(h_file_path.AsUTF8Unsafe()) + "_";
CHECK(base::ReplaceChars(include_guard, ".-/\\", "_", &include_guard));
auto forward_declarations = [&]() {
if (generator_options.generate_to_value_serialization) {
NamespaceOpener ns("base", &h_printer);
h_printer.Print("class DictValue;\nclass Value;\n");
}
NamespaceOpener ns(Namespace(file), &h_printer);
base::flat_set<std::string> forward_declarations;
for (int i = 0; i < file->message_type_count(); i++) {
GetAllClassNames(*file->message_type(i), forward_declarations);
}
for (const auto& forward_declaration : forward_declarations) {
h_printer.Print("class $class$;\n", "class", forward_declaration);
}
};
h_printer.Emit(
{
{"include_guard", include_guard},
{"proto_file_path", proto_file_path.AsUTF8Unsafe()},
{"includes",
[&] {
if (generator_options.generate_stream_operator) {
h_printer.Print("#include <iosfwd>\n\n");
}
if (generator_options.generate_to_value_serialization) {
h_printer.Print(
"#include <optional>\n#include <string_view>\n\n");
}
}},
{"forward_declarations", forward_declarations},
{"function_declarations",
[&] {
NamespaceOpener ns(Namespace(file), &h_printer);
for (int i = 0; i < file->message_type_count(); i++) {
PrintFunctionDeclarations(*file->message_type(i), &h_printer,
error, generator_options);
}
}},
},
R"(// Generated by the proto_extras plugin. DO NOT EDIT!
// source: $proto_file_path$
#ifndef $include_guard$
#define $include_guard$
$includes$
$forward_declarations$
$function_declarations$
#endif // $include_guard$
)");
// Determine the #includes for the implementation file.
base::flat_set<std::string> impl_system_includes;
base::flat_set<std::string> impl_user_includes = {
proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("pb.h"))
.AsUTF8Unsafe(),
};
if (generator_options.generate_stream_operator) {
impl_system_includes.insert("<ostream>");
impl_user_includes.insert(
proto_file_path.ReplaceExtension(FILE_PATH_LITERAL("to_value.h"))
.AsUTF8Unsafe());
impl_user_includes.insert("base/values.h");
}
if (generator_options.generate_to_value_serialization) {
impl_user_includes.insert("base/base64.h");
impl_user_includes.insert("base/strings/string_number_conversions.h");
impl_user_includes.insert("base/values.h");
impl_user_includes.insert("components/proto_extras/proto_extras_lib.h");
}
for (int i = 0; i < file->dependency_count(); i++) {
base::FilePath dependency_proto_file_path =
base::FilePath::FromASCII(file->dependency(i)->name());
bool needs_pb_h = generator_options.generate_to_value_serialization ||
generator_options.generate_equality;
if (needs_pb_h) {
impl_user_includes.insert(
dependency_proto_file_path
.ReplaceExtension(FILE_PATH_LITERAL("pb.h"))
.AsUTF8Unsafe());
}
if (generator_options.generate_to_value_serialization) {
impl_user_includes.insert(
dependency_proto_file_path
.ReplaceExtension(FILE_PATH_LITERAL("to_value.h"))
.AsUTF8Unsafe());
} else if (generator_options.generate_equality) {
impl_user_includes.insert(
dependency_proto_file_path
.ReplaceExtension(FILE_PATH_LITERAL("equal.h"))
.AsUTF8Unsafe());
}
}
if (generator_options.protobuf_full_support) {
impl_user_includes.insert(
"components/proto_extras/protobuf_full_support.h");
}
cc_printer.Emit(
{
{"proto_file_path", proto_file_path.AsUTF8Unsafe()},
{"header_file_path", h_file_path.AsUTF8Unsafe()},
{"system_includes",
[&] {
for (const auto& include : impl_system_includes) {
cc_printer.Print("#include $f$\n", "f", include);
}
}},
{"user_includes",
[&] {
for (const auto& include : impl_user_includes) {
cc_printer.Print("#include \"$f$\"\n", "f", include);
}
}},
{"function_definitions",
[&] {
NamespaceOpener ns(Namespace(file), &cc_printer);
for (int i = 0; i < file->message_type_count(); i++) {
PrintFunctionDefinitions(*file->message_type(i), &cc_printer,
error, generator_options);
}
}},
},
R"(// Generated by the proto_extras plugin. DO NOT EDIT!
// source: $proto_file_path$
#include "$header_file_path$"
$system_includes$
$user_includes$
$function_definitions$
)");
return true;
}
bool PrintFunctionDeclaration(
const Descriptor& message,
Printer* printer,
std::string* error,
const ProtoExtrasGeneratorOptions& options) const {
if (IsSyntheticMapEntry(message)) {
return true;
}
std::string message_type = ClassName(&message);
if (options.generate_to_value_serialization) {
printer->Print("base::Value Serialize(const $m$& message);", "m",
message_type);
printer->Print(
R"(
void MaybeSerialize(const std::optional<$m$>& opt_message,
std::string_view output_dictionary_field_name,
base::DictValue& output_dictionary);
)", "m", message_type);
}
if (options.generate_stream_operator) {
printer->Print(
"std::ostream& operator<<(std::ostream& out, const "
"$m$& message);\n",
"m", message_type);
}
if (options.generate_equality) {
printer->Print("bool operator==(const $m$& lhs, const $m$& rhs);\n", "m",
message_type);
}
return true;
}
bool PrintFunctionDefinition(
const Descriptor& message,
Printer* printer,
std::string* error,
const ProtoExtrasGeneratorOptions& options) const {
if (IsSyntheticMapEntry(message)) {
return true;
}
if (options.generate_to_value_serialization) {
CreateToValueSerializationDefinitions(message, printer, options);
}
if (options.generate_stream_operator) {
CreateOstreamDefinition(message, printer, options);
}
if (options.generate_equality) {
CreateEqualityOperatorDefinition(message, printer, options);
}
return true;
}
bool PrintFunctionDeclarations(
const Descriptor& message,
Printer* printer,
std::string* error,
const ProtoExtrasGeneratorOptions& options) const {
if (!PrintFunctionDeclaration(message, printer, error, options)) {
return false;
}
for (int i = 0; i < message.nested_type_count(); i++) {
if (!PrintFunctionDeclarations(*message.nested_type(i), printer, error,
options)) {
return false;
}
}
return true;
}
bool PrintFunctionDefinitions(
const Descriptor& message,
Printer* printer,
std::string* error,
const ProtoExtrasGeneratorOptions& options) const {
if (!PrintFunctionDefinition(message, printer, error, options)) {
return false;
}
for (int i = 0; i < message.nested_type_count(); i++) {
if (!PrintFunctionDefinitions(*message.nested_type(i), printer, error,
options)) {
return false;
}
}
return true;
}
};
} // namespace
int main(int argc, char** argv) {
ProtoExtrasGenerator generator;
return google::protobuf::compiler::PluginMain(argc, argv, &generator);
}