blob: d00f50c0bc8e59315020257755d6df2f2349101d [file] [log] [blame]
// Copyright 2021 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_CORE_LIB_AVL_AVL_H
#define GRPC_CORE_LIB_AVL_AVL_H
#include <grpc/support/port_platform.h>

#include <stdlib.h>

#include <algorithm>
#include <memory>
#include <utility>

#include "absl/container/inlined_vector.h"
namespace grpc_core {
// An immutable (persistent) ordered map from K to V, stored as an AVL tree
// whose nodes are shared through std::shared_ptr.
//
// Add() and Remove() never modify an existing tree. They return a new AVL in
// which only the nodes on the path to the affected key are rebuilt; every
// other subtree is shared with the original. Ordering uses operator< (and
// operator> in the lookup helpers).
template <class K, class V = void>
class AVL {
 public:
  AVL() {}

  // Returns a copy of this map with `key` mapped to `value`. If `key` is
  // already present, the returned map holds the new value.
  AVL Add(K key, V value) const {
    return AVL(AddKey(root_, std::move(key), std::move(value)));
  }

  // Returns a copy of this map without `key`. `key` may be any type that is
  // comparable with K via operator< in both directions. If `key` is absent,
  // the result compares equal to this map.
  template <typename SomethingLikeK>
  AVL Remove(const SomethingLikeK& key) const {
    return AVL(RemoveKey(root_, key));
  }

  // Returns a pointer to the value mapped to `key`, or nullptr if absent.
  // The pointee lives inside a node owned by this tree.
  template <typename SomethingLikeK>
  const V* Lookup(const SomethingLikeK& key) const {
    NodePtr n = Get(root_, key);
    return n ? &n->kv.second : nullptr;
  }

  // Returns the entry with the greatest key that is not greater than `key`,
  // or nullptr if every key in the map is greater than `key`.
  const std::pair<K, V>* LookupBelow(const K& key) const {
    NodePtr n = GetBelow(root_, key);
    return n ? &n->kv : nullptr;
  }

  bool Empty() const { return root_ == nullptr; }

  // Invokes f(key, value) on every entry in ascending key order.
  template <class F>
  void ForEach(F&& f) const {
    ForEachImpl(root_.get(), std::forward<F>(f));
  }

  // True iff both maps share the same root node. This is a pointer
  // comparison: separately built maps with equal contents are not
  // SameIdentity.
  bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; }

  // Element-wise equality of the in-order (key, value) sequences.
  bool operator==(const AVL& other) const {
    Iterator a(root_);
    Iterator b(other.root_);
    for (;;) {
      Node* p = a.current();
      Node* q = b.current();
      if (p == nullptr) return q == nullptr;
      if (q == nullptr) return false;
      if (p->kv != q->kv) return false;
      a.MoveNext();
      b.MoveNext();
    }
  }

  // Lexicographic ordering of the in-order (key, value) sequences. A proper
  // prefix orders before the longer sequence.
  bool operator<(const AVL& other) const {
    Iterator a(root_);
    Iterator b(other.root_);
    for (;;) {
      Node* p = a.current();
      Node* q = b.current();
      if (p == nullptr) return q != nullptr;
      if (q == nullptr) return false;
      if (p->kv < q->kv) return true;
      if (p->kv != q->kv) return false;
      a.MoveNext();
      b.MoveNext();
    }
  }

 private:
  struct Node;
  typedef std::shared_ptr<Node> NodePtr;

  // Immutable tree node. All members are const, which is what makes sharing
  // subtrees between AVL instances safe. Nodes are only ever reached through
  // an existing NodePtr, so there is no need for enable_shared_from_this (and
  // its per-node weak pointer).
  struct Node {
    Node(K k, V v, NodePtr l, NodePtr r, long h)
        : kv(std::move(k), std::move(v)),
          left(std::move(l)),
          right(std::move(r)),
          height(h) {}
    const std::pair<K, V> kv;
    const NodePtr left;
    const NodePtr right;
    // Height of the subtree rooted here; a leaf has height 1 (see MakeNode).
    const long height;
  };

  NodePtr root_;

  // In-order traversal using an explicit stack. current() is the top of the
  // stack, or nullptr once the traversal is exhausted. MoveNext() must only
  // be called while current() is non-null.
  class Iterator {
   public:
    explicit Iterator(const NodePtr& root) {
      auto* n = root.get();
      while (n != nullptr) {
        stack_.push_back(n);
        n = n->left.get();
      }
    }
    Node* current() const { return stack_.empty() ? nullptr : stack_.back(); }
    void MoveNext() {
      auto* n = stack_.back();
      stack_.pop_back();
      if (n->right != nullptr) {
        n = n->right.get();
        while (n != nullptr) {
          stack_.push_back(n);
          n = n->left.get();
        }
      }
    }

