#pragma once
#include <algorithm>
#include <iostream>
#include <stdexcept>
/// @brief Self-balancing binary search tree (AVL tree) holding unique values.
///
/// Heights follow the convention used by height(): an empty subtree has
/// height -1 and a leaf has height 0. After every insertion or removal the
/// nodes on the search path are rebalanced so that the heights of the two
/// subtrees of any node differ by at most one.
///
/// @tparam T Element type. It must be copy-constructible and copy-assignable
///           (removal overwrites a node's value with its in-order successor),
///           comparable with operator<, operator> and operator==, and
///           streamable with operator<< if print_tree() is used.
template <typename T>
class AVLTree {
private:
    /// A single tree node; owns nothing itself, the tree owns all nodes.
    struct Node {
        T data;
        Node* left;
        Node* right;
        int height;

        /// Builds a node; with only @p item given it is a leaf of height 0.
        explicit Node(const T& item, Node* _left = nullptr, Node* _right = nullptr, int _height = 0)
            : data(item)
            , left(_left)
            , right(_right)
            , height(_height) {}
    };

    Node* m_root;

    /// Frees every node of the subtree rooted at @p node (nullptr is allowed).
    static void destroy_subtree(Node* node) noexcept {
        if (node == nullptr) {
            return;
        }
        destroy_subtree(node->left);
        destroy_subtree(node->right);
        delete node;
    }

    /// Height of the subtree rooted at @p node; -1 for an empty subtree.
    int height(Node* node) const {
        return node == nullptr ? -1 : node->height;
    }

    /// Recomputes the cached height of @p node from its children.
    void update_height(Node* node) const {
        node->height = std::max(height(node->left), height(node->right)) + 1;
    }

    /// Restores the AVL invariant at @p node (if violated) and refreshes its
    /// cached height. @p node is a reference to the parent's link, so a
    /// rotation re-seats the subtree root in place.
    void balance(Node*& node) {
        if (node == nullptr) {
            return;
        }
        if (height(node->left) - height(node->right) > 1) {
            // Left-heavy: a single rotation suffices unless the left child
            // leans to the right, which needs a double rotation.
            if (height(node->left->left) >= height(node->left->right))
                lc_rotate(node);
            else
                dlc_rotate(node);
        } else if (height(node->right) - height(node->left) > 1) {
            // Mirror image of the case above.
            if (height(node->right->right) >= height(node->right->left))
                rc_rotate(node);
            else
                drc_rotate(node);
        }
        update_height(node);
    }

    /// Deep-copies the subtree rooted at @p node and returns the new root.
    /// If an allocation or a copy of T throws, everything cloned so far is
    /// released before the exception propagates.
    Node* copy(Node* node) const {
        if (node == nullptr) {
            return nullptr;
        }
        Node* left_copy = nullptr;
        Node* right_copy = nullptr;
        try {
            left_copy = copy(node->left);
            right_copy = copy(node->right);
            return new Node(node->data, left_copy, right_copy, node->height);
        } catch (...) {
            destroy_subtree(left_copy);
            destroy_subtree(right_copy);
            throw;
        }
    }

    /// Writes the subtree rooted at @p root to @p out, one value per line,
    /// right subtree first, each line indented by one space per depth level
    /// (@p counter is the depth of @p root).
    void display(Node* root, std::ostream& out, unsigned int counter = 0) const {
        if (root == nullptr) {
            return;
        }
        display(root->right, out, counter + 1);
        for (unsigned int i = 0; i < counter; ++i) out << " ";
        out << root->data << "\n";
        display(root->left, out, counter + 1);
    }

    /// Double rotation for the left-right case.
    void dlc_rotate(Node*& node) {
        rc_rotate(node->left);
        lc_rotate(node);
    }

    /// Double rotation for the right-left case.
    void drc_rotate(Node*& node) {
        lc_rotate(node->right);
        rc_rotate(node);
    }

    /// Inserts @p item into the subtree whose root link is @p node and
    /// rebalances on the way back up. A value that is neither less nor
    /// greater than an existing one is not inserted.
    void insert(const T& item, Node*& node) {
        if (node == nullptr) {
            node = new Node(item);
        } else if (item < node->data) {
            insert(item, node->left);
        } else if (item > node->data) {
            insert(item, node->right);
        }
        balance(node);
    }

    /// Rotates @p node with its left child; the left child becomes the new
    /// subtree root. Requires node->left != nullptr.
    void lc_rotate(Node*& node) {
        Node* left_child = node->left;
        node->left = left_child->right;
        left_child->right = node;
        // The demoted node must be updated first: the new root's height
        // depends on it.
        update_height(node);
        update_height(left_child);
        node = left_child;
    }

