#pragma once

#include <iostream>
#include <stdexcept>

/**
 * @brief A set of unique values of type T stored in a red-black tree.
 *
 * Both ordering and duplicate detection use only `operator<` on T: two values
 * are treated as the same when neither is less than the other. T must
 * therefore provide a strict weak ordering. The tree owns its nodes: copies
 * are deep, moves transfer ownership, and destruction frees every node.
 */
template <typename T>
class RedBlackTree {
    private:
        enum Color {RED, BLACK};

        /// One tree node. `left`/`right` are owned by the tree; `parent` is a non-owning back link.
        struct Node {
            T data;
            Color color;
            Node* left;
            Node* right;
            Node* parent;

            /// Creates a detached BLACK node holding a copy of `item`.
            explicit Node(const T& item)
            : data(item)
            , color(BLACK)
            , left(nullptr)
            , right(nullptr)
            , parent(nullptr) {}

            /// Creates a node with every field given explicitly.
            Node(const T& item, Color colour, Node* left_child, Node* right_child, Node* node_parent)
            : data(item)
            , color(colour)
            , left(left_child)
            , right(right_child)
            , parent(node_parent) {}
        };

        /// Root of the tree, or nullptr when the tree is empty.
        Node* m_root;

        /// Deletes every node of the subtree rooted at `node` (post-order). Null is a no-op.
        static void destroy_subtree(Node* node) {
            if (node == nullptr)
                return;
            destroy_subtree(node->left);
            destroy_subtree(node->right);
            delete node;
        }

        /**
         * @brief Deep-copies the subtree rooted at `node`.
         * @param node   Root of the subtree to clone (may be null).
         * @param parent Node that the clone's root should point back to.
         * @return Root of the freshly allocated clone, or nullptr for an empty subtree.
         *
         * Parent links are rebuilt so they point into the new tree, not the source.
         * If copying a T throws, everything allocated so far is freed before the
         * exception propagates.
         */
        Node* copy(const Node* node, Node* parent = nullptr) const {
            if (node == nullptr)
                return nullptr;

            Node* clone = new Node(node->data, node->color, nullptr, nullptr, parent);
            try {
                clone->left = copy(node->left, clone);
                clone->right = copy(node->right, clone);
            } catch (...) {
                destroy_subtree(clone);  // frees clone plus any child subtree already attached
                throw;
            }
            return clone;
        }

        /**
         * @brief Prints the subtree sideways: right subtree first, two spaces per level.
         *
         * Each node is printed as its color letter ('b' or 'r') followed by its value.
         * If the whole tree is empty, prints "<empty>".
         */
        void display(const Node* root, std::ostream& out, unsigned int counter = 0) const {
            if (m_root == nullptr) {
                out << "<empty>\n";
                return;
            }
            if (root == nullptr)
                return;

            display(root->right, out, counter + 1);
            for (unsigned int i = 0; i < counter; ++i)
                out << "  ";
            out << (root->color == BLACK ? 'b' : 'r') << root->data << '\n';
            display(root->left, out, counter + 1);
        }

        /// Rotates left around `x`; `x->right` takes x's place. Requires x->right != nullptr.
        void left_rotate(Node* x) {
            Node* y = x->right;
            x->right = y->left;
            if (y->left != nullptr) {
                y->left->parent = x;
            }
            y->parent = x->parent;
            if (x->parent == nullptr) {
                m_root = y;
            } else if (x == x->parent->left) {
                x->parent->left = y;
            } else {
                x->parent->right = y;
            }
            y->left = x;
            x->parent = y;
        }

        /// Rotates right around `x`; `x->left` takes x's place. Requires x->left != nullptr.
        void right_rotate(Node* x) {
            Node* y = x->left;
            x->left = y->right;
            if (y->right != nullptr) {
                y->right->parent = x;
            }
            y->parent = x->parent;
            if (x->parent == nullptr) {
                m_root = y;
            } else if (x == x->parent->right) {
                x->parent->right = y;
            } else {
                x->parent->left = y;
            }
            y->right = x;
            x->parent = y;
        }

    public:
        /// Creates an empty tree.
        RedBlackTree() : m_root(nullptr) {}

        /// Deep copy of `tree`.
        RedBlackTree(const RedBlackTree& tree) : m_root(copy(tree.m_root)) {}

        /// Takes ownership of `tree`'s nodes, leaving `tree` empty.
        RedBlackTree(RedBlackTree&& tree) noexcept : m_root(tree.m_root) {
            tree.m_root = nullptr;
        }

        ~RedBlackTree() {
            make_empty(m_root);
        }

