#pragma once
#include <iostream>
#include <stdexcept>
/// Red-black tree holding unique values of type T.
///
/// Ordering and equality are both derived from `operator<` alone: two values
/// `a` and `b` are treated as the same element when neither `a < b` nor
/// `b < a`, so T only needs a strict weak ordering via `operator<`.
///
/// insert() maintains the red-black invariants:
///   * the root is BLACK;
///   * a RED node never has a RED child;
///   * every path from a node down to a null child crosses the same number of
///     BLACK nodes.
/// Null children count as BLACK leaves (see color()).
///
/// The tree owns its nodes through raw pointers: copying deep-copies every
/// node, moving transfers the nodes and leaves the source empty.
template <typename T>
class RedBlackTree {
private:
    enum Color {RED, BLACK};

    struct Node {
        T data;
        Color color;
        Node* left;
        Node* right;
        Node* parent;

        // Creates a detached BLACK node.
        explicit Node(const T& item)
            : data(item)
            , color(BLACK)
            , left(nullptr)
            , right(nullptr)
            , parent(nullptr) {}

        Node(const T& item, Color colour, Node* left_child, Node* right_child, Node* node_parent)
            : data(item)
            , color(colour)
            , left(left_child)
            , right(right_child)
            , parent(node_parent) {}
    };

    Node* m_root;

    // Deep-copies the subtree rooted at `node` and attaches the copy to `parent`.
    // Every parent pointer in the result points into the new subtree, never into
    // the source tree. If allocating a node or copying a T throws, the nodes
    // copied so far are freed before the exception propagates.
    static Node* copy(const Node* node, Node* parent = nullptr) {
        if (node == nullptr)
            return nullptr;
        Node* clone = new Node(node->data, node->color, nullptr, nullptr, parent);
        try {
            clone->left = copy(node->left, clone);
            clone->right = copy(node->right, clone);
        } catch (...) {
            destroy(clone);
            throw;
        }
        return clone;
    }