   private:
    absl::InlinedVector<Node*, 8> stack_;
  };

  explicit AVL(NodePtr root) : root_(std::move(root)) {}

  // Recursive in-order walk. kv is a const pair, so f receives const refs.
  template <class F>
  static void ForEachImpl(const Node* n, F&& f) {
    if (n == nullptr) return;
    ForEachImpl(n->left.get(), std::forward<F>(f));
    f(n->kv.first, n->kv.second);
    ForEachImpl(n->right.get(), std::forward<F>(f));
  }

  static long Height(const NodePtr& n) { return n ? n->height : 0; }

  static NodePtr MakeNode(K key, V value, const NodePtr& left,
                          const NodePtr& right) {
    return std::make_shared<Node>(std::move(key), std::move(value), left, right,
                                  1 + std::max(Height(left), Height(right)));
  }

  // Returns the node whose key is equivalent to `key`, or nullptr.
  template <typename SomethingLikeK>
  static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) {
    if (node == nullptr) {
      return nullptr;
    }
    if (node->kv.first > key) {
      return Get(node->left, key);
    } else if (node->kv.first < key) {
      return Get(node->right, key);
    } else {
      return node;
    }
  }

  // Returns the node with the greatest key <= `key`, or nullptr.
  static NodePtr GetBelow(const NodePtr& node, const K& key) {
    if (!node) return nullptr;
    if (node->kv.first > key) {
      return GetBelow(node->left, key);
    } else if (node->kv.first < key) {
      // A closer match may exist in the right subtree; otherwise this node is
      // the best candidate.
      NodePtr n = GetBelow(node->right, key);
      if (n == nullptr) n = node;
      return n;
    } else {
      return node;
    }
  }

  // The rotation helpers build a new subtree rooted at a rotated node, given
  // the (key, value) of the would-be parent and its two children.
  static NodePtr RotateLeft(K key, V value, const NodePtr& left,
                            const NodePtr& right) {
    return MakeNode(
        right->kv.first, right->kv.second,
        MakeNode(std::move(key), std::move(value), left, right->left),
        right->right);
  }

  static NodePtr RotateRight(K key, V value, const NodePtr& left,
                             const NodePtr& right) {
    return MakeNode(
        left->kv.first, left->kv.second, left->left,
        MakeNode(std::move(key), std::move(value), left->right, right));
  }

  static NodePtr RotateLeftRight(K key, V value, const NodePtr& left,
                                 const NodePtr& right) {
    /* rotate_right(..., rotate_left(left), right) */
    return MakeNode(
        left->right->kv.first, left->right->kv.second,
        MakeNode(left->kv.first, left->kv.second, left->left,
                 left->right->left),
        MakeNode(std::move(key), std::move(value), left->right->right, right));
  }

  static NodePtr RotateRightLeft(K key, V value, const NodePtr& left,
                                 const NodePtr& right) {
    /* rotate_left(..., left, rotate_right(right)) */
    return MakeNode(
        right->left->kv.first, right->left->kv.second,
        MakeNode(std::move(key), std::move(value), left, right->left->left),
        MakeNode(right->kv.first, right->kv.second, right->left->right,
                 right->right));
  }

  // Builds a node for (key, value) over `left`/`right`, rotating when the
  // children's heights differ by 2 so the AVL height invariant is restored.
  static NodePtr Rebalance(K key, V value, const NodePtr& left,
                           const NodePtr& right) {
    switch (Height(left) - Height(right)) {
      case 2:
        if (Height(left->left) - Height(left->right) == -1) {
          return RotateLeftRight(std::move(key), std::move(value), left, right);
        } else {
          return RotateRight(std::move(key), std::move(value), left, right);
        }
      case -2:
        if (Height(right->left) - Height(right->right) == 1) {
          return RotateRightLeft(std::move(key), std::move(value), left, right);
        } else {
          return RotateLeft(std::move(key), std::move(value), left, right);
        }
      default:
        // Already balanced; the by-value parameters are not used again.
        return MakeNode(std::move(key), std::move(value), left, right);
    }
  }

  // Returns a new subtree equal to `node` with key -> value inserted or
  // replaced. Only nodes on the search path are rebuilt.
  static NodePtr AddKey(const NodePtr& node, K key, V value) {
    if (!node) {
      return MakeNode(std::move(key), std::move(value), nullptr, nullptr);
    }
    if (node->kv.first < key) {
      return Rebalance(node->kv.first, node->kv.second, node->left,
                       AddKey(node->right, std::move(key), std::move(value)));
    }
    if (key < node->kv.first) {
      return Rebalance(node->kv.first, node->kv.second,
                       AddKey(node->left, std::move(key), std::move(value)),
                       node->right);
    }
    return MakeNode(std::move(key), std::move(value), node->left, node->right);
  }

  // Leftmost (smallest-key) node of a non-null subtree.
  static NodePtr InOrderHead(NodePtr node) {
    while (node->left != nullptr) {
      node = node->left;
    }
    return node;
  }

  // Rightmost (largest-key) node of a non-null subtree.
  static NodePtr InOrderTail(NodePtr node) {
    while (node->right != nullptr) {
      node = node->right;
    }
    return node;
  }

  // Returns a new subtree equal to `node` without `key`. A node with two
  // children is replaced by its in-order neighbour taken from the taller side.
  template <typename SomethingLikeK>
  static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) {
    if (node == nullptr) {
      return nullptr;
    }
    if (key < node->kv.first) {
      return Rebalance(node->kv.first, node->kv.second,
                       RemoveKey(node->left, key), node->right);
    } else if (node->kv.first < key) {
      return Rebalance(node->kv.first, node->kv.second, node->left,
                       RemoveKey(node->right, key));
    } else {
      if (node->left == nullptr) {
        return node->right;
      } else if (node->right == nullptr) {
        return node->left;
      } else if (node->left->height < node->right->height) {
        NodePtr h = InOrderHead(node->right);
        return Rebalance(h->kv.first, h->kv.second, node->left,
                         RemoveKey(node->right, h->kv.first));
      } else {
        NodePtr h = InOrderTail(node->left);
        return Rebalance(h->kv.first, h->kv.second,
                         RemoveKey(node->left, h->kv.first), node->right);
      }
    }
    abort();
  }
};
// Specialization for V = void: an immutable (persistent) ordered set of K,
// stored as an AVL tree of shared nodes. Add() and Remove() return new sets
// that rebuild only the search path and share every other subtree.
template <class K>
class AVL<K, void> {
 public:
  AVL() {}