    /// Rotates @p node with its right child; the right child becomes the new
    /// subtree root. Requires node->right != nullptr.
    void rc_rotate(Node*& node) {
        Node* right_child = node->right;
        node->right = right_child->left;
        right_child->left = node;
        update_height(node);
        update_height(right_child);
        node = right_child;
    }

    /// Returns the smallest value in the non-empty subtree rooted at @p root.
    const T& right_subtree_min(Node* root) const {
        while (root->left != nullptr) root = root->left;
        return root->data;
    }

    /// Removes @p item from the subtree whose root link is @p root (no-op if
    /// the search runs off the tree) and rebalances on the way back up.
    void remove(const T& item, Node*& root) {
        if (root == nullptr) {
            return;
        }
        if (item < root->data) {
            remove(item, root->left);
        } else if (item > root->data) {
            remove(item, root->right);
        } else if (root->left != nullptr && root->right != nullptr) {
            // Two children: overwrite with the in-order successor, then
            // delete that successor from the right subtree.
            root->data = right_subtree_min(root->right);
            remove(root->data, root->right);
        } else {
            // Zero or one child: splice the node out.
            Node* node_to_remove = root;
            root = root->left != nullptr ? root->left : root->right;
            delete node_to_remove;
        }
        balance(root);
    }

public:
    /// Creates an empty tree.
    AVLTree() : m_root(nullptr) {}

    /// Creates a deep copy of @p tree.
    AVLTree(const AVLTree& tree) : m_root(copy(tree.m_root)) {}

    /// Takes over the nodes of @p tree, leaving it empty.
    AVLTree(AVLTree&& tree) noexcept : m_root(tree.m_root) {
        tree.m_root = nullptr;
    }

    /// Releases all nodes.
    ~AVLTree() {
        make_empty(m_root);
    }

    /// Replaces the contents with a deep copy of @p rhs. The copy is built
    /// before the old nodes are released, so if copying throws this tree is
    /// left unchanged.
    AVLTree& operator=(const AVLTree& rhs) {
        if (&rhs != this) {
            Node* fresh_root = copy(rhs.m_root);
            make_empty(m_root);
            m_root = fresh_root;
        }
        return *this;
    }

    /// Replaces the contents with the nodes of @p rhs, leaving it empty.
    AVLTree& operator=(AVLTree&& rhs) noexcept {
        if (&rhs != this) {
            make_empty(m_root);
            m_root = rhs.m_root;
            rhs.m_root = nullptr;
        }
        return *this;
    }

    /// @return true if some stored value compares equal (operator==) to @p item.
    bool contains(const T& item) const {
        Node* temp = m_root;
        while (temp != nullptr) {
            if (temp->data == item)
                return true;
            else if (temp->data < item)
                temp = temp->right;
            else
                temp = temp->left;
        }
        return false;
    }

    /// Inserts @p item unless an equal value is already stored.
    void insert(const T& item) {
        if (!contains(item)) {
            insert(item, m_root);
        }
    }

    /// Removes @p item; does nothing if it is not stored.
    void remove(const T& item) {
        // The membership check also protects against values that compare
        // neither less nor greater than a stored one without being equal.
        if (contains(item)) {
            remove(item, m_root);
        }
    }

    /// @return the smallest stored value.
    /// @throws std::invalid_argument if the tree is empty.
    const T& find_min() const {
        if (m_root == nullptr)
            throw std::invalid_argument("tree is empty");
        Node* temp = m_root;
        while (temp->left != nullptr) temp = temp->left;
        return temp->data;
    }

    /// @return the largest stored value.
    /// @throws std::invalid_argument if the tree is empty.
    const T& find_max() const {
        if (m_root == nullptr)
            throw std::invalid_argument("tree is empty");
        Node* temp = m_root;
        while (temp->right != nullptr) temp = temp->right;
        return temp->data;
    }

    /// Prints the tree sideways (largest value on the first line, one space of
    /// indentation per depth level), or "<empty>" for an empty tree.
    void print_tree(std::ostream& out = std::cout) const {
        if (m_root == nullptr) {
            out << "<empty>\n";
        } else {
            display(m_root, out);
        }
    }

    /// @return the value stored at the root.
    /// @throws std::invalid_argument if the tree is empty.
    const T& root_data() const {
        if (m_root == nullptr)
            throw std::invalid_argument("tree is empty");
        return m_root->data;
    }

    /// Frees the subtree rooted at @p node and resets @p node to nullptr.
    void make_empty(Node*& node) noexcept {
        destroy_subtree(node);
        node = nullptr;
    }

    /// Removes every element from the tree.
    void make_empty() noexcept {
        make_empty(m_root);
    }
};