blob: d574da084e05b088372cb346d6b849410ec2d7ba [file] [log] [blame]
// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H
#define GRPC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H
// Scaffolding to allow the per-call part of a filter to be authored in a
// promise-style. Most of this will be removed once the promises conversion is
// completed.
#include <grpc/support/port_platform.h>
#include <stdint.h>
#include <stdlib.h>
#include <atomic>
#include <new>
#include <type_traits>
#include <utility>
#include "absl/container/inlined_vector.h"
#include "absl/meta/type_traits.h"
#include <grpc/impl/codegen/grpc_types.h>
#include <grpc/support/log.h>
#include "src/core/lib/channel/call_finalization.h"
#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_fwd.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/channel/context.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/call_combiner.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/exec_ctx.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/latch.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/transport/error_utils.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
namespace grpc_core {
class ChannelFilter {
 public:
  // Arguments handed to a filter's static Create() factory by
  // MakePromiseBasedFilter.
  class Args {
   public:
    // A default-constructed Args refers to no channel stack and no element.
    Args() = default;
    explicit Args(grpc_channel_stack* channel_stack,
                  grpc_channel_element* channel_element)
        : channel_stack_(channel_stack), channel_element_(channel_element) {}

    // The channel stack the filter is being created in.
    grpc_channel_stack* channel_stack() const { return channel_stack_; }

    // The element whose channel_data will hold the filter. Create() runs
    // before the filter object is placed there, so the element's
    // channel_data is not yet a valid filter when this is handed out.
    grpc_channel_element* uninitialized_channel_element() {
      return channel_element_;
    }

   private:
    friend class ChannelFilter;
    grpc_channel_stack* channel_stack_ = nullptr;
    grpc_channel_element* channel_element_ = nullptr;
  };

  // Hook invoked once the whole channel stack has been initialized.
  // Does nothing by default.
  virtual void PostInit() {}

  // Build the promise that implements this filter for a single call.
  // next_promise_factory produces the promise for the rest of the stack.
  virtual ArenaPromise<ServerMetadataHandle> MakeCallPromise(
      CallArgs call_args, NextPromiseFactory next_promise_factory) = 0;

  // Legacy transport op hook. Returning true means the op was consumed;
  // returning false (the default) forwards it to the next filter.
  // TODO(ctiller): design a new API for this - we probably don't want big op
  // structures going forward.
  virtual bool StartTransportOp(grpc_transport_op*) { return false; }

  // Legacy get-info hook. Returning true means the request was handled;
  // returning false (the default) forwards it to the next filter.
  // TODO(ctiller): design a new API for this
  virtual bool GetChannelInfo(const grpc_channel_info*) { return false; }

