| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/c/c_api.h" |
| |
#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <vector>
| |
| #include "tensorflow/core/common_runtime/shape_refiner.h" |
| #include "tensorflow/core/framework/log_memory.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/partial_tensor_shape.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/lib/core/coding.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/stringpiece.h" |
| #include "tensorflow/core/lib/gtl/array_slice.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/thread_annotations.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/public/session.h" |
| |
| // The implementation below is at the top level instead of the |
| // brain namespace because we are defining 'extern "C"' functions. |
| using tensorflow::error::Code; |
| using tensorflow::errors::InvalidArgument; |
| using tensorflow::gtl::ArraySlice; |
| using tensorflow::AllocationDescription; |
| using tensorflow::DataType; |
| using tensorflow::Env; |
| using tensorflow::Graph; |
| using tensorflow::GraphDef; |
| using tensorflow::mutex; |
| using tensorflow::mutex_lock; |
| using tensorflow::NameRangeMap; |
| using tensorflow::NameRangesForNode; |
| using tensorflow::NewSession; |
| using tensorflow::Node; |
| using tensorflow::NodeDef; |
| using tensorflow::NodeBuilder; |
| using tensorflow::OpDef; |
| using tensorflow::OpRegistry; |
| using tensorflow::PartialTensorShape; |
| using tensorflow::Reset; |
| using tensorflow::RunMetadata; |
| using tensorflow::RunOptions; |
| using tensorflow::Session; |
| using tensorflow::SessionOptions; |
| using tensorflow::Status; |
| using tensorflow::Tensor; |
| using tensorflow::TensorBuffer; |
| using tensorflow::TensorShape; |
| using tensorflow::TensorShapeProto; |
| |
| extern "C" { |
| |
| // -------------------------------------------------------------------------- |
| struct TF_Status { |
| Status status; |
| }; |
| |
| TF_Status* TF_NewStatus() { return new TF_Status; } |
| |
| void TF_DeleteStatus(TF_Status* s) { delete s; } |
| |
| void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { |
| s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg)); |
| } |
| |
| TF_Code TF_GetCode(const TF_Status* s) { |
| return static_cast<TF_Code>(s->status.code()); |
| } |
| |
| const char* TF_Message(const TF_Status* s) { |
| return s->status.error_message().c_str(); |
| } |
| |
| // -------------------------------------------------------------------------- |
| |
| namespace { |
| class TF_ManagedBuffer : public TensorBuffer { |
| public: |
| void* data_; |
| size_t len_; |
| void (*deallocator_)(void* data, size_t len, void* arg); |
| void* deallocator_arg_; |
| |
| ~TF_ManagedBuffer() override { |
| (*deallocator_)(data_, len_, deallocator_arg_); |
| } |
| |
| void* data() const override { return data_; } |
| size_t size() const override { return len_; } |
| TensorBuffer* root_buffer() override { return this; } |
| void FillAllocationDescription(AllocationDescription* proto) const override { |
| tensorflow::int64 rb = size(); |
| proto->set_requested_bytes(rb); |
| proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); |
| } |
| }; |
| |
| void* allocate_tensor(const char* operation, size_t len) { |
| void* data = |
| tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); |
| if (tensorflow::LogMemory::IsEnabled()) { |
| tensorflow::LogMemory::RecordRawAllocation( |
| operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, |
| len, data, tensorflow::cpu_allocator()); |
| } |
| return data; |
| } |
| |
| void deallocate_buffer(void* data, size_t len, void* arg) { |
| if (tensorflow::LogMemory::IsEnabled()) { |
| tensorflow::LogMemory::RecordRawDeallocation( |
| "TensorFlow C Api", |
| tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data, |
| tensorflow::cpu_allocator(), false); |
| } |
| tensorflow::cpu_allocator()->DeallocateRaw(data); |
| } |
| |
| Status MessageToBuffer(const tensorflow::protobuf::Message& in, |
| TF_Buffer* out) { |
| if (out->data != nullptr) { |
| return InvalidArgument("Passing non-empty TF_Buffer is invalid."); |
| } |
| const auto proto_size = in.ByteSize(); |
| void* buf = malloc(proto_size); |
| in.SerializeToArray(buf, proto_size); |
| out->data = buf; |
| out->length = proto_size; |
| out->data_deallocator = [](void* data, size_t length) { free(data); }; |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| struct TF_Tensor { |
| TF_DataType dtype; |
| TensorShape shape; |
| TensorBuffer* buffer; |
| }; |
| |
| TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, |
| int num_dims, size_t len) { |
| void* data = allocate_tensor("TF_AllocateTensor", len); |
| return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer, |
| nullptr); |
| } |
| |
| TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, |
| void* data, size_t len, |
| void (*deallocator)(void* data, size_t len, void* arg), |
| void* deallocator_arg) { |
| std::vector<tensorflow::int64> dimvec(num_dims); |
| for (int i = 0; i < num_dims; ++i) { |
| dimvec[i] = static_cast<tensorflow::int64>(dims[i]); |
| } |
| |
| TF_ManagedBuffer* buf = new TF_ManagedBuffer; |
| buf->len_ = len; |
| if (reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) { |
| // Copy the data into a buffer that satisfies Eigen's alignment |
| // requirements. |
| buf->data_ = allocate_tensor("TF_NewTensor", len); |
| std::memcpy(buf->data_, data, len); |
| buf->deallocator_ = deallocate_buffer; |
| buf->deallocator_arg_ = nullptr; |
| // Free the original buffer. |
| deallocator(data, len, deallocator_arg); |
| } else { |
| buf->data_ = data; |
| buf->deallocator_ = deallocator; |
| buf->deallocator_arg_ = deallocator_arg; |
| } |
| return new TF_Tensor{dtype, TensorShape(dimvec), buf}; |
| } |
| |
| void TF_DeleteTensor(TF_Tensor* t) { |
| t->buffer->Unref(); |
| delete t; |
| } |
| |
| TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; } |
| int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); } |
| int64_t TF_Dim(const TF_Tensor* t, int dim_index) { |
| return static_cast<int64_t>(t->shape.dim_size(dim_index)); |
| } |
| size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } |
| void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } |
| |
| // -------------------------------------------------------------------------- |
| struct TF_SessionOptions { |
| SessionOptions options; |
| }; |
| TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } |
| void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } |
| |
| void TF_SetTarget(TF_SessionOptions* options, const char* target) { |
| options->options.target = target; |
| } |
| |
| void TF_SetConfig(TF_SessionOptions* options, const void* proto, |
| size_t proto_len, TF_Status* status) { |
| if (!options->options.config.ParseFromArray(proto, proto_len)) { |
| status->status = InvalidArgument("Unparseable ConfigProto"); |
| } |
| } |
| // -------------------------------------------------------------------------- |
| TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } |
| |
| TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { |
| void* copy = malloc(proto_len); |
| memcpy(copy, proto, proto_len); |
| |
| TF_Buffer* buf = new TF_Buffer; |
| buf->data = copy; |
| buf->length = proto_len; |
| buf->data_deallocator = [](void* data, size_t length) { free(data); }; |
| return buf; |
| } |
| |
| void TF_DeleteBuffer(TF_Buffer* buffer) { |
| if (buffer->data_deallocator != nullptr) { |
| (*buffer->data_deallocator)(const_cast<void*>(buffer->data), |
| buffer->length); |
| } |
| delete buffer; |
| } |
| |
| TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } |
| |
| // -------------------------------------------------------------------------- |
| struct TF_Session { |
| Session* session; |
| }; |
| |
| TF_Session* TF_NewSession(const TF_SessionOptions* opt, TF_Status* status) { |
| Session* session; |
| status->status = NewSession(opt->options, &session); |
| if (status->status.ok()) { |
| return new TF_Session({session}); |
| } else { |
| DCHECK_EQ(nullptr, session); |
| return NULL; |
| } |
| } |
| |
| void TF_CloseSession(TF_Session* s, TF_Status* status) { |
| status->status = s->session->Close(); |
| } |
| |
| void TF_DeleteSession(TF_Session* s, TF_Status* status) { |
| status->status = Status::OK(); |
| delete s->session; |
| delete s; |
| } |
| |
| void TF_ExtendGraph(TF_Session* s, const void* proto, size_t proto_len, |
| TF_Status* status) { |
| GraphDef g; |
| if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { |
| status->status = InvalidArgument("Invalid GraphDef"); |
| return; |
| } |
| status->status = s->session->Extend(g); |
| } |
| |
| static void DeleteArray(void* data, size_t size, void* arg) { |
| DCHECK_EQ(data, arg); |
| delete[] reinterpret_cast<char*>(arg); |
| } |
| |
| } // end extern "C" |
| |
| namespace tensorflow { |
| namespace { |
| |
| // Reset helper for converting character arrays to string vectors. |
| void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, |
| int ncontainers, TF_Status* status) { |
| std::vector<tensorflow::string> container_names(ncontainers); |
| for (int i = 0; i < ncontainers; ++i) { |
| container_names[i] = containers[i]; |
| } |
| |
| status->status = Reset(opt->options, container_names); |
| } |
| |
| } // namespace |
| } // namespace tensorflow |
| |
| extern "C" { |
| |
| void TF_Reset(const TF_SessionOptions* opt, const char** containers, |
| int ncontainers, TF_Status* status) { |
| tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status); |
| } |
| |
| } // end extern "C" |
| |
| namespace tensorflow { |
| |
| // Non-static for testing. |
| bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) { |
| const tensorflow::int64 num_elements = src->shape.num_elements(); |
| const char* input = reinterpret_cast<const char*>(TF_TensorData(src)); |
| const size_t src_size = TF_TensorByteSize(src); |
| if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) < |
| num_elements) { |
| status->status = InvalidArgument( |
| "Malformed TF_STRING tensor; too short to hold number of elements"); |
| return false; |
| } |
| const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; |
| const char* limit = input + src_size; |
| |
| *dst = Tensor(static_cast<DataType>(src->dtype), src->shape); |
| auto dstarray = dst->flat<tensorflow::string>(); |
| for (tensorflow::int64 i = 0; i < num_elements; ++i) { |
| tensorflow::uint64 offset = |
| reinterpret_cast<const tensorflow::uint64*>(input)[i]; |
| tensorflow::uint64 len; |
| const char* p; |
| if (static_cast<ptrdiff_t>(offset) >= (limit - data_start) || |
| !(p = tensorflow::core::GetVarint64Ptr(data_start + offset, limit, |
| &len)) || |
| (static_cast<ptrdiff_t>(len) > (limit - p))) { |
| status->status = InvalidArgument("Malformed TF_STRING tensor; element ", |
| i, " out of range"); |
| return false; |
| } |
| dstarray(i).assign(p, len); |
| } |
| return true; |
| } |
| |
| // Non-static for testing. |
| TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) { |
| // Compute bytes needed for encoding. |
| size_t size = 0; |
| const auto& srcarray = src.flat<tensorflow::string>(); |
| for (int i = 0; i < srcarray.size(); ++i) { |
| const tensorflow::string& s = srcarray(i); |
| // uint64 starting_offset, varint64 length, string contents |
| size += sizeof(tensorflow::uint64) + |
| tensorflow::core::VarintLength(s.size()) + s.size(); |
| } |
| |
| // Encode all strings. |
| char* base = new char[size]; |
| char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size(); |
| char* dst = data_start; // Where next string is encoded. |
| tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base); |
| for (int i = 0; i < srcarray.size(); ++i) { |
| const tensorflow::string& s = srcarray(i); |
| *offsets = (dst - data_start); |
| offsets++; |
| dst = tensorflow::core::EncodeVarint64(dst, s.size()); |
| memcpy(dst, s.data(), s.size()); |
| dst += s.size(); |
| } |
| CHECK_EQ(dst, base + size); |
| |
| auto dims = src.shape().dim_sizes(); |
| std::vector<tensorflow::int64> dimvec(dims.size()); |
| for (size_t i = 0; i < dims.size(); ++i) { |
| dimvec[i] = dims[i]; |
| } |
| static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
| "64-bit int types should match in size"); |
| return TF_NewTensor(TF_STRING, |
| reinterpret_cast<const int64_t*>(dimvec.data()), |
| dimvec.size(), base, size, DeleteArray, base); |
| } |
| |
| class TensorCApi { |
| public: |
| static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } |
| static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, |
| TensorBuffer* buf) { |
| return Tensor(static_cast<DataType>(type), shape, buf); |
| } |
| }; |
| |
| // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to |
| // result in a zero-sized tensor. |
| static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { |
| static char empty; |
| tensorflow::int64 nelems = 1; |
| std::vector<tensorflow::int64> dims; |
| for (int i = 0; i < shape.dims(); ++i) { |
| dims.push_back(shape.dim_size(i)); |
| nelems *= shape.dim_size(i); |
| } |
| CHECK_EQ(nelems, 0); |
| static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
| "64-bit int types should match in size"); |
| return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()), |
| shape.dims(), reinterpret_cast<void*>(&empty), 0, |
| [](void*, size_t, void*) {}, nullptr); |
| } |
| |
// Helpers for loading a TensorFlow plugin (a .so file).
// Declared here and defined elsewhere. TF_LoadLibrary() passes pointers into
// a TF_Library: `result` receives the library handle, and `buf`/`len`
// receive the plugin's serialized op list. TF_DeleteLibraryHandle() later
// releases `buf` with free(), so the definition presumably allocates it with
// malloc; confirm against the implementation.
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);
| |
| } // namespace tensorflow |
| |
| static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, |
| TF_Status* status) { |
| status->status = Status::OK(); |
| for (int i = 0; i < noutputs; ++i) { |
| c_outputs[i] = NULL; |
| } |
| } |
| |
| static bool TF_Run_Inputs( |
| TF_Tensor* const* c_inputs, |
| std::vector<std::pair<tensorflow::string, Tensor>>* input_pairs, |
| TF_Status* status) { |
| const int ninputs = input_pairs->size(); |
| bool ok = true; |
| for (int i = 0; i < ninputs; ++i) { |
| TF_Tensor* src = c_inputs[i]; |
| if (ok) { |
| if (c_inputs[i]->dtype != TF_STRING) { |
| (*input_pairs)[i].second = tensorflow::TensorCApi::MakeTensor( |
| src->dtype, src->shape, src->buffer); |
| } else { |
| // TF_STRING tensors require copying since Tensor class expects |
| // a sequence of string objects. |
| ok = tensorflow::TF_Tensor_DecodeStrings(src, &(*input_pairs)[i].second, |
| status); |
| // Must keep looping through all c_inputs even if there is an error |
| // so that TF_DeleteTensor() is called unconditionally on all c_inputs. |
| } |
| } |
| TF_DeleteTensor(src); |
| } |
| return ok; |
| } |
| |
| static void TF_Run_Helper( |
| Session* session, const char* handle, const TF_Buffer* run_options, |
| // Input tensors |
| const std::vector<std::pair<tensorflow::string, Tensor>>& input_pairs, |
| // Output tensors |
| const std::vector<tensorflow::string>& output_tensor_names, |
| TF_Tensor** c_outputs, |
| // Target nodes |
| const std::vector<tensorflow::string>& target_oper_names, |
| TF_Buffer* run_metadata, TF_Status* status) { |
| const int noutputs = output_tensor_names.size(); |
| std::vector<Tensor> outputs(noutputs); |
| Status result; |
| |
| if (handle == nullptr) { |
| RunOptions run_options_proto; |
| if (run_options != nullptr && |
| !run_options_proto.ParseFromArray(run_options->data, |
| run_options->length)) { |
| status->status = InvalidArgument("Unparseable RunOptions proto"); |
| return; |
| } |
| if (run_metadata != nullptr && run_metadata->data != nullptr) { |
| status->status = |
| InvalidArgument("Passing non-empty run_metadata is invalid."); |
| return; |
| } |
| |
| RunMetadata run_metadata_proto; |
| result = session->Run(run_options_proto, input_pairs, output_tensor_names, |
| target_oper_names, &outputs, &run_metadata_proto); |
| |
| // Serialize back to upstream client, who now owns the new buffer |
| if (run_metadata != nullptr) { |
| status->status = MessageToBuffer(run_metadata_proto, run_metadata); |
| if (!status->status.ok()) return; |
| } |
| } else { |
| // NOTE(zongheng): PRun does not support RunOptions yet. |
| result = session->PRun(handle, input_pairs, output_tensor_names, &outputs); |
| } |
| if (!result.ok()) { |
| status->status = result; |
| return; |
| } |
| |
| // Store results in c_outputs[] |
| for (int i = 0; i < noutputs; ++i) { |
| const Tensor& src = outputs[i]; |
| if (!src.IsInitialized() || src.NumElements() == 0) { |
| c_outputs[i] = tensorflow::EmptyTensor( |
| static_cast<TF_DataType>(src.dtype()), src.shape()); |
| continue; |
| } |
| if (src.dtype() != tensorflow::DT_STRING) { |
| // Share the underlying buffer. |
| TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src); |
| buf->Ref(); |
| c_outputs[i] = new TF_Tensor{static_cast<TF_DataType>(src.dtype()), |
| src.shape(), buf}; |
| } else { |
| c_outputs[i] = tensorflow::TF_Tensor_EncodeStrings(src); |
| } |
| } |
| } |
| |
| extern "C" { |
| |
| void TF_Run(TF_Session* s, const TF_Buffer* run_options, |
| // Input tensors |
| const char** c_input_names, TF_Tensor** c_inputs, int ninputs, |
| // Output tensors |
| const char** c_output_names, TF_Tensor** c_outputs, int noutputs, |
| // Target nodes |
| const char** c_target_oper_names, int ntargets, |
| TF_Buffer* run_metadata, TF_Status* status) { |
| TF_Run_Setup(noutputs, c_outputs, status); |
| std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); |
| if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; |
| for (int i = 0; i < ninputs; ++i) { |
| input_pairs[i].first = c_input_names[i]; |
| } |
| std::vector<tensorflow::string> output_names(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| output_names[i] = c_output_names[i]; |
| } |
| std::vector<tensorflow::string> target_oper_names(ntargets); |
| for (int i = 0; i < ntargets; ++i) { |
| target_oper_names[i] = c_target_oper_names[i]; |
| } |
| TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, |
| c_outputs, target_oper_names, run_metadata, status); |
| } |
| |
| void TF_PRunSetup(TF_Session* s, |
| // Input names |
| const char** c_input_names, int ninputs, |
| // Output names |
| const char** c_output_names, int noutputs, |
| // Target nodes |
| const char** c_target_oper_names, int ntargets, |
| const char** handle, TF_Status* status) { |
| status->status = Status::OK(); |
| |
| std::vector<tensorflow::string> input_names(ninputs); |
| std::vector<tensorflow::string> output_names(noutputs); |
| std::vector<tensorflow::string> target_oper_names(ntargets); |
| for (int i = 0; i < ninputs; ++i) { |
| input_names[i] = c_input_names[i]; |
| } |
| for (int i = 0; i < noutputs; ++i) { |
| output_names[i] = c_output_names[i]; |
| } |
| for (int i = 0; i < ntargets; ++i) { |
| target_oper_names[i] = c_target_oper_names[i]; |
| } |
| tensorflow::string new_handle; |
| Status result; |
| result = s->session->PRunSetup(input_names, output_names, target_oper_names, |
| &new_handle); |
| if (result.ok()) { |
| char* buf = new char[new_handle.size() + 1]; |
| memcpy(buf, new_handle.c_str(), new_handle.size() + 1); |
| *handle = buf; |
| } else { |
| status->status = result; |
| } |
| } |
| |
| void TF_PRun(TF_Session* s, const char* handle, |
| // Input tensors |
| const char** c_input_names, TF_Tensor** c_inputs, int ninputs, |
| // Output tensors |
| const char** c_output_names, TF_Tensor** c_outputs, int noutputs, |
| // Target nodes |
| const char** c_target_oper_names, int ntargets, |
| TF_Status* status) { |
| TF_Run_Setup(noutputs, c_outputs, status); |
| std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); |
| if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; |
| for (int i = 0; i < ninputs; ++i) { |
| input_pairs[i].first = c_input_names[i]; |
| } |
| |
| std::vector<tensorflow::string> output_names(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| output_names[i] = c_output_names[i]; |
| } |
| std::vector<tensorflow::string> target_oper_names(ntargets); |
| for (int i = 0; i < ntargets; ++i) { |
| target_oper_names[i] = c_target_oper_names[i]; |
| } |
| TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names, |
| c_outputs, target_oper_names, nullptr, status); |
| } |
| |
| struct TF_Library { |
| void* lib_handle; |
| TF_Buffer op_list; |
| }; |
| |
| TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { |
| TF_Library* lib_handle = new TF_Library; |
| status->status = tensorflow::LoadLibrary( |
| library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, |
| &lib_handle->op_list.length); |
| if (!status->status.ok()) { |
| delete lib_handle; |
| return nullptr; |
| } |
| return lib_handle; |
| } |
| |
| TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } |
| |
| void TF_DeleteLibraryHandle(TF_Library* lib_handle) { |
| free(const_cast<void*>(lib_handle->op_list.data)); |
| delete lib_handle; |
| } |
| |
| TF_Buffer* TF_GetAllOpList() { |
| std::vector<tensorflow::OpDef> op_defs; |
| tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs); |
| tensorflow::OpList op_list; |
| for (const auto& op : op_defs) { |
| *(op_list.add_op()) = op; |
| } |
| TF_Buffer* ret = TF_NewBuffer(); |
| MessageToBuffer(op_list, ret); |
| return ret; |
| } |
| |
| } // end extern "C" |
| |
| // -------------------------------------------------------------------------- |
| // New Graph and Session API |
| |
| // Structures ----------------------------------------------------------------- |
| |
| extern "C" { |
| |
// A graph under construction through the C API, with its shape-inference
// state. Every field except `mu` is annotated GUARDED_BY(mu).
struct TF_Graph {
  // Member declaration order matters: `refiner` is initialized from
  // graph.op_registry(), so `graph` must be declared (and therefore
  // constructed) before it.
  TF_Graph()
      : graph(OpRegistry::Global()),
        refiner(graph.op_registry()),
        num_sessions(0),
        delete_requested(false) {}
  mutex mu;
  Graph graph GUARDED_BY(mu);

  // Runs shape inference.
  tensorflow::ShapeRefiner refiner GUARDED_BY(mu);

  // Maps from name of an operation to the Node* in 'graph'.
  std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu);

  // TF_Graph may only / must be deleted when
  // num_sessions == 0 && delete_requested == true

  // num_sessions incremented by TF_NewSessionWithGraph, and decremented by
  // TF_DeleteSessionWithGraph.
  int num_sessions GUARDED_BY(mu);
  bool delete_requested GUARDED_BY(mu);  // set true by TF_DeleteGraph
};
| |
// Accumulates a new operation's inputs, attrs, device, and colocation
// constraints (via the TF_SetAttr*/TF_AddInput* functions) for `graph`.
struct TF_OperationDescription {
  TF_OperationDescription(TF_Graph* g, const char* op_type,
                          const char* node_name)
      : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}

  NodeBuilder node_builder;
  TF_Graph* graph;
  // Colocation group names (kColocationGroupPrefix + op name), appended by
  // TF_ColocateWith().
  std::vector<tensorflow::string> colocation_constraints;
};

