blob: c04295565b7b2da311060ecbdbf6f38995dbe8e6 [file] [log] [blame] [edit]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "third_party/protobuf/src/google/protobuf/compiler/importer.h"
#include "third_party/protobuf/src/google/protobuf/descriptor.h"
namespace {
using ::google::protobuf::Descriptor;
using ::google::protobuf::FileDescriptor;
using ::google::protobuf::MethodDescriptor;
using ::google::protobuf::ServiceDescriptor;
using ::google::protobuf::compiler::DiskSourceTree;
using ::google::protobuf::compiler::Importer;
using ::google::protobuf::compiler::MultiFileErrorCollector;
// Command line flag naming the .proto file to parse.
constexpr const char kProtoFlag[] = "--proto";
// Command line flag naming an extra header to #include in the generated file.
// May be passed more than once.
constexpr const char kIncludesFlag[] = "--include";
// Command line flag for the directory the generated file is written into.
constexpr const char kOutDirFlag[] = "--out-dir";
// Command line flag for the source root directory that proto paths are
// resolved against.
constexpr const char kSourceDirFlag[] = "--source-dir";
// Command line flag for an additional proto import directory. May be passed
// more than once.
constexpr const char kImportDir[] = "--import-dir";

// Suffix appended to the proto path stem to form the generated file name.
constexpr const char kCastCorePbExtension[] = ".castcore.pb.h";
// Suffix of the gRPC protobuf header that the generated file includes.
constexpr const char kGrpcPbExtension[] = ".grpc.pb.h";

// When true, progress messages prefixed with "[INFO]" are printed to stdout.
constexpr bool is_debug = false;
// Returns |file_path| with its trailing ".proto" extension removed, e.g.
// "foo/bar.proto" -> "foo/bar".
//
// If |file_path| does not end in ".proto", an "[ERROR]" line is written to
// stderr and the process exits with status 1.
std::string RemoveProtoExtension(const std::string& file_path) {
  // A constexpr view avoids a function-local static std::string, which would
  // need a thread-safe initialization guard and a destructor at exit.
  constexpr std::string_view kProtoExtension = ".proto";
  if (file_path.length() < kProtoExtension.length()) {
    std::cerr << "[ERROR] File path is too short: " << file_path << std::endl;
    exit(1);
  }
  // size_type (not int) so the offset cannot be narrowed for long paths.
  const std::string::size_type ext_pos =
      file_path.length() - kProtoExtension.length();
  // Compare through a view to avoid materializing a temporary substring.
  if (std::string_view(file_path).substr(ext_pos) != kProtoExtension) {
    std::cerr << "[ERROR] Not a proto file: " << file_path << std::endl;
    exit(1);
  }
  return file_path.substr(0, ext_pos);
}
// Returns a copy of |input| in which every non-overlapping occurrence of
// |old_token|, found scanning left to right, is replaced by |new_token|.
//
// Replaced text is never rescanned. This is why the function terminates even
// when |new_token| itself contains |old_token| (e.g. "." -> "..").
// An empty |old_token| matches nothing, so |input| is returned unchanged.
std::string ReplaceString(const std::string& input,
                          const std::string& old_token,
                          const std::string& new_token) {
  if (old_token.empty()) {
    return input;
  }
  std::string output;
  output.reserve(input.size());
  std::string::size_type start = 0;
  std::string::size_type pos;
  while ((pos = input.find(old_token, start)) != std::string::npos) {
    // Copy the untouched run before the match, then the replacement.
    output.append(input, start, pos - start);
    output.append(new_token);
    start = pos + old_token.length();
  }
  // Copy whatever follows the last match (or all of |input| if none).
  output.append(input, start, std::string::npos);
  return output;
}
// MultiFileErrorCollector that reports every error it receives from the
// protobuf Importer on stderr, one line per error, in the form
//   [ERROR] <filename>, Line <line>, Column <column>: <message>
class SimpleErrorCollector : public MultiFileErrorCollector {
 public:
  void AddError(const std::string& filename,
                int line,
                int column,
                const std::string& message) override {
    std::ostream& log = std::cerr;
    log << "[ERROR] " << filename;
    log << ", Line " << line;
    log << ", Column " << column;
    log << ": " << message << std::endl;
  }
};
// Ordered, de-duplicated collection of header paths that are #included at the
// top of the generated file.
class Includes {
 public:
  // Seeds the collection with |includes|; duplicate paths collapse to one.
  // The strings are moved out of the by-value vector rather than copied.
  explicit Includes(std::vector<std::string> includes)
      : includes_(std::make_move_iterator(includes.begin()),
                  std::make_move_iterator(includes.end())) {}

  // Adds |include_file_path|; a path that is already present is ignored.
  void AddInclude(std::string include_file_path) {
    includes_.insert(std::move(include_file_path));
  }

  // Writes one `#include "<path>"` line per entry, in lexicographic order,
  // followed by a single blank line.
  void PrintIncludePrologues(std::ostream& os) const {
    for (const auto& include : includes_) {
      os << "#include \"" << include << "\"" << std::endl;
    }
    os << std::endl;
  }

