#pragma once

#include <algorithm>
#include <iostream>
#include <stdexcept>

/// Self-balancing binary search tree (AVL tree) holding unique values of type T.
///
/// T must be copy-constructible and copy-assignable and must provide
/// operator<, operator> and operator== (all three are used by the tree).
/// Duplicate insertions are ignored. Heights follow the convention that an
/// empty subtree has height -1 and a leaf has height 0.
template <typename T>
class AVLTree {
    private:
        /// A single tree node; owns nothing itself, the tree frees the nodes.
        struct Node {
            T data;
            Node* left;
            Node* right;
            int height;

            explicit Node(const T& item, Node* _left = nullptr, Node* _right = nullptr, int _height = 0)
            : data(item)
            , left(_left)
            , right(_right)
            , height(_height) {}
        };

        Node* m_root;

        /// Frees every node of the subtree rooted at `node` (post-order).
        static void destroy_subtree(Node* node) noexcept {
            if (node == nullptr) {
                return;
            }

            destroy_subtree(node->left);
            destroy_subtree(node->right);
            delete node;
        }

        /// Restores the AVL property at `node` (if its subtrees differ in
        /// height by more than one) and refreshes its cached height.
        /// `node` is a reference to the parent's link, so a rotation can
        /// replace the subtree root in place.
        void balance(Node* &node) {
            if (node == nullptr) {
                return;
            }

            if (height(node->left) - height(node->right) > 1) {
                // Left-heavy: a single rotation suffices unless the inner
                // (left-right) grandchild is the taller one.
                if (height(node->left->left) >= height(node->left->right))
                    lc_rotate(node);
                else
                    dlc_rotate(node);
            } else if (height(node->right) - height(node->left) > 1) {
                // Right-heavy: mirror image of the case above.
                if (height(node->right->right) >= height(node->right->left))
                    rc_rotate(node);
                else
                    drc_rotate(node);
            }

            node->height = std::max(height(node->left), height(node->right)) + 1;
        }

        /// Returns a deep copy of the subtree rooted at `node`.
        /// If an allocation or a copy of T throws part-way through, the nodes
        /// cloned so far are released before the exception propagates.
        Node* copy(Node* node) const {
            if (node == nullptr)
                return nullptr;

            Node* duplicate = new Node(node->data, nullptr, nullptr, node->height);
            try {
                duplicate->left = copy(node->left);
                duplicate->right = copy(node->right);
            } catch (...) {
                destroy_subtree(duplicate);
                throw;
            }
            return duplicate;
        }

        /// Writes the subtree rooted at `root` to `out`, one value per line,
        /// right subtree first, indenting each value by two spaces per level
        /// of depth (`counter`). An empty subtree writes nothing.
        void display(Node* root, std::ostream& out, unsigned int counter = 0) const {
            if (root == nullptr) {
                return;
            }

            display(root->right, out, counter + 1);
            for (unsigned int i = 0; i < counter; i++) out << "  ";
            out << root->data << "\n";
            display(root->left, out, counter + 1);
        }

        /// Double rotation for the left-right case.
        void dlc_rotate(Node* &node) {
            rc_rotate(node->left);
            lc_rotate(node);
        }

        /// Double rotation for the right-left case.
        void drc_rotate(Node* &node) {
            lc_rotate(node->right);
            rc_rotate(node);
        }

        /// Inserts `item` into the subtree rooted at `node`, rebalancing every
        /// node on the way back up. An item comparing neither less nor greater
        /// than an existing value is not inserted.
        void insert(const T& item, Node* &node) {
            if (node == nullptr) {
                node = new Node(item);
            } else if (item < node->data) {
                insert(item, node->left);
            } else if (item > node->data) {
                insert(item, node->right);
            }

            balance(node);
        }

        /// Cached height of `node`; -1 for an empty subtree.
        int height(Node* node) const {
            return node == nullptr ? -1 : node->height;
        }

        /// Rotates `node` with its left child; the left child becomes the new
        /// subtree root. Requires node->left != nullptr.
        void lc_rotate(Node* &node) {
            Node* left_child = node->left;
            node->left = left_child->right;
            left_child->right = node;
            node->height = std::max(height(node->left), height(node->right)) + 1;
            left_child->height = std::max(height(left_child->left), node->height) + 1;
            node = left_child;
        }