    // Frees every node of the subtree rooted at `node` (children before parent).
    static void destroy(Node* node) {
        if (node == nullptr)
            return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Prints the subtree rooted at `root` sideways: right subtree first, then the
    // node itself indented by `counter` spaces and prefixed with 'b' (BLACK) or
    // 'r' (RED), then the left subtree. Prints "<empty>" when the whole tree is empty.
    void display(const Node* root, std::ostream& out, unsigned int counter = 0) const {
        if (m_root == nullptr) {
            out << "<empty>\n";
            return;
        }
        if (root == nullptr)
            return;
        display(root->right, out, counter + 1);
        for (unsigned int i = 0; i < counter; ++i)
            out << ' ';
        out << (root->color == BLACK ? 'b' : 'r') << root->data << '\n';
        display(root->left, out, counter + 1);
    }

    // Rotates the subtree rooted at `x` to the left: x's right child `y` takes
    // x's position and x becomes y's left child. Requires x->right != nullptr.
    void left_rotate(Node* x) {
        Node* y = x->right;
        x->right = y->left;
        if (y->left != nullptr)
            y->left->parent = x;
        y->parent = x->parent;
        if (x->parent == nullptr)
            m_root = y;
        else if (x == x->parent->left)
            x->parent->left = y;
        else
            x->parent->right = y;
        y->left = x;
        x->parent = y;
    }

    // Mirror image of left_rotate: x's left child `y` takes x's position and
    // x becomes y's right child. Requires x->left != nullptr.
    void right_rotate(Node* x) {
        Node* y = x->left;
        x->left = y->right;
        if (y->right != nullptr)
            y->right->parent = x;
        y->parent = x->parent;
        if (x->parent == nullptr)
            m_root = y;
        else if (x == x->parent->right)
            x->parent->right = y;
        else
            x->parent->left = y;
        y->right = x;
        x->parent = y;
    }

public:
    RedBlackTree() : m_root(nullptr) {}

    RedBlackTree(const RedBlackTree& tree) : m_root(copy(tree.m_root)) {}

    // Takes over `tree`'s nodes; `tree` is left empty.
    RedBlackTree(RedBlackTree&& tree) noexcept : m_root(tree.m_root) {
        tree.m_root = nullptr;
    }

    ~RedBlackTree() {
        make_empty(m_root);
    }

    // Replaces the contents with a deep copy of `rhs`. The copy is built before
    // the old nodes are released, so if copying throws, *this is unchanged.
    RedBlackTree& operator=(const RedBlackTree& rhs) {
        if (this != &rhs) {
            Node* replacement = copy(rhs.m_root);
            make_empty(m_root);
            m_root = replacement;
        }
        return *this;
    }

    // Releases the current nodes and takes over `rhs`'s nodes; `rhs` is left empty.
    RedBlackTree& operator=(RedBlackTree&& rhs) noexcept {
        if (this != &rhs) {
            make_empty(m_root);
            m_root = rhs.m_root;
            rhs.m_root = nullptr;
        }
        return *this;
    }

    /// Color of `node`; a null node is a leaf and counts as BLACK.
    Color color(const Node* node) const {
        return node == nullptr ? BLACK : node->color;
    }

    /// Read-only access to the root node (nullptr when the tree is empty).
    const Node* get_root() const {
        return m_root;
    }

    /// @return true if an element equivalent to `item` (under operator<) is stored.
    bool contains(const T& item) const {
        const Node* cursor = m_root;
        while (cursor != nullptr) {
            if (item < cursor->data)
                cursor = cursor->left;
            else if (cursor->data < item)
                cursor = cursor->right;
            else
                return true;
        }
        return false;
    }

    /// @return the largest stored element.
    /// @throws std::invalid_argument if the tree is empty.
    const T& find_max() const {
        if (m_root == nullptr)
            throw std::invalid_argument("tree is empty");
        const Node* cursor = m_root;
        while (cursor->right != nullptr)
            cursor = cursor->right;
        return cursor->data;
    }

    /// @return the smallest stored element.
    /// @throws std::invalid_argument if the tree is empty.
    const T& find_min() const {
        if (m_root == nullptr)
            throw std::invalid_argument("tree is empty");
        const Node* cursor = m_root;
        while (cursor->left != nullptr)
            cursor = cursor->left;
        return cursor->data;
    }

    /// Inserts `item` unless an equivalent element is already present.
    /// A single descent both detects duplicates and finds the attachment point;
    /// only operator< is used, so the loop always makes progress.
    void insert(const T& item) {
        Node* parent = nullptr;
        Node* cursor = m_root;
        bool attach_left = false;
        while (cursor != nullptr) {
            parent = cursor;
            if (item < cursor->data) {
                cursor = cursor->left;
                attach_left = true;
            } else if (cursor->data < item) {
                cursor = cursor->right;
                attach_left = false;
            } else {
                return;  // equivalent element already stored
            }
        }

        // New nodes start RED so black heights are unchanged; insert_fix then
        // repairs any RED-RED violation and recolors the root BLACK.
        Node* new_node = new Node(item, RED, nullptr, nullptr, parent);
        if (parent == nullptr)
            m_root = new_node;
        else if (attach_left)
            parent->left = new_node;
        else
            parent->right = new_node;
        insert_fix(new_node);
    }

    /// Restores the red-black invariants after `node` was attached as a RED leaf.
    /// Missing uncles are null leaves and are treated as BLACK via color().
    void insert_fix(Node* node) {
        while (node != m_root && color(node->parent) == RED) {
            Node* parent = node->parent;
            // A RED parent is never the root (the root is always BLACK), so the
            // grandparent exists.
            Node* grandparent = parent->parent;
            if (parent == grandparent->right) {
                Node* uncle = grandparent->left;
                if (color(uncle) == RED) {
                    // Recolor and continue fixing from the grandparent.
                    uncle->color = BLACK;
                    parent->color = BLACK;
                    grandparent->color = RED;
                    node = grandparent;
                } else {
                    if (node == parent->left) {
                        // Inner child: rotate it to the outside first.
                        node = parent;
                        right_rotate(node);
                    }
                    node->parent->color = BLACK;
                    node->parent->parent->color = RED;
                    left_rotate(node->parent->parent);
                }
            } else {
                Node* uncle = grandparent->right;
                if (color(uncle) == RED) {
                    uncle->color = BLACK;
                    parent->color = BLACK;
                    grandparent->color = RED;
                    node = grandparent;
                } else {
                    if (node == parent->right) {
                        node = parent;
                        left_rotate(node);
                    }
                    node->parent->color = BLACK;
                    node->parent->parent->color = RED;
                    right_rotate(node->parent->parent);
                }
            }
        }
        m_root->color = BLACK;
    }

    // NOTE(review): only declared here; the definition is not in this section.
    // Confirm one exists elsewhere, otherwise calls fail at link time.
    void remove(const T& item);

    /// Frees every node of the subtree rooted at `node` and sets `node` to nullptr.
    void make_empty(Node* &node) {
        destroy(node);
        node = nullptr;
    }

    /// Writes a sideways rendering of the tree to `out` (see display()).
    void print_tree(std::ostream& out=std::cout) const {
        display(m_root, out);
    }
};