 private:
  std::set<std::string> includes_;
};
class SourceProto {
public:
SourceProto(std::string proto_file_path, Includes* includes)
: proto_file_path_(std::move(proto_file_path)),
output_file_path_(RemoveProtoExtension(proto_file_path_)),
cpp_headerguard_(CreateHeaderGuard(output_file_path_)),
includes_(includes) {}
bool Initialize(std::string source_dir,
std::vector<std::string> import_dirs) {
includes_->AddInclude(output_file_path_ + kGrpcPbExtension);
source_tree_.MapPath("", source_dir);
for (const auto& import_dir : import_dirs) {
source_tree_.MapPath("", source_dir + import_dir);
}
importer_ = std::make_unique<Importer>(&source_tree_, &error_collector_);
file_descriptor_ = importer_->Import(proto_file_path_);
if (!file_descriptor_) {
return false;
}
if (file_descriptor_->service_count() == 0) {
std::cout << "[WARNING] At least one service should be specified in the "
"proto file: "
<< proto_file_path_ << std::endl;
}
const std::string& package = file_descriptor_->package();
std::string::size_type start_pos = 0;
std::string::size_type next_pos;
while ((next_pos = package.find(".", start_pos)) != std::string::npos) {
namespaces_.emplace_back(package.substr(start_pos, next_pos - start_pos));
start_pos = next_pos + 1;
}
namespaces_.emplace_back(package.substr(start_pos));
return true;
}
bool Generate(std::string out_dir) const {
std::ostringstream header;
PrintHeaderGruardPrologues(header);
includes_->PrintIncludePrologues(header);
PrintNamespacePrologues(header);
for (int i = 0; i < file_descriptor_->service_count(); ++i) {
if (is_debug) {
std::cout << "[INFO] Generating Cast Core gRPC definitions for "
<< file_descriptor_->service(i)->name() << std::endl;
}
PrintCastCoreHandlerDefinition(header, file_descriptor_->service(i));
PrintCastCoreStubDefinition(header, file_descriptor_->service(i));
}
PrintNamespaceEpilogues(header);
PrintHeaderGruardEpilogues(header);
auto target_file = out_dir + output_file_path_ + kCastCorePbExtension;
if (is_debug) {
std::cout << "[INFO] Writing generated files: " << target_file
<< std::endl;
}
std::ofstream cpp_headerfile(target_file);
if (!cpp_headerfile) {
std::cerr << "[ERROR] Failed to open target header file: " << target_file
<< std::endl;
return false;
}
cpp_headerfile << header.str();
return true;
}
private:
void PrintHeaderGruardPrologues(std::ostream& os) const {
os << "#ifndef " << cpp_headerguard_ << std::endl
<< "#define " << cpp_headerguard_ << std::endl
<< std::endl;
}
void PrintHeaderGruardEpilogues(std::ostream& os) const {
os << "#endif // " << cpp_headerguard_ << std::endl;
}
void PrintNamespacePrologues(std::ostream& os) const {
for (const auto& ns : namespaces_) {
os << "namespace " << ns << " {" << std::endl;
}
os << std::endl;
}
void PrintNamespaceEpilogues(std::ostream& os) const {
for (const auto& ns : namespaces_) {
os << "} // namespace " << ns << std::endl;
}
os << std::endl;
}
void PrintCastCoreHandlerDefinition(
std::ostream& header, const ServiceDescriptor* service_descriptor) const {
const std::string& service_name = service_descriptor->name();
std::ostringstream method_names;
std::ostringstream class_methods;
for (int i = 0; i < service_descriptor->method_count(); ++i) {
const MethodDescriptor* method = service_descriptor->method(i);
if (method->client_streaming()) {
std::cout << "[WARNING] Client streaming APIs are not supported yet - "
"skipping "
<< method->name() << std::endl;
continue;
}
const Descriptor* input = method->input_type();
const Descriptor* output = method->output_type();
const std::string method_name_var = "k" + std::string(service_name) +
"_" + std::string(method->name()) +
"_MethodName";
method_names << "constexpr char " << method_name_var << "[] = \""
<< method->name() << "\";" << std::endl;
class_methods << " using " << method->name() << " = ";
if (method->server_streaming()) {
class_methods << "::cast::utils::GrpcServerStreamingHandler<";
} else {
class_methods << "::cast::utils::GrpcUnaryHandler<";
}
const auto request_type = FullTypeName(input->full_name());
const auto response_type = FullTypeName(output->full_name());
class_methods << service_name << ", " << request_type << ", "
<< response_type << ", " << method_name_var << ">;"
<< std::endl;
}
header << "// " << service_name << " gRPC handler." << std::endl
<< method_names.str() << std::endl
<< "class " << service_name << "Handler {" << std::endl
<< " public: " << std::endl
<< class_methods.str() << "};" << std::endl
<< std::endl;
}
void PrintCastCoreStubDefinition(
std::ostream& header, const ServiceDescriptor* service_descriptor) const {
const std::string& service_name = service_descriptor->name();
header << "// " << service_name << " gRPC stub." << std::endl
<< "class " << service_name << "Stub : "
<< " public ::cast::utils::GrpcStub<" << service_name << "> {"
<< std::endl
<< " public:" << std::endl
<< " using GrpcStub::GrpcStub;" << std::endl
<< " using GrpcStub::operator=;" << std::endl
<< " using GrpcStub::AsyncInterface;" << std::endl
<< " using GrpcStub::CreateCall;" << std::endl
<< " using GrpcStub::SyncInterface;" << std::endl
<< std::endl;
for (int i = 0; i < service_descriptor->method_count(); ++i) {
const MethodDescriptor* method = service_descriptor->method(i);
if (method->client_streaming()) {
std::cout << "[WARNING] Client streaming APIs are not supported yet - "
"skipping "
<< method->name() << std::endl;
continue;
}
const Descriptor* input = method->input_type();
const Descriptor* output = method->output_type();
const std::string method_name_var =
"k" + std::string(method->name()) + "Method";
header << " using " << method->name() << " = ";
if (method->server_streaming()) {
header << "::cast::utils::GrpcServerStreamingCall<";
} else {
header << "::cast::utils::GrpcUnaryCall<";
}
const auto request_type = FullTypeName(input->full_name());
const auto response_type = FullTypeName(output->full_name());
header << service_name << "Stub, " << request_type << ", "
<< response_type << ", &AsyncInterface::" << method->name();
if (!method->client_streaming() && !method->server_streaming()) {
// Add the sync interface defs
header << ", &SyncInterface::" << method->name();
}
header << ">;" << std::endl;
}
header << "};" << std::endl << std::endl;
}
static std::string FullTypeName(const std::string& proto_type_name) {
return "::" + ReplaceString(proto_type_name, ".", "::");
}
static std::string CreateHeaderGuard(const std::string& proto_file_path) {
std::ostringstream cpp_headerguard;
std::string result = proto_file_path;
std::transform(result.begin(), result.end(), result.begin(),
[](unsigned char c) -> unsigned char {
if (c == '.' || c == '/' || c == '\\') {
return '_';
}
return std::toupper(c);
});
result.append("_CASTCORE_PB_H_");
return result;
}
const std::string proto_file_path_;
const std::string output_file_path_;
const std::string cpp_headerguard_;
Includes* const includes_;
SimpleErrorCollector error_collector_;
DiskSourceTree source_tree_;
std::unique_ptr<Importer> importer_;
const FileDescriptor* file_descriptor_;
std::vector<std::string> namespaces_;
};
} // namespace
// Entry point. Consumes "<flag> <value>" argument pairs, then generates the
// Cast Core header for the proto named by --proto.
//
// Returns 0 on success; 1 for a malformed command line (unknown flag, flag
// without a value, or empty --out-dir); -1 if --proto is missing; -2 if the
// proto cannot be imported; -3 if the header cannot be written.
int main(int argc, char** argv) {
  std::string flag;
  std::string proto_file_path;
  std::string out_dir;
  std::string source_dir;
  std::vector<std::string> include_file_paths;
  std::vector<std::string> import_dirs;
  for (int i = 1; i < argc; ++i) {
    // Arguments alternate: a flag name, then that flag's value.
    if (flag.empty()) {
      flag = argv[i];
      continue;
    }
    if (flag == kProtoFlag) {
      proto_file_path = argv[i];
    } else if (flag == kIncludesFlag) {
      include_file_paths.push_back(argv[i]);
    } else if (flag == kImportDir) {
      import_dirs.push_back(argv[i]);
    } else if (flag == kOutDirFlag) {
      out_dir = argv[i];
      if (out_dir.empty()) {
        std::cerr << "[ERROR] Output directory must be specified" << std::endl;
        return 1;
      }
      // Generate() concatenates out_dir directly with the relative file path.
      if (out_dir.back() != '/') {
        out_dir.append("/");
      }
    } else if (flag == kSourceDirFlag) {
      source_dir = argv[i];
    } else {
      std::cerr << "[ERROR] Unexpected flag: " << flag << std::endl;
      return 1;
    }
    flag.clear();
  }
  // A trailing flag without a value would otherwise be silently ignored.
  if (!flag.empty()) {
    std::cerr << "[ERROR] Missing value for flag: " << flag << std::endl;
    return 1;
  }
  if (proto_file_path.empty()) {
    std::cerr << "[ERROR] Proto file cannot be empty" << std::endl;
    return -1;
  }
  if (is_debug) {
    std::cout << "[INFO] Generating Cast Core gRPC definitions: proto="
              << proto_file_path << ", out_dir=" << out_dir
              << ", source_dir=" << source_dir << std::endl;
  }
  Includes includes(std::move(include_file_paths));
  SourceProto source_proto(std::move(proto_file_path), &includes);
  if (!source_proto.Initialize(std::move(source_dir), std::move(import_dirs))) {
    return -2;
  }
  if (!source_proto.Generate(std::move(out_dir))) {
    return -3;
  }
  return 0;
}