        /// Deep-copy assignment. If copying throws, *this is left unchanged (strong guarantee).
        RedBlackTree& operator=(const RedBlackTree& rhs) {
            if (this == &rhs)
                return *this;

            Node* fresh = copy(rhs.m_root);  // copy first so a throw leaves *this intact
            make_empty(m_root);
            m_root = fresh;
            return *this;
        }

        /// Move assignment: frees the current nodes, takes `rhs`'s nodes, and leaves `rhs` empty.
        RedBlackTree& operator=(RedBlackTree&& rhs) noexcept {
            if (this != &rhs) {
                make_empty(m_root);
                m_root = rhs.m_root;
                rhs.m_root = nullptr;
            }
            return *this;
        }

        /// Color of `node`; null leaves count as BLACK.
        Color color(const Node* node) const {
            if (node == nullptr)
                return BLACK;
            else
                return node->color;
        }

        /// Read-only access to the root node (nullptr when empty).
        const Node* get_root() const {
            return m_root;
        }

        /// @return true if some stored value is equivalent to `item` under operator<.
        bool contains(const T& item) const {
            const Node* cursor = m_root;
            while (cursor != nullptr) {
                if (item < cursor->data)
                    cursor = cursor->left;
                else if (cursor->data < item)
                    cursor = cursor->right;
                else
                    return true;
            }
            return false;
        }

        /// @return The largest stored value. @throws std::invalid_argument if the tree is empty.
        const T& find_max() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");

            const Node* temp = m_root;
            while (temp->right != nullptr) temp = temp->right;

            return temp->data;
        }

        /// @return The smallest stored value. @throws std::invalid_argument if the tree is empty.
        const T& find_min() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");

            const Node* temp = m_root;
            while (temp->left != nullptr) temp = temp->left;

            return temp->data;
        }

        /**
         * @brief Inserts `item` unless an equivalent value is already present.
         *
         * Descends once from the root, comparing only with operator<. The new node
         * is attached as a red leaf, and insert_fix restores the red-black invariants.
         */
        void insert(const T& item) {
            Node* parent = nullptr;
            Node* cursor = m_root;
            while (cursor != nullptr) {
                parent = cursor;
                if (item < cursor->data)
                    cursor = cursor->left;
                else if (cursor->data < item)
                    cursor = cursor->right;
                else
                    return;  // duplicate: nothing to do
            }

            Node* fresh = new Node(item);  // BLACK by default
            if (parent == nullptr) {
                m_root = fresh;  // the root stays black
                return;
            }

            fresh->color = RED;  // every non-root insertion starts red
            fresh->parent = parent;
            if (item < parent->data)
                parent->left = fresh;
            else
                parent->right = fresh;

            insert_fix(fresh);
        }

        /**
         * @brief Restores the red-black invariants after `node` was inserted as a red leaf.
         *
         * Missing (null) uncles are treated as black via color(), so they take the
         * rotation path rather than being dereferenced. While the parent is red it
         * cannot be the root, because the root is kept black, so the grandparent
         * always exists.
         */
        void insert_fix(Node* node) {
            while (node != m_root && color(node->parent) == RED) {
                Node* grandparent = node->parent->parent;
                if (node->parent == grandparent->right) {
                    Node* uncle = grandparent->left;
                    if (color(uncle) == RED) {
                        // Red uncle: recolor and continue from the grandparent.
                        uncle->color = BLACK;
                        node->parent->color = BLACK;
                        grandparent->color = RED;
                        node = grandparent;
                    } else {
                        if (node == node->parent->left) {
                            // Inner child: rotate it into the outer position first.
                            node = node->parent;
                            right_rotate(node);
                        }
                        node->parent->color = BLACK;
                        node->parent->parent->color = RED;
                        left_rotate(node->parent->parent);
                    }
                } else {
                    Node* uncle = grandparent->right;
                    if (color(uncle) == RED) {
                        uncle->color = BLACK;
                        node->parent->color = BLACK;
                        grandparent->color = RED;
                        node = grandparent;
                    } else {
                        if (node == node->parent->right) {
                            node = node->parent;
                            left_rotate(node);
                        }
                        node->parent->color = BLACK;
                        node->parent->parent->color = RED;
                        right_rotate(node->parent->parent);
                    }
                }
            }
            m_root->color = BLACK;
        }

        /// Removes `item` (declared only; its definition is not part of this class body).
        void remove(const T& item);

        /// Frees the subtree rooted at `node` and sets `node` to nullptr.
        void make_empty(Node* &node) {
            destroy_subtree(node);
            node = nullptr;
        }

        /// Writes a sideways rendering of the tree (see display) to `out`.
        void print_tree(std::ostream& out=std::cout) const {
            display(m_root, out);
        }
};