  // Returns a copy of this set that contains `key`.
  AVL Add(K key) const { return AVL(AddKey(root_, std::move(key))); }

  // Returns a copy of this set without `key`.
  AVL Remove(const K& key) const { return AVL(RemoveKey(root_, key)); }

  // True iff `key` is a member of this set.
  bool Lookup(const K& key) const { return Get(root_, key) != nullptr; }

  bool Empty() const { return root_ == nullptr; }

  // Invokes f(key) on every member in ascending order.
  template <class F>
  void ForEach(F&& f) const {
    ForEachImpl(root_.get(), std::forward<F>(f));
  }

  // True iff both sets share the same root node. Taken by const reference
  // (matching the primary template) so no shared_ptr copy is made.
  bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; }

 private:
  struct Node;
  typedef std::shared_ptr<Node> NodePtr;

  // Immutable tree node. All members are const, which is what makes sharing
  // subtrees between AVL instances safe. Nodes are only ever reached through
  // an existing NodePtr, so there is no need for enable_shared_from_this (and
  // its per-node weak pointer).
  struct Node {
    Node(K k, NodePtr l, NodePtr r, long h)
        : key(std::move(k)),
          left(std::move(l)),
          right(std::move(r)),
          height(h) {}
    const K key;
    const NodePtr left;
    const NodePtr right;
    // Height of the subtree rooted here; a leaf has height 1 (see MakeNode).
    const long height;
  };

  NodePtr root_;

  explicit AVL(NodePtr root) : root_(std::move(root)) {}

  // Recursive in-order walk. Node::key is const, so f receives a const ref.
  template <class F>
  static void ForEachImpl(const Node* n, F&& f) {
    if (n == nullptr) return;
    ForEachImpl(n->left.get(), std::forward<F>(f));
    f(n->key);
    ForEachImpl(n->right.get(), std::forward<F>(f));
  }

  static long Height(const NodePtr& n) { return n ? n->height : 0; }

  static NodePtr MakeNode(K key, const NodePtr& left, const NodePtr& right) {
    return std::make_shared<Node>(std::move(key), left, right,
                                  1 + std::max(Height(left), Height(right)));
  }

  // Returns the node whose key is equivalent to `key`, or nullptr.
  static NodePtr Get(const NodePtr& node, const K& key) {
    if (node == nullptr) {
      return nullptr;
    }
    if (node->key > key) {
      return Get(node->left, key);
    } else if (node->key < key) {
      return Get(node->right, key);
    } else {
      return node;
    }
  }

  // The rotation helpers build a new subtree rooted at a rotated node, given
  // the key of the would-be parent and its two children.
  static NodePtr RotateLeft(K key, const NodePtr& left, const NodePtr& right) {
    return MakeNode(right->key, MakeNode(std::move(key), left, right->left),
                    right->right);
  }

  static NodePtr RotateRight(K key, const NodePtr& left, const NodePtr& right) {
    return MakeNode(left->key, left->left,
                    MakeNode(std::move(key), left->right, right));
  }

  static NodePtr RotateLeftRight(K key, const NodePtr& left,
                                 const NodePtr& right) {
    /* rotate_right(..., rotate_left(left), right) */
    return MakeNode(left->right->key,
                    MakeNode(left->key, left->left, left->right->left),
                    MakeNode(std::move(key), left->right->right, right));
  }

  static NodePtr RotateRightLeft(K key, const NodePtr& left,
                                 const NodePtr& right) {
    /* rotate_left(..., left, rotate_right(right)) */
    return MakeNode(right->left->key,
                    MakeNode(std::move(key), left, right->left->left),
                    MakeNode(right->key, right->left->right, right->right));
  }

  // Builds a node for `key` over `left`/`right`, rotating when the children's
  // heights differ by 2 so the AVL height invariant is restored.
  static NodePtr Rebalance(K key, const NodePtr& left, const NodePtr& right) {
    switch (Height(left) - Height(right)) {
      case 2:
        if (Height(left->left) - Height(left->right) == -1) {
          return RotateLeftRight(std::move(key), left, right);
        } else {
          return RotateRight(std::move(key), left, right);
        }
      case -2:
        if (Height(right->left) - Height(right->right) == 1) {
          return RotateRightLeft(std::move(key), left, right);
        } else {
          return RotateLeft(std::move(key), left, right);
        }
      default:
        // Already balanced; the by-value `key` is not used again.
        return MakeNode(std::move(key), left, right);
    }
  }

  // Returns a new subtree equal to `node` with `key` inserted (or its
  // equivalent node replaced). Only nodes on the search path are rebuilt.
  static NodePtr AddKey(const NodePtr& node, K key) {
    if (!node) {
      return MakeNode(std::move(key), nullptr, nullptr);
    }
    if (node->key < key) {
      return Rebalance(node->key, node->left,
                       AddKey(node->right, std::move(key)));
    }
    if (key < node->key) {
      return Rebalance(node->key, AddKey(node->left, std::move(key)),
                       node->right);
    }
    return MakeNode(std::move(key), node->left, node->right);
  }

  // Leftmost (smallest-key) node of a non-null subtree.
  static NodePtr InOrderHead(NodePtr node) {
    while (node->left != nullptr) {
      node = node->left;
    }
    return node;
  }

  // Rightmost (largest-key) node of a non-null subtree.
  static NodePtr InOrderTail(NodePtr node) {
    while (node->right != nullptr) {
      node = node->right;
    }
    return node;
  }

  // Returns a new subtree equal to `node` without `key`. A node with two
  // children is replaced by its in-order neighbour taken from the taller side.
  static NodePtr RemoveKey(const NodePtr& node, const K& key) {
    if (node == nullptr) {
      return nullptr;
    }
    if (key < node->key) {
      return Rebalance(node->key, RemoveKey(node->left, key), node->right);
    } else if (node->key < key) {
      return Rebalance(node->key, node->left, RemoveKey(node->right, key));
    } else {
      if (node->left == nullptr) {
        return node->right;
      } else if (node->right == nullptr) {
        return node->left;
      } else if (node->left->height < node->right->height) {
        NodePtr h = InOrderHead(node->right);
        return Rebalance(h->key, node->left, RemoveKey(node->right, h->key));
      } else {
        NodePtr h = InOrderTail(node->left);
        return Rebalance(h->key, RemoveKey(node->left, h->key), node->right);
      }
    }
    abort();
  }
};
} // namespace grpc_core
#endif // GRPC_CORE_LIB_AVL_AVL_H