  // Virtual so that MakePromiseBasedFilter can destroy the concrete filter
  // through a ChannelFilter pointer.
  virtual ~ChannelFilter() = default;
};
// Designator for whether a filter is client side or server side.
// Please don't use this outside calls to MakePromiseBasedFilter - it's
// intended to be deleted once the promise conversion is complete.
enum class FilterEndpoint {
  kClient = 0,  // The filter is installed on the client side of a call.
  kServer = 1,  // The filter is installed on the server side of a call.
};
// Flags for MakePromiseBasedFilter.
// Bit flags accepted through the kFlags template argument.
// Presumably requests access to server initial metadata from the promise
// (see BaseCallData::server_initial_metadata_latch()) -- TODO confirm.
static constexpr uint8_t kFilterExaminesServerInitialMetadata = 1 << 0;
// Declares the filter to be the last element of its stack; checked against
// grpc_channel_element_args::is_last at channel-element init time.
static constexpr uint8_t kFilterIsLast = 1 << 1;
namespace promise_filter_detail {
// Proxy channel filter for initialization failure, since we must leave a
// valid filter in place.
class InvalidChannelFilter : public ChannelFilter {
 public:
  // This stand-in only occupies the slot of a filter whose Create() failed.
  // Starting a call through it is treated as fatal: the process aborts.
  ArenaPromise<ServerMetadataHandle> MakeCallPromise(
      CallArgs /*call_args*/,
      NextPromiseFactory /*next_promise_factory*/) override {
    abort();
  }
};
// Call data shared between all implementations of promise-based filters.
class BaseCallData : public Activity, private Wakeable {
public:
// Constructed in place in the call element's call_data by
// MakePromiseBasedFilter; flags carries the kFilter* bits given there.
BaseCallData(grpc_call_element* elem, const grpc_call_element_args* args,
uint8_t flags);
~BaseCallData() override;
// Record the polling entity for this call. May only be called once: asserts
// that no entity was set before. The release store pairs with the acquire
// load performed by ScopedContext.
void set_pollent(grpc_polling_entity* pollent) {
GPR_ASSERT(nullptr ==
pollent_.exchange(pollent, std::memory_order_release));
}
// Activity implementation (partial).
void Orphan() final;
Waker MakeNonOwningWaker() final;
Waker MakeOwningWaker() final;
// Run the finalizers registered with this call's CallFinalization (which
// promises can reach through ScopedContext) with the call's final info.
void Finalize(const grpc_call_final_info* final_info) {
finalization_.Run(final_info);
}
protected:
// Installs this call's arena, call context, polling entity and
// CallFinalization as promise contexts; presumably scoped to this object's
// lifetime (see promise/context.h) -- TODO confirm.
class ScopedContext
: public promise_detail::Context<Arena>,
public promise_detail::Context<grpc_call_context_element>,
public promise_detail::Context<grpc_polling_entity>,
public promise_detail::Context<CallFinalization> {
public:
explicit ScopedContext(BaseCallData* call_data)
: promise_detail::Context<Arena>(call_data->arena_),
promise_detail::Context<grpc_call_context_element>(
call_data->context_),
promise_detail::Context<grpc_polling_entity>(
// Acquire pairs with the release exchange in set_pollent().
call_data->pollent_.load(std::memory_order_acquire)),
promise_detail::Context<CallFinalization>(&call_data->finalization_) {
}
};
// Collects batches and closures produced while inside the call combiner so
// that the work is performed together when the Flusher is destroyed.
class Flusher {
public:
explicit Flusher(BaseCallData* call);
// Calls closures, schedules batches, relinquishes call combiner.
~Flusher();
// Queue batch to be passed down the stack when this Flusher is destroyed.
void Resume(grpc_transport_stream_op_batch* batch) {
release_.push_back(batch);
}
// Fail batch with error; the resulting completion closures are queued in
// call_closures_ and run when this Flusher is destroyed.
void Cancel(grpc_transport_stream_op_batch* batch,
grpc_error_handle error) {
grpc_transport_stream_op_batch_queue_finish_with_failure(batch, error,
&call_closures_);
}
// Queue batch->on_complete to run with no error.
void Complete(grpc_transport_stream_op_batch* batch) {
call_closures_.Add(batch->on_complete, GRPC_ERROR_NONE,
"Flusher::Complete");
}
// Queue an arbitrary closure to run with error when flushed.
void AddClosure(grpc_closure* closure, grpc_error_handle error,
const char* reason) {
call_closures_.Add(closure, error, reason);
}
private:
// Batches to send down the stack on flush.
absl::InlinedVector<grpc_transport_stream_op_batch*, 1> release_;
// Closures to run on flush.
CallCombinerClosureList call_closures_;
BaseCallData* const call_;
};
// Smart pointer like wrapper around a batch.
// Creation makes a ref count of one capture.
// Copying increments.
// Must be moved from or resumed or cancelled before destruction.
class CapturedBatch final {
public:
CapturedBatch();
explicit CapturedBatch(grpc_transport_stream_op_batch* batch);
~CapturedBatch();
CapturedBatch(const CapturedBatch&);
CapturedBatch& operator=(const CapturedBatch&);
CapturedBatch(CapturedBatch&&) noexcept;
CapturedBatch& operator=(CapturedBatch&&) noexcept;
// Access the underlying batch.
grpc_transport_stream_op_batch* operator->() { return batch_; }
// True while this wrapper holds a (non-null) batch.
bool is_captured() const { return batch_ != nullptr; }
// Resume processing this batch (releases one ref, passes it down the
// stack)
void ResumeWith(Flusher* releaser);
// Cancel this batch immediately (releases all refs)
void CancelWith(grpc_error_handle error, Flusher* releaser);
// Complete this batch (pass it up) assuming refs drop to zero
void CompleteWith(Flusher* releaser);
// Exchange the held batch pointers with other.
void Swap(CapturedBatch* other) { std::swap(batch_, other->batch_); }
private:
grpc_transport_stream_op_batch* batch_;
};
// Wrap a raw metadata batch pointer in a MetadataHandle.
static MetadataHandle<grpc_metadata_batch> WrapMetadata(
grpc_metadata_batch* p) {
return MetadataHandle<grpc_metadata_batch>(p);
}
// Take the raw metadata batch pointer back out of a MetadataHandle.
static grpc_metadata_batch* UnwrapMetadata(
MetadataHandle<grpc_metadata_batch> p) {
return p.Unwrap();
}
Arena* arena() { return arena_; }
grpc_call_element* elem() const { return elem_; }
CallCombiner* call_combiner() const { return call_combiner_; }
Timestamp deadline() const { return deadline_; }
grpc_call_stack* call_stack() const { return call_stack_; }
// nullptr unless populated elsewhere (presumably by the constructor when
// kFilterExaminesServerInitialMetadata is set) -- TODO confirm.
Latch<ServerMetadata*>* server_initial_metadata_latch() const {
return server_initial_metadata_latch_;
}
// True if this element is the final element of the call stack.
bool is_last() const {
return grpc_call_stack_element(call_stack_, call_stack_->count - 1) ==
elem_;
}
private:
// Wakeable implementation.
void Wakeup() final;
void Drop() final;
// Implemented by ClientCallData / ServerCallData to react to a wakeup.
virtual void OnWakeup() = 0;
grpc_call_stack* const call_stack_;
grpc_call_element* const elem_;
Arena* const arena_;
CallCombiner* const call_combiner_;
const Timestamp deadline_;
// Finalizers registered by promises; run from Finalize().
CallFinalization finalization_;
grpc_call_context_element* const context_;
// Set once via set_pollent(); read by ScopedContext.
std::atomic<grpc_polling_entity*> pollent_{nullptr};
Latch<ServerMetadata*>* server_initial_metadata_latch_ = nullptr;
};
class ClientCallData : public BaseCallData {
public:
// Arguments are forwarded from the filter's init_call_elem.
ClientCallData(grpc_call_element* elem, const grpc_call_element_args* args,
uint8_t flags);
~ClientCallData() override;
// Activity implementation.
void ForceImmediateRepoll() final;
// Handle one grpc_transport_stream_op_batch
void StartBatch(grpc_transport_stream_op_batch* batch);
private:
// At what stage is our handling of send initial metadata?
enum class SendInitialState {
// Start state: no op seen
kInitial,
// We've seen the op, and started the promise in response to it, but have
// not yet sent the op to the next filter.
kQueued,
// We've sent the op to the next filter.
kForwarded,
// We were cancelled.
kCancelled
};
// At what stage is our handling of recv trailing metadata?
enum class RecvTrailingState {
// Start state: no op seen
kInitial,
// We saw the op, and since it was bundled with send initial metadata, we
// queued it until the send initial metadata can be sent to the next
// filter.
kQueued,
// We've forwarded the op to the next filter.
kForwarded,
// The op has completed from below, but we haven't yet forwarded it up (the
// promise gets to interject and mutate it).
kComplete,
// We've called the original recv_trailing_metadata_ready callback from the
// recv_trailing_metadata op that was presented to us.
kResponded,
// We've been cancelled and handled that locally.
// (i.e. whilst the recv_trailing_metadata op is queued in this filter).
kCancelled
};
// Defined out of line.
struct RecvInitialMetadata;
class PollContext;
// Handle cancellation.
void Cancel(grpc_error_handle error);
// Begin running the promise - which will ultimately take some initial
// metadata and return some trailing metadata.
void StartPromise(Flusher* flusher);
// Interject our callback into the op batch for recv trailing metadata
// ready. Stash a pointer to the trailing metadata that will be filled in,
// so we can manipulate it later.
void HookRecvTrailingMetadata(CapturedBatch batch);
// Construct a promise that will "call" the next filter.
// Effectively:
// - put the modified initial metadata into the batch to be sent down.
// - return a wrapper around PollTrailingMetadata as the promise.
ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
// Wrapper to make it look like we're calling the next filter as a promise.
// First poll: send the send_initial_metadata op down the stack.
// All polls: await receiving the trailing metadata, then return it to the
// application.
Poll<ServerMetadataHandle> PollTrailingMetadata();
// Static trampoline installed as recv_trailing_metadata_ready_; arg is
// presumably this ClientCallData -- TODO confirm in the .cc.
static void RecvTrailingMetadataReadyCallback(void* arg,
grpc_error_handle error);
void RecvTrailingMetadataReady(grpc_error_handle error);
void RecvInitialMetadataReady(grpc_error_handle error);
// Given an error, fill in *metadata so that it represents that error.
void SetStatusFromError(grpc_metadata_batch* metadata,
grpc_error_handle error);
// Wakeup and poll the promise if appropriate.
void WakeInsideCombiner(Flusher* flusher);
void OnWakeup() override;
// Contained promise
ArenaPromise<ServerMetadataHandle> promise_;
// Queued batch containing at least a send_initial_metadata op.
CapturedBatch send_initial_metadata_batch_;
// Pointer to where trailing metadata will be stored.
grpc_metadata_batch* recv_trailing_metadata_ = nullptr;
// State tracking recv initial metadata for filters that care about it.
RecvInitialMetadata* recv_initial_metadata_ = nullptr;
// Closure to call when we're done with the trailing metadata.
grpc_closure* original_recv_trailing_metadata_ready_ = nullptr;
// Our closure pointing to RecvTrailingMetadataReadyCallback.
grpc_closure recv_trailing_metadata_ready_;
// Error received during cancellation.
grpc_error_handle cancelled_error_ = GRPC_ERROR_NONE;
// State of the send_initial_metadata op.
SendInitialState send_initial_state_ = SendInitialState::kInitial;
// State of the recv_trailing_metadata op.
RecvTrailingState recv_trailing_state_ = RecvTrailingState::kInitial;
// Polling related data. Non-null if we're actively polling
PollContext* poll_ctx_ = nullptr;
};
class ServerCallData : public BaseCallData {
public:
// Arguments are forwarded from the filter's init_call_elem.
ServerCallData(grpc_call_element* elem, const grpc_call_element_args* args,
uint8_t flags);
~ServerCallData() override;
// Activity implementation.
void ForceImmediateRepoll() final;
// Handle one grpc_transport_stream_op_batch
void StartBatch(grpc_transport_stream_op_batch* batch);
private:
// At what stage is our handling of recv initial metadata?
enum class RecvInitialState {
// Start state: no op seen
kInitial,
// Op seen, and forwarded to the next filter.
// Now waiting for the callback.
kForwarded,
// The op has completed from below, but we haven't yet forwarded it up
// (the promise gets to interject and mutate it).
kComplete,
// We've sent the response to the next filter up.
kResponded,
};
// At what stage is our handling of send trailing metadata?
enum class SendTrailingState {
// Start state: no op seen
kInitial,
// We saw the op, and are waiting for the promise to complete
// to forward it.
kQueued,
// We've forwarded the op to the next filter.
kForwarded,
// We were cancelled.
kCancelled
};
// Defined out of line.
class PollContext;
struct SendInitialMetadata;
// Handle cancellation.
void Cancel(grpc_error_handle error, Flusher* flusher);
// Construct a promise that will "call" the next filter.
// Effectively:
// - put the modified initial metadata into the batch being sent up.
// - return a wrapper around PollTrailingMetadata as the promise.
ArenaPromise<ServerMetadataHandle> MakeNextPromise(CallArgs call_args);
// Wrapper to make it look like we're calling the next filter as a promise.
// All polls: await sending the trailing metadata, then forward it down the
// stack.
Poll<ServerMetadataHandle> PollTrailingMetadata();
// Static trampoline installed as recv_initial_metadata_ready_; arg is
// presumably this ServerCallData -- TODO confirm in the .cc.
static void RecvInitialMetadataReadyCallback(void* arg,
grpc_error_handle error);
void RecvInitialMetadataReady(grpc_error_handle error);
// Wakeup and poll the promise if appropriate.
void WakeInsideCombiner(Flusher* flusher);
void OnWakeup() override;
// Contained promise
ArenaPromise<ServerMetadataHandle> promise_;
// Pointer to where initial metadata will be stored.
grpc_metadata_batch* recv_initial_metadata_ = nullptr;
// State for sending initial metadata.
SendInitialMetadata* send_initial_metadata_ = nullptr;
// The original recv_initial_metadata_ready closure: call it when we're done
// with the initial metadata.
grpc_closure* original_recv_initial_metadata_ready_ = nullptr;
// Our closure pointing to RecvInitialMetadataReadyCallback.
grpc_closure recv_initial_metadata_ready_;
// Error received during cancellation.
grpc_error_handle cancelled_error_ = GRPC_ERROR_NONE;
// Trailing metadata batch
CapturedBatch send_trailing_metadata_batch_;
// State of the recv_initial_metadata op.
RecvInitialState recv_initial_state_ = RecvInitialState::kInitial;
// State of the send_trailing_metadata op.
SendTrailingState send_trailing_state_ = SendTrailingState::kInitial;
// Current poll context (or nullptr if not polling).
PollContext* poll_ctx_ = nullptr;
// Whether to forward the recv_initial_metadata op at the end of promise
// wakeup.
bool forward_recv_initial_metadata_callback_ = false;
};
// Specific call data per channel filter.
// Note that we further specialize for clients and servers since their
// implementations are very different.
template <class Filter, FilterEndpoint kEndpoint>
class CallData;

// Client side: all behavior is inherited from ClientCallData. The Filter
// argument is unused in the body; it only gives every filter its own
// distinct CallData type.
template <class Filter>
class CallData<Filter, FilterEndpoint::kClient> : public ClientCallData {
 public:
  using ClientCallData::ClientCallData;
};

// Server side: all behavior is inherited from ServerCallData. As above, the
// Filter argument only distinguishes the type.
template <class Filter>
class CallData<Filter, FilterEndpoint::kServer> : public ServerCallData {
 public:
  using ServerCallData::ServerCallData;
};
} // namespace promise_filter_detail
// F derives from ChannelFilter and provides a static Create() factory:
// class SomeChannelFilter : public ChannelFilter {
// public:
// static absl::StatusOr<SomeChannelFilter> Create(
// ChannelArgs channel_args, ChannelFilter::Args filter_args);
// };
template <typename F, FilterEndpoint kEndpoint, uint8_t kFlags = 0>
absl::enable_if_t<std::is_base_of<ChannelFilter, F>::value, grpc_channel_filter>
MakePromiseBasedFilter(const char* name) {
  // Per-call state type for this filter/endpoint combination.
  using CallData = promise_filter_detail::CallData<F, kEndpoint>;
  // Fields below are positional and must stay in grpc_channel_filter order.
  return grpc_channel_filter{
      // start_transport_stream_op_batch: hand the batch to the call state.
      [](grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
        auto* call_data = static_cast<CallData*>(elem->call_data);
        call_data->StartBatch(batch);
      },
      // make_call_promise: delegate to the filter object.
      [](grpc_channel_element* elem, CallArgs call_args,
         NextPromiseFactory next_promise_factory) {
        auto* filter = static_cast<ChannelFilter*>(elem->channel_data);
        return filter->MakeCallPromise(std::move(call_args),
                                       std::move(next_promise_factory));
      },
      // start_transport_op: the filter may consume it; otherwise pass it on.
      [](grpc_channel_element* elem, grpc_transport_op* op) {
        auto* filter = static_cast<ChannelFilter*>(elem->channel_data);
        if (filter->StartTransportOp(op)) return;
        grpc_channel_next_op(elem, op);
      },
      // sizeof_call_data
      sizeof(CallData),
      // init_call_elem: construct the call state in place.
      [](grpc_call_element* elem, const grpc_call_element_args* args) {
        new (elem->call_data) CallData(elem, args, kFlags);
        return GRPC_ERROR_NONE;
      },
      // set_pollset_or_pollset_set
      [](grpc_call_element* elem, grpc_polling_entity* pollent) {
        auto* call_data = static_cast<CallData*>(elem->call_data);
        call_data->set_pollent(pollent);
      },
      // destroy_call_elem: run finalizers, destroy the call state, and (only
      // for the last filter) schedule then_schedule_closure.
      [](grpc_call_element* elem, const grpc_call_final_info* final_info,
         grpc_closure* then_schedule_closure) {
        auto* call_data = static_cast<CallData*>(elem->call_data);
        call_data->Finalize(final_info);
        call_data->~CallData();
        if ((kFlags & kFilterIsLast) == 0) {
          GPR_ASSERT(then_schedule_closure == nullptr);
          return;
        }
        ExecCtx::Run(DEBUG_LOCATION, then_schedule_closure, GRPC_ERROR_NONE);
      },
      // sizeof_channel_data
      sizeof(F),
      // init_channel_elem: build F via F::Create. On failure, park an
      // InvalidChannelFilter in the slot so the element stays destructible.
      [](grpc_channel_element* elem, grpc_channel_element_args* args) {
        GPR_ASSERT(args->is_last == ((kFlags & kFilterIsLast) != 0));
        auto status = F::Create(ChannelArgs::FromC(args->channel_args),
                                ChannelFilter::Args(args->channel_stack, elem));
        if (status.ok()) {
          new (elem->channel_data) F(std::move(*status));
          return GRPC_ERROR_NONE;
        }
        static_assert(
            sizeof(promise_filter_detail::InvalidChannelFilter) <= sizeof(F),
            "InvalidChannelFilter must fit in F");
        new (elem->channel_data) promise_filter_detail::InvalidChannelFilter();
        return absl_status_to_grpc_error(status.status());
      },
      // post_init_channel_elem
      [](grpc_channel_stack*, grpc_channel_element* elem) {
        auto* filter = static_cast<ChannelFilter*>(elem->channel_data);
        filter->PostInit();
      },
      // destroy_channel_elem: virtual destructor reaches F (or the
      // InvalidChannelFilter stand-in).
      [](grpc_channel_element* elem) {
        auto* filter = static_cast<ChannelFilter*>(elem->channel_data);
        filter->~ChannelFilter();
      },
      // get_channel_info: the filter may answer; otherwise pass it on.
      [](grpc_channel_element* elem, const grpc_channel_info* info) {
        auto* filter = static_cast<ChannelFilter*>(elem->channel_data);
        if (filter->GetChannelInfo(info)) return;
        grpc_channel_next_get_info(elem, info);
      },
      // name
      name,
  };
}
} // namespace grpc_core
#endif // GRPC_CORE_LIB_CHANNEL_PROMISE_BASED_FILTER_H