        /// Rotates `node` with its right child; the right child becomes the
        /// new subtree root. Requires node->right != nullptr.
        void rc_rotate(Node* &node) {
            Node* right_child = node->right;
            node->right = right_child->left;
            right_child->left = node;
            node->height = std::max(height(node->left), height(node->right)) + 1;
            right_child->height = std::max(height(right_child->right), node->height) + 1;
            node = right_child;
        }

        /// Returns the smallest value in the subtree rooted at `root`
        /// (the leftmost node). Requires root != nullptr.
        T& right_subtree_min(Node* root) const {
            while (root->left != nullptr) root = root->left;

            return root->data;
        }

        /// Removes `item` from the subtree rooted at `root` (no-op if absent),
        /// rebalancing every node on the way back up.
        void remove(const T& item, Node* &root) {
            if (root == nullptr) {
                return;
            }

            if (item < root->data) {
                remove(item, root->left);
            } else if (item > root->data) {
                remove(item, root->right);
            } else if (root->left != nullptr && root->right != nullptr) {
                // Two children: overwrite with the in-order successor, then
                // delete that successor from the right subtree.
                root->data = right_subtree_min(root->right);
                remove(root->data, root->right);
            } else {
                // Zero or one child: splice the node out.
                Node* node_to_remove = root;
                root = root->left != nullptr ? root->left : root->right;
                delete node_to_remove;
            }

            balance(root);
        }

    public:
        /// Creates an empty tree.
        AVLTree() : m_root(nullptr) {}

        /// Deep-copies `tree`.
        AVLTree(const AVLTree& tree) : m_root(nullptr) {
            m_root = copy(tree.m_root);
        }

        /// Takes over the nodes of `other`, leaving it empty.
        AVLTree(AVLTree&& other) noexcept : m_root(other.m_root) {
            other.m_root = nullptr;
        }

        ~AVLTree() {
            make_empty(m_root);
        }

        /// Replaces the contents with a deep copy of `rhs`. The copy is built
        /// before the old nodes are released, so if copying throws this tree
        /// is left unchanged.
        AVLTree& operator=(const AVLTree& rhs) {
            if (&rhs == this)
                return *this;

            Node* replacement = copy(rhs.m_root);
            make_empty(m_root);
            m_root = replacement;
            return *this;
        }

        /// Releases the current nodes and takes over those of `rhs`,
        /// leaving `rhs` empty.
        AVLTree& operator=(AVLTree&& rhs) noexcept {
            if (&rhs != this) {
                make_empty(m_root);
                m_root = rhs.m_root;
                rhs.m_root = nullptr;
            }
            return *this;
        }

        /// Returns true if a value equal to `item` is stored in the tree.
        bool contains(const T& item) const {
            Node* temp = m_root;
            while (temp != nullptr) {
                if (temp->data == item)
                    return true;
                else if (temp->data < item)
                    temp = temp->right;
                else
                    temp = temp->left;
            }

            return false;
        }

        /// Inserts `item`; does nothing if an equal value is already present.
        void insert(const T& item) {
            if (!contains(item)) {
                insert(item, m_root);
            }
        }

        /// Removes `item`; does nothing if no equal value is present.
        void remove(const T& item) {
            if (contains(item)) {
                remove(item, m_root);
            }
        }

        /// Returns the smallest stored value.
        /// @throws std::invalid_argument if the tree is empty.
        const T& find_min() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");

            Node* temp = m_root;
            while (temp->left != nullptr) temp = temp->left;

            return temp->data;
        }

        /// Returns the largest stored value.
        /// @throws std::invalid_argument if the tree is empty.
        const T& find_max() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");

            Node* temp = m_root;
            while (temp->right != nullptr) temp = temp->right;

            return temp->data;
        }

        /// Writes the tree to `out` rotated 90 degrees (largest value on the
        /// first line, two spaces of indent per depth level), or the line
        /// "<empty>" when the tree holds nothing.
        void print_tree(std::ostream& out=std::cout) const {
            if (m_root == nullptr) {
                out << "<empty>\n";
                return;
            }

            display(m_root, out);
        }

        /// Returns the value stored at the root.
        /// @throws std::invalid_argument if the tree is empty.
        const T& root_data() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");

            return m_root->data;
        }

        /// Removes every value, leaving the tree empty.
        void make_empty() {
            make_empty(m_root);
        }

        /// Frees the subtree rooted at `node` and resets the link to nullptr.
        void make_empty(Node* &node) {
            destroy_subtree(node);
            node = nullptr;
        }
};