// C handle for a graph operation. Node is its only member, which lets
// ToOperation() reinterpret a Node* as a TF_Operation*. Do not add members.
struct TF_Operation {
  Node node;
};

// A Session paired with the TF_Graph it executes.
struct TF_SessionWithGraph {
  TF_SessionWithGraph(Session* s, TF_Graph* g)
      : session(s), graph(g), last_num_graph_nodes(0) {}
  Session* session;
  TF_Graph* graph;
  mutex mu;
  // NOTE(review): presumably the number of graph nodes already sent to
  // `session`, so later runs extend it with only the new nodes. Confirm
  // against the session-with-graph run code.
  int last_num_graph_nodes;
};
| |
| } // end extern "C" |
| |
| // Helper functions ----------------------------------------------------------- |
| |
| namespace { |
| |
| TF_Operation* ToOperation(Node* node) { |
| return static_cast<TF_Operation*>(static_cast<void*>(node)); |
| } |
| |
| tensorflow::string PortName(const TF_Port& port) { |
| return tensorflow::strings::StrCat(port.oper->node.name(), ":", port.index); |
| } |
| |
| const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, |
| const char* attr_name, |
| TF_Status* status) { |
| const tensorflow::AttrValue* attr = |
| tensorflow::AttrSlice(oper->node.def()).Find(attr_name); |
| if (attr == nullptr) { |
| status->status = |
| InvalidArgument("Operation has no attr named '", attr_name, "'."); |
| } |
| return attr; |
| } |
| |
| } // namespace |
| |
| // Shape functions ----------------------------------------------------------- |
| |
| void TF_GraphSetTensorShape(TF_Graph* graph, TF_Port port, const int64_t* dims, |
| const int num_dims, TF_Status* status) { |
| Node* node = &port.oper->node; |
| |
| mutex_lock l(graph->mu); |
| // Set the shape. |
| tensorflow::shape_inference::InferenceContext* ic = |
| graph->refiner.GetContext(node); |
| if (ic == nullptr) { |
| status->status = |
| InvalidArgument("Node ", node->name(), " was not found in the graph"); |
| return; |
| } |
| |
| std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec; |
| for (int i = 0; i < num_dims; ++i) { |
| dim_vec.push_back(ic->MakeDim(dims[i])); |
| } |
| |
| tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec); |
| status->status = graph->refiner.SetShape(node, port.index, new_shape); |
| } |
| |
| int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Port port, TF_Status* status) { |
| Node* node = &port.oper->node; |
| |
| mutex_lock l(graph->mu); |
| tensorflow::shape_inference::InferenceContext* ic = |
| graph->refiner.GetContext(node); |
| if (ic == nullptr) { |
| status->status = |
| InvalidArgument("Node ", node->name(), " was not found in the graph"); |
| return -1; |
| } |
| |
| tensorflow::shape_inference::ShapeHandle shape = ic->output(port.index); |
| |
| // Unknown rank means the number of dimensions is -1. |
| if (!ic->RankKnown(shape)) { |
| return -1; |
| } |
| |
| return ic->Rank(shape); |
| } |
| |
| void TF_GraphGetTensorShape(TF_Graph* graph, TF_Port port, int64_t* dims, |
| int num_dims, TF_Status* status) { |
| Node* node = &port.oper->node; |
| |
| mutex_lock l(graph->mu); |
| tensorflow::shape_inference::InferenceContext* ic = |
| graph->refiner.GetContext(node); |
| if (ic == nullptr) { |
| status->status = |
| InvalidArgument("Node ", node->name(), " was not found in the graph"); |
| return; |
| } |
| |
| tensorflow::shape_inference::ShapeHandle shape = ic->output(port.index); |
| |
| int rank = -1; |
| if (ic->RankKnown(shape)) { |
| rank = ic->Rank(shape); |
| } |
| |
| if (num_dims != rank) { |
| status->status = InvalidArgument("Expected rank is ", num_dims, |
| " but actual rank is ", rank); |
| return; |
| } |
| |
| if (num_dims == 0) { |
| // Output shape is a scalar. |
| return; |
| } |
| |
| // Rank is greater than 0, so fill in the values, if known, and |
| // -1 for unknown values. |
| for (int i = 0; i < num_dims; ++i) { |
| auto dim = ic->Dim(shape, i); |
| tensorflow::int64 value = -1; |
| if (ic->ValueKnown(dim)) { |
| value = ic->Value(dim); |
| } |
| dims[i] = value; |
| } |
| } |
| |
| // TF_OperationDescription functions ------------------------------------------ |
| |
| extern "C" { |
| |
| TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type, |
| const char* oper_name) { |
| mutex_lock l(graph->mu); |
| return new TF_OperationDescription(graph, op_type, oper_name); |
| } |
| |
| void TF_SetDevice(TF_OperationDescription* desc, const char* device) { |
| desc->node_builder.Device(device); |
| } |
| |
| void TF_AddInput(TF_OperationDescription* desc, TF_Port input) { |
| desc->node_builder.Input(&input.oper->node, input.index); |
| } |
| |
| void TF_AddInputList(TF_OperationDescription* desc, const TF_Port* inputs, |
| int num_inputs) { |
| std::vector<NodeBuilder::NodeOut> input_list; |
| input_list.reserve(num_inputs); |
| for (int i = 0; i < num_inputs; ++i) { |
| input_list.emplace_back(&inputs[i].oper->node, inputs[i].index); |
| } |
| desc->node_builder.Input(input_list); |
| } |
| |
| void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) { |
| desc->node_builder.ControlInput(&input->node); |
| } |
| |
| void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) { |
| desc->colocation_constraints.emplace_back(tensorflow::strings::StrCat( |
| tensorflow::kColocationGroupPrefix, op->node.name())); |
| } |
| |
| void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name, |
| const void* value, int length) { |
| tensorflow::StringPiece s(static_cast<const char*>(value), length); |
| desc->node_builder.Attr(attr_name, s); |
| } |
| |
| void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name, |
| const void* const* values, const int* lengths, |
| int num_values) { |
| std::vector<tensorflow::StringPiece> v; |
| v.reserve(num_values); |
| for (int i = 0; i < num_values; ++i) { |
| v.emplace_back(static_cast<const char*>(values[i]), lengths[i]); |
| } |
| desc->node_builder.Attr(attr_name, v); |
| } |
| |
| void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, |
| int64_t value) { |
| static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
| "64-bit int types should match in size"); |
| desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value)); |
| } |
| |
| void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name, |
| const int64_t* values, int num_values) { |
| static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
| "64-bit int types should match in size"); |
| desc->node_builder.Attr( |
| attr_name, |
| ArraySlice<const tensorflow::int64>( |
| reinterpret_cast<const tensorflow::int64*>(values), num_values)); |
| } |
| |
| void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name, |
| float value) { |
| desc->node_builder.Attr(attr_name, value); |
| } |
| |
| void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name, |
| const float* values, int num_values) { |
| desc->node_builder.Attr(attr_name, |
| ArraySlice<const float>(values, num_values)); |
| } |
| |
| void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, |
| unsigned char value) { |
| desc->node_builder.Attr(attr_name, static_cast<bool>(value)); |
| } |
| |
| void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name, |
| const unsigned char* values, int num_values) { |
| std::unique_ptr<bool[]> b(new bool[num_values]); |
| for (int i = 0; i < num_values; ++i) { |
| b[i] = values[i]; |
| } |
| desc->node_builder.Attr(attr_name, |
| ArraySlice<const bool>(b.get(), num_values)); |
| } |
| |
| void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, |
| TF_DataType value) { |
| desc->node_builder.Attr(attr_name, static_cast<DataType>(value)); |
| } |
| |
| void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, |
| const TF_DataType* values, int num_values) { |
| desc->node_builder.Attr( |
| attr_name, ArraySlice<const DataType>( |
| reinterpret_cast<const DataType*>(values), num_values)); |
| } |
| |
| void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, |
| const int64_t* dims, int num_dims) { |
| PartialTensorShape shape; |
| if (num_dims >= 0) { |
| static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
| "64-bit int types should match in size"); |
| shape = PartialTensorShape(ArraySlice<tensorflow::int64>( |
| reinterpret_cast<const tensorflow::int64*>(dims), num_dims)); |
| } |
| desc->node_builder.Attr(attr_name, shape); |
| } |
| |
| void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, |
| const int64_t* const* dims, const int* num_dims, |
| int num_shapes) { |
| std::vector<PartialTensorShape> shapes; |
| shapes.reserve(num_shapes); |
| for (int i = 0; i < num_shapes; ++i) { |
| if (num_dims[i] < 0) { |
| shapes.emplace_back(); |
| } else { |
| static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), |
| "64-bit int types should match in size"); |
| shapes.emplace_back(ArraySlice<tensorflow::int64>( |
| reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i])); |
| } |
| } |
| desc->node_builder.Attr(attr_name, shapes); |
| } |
| |
| void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, |
| const char* attr_name, const void* proto, |
| int proto_len, TF_Status* status) { |
| TensorShapeProto shape; |
| if (shape.ParseFromArray(proto, proto_len)) { |
| desc->node_builder.Attr(attr_name, shape); |
| status->status = Status::OK(); |
| } else { |
| status->status = InvalidArgument("Unparseable TensorShapeProto"); |
| } |
| } |
| |
| void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, |
| const char* attr_name, |
| const void* const* protos, |
| const int* proto_lens, int num_shapes, |
| TF_Status* status) { |
| std::vector<TensorShapeProto> shapes; |
| shapes.resize(num_shapes); |
| for (int i = 0; i < num_shapes; ++i) { |
| if (!shapes[i].ParseFromArray(protos[i], proto_lens[i])) { |
| status->status = |
| InvalidArgument("Unparseable TensorShapeProto at index ", i); |
| return; |
| } |
| } |
| desc->node_builder.Attr(attr_name, shapes); |
| status->status = Status::OK(); |
| } |
| |
| void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name, |
| TF_Tensor* value, TF_Status* status) { |
| status->status = Status::OK(); |
| Tensor t; |
| bool ok = true; |
| |
| if (value->dtype != TF_STRING) { |
| t = tensorflow::TensorCApi::MakeTensor(value->dtype, value->shape, |
| value->buffer); |
| } else { |
| // TF_STRING tensors require copying since Tensor class expects |
| // a sequence of string objects. |
| ok = tensorflow::TF_Tensor_DecodeStrings(value, &t, status); |
| } |
| |
| TF_DeleteTensor(value); |
| if (ok) desc->node_builder.Attr(attr_name, t); |
| } |
| |
| void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name, |
| TF_Tensor* const* values, int num_values, |
| TF_Status* status) { |
| status->status = Status::OK(); |
| std::vector<Tensor> t; |
| t.reserve(num_values); |
| bool ok = true; |
| |
| for (int i = 0; i < num_values; ++i) { |
| if (ok) { |
| if (values[i]->dtype != TF_STRING) { |
| t.emplace_back(tensorflow::TensorCApi::MakeTensor( |
| values[i]->dtype, values[i]->shape, values[i]->buffer)); |
| } else { |
| t.emplace_back(::tensorflow::DT_STRING); |
| // TF_STRING tensors require copying since Tensor class expects |
| // a sequence of string objects. |
| ok = tensorflow::TF_Tensor_DecodeStrings(values[i], &t.back(), status); |
| } |
| } |
| // We always delete value[i], even when there is an error, |
| // as promised in the API. |
| TF_DeleteTensor(values[i]); |
| } |
| |
| if (ok) desc->node_builder.Attr(attr_name, t); |
| } |
| |
| void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, |
| const void* proto, size_t proto_len, |
| TF_Status* status) { |
| tensorflow::AttrValue attr_value; |
| if (attr_value.ParseFromArray(proto, proto_len)) { |
| desc->node_builder.Attr(attr_name, attr_value); |
| status->status = Status::OK(); |
| } else { |
| status->status = InvalidArgument("Unparseable AttrValue proto"); |
| } |
| } |
| |
| TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, |
| TF_Status* status) { |
| Node* ret = nullptr; |
| mutex_lock l(desc->graph->mu); |
| |
| if (desc->graph->name_map.count(desc->node_builder.node_name())) { |
| status->status = InvalidArgument("Duplicate node name in graph: '", |
| desc->node_builder.node_name(), "'"); |
| } else { |
| std::sort(desc->colocation_constraints.begin(), |
| desc->colocation_constraints.end()); |
| desc->node_builder.Attr(tensorflow::kColocationAttrName, |
| desc->colocation_constraints); |
| status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); |
| |
| if (status->status.ok()) { |
| // Run shape inference function for newly added node. |
| // |
| // TODO(b/28152992): Enable returning the result of this |
| // code-path once we have converted all python shape functions |
| // to call their C++ versions. |
| desc->graph->refiner.AddNode(ret); |
| |
| // Add the node to the name-to-node mapping. |
| desc->graph->name_map[ret->name()] = ret; |
| } |
| } |
| |
| delete desc; |
| |
| return ToOperation(ret); |
| } |
| |
| // TF_Operation functions |
| // ---------------------------------------------------------- |
| |
| const char* TF_OperationName(TF_Operation* oper) { |
| return oper->node.name().c_str(); |
| } |
| |
| const char* TF_OperationOpType(TF_Operation* oper) { |
| return oper->node.type_string().c_str(); |
| } |
| |
| const char* TF_OperationDevice(TF_Operation* oper) { |
| return oper->node.def().device().c_str(); |
| } |
| |
| int TF_OperationNumOutputs(TF_Operation* oper) { |
| return oper->node.num_outputs(); |
| } |
| |
| TF_DataType TF_OperationOutputType(TF_Port oper_out) { |
| return static_cast<TF_DataType>( |
| oper_out.oper->node.output_type(oper_out.index)); |
| } |
| |
| int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, |
| TF_Status* status) { |
| NameRangeMap name_ranges; |
| status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), |
| nullptr, &name_ranges); |
| if (!status->status.ok()) return -1; |
| auto iter = name_ranges.find(arg_name); |
| if (iter == name_ranges.end()) { |
| status->status = InvalidArgument("Input arg '", arg_name, "' not found"); |
| return -1; |
| } |
| return iter->second.second - iter->second.first; |
| } |
| |
| int TF_OperationNumInputs(TF_Operation* oper) { |
| return oper->node.num_inputs(); |
| } |
| |
| TF_DataType TF_OperationInputType(TF_Port oper_in) { |
| return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index)); |
| } |
| |
| int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, |
| TF_Status* status) { |
| NameRangeMap name_ranges; |
| status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), |
| &name_ranges, nullptr); |
| if (!status->status.ok()) return -1; |
| auto iter = name_ranges.find(arg_name); |
| if (iter == name_ranges.end()) { |
| status->status = InvalidArgument("Input arg '", arg_name, "' not found"); |
| return -1; |
| } |
| return iter->second.second - iter->second.first; |
| } |
| |
| TF_Port TF_OperationInput(TF_Port oper_in) { |
| const tensorflow::Edge* edge; |
| Status s = oper_in.oper->node.input_edge(oper_in.index, &edge); |
| if (!s.ok()) { |
| return {nullptr, -1}; |
| } |
| |
| return {ToOperation(edge->src()), edge->src_output()}; |
| } |
| |
| int TF_OperationOutputNumConsumers(TF_Port oper_out) { |
| int count = 0; |
| for (const auto* edge : oper_out.oper->node.out_edges()) { |
| if (edge->src_output() == oper_out.index) { |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| int TF_OperationOutputConsumers(TF_Port oper_out, TF_Port* consumers, |
| int max_consumers) { |
| int count = 0; |
| for (const auto* edge : oper_out.oper->node.out_edges()) { |
| if (edge->src_output() == oper_out.index) { |
| if (count < max_consumers) { |
| consumers[count] = {ToOperation(edge->dst()), edge->dst_input()}; |
| } |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| int TF_OperationNumControlInputs(TF_Operation* oper) { |
| return oper->node.in_edges().size() - oper->node.num_inputs(); |
| } |
| |
| int TF_OperationGetControlInputs(TF_Operation* oper, |
| TF_Operation** control_inputs, |
| int max_control_inputs) { |
| int count = 0; |
| for (const auto* edge : oper->node.in_edges()) { |
| if (edge->IsControlEdge()) { |
| if (count < max_control_inputs) { |
| control_inputs[count] = ToOperation(edge->src()); |
| } |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| int TF_OperationNumControlOutputs(TF_Operation* oper) { |
| int count = 0; |
| for (const auto* edge : oper->node.out_edges()) { |
| if (edge->IsControlEdge()) { |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| int TF_OperationGetControlOutputs(TF_Operation* oper, |
| TF_Operation** control_outputs, |
| int max_control_outputs) { |
| int count = 0; |
| for (const auto* edge : oper->node.out_edges()) { |
| if (edge->IsControlEdge()) { |
| if (count < max_control_outputs) { |
| control_outputs[count] = ToOperation(edge->dst()); |
| } |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| TF_Attr_Metadata TF_OperationGetAttrMetadata(TF_Operation* oper, |
| const char* attr_name, |
| TF_Status* status) { |
| TF_Attr_Metadata metadata; |
| const auto* attr = GetAttrValue(oper, attr_name, status); |
| if (!status->status.ok()) return metadata; |
| switch (attr->value_case()) { |
| #define SINGLE_CASE(kK, attr_type, size_expr) \ |
| case tensorflow::AttrValue::kK: \ |
| metadata.is_list = 0; \ |
| metadata.list_size = -1; \ |
| metadata.type = attr_type; \ |
| metadata.total_size = size_expr; \ |
| break; |
| |
| SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); |
| SINGLE_CASE(kI, TF_ATTR_INT, -1); |
| SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); |
| SINGLE_CASE(kB, TF_ATTR_BOOL, -1); |
| SINGLE_CASE(kType, TF_ATTR_TYPE, -1); |
| SINGLE_CASE(kShape, TF_ATTR_SHAPE, |
| attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); |
| SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); |
| #undef SINGLE_CASE |
| |
| case tensorflow::AttrValue::kList: |
| metadata.is_list = 1; |
| metadata.list_size = 0; |
| metadata.total_size = -1; |
| #define LIST_CASE(field, attr_type, ...) \ |
| if (attr->list().field##_size() > 0) { \ |
| metadata.type = attr_type; \ |
| metadata.list_size = attr->list().field##_size(); \ |
| __VA_ARGS__; \ |
| break; \ |
| } |
| |
| LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; |
| for (int i = 0; i < attr->list().s_size(); |
| ++i) { metadata.total_size += attr->list().s(i).size(); }); |
| LIST_CASE(i, TF_ATTR_INT); |
| LIST_CASE(f, TF_ATTR_FLOAT); |
| LIST_CASE(b, TF_ATTR_BOOL); |
| LIST_CASE(type, TF_ATTR_TYPE); |
| LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; |
| for (int i = 0; i < attr->list().shape_size(); ++i) { |
| const auto& s = attr->list().shape(i); |
| metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); |
| }); |
| LIST_CASE(tensor, TF_ATTR_TENSOR); |
| #undef LIST_CASE |
| // All lists empty, determine the type from the OpDef. |
| if (metadata.list_size == 0) { |
| for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { |
| const auto& a = oper->node.op_def().attr(i); |
| if (a.name().compare(attr_name) != 0) continue; |
| const tensorflow::string& typestr = a.type(); |
| if (typestr == "list(string)") { |
| metadata.type = TF_ATTR_STRING; |
| } else if (typestr == "list(int)") { |
| metadata.type = TF_ATTR_INT; |
| } else if (typestr == "list(float)") { |
| metadata.type = TF_ATTR_FLOAT; |
| } else if (typestr == "list(bool)") { |
| metadata.type = TF_ATTR_BOOL; |
| } else if (typestr == "list(type)") { |
| metadata.type = TF_ATTR_TYPE; |
| } else if (typestr == "list(shape)") { |
| metadata.type = TF_ATTR_SHAPE; |
| } else if (typestr == "list(tensor)") { |
| metadata.type = TF_ATTR_TENSOR; |
| } else { |
| status->status = InvalidArgument( |
| "Attribute '", attr_name, |
| "' has an empty value of an unrecognized type '", typestr, "'"); |
| return metadata; |
| } |
| } |
| } |
| break; |
| |
| case tensorflow::AttrValue::kPlaceholder: |
| metadata.is_list = 0; |
| metadata.list_size = -1; |
| metadata.type = TF_ATTR_PLACEHOLDER; |
| metadata.total_size = -1; |
| break; |
| |
| case tensorflow::AttrValue::kFunc: |
| metadata.is_list = 0; |
| metadata.list_size = -1; |
| metadata.type = TF_ATTR_FUNC; |
| metadata.total_size = -1; |
| break; |
| |
| case tensorflow::AttrValue::VALUE_NOT_SET: |
| status->status = |
| InvalidArgument("Attribute '", attr_name, "' has no value set"); |
| break; |
| } |
| return metadata; |
| } |
| |
| void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, |
| void* value, int max_length, TF_Status* status) { |
| const auto* attr = GetAttrValue(oper, attr_name, status); |
| if (!status->status.ok()) return; |
| if (attr->value_case() != tensorflow::AttrValue::kS) { |
| status->status = |
| InvalidArgument("Attribute '", attr_name, "' is not a string"); |
| return; |
| } |
| if (max_length <= 0) { |
| return; |
| } |
| const auto& s = attr->s(); |
| std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length)); |
| } |
| |
| void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, |
| void** values, int* lengths, int max_values, |
| void* storage, size_t storage_size, |
| TF_Status* status) { |
| const auto* attr = GetAttrValue(oper, attr_name, status); |
| if (!status->status.ok()) return; |
| if (attr->value_case() != tensorflow::AttrValue::kList) { |
| status->status = |
| InvalidArgument("Value for '", attr_name, "' is not a list"); |
| return; |
| } |
| const auto len = std::min(max_values, attr->list().s_size()); |
| char* p = static_cast<char*>(storage); |
| for (int i = 0; i < len; ++i) { |
| const tensorflow::string& s = attr->list().s(i); |
| values[i] = p; |
| lengths[i] = s.size(); |
| if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) { |
| status->status = InvalidArgument( |
| "Not enough storage to hold the requested list of strings"); |
| return; |
| } |
| memcpy(values[i], s.data(), s.size()); |
| p += s.size(); |
| } |
| } |
| |
| #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ |
| void func(TF_Operation* oper, const char* attr_name, c_type* value, \ |
| TF_Status* status) { \ |
| cpp_type v; \ |
| status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \ |
| *value = static_cast<c_type>(v); \ |
| } \ |
| void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ |
| int max_values, TF_Status* status) { \ |
| const auto* attr = GetAttrValue(oper, attr_name, status); \ |
| if (!status->status.ok()) return; \ |
| if (attr->value_case() != tensorflow::AttrValue::kList) { \ |
| status->status = \ |
| InvalidArgument("Value for '", attr_name, "' is not a list."); \ |
| return; \ |
| } \ |
| const auto len = std::min(max_values, attr->list().list_field##_size()); \ |
| for (int i = 0; i < len; ++i) { \ |
| values[i] = static_cast<c_type>(attr->list().list_field(i)); \ |
| } \ |
| } |
| DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); |
| DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); |
| DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b); |
| DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); |
| #undef DEFINE_GETATTR |
| |
| void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, |
| int64_t* value, int num_dims, TF_Status* status) { |
| PartialTensorShape shape; |
| status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape); |
| if (!status->status.ok()) return; |
| auto len = std::min(shape.dims(), num_dims); |
| for (int i = 0; i < len; ++i) { |
| value[i] = shape.dim_size(i); |
| } |
| } |
| |
| void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, |
| int64_t** values, int* num_dims, |
| int max_values, int64_t* storage, |
| int storage_size, TF_Status* status) { |
| std::vector<PartialTensorShape> shapes; |
| status->status = |
| tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes); |
| if (!status->status.ok()) return; |
| auto len = std::min(static_cast<int>(shapes.size()), max_values); |
| int64_t* p = storage; |
| int storage_left = storage_size; |
| for (int i = 0; i < len; ++i) { |
| // shapes[i].dims() == -1 for shapes with an unknown rank. |
| int64_t n = shapes[i].dims(); |
| num_dims[i] = n; |
| values[i] = p; |
| if (n < 0) { |
| continue; |
| } |
| if (storage_left < n) { |
| status->status = InvalidArgument( |
| "Not enough storage to hold the requested list of shapes"); |
| return; |
| } |
| storage_left -= n; |
| for (int j = 0; j < n; ++j, ++p) { |
| *p = shapes[i].dim_size(j); |
| } |
| } |
| } |
| |
| void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, |
| const char* attr_name, |
| TF_Buffer* value, TF_Status* status) { |
| const auto* attr = GetAttrValue(oper, attr_name, status); |
| if (!status->status.ok()) return; |
| if (attr->value_case() != tensorflow::AttrValue::kShape) { |
| status->status = |
| InvalidArgument("Value for '", attr_name, "' is not a shape."); |
| return; |
| } |
| status->status = MessageToBuffer(attr->shape(), value); |
| } |
| |
| void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, |
| const char* attr_name, |
| TF_Buffer** values, int max_values, |
| TF_Status* status) { |
| const auto* attr = GetAttrValue(oper, attr_name, status); |
| if (!status->status.ok()) return; |
| if (attr->value_case() != tensorflow::AttrValue::kList) { |
| status->status = |
| InvalidArgument("Value for '", attr_name, "' is not a list"); |
| return; |
| } |
| const auto len = std::min(max_values, attr->list().shape_size()); |
| for (int i = 0; i < len; ++i) { |
| values[i] = TF_NewBuffer(); |
| status->status = MessageToBuffer(attr->list().shape(i), values[i]); |
| if (!status->status.ok()) { |
| // Delete everything allocated to far, the operation has failed. |
| for (int j = 0; j <= i; ++j) { |
| TF_DeleteBuffer(values[j]); |
| } |
| return; |
| } |
| } |
| } |
| |
| void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, |
| TF_Tensor** value, TF_Status* status) { |
| *value = nullptr; |
| Tensor t; |
| status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t); |
| if (!status->status.ok()) return; |
| *value = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(), |
| tensorflow::TensorCApi::Buffer(t)}; |
| (*value)->buffer->Ref(); |
| } |
| |
| void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, |
| TF_Tensor** values, int max_values, |
| TF_Status* status) { |
| std::vector<Tensor> ts; |
| status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts); |
| if (!status->status.ok()) return; |
| const auto len = std::min(max_values, static_cast<int>(ts.size())); |
| for (int i = 0; i < len; ++i) { |
| const Tensor& t = ts[i]; |
| values[i] = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(), |
| tensorflow::TensorCApi::Buffer(t)}; |
| values[i]->buffer->Ref(); |
| } |
| } |
| |
| void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name, |
| TF_Buffer* output_attr_value, |
| TF_Status* status) { |
| const auto* attr = GetAttrValue(oper, attr_name, status); |
| if (!status->status.ok()) return; |
| status->status = MessageToBuffer(*attr, output_attr_value); |
| } |
| |
| void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, |
| TF_Status* status) { |
| status->status = MessageToBuffer(oper->node.def(), output_node_def); |
| } |
| |
| // TF_Graph functions --------------------------------------------------------- |
| |
| TF_Graph* TF_NewGraph() { return new TF_Graph; } |
| |
| void TF_DeleteGraph(TF_Graph* g) { |
| g->mu.lock(); |
| g->delete_requested = true; |
| const bool del = g->num_sessions == 0; |
| g->mu.unlock(); |
| if (del) delete g; |
| } |
| |
| TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) { |
| mutex_lock l(graph->mu); |
| auto iter = graph->name_map.find(oper_name); |
| if (iter == graph->name_map.end()) { |
| return nullptr; |
| } else { |
| return ToOperation(iter->second); |
| } |
| } |
| |
| TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) { |
| if (*pos == 0) { |
| // Advance past the first sentinal nodes in every graph (the source & sink). |
| *pos += 2; |
| } else { |
| // Advance to the next node. |
| *pos += 1; |
| } |
| |
| mutex_lock l(graph->mu); |
| while (*pos < graph->graph.num_node_ids()) { |
| Node* node = graph->graph.FindNodeId(*pos); |
| // FindNodeId() returns nullptr for nodes that have been deleted. |
| // We aren't currently allowing nodes to be deleted, but it is safer |
| // to still check. |
| if (node != nullptr) return ToOperation(node); |
| *pos += 1; |
| } |
| |
| // No more nodes. |
| return nullptr; |
| } |
| |
| void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, |
| TF_Status* status) { |
| GraphDef def; |
| { |
| mutex_lock l(graph->mu); |
| graph->graph.ToGraphDef(&def); |
| } |
| status->status = MessageToBuffer(def, output_graph_def); |
| } |
| |
// C-API handle wrapping the C++ options passed to ImportGraphDef() by
// TF_GraphImportGraphDef().
struct TF_ImportGraphDefOptions {
  tensorflow::ImportGraphDefOptions opts;
};
| |
| TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { |
| return new TF_ImportGraphDefOptions; |
| } |
| void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) { |
| delete opts; |
| } |
| void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, |
| const char* prefix) { |
| opts->opts.prefix = prefix; |
| } |
| |
| void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, |
| const TF_ImportGraphDefOptions* opts, |
| TF_Status* status) { |
| GraphDef def; |
| if (!def.ParseFromArray(graph_def->data, graph_def->length)) { |
| status->status = InvalidArgument("Invalid GraphDef"); |
| return; |
| } |
| mutex_lock l(graph->mu); |
| const int last_node_id = graph->graph.num_node_ids(); |
| status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, |
| &graph->refiner); |
| if (!status->status.ok()) return; |
| for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { |
| auto* node = graph->graph.FindNodeId(i); |
| if (node != nullptr) graph->name_map[node->name()] = node; |
| } |
| } |
| |
| // TF_SessionWithGraph functions ---------------------------------------------- |
| |
| TF_SessionWithGraph* TF_NewSessionWithGraph(TF_Graph* graph, |
| const TF_SessionOptions* opt, |
| TF_Status* status) { |
| Session* session; |
| status->status = NewSession(opt->options, &session); |
| if (status->status.ok()) { |
| if (graph != nullptr) { |
| mutex_lock l(graph->mu); |
| graph->num_sessions += 1; |
| } |
| return new TF_SessionWithGraph(session, graph); |
| } else { |
| DCHECK_EQ(nullptr, session); |
| return NULL; |
| } |
| } |
| |
| void TF_CloseSessionWithGraph(TF_SessionWithGraph* s, TF_Status* status) { |
| status->status = s->session->Close(); |
| } |
| |
| void TF_DeleteSessionWithGraph(TF_SessionWithGraph* s, TF_Status* status) { |
| status->status = Status::OK(); |
| TF_Graph* const graph = s->graph; |
| if (graph != nullptr) { |
| graph->mu.lock(); |
| graph->num_sessions -= 1; |
| const bool del = graph->delete_requested && graph->num_sessions == 0; |
| graph->mu.unlock(); |
| if (del) delete graph; |
| } |
| delete s->session; |
| delete s; |
| } |
| |
| // TODO(josh11b,mrry): Change Session to be able to use a Graph* |
| // directly, instead of requiring us to serialize to a GraphDef and |
| // call Session::Extend(). |
| static bool ExtendSessionGraphHelper(TF_SessionWithGraph* session, |
| TF_Status* status) { |
| if (session->graph != nullptr) { |
| mutex_lock session_lock(session->mu); |
| session->graph->mu.lock(); |
| const Graph& graph = session->graph->graph; |
| const auto num_nodes = graph.num_node_ids(); |
| if (session->last_num_graph_nodes < num_nodes) { |
| GraphDef graph_def; |
| graph_def.mutable_versions()->CopyFrom(graph.versions()); |
| // Fill graph_def with nodes with ids in the range |
| // [session->last_num_graph_nodes, num_nodes), that is the nodes |
| // added since the last TF_SessionRun() call. |
| for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { |
| Node* const node = graph.FindNodeId(id); |
| if (node != nullptr && node->IsOp()) { |
| NodeDef* const node_def = graph_def.add_node(); |
| *node_def = node->def(); |
| } |
| } |
| session->graph->mu.unlock(); |
| // TODO(josh11b): Also send the function library if needed. |
| status->status = session->session->Extend(graph_def); |
| if (!status->status.ok()) { |
| // Contract is we always delete input_values[i]. |
| return false; |
| } |
| // Note: session->session is not modified if Extend() fails, so |
| // we only set last_num_graph_nodes if it succeeds. |
| session->last_num_graph_nodes = num_nodes; |
| } else { |
| session->graph->mu.unlock(); |
| } |
| } |
| return true; |
| } |
| |
| void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options, |
| const TF_Port* inputs, TF_Tensor* const* input_values, |
| int ninputs, const TF_Port* outputs, |
| TF_Tensor** output_values, int noutputs, |
| const TF_Operation* const* target_opers, int ntargets, |
| TF_Buffer* run_metadata, TF_Status* status) { |
| // TODO(josh11b,mrry): Change Session to be able to use a Graph* |
| // directly, instead of requiring us to serialize to a GraphDef and |
| // call Session::Extend(). |
| if (!ExtendSessionGraphHelper(session, status)) { |
| for (int i = 0; i < ninputs; ++i) { |
| TF_DeleteTensor(input_values[i]); |
| } |
| return; |
| } |
| |
| TF_Run_Setup(noutputs, output_values, status); |
| |
| // Convert from TF_Port and TF_Tensor to a string and Tensor. |
| std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); |
| if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; |
| for (int i = 0; i < ninputs; ++i) { |
| input_pairs[i].first = PortName(inputs[i]); |
| } |
| |
| // Convert from TF_Port to string names. |
| std::vector<tensorflow::string> output_names(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| output_names[i] = PortName(outputs[i]); |
| } |
| |
| // Convert from TF_Operation* to string names. |
| std::vector<tensorflow::string> target_names(ntargets); |
| for (int i = 0; i < ntargets; ++i) { |
| target_names[i] = target_opers[i]->node.name(); |
| } |
| |
| // Actually run. |
| TF_Run_Helper(session->session, nullptr, run_options, input_pairs, |
| output_names, output_values, target_names, run_metadata, |
| status); |
| } |
| |
| void TF_SessionPRunSetup(TF_SessionWithGraph* session, const TF_Port* inputs, |
| int ninputs, const TF_Port* outputs, int noutputs, |
| const TF_Operation* const* target_opers, int ntargets, |
| const char** handle, TF_Status* status) { |
| if (!ExtendSessionGraphHelper(session, status)) { |
| return; |
| } |
| |
| std::vector<tensorflow::string> input_names(ninputs); |
| for (int i = 0; i < ninputs; ++i) { |
| input_names[i] = PortName(inputs[i]); |
| } |
| |
| std::vector<tensorflow::string> output_names(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| output_names[i] = PortName(outputs[i]); |
| } |
| |
| std::vector<tensorflow::string> target_names(ntargets); |
| for (int i = 0; i < ntargets; ++i) { |
| target_names[i] = target_opers[i]->node.name(); |
| } |
| |
| tensorflow::string new_handle; |
| status->status = session->session->PRunSetup(input_names, output_names, |
| target_names, &new_handle); |
| if (status->status.ok()) { |
| char* buf = new char[new_handle.size() + 1]; |
| memcpy(buf, new_handle.c_str(), new_handle.size() + 1); |
| *handle = buf; |
| } |
| } |
| |
| void TF_SessionPRun(TF_SessionWithGraph* session, const char* handle, |
| const TF_Port* inputs, TF_Tensor* const* input_values, |
| int ninputs, const TF_Port* outputs, |
| TF_Tensor** output_values, int noutputs, |
| const TF_Operation* const* target_opers, int ntargets, |
| TF_Status* status) { |
| // TODO(josh11b,mrry): Change Session to be able to use a Graph* |
| // directly, instead of requiring us to serialize to a GraphDef and |
| // call Session::Extend(). |
| if (!ExtendSessionGraphHelper(session, status)) { |
| for (int i = 0; i < ninputs; ++i) { |
| TF_DeleteTensor(input_values[i]); |
| } |
| return; |
| } |
| |
| TF_Run_Setup(noutputs, output_values, status); |
| |
| // Convert from TF_Port and TF_Tensor to a string and Tensor. |
| std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); |
| if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; |
| for (int i = 0; i < ninputs; ++i) { |
| input_pairs[i].first = PortName(inputs[i]); |
| } |
| |
| // Convert from TF_Port to string names. |
| std::vector<tensorflow::string> output_names(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| output_names[i] = PortName(outputs[i]); |
| } |
| |
| // Convert from TF_Operation* to string names. |
| std::vector<tensorflow::string> target_names(ntargets); |
| for (int i = 0; i < ntargets; ++i) { |
| target_names[i] = target_opers[i]->node.name(); |
| } |
| |
| TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, |
| output_values, target_names, nullptr, status); |
| } |
| |
| } // end extern "C" |