#pragma once

#include <iostream>
#include <stdexcept>
#include <utility>

/**
 * @brief Unbalanced binary search tree holding unique values of type T.
 *
 * Ordering uses only `operator<`: two values a and b are treated as the same
 * key when neither a < b nor b < a. Inserting a value that is already present
 * is a no-op; removing an absent value is a no-op.
 *
 * The tree owns its nodes. Copying performs a deep copy, moving transfers
 * ownership and leaves the source empty. No rebalancing is performed, so
 * lookups/inserts/removals cost O(height), which degrades to O(n) for
 * e.g. already-sorted insertion order.
 */
template <typename T>
class BinarySearchTree {
    private:
        struct Node {
            T data;
            Node* left;
            Node* right;

            explicit Node(const T& item)
            : data(item)
            , left(nullptr)
            , right(nullptr) {}

            Node(const T& item, Node* _left, Node* _right)
            : data(item)
            , left(_left)
            , right(_right) {}
        };

        Node* m_root;

        // Deep-copies the subtree rooted at `node` and returns the new root.
        // If allocating a node or copying a T throws, every node this call has
        // already allocated is freed before the exception propagates.
        static Node* copy(const Node* node) {
            if (node == nullptr)
                return nullptr;

            Node* left = copy(node->left);
            Node* right = nullptr;
            try {
                right = copy(node->right);
                return new Node(node->data, left, right);
            } catch (...) {
                destroy(left);
                destroy(right);
                throw;
            }
        }

        // Frees every node in the subtree rooted at `node` without recursion.
        // While the current node has a left child, that child is rotated up.
        // Once there is no left child, the node is deleted and the walk moves
        // to its right child. Stack usage therefore stays constant even for a
        // list-shaped tree.
        static void destroy(Node* node) {
            while (node != nullptr) {
                if (node->left != nullptr) {
                    Node* pivot = node->left;
                    node->left = pivot->right;
                    pivot->right = node;
                    node = pivot;
                } else {
                    Node* next = node->right;
                    delete node;
                    node = next;
                }
            }
        }

        // Returns the leftmost (smallest) node of the subtree, or nullptr if empty.
        static Node* min_node(Node* node) {
            while (node != nullptr && node->left != nullptr)
                node = node->left;
            return node;
        }

        // Returns the rightmost (largest) node of the subtree, or nullptr if empty.
        static Node* max_node(Node* node) {
            while (node != nullptr && node->right != nullptr)
                node = node->right;
            return node;
        }

        // Removes `item` from the subtree whose root pointer is `root`.
        // `root` is taken by reference so the parent's link is updated in place.
        // Does nothing if the item is not present.
        void remove(const T& item, Node* &root) {
            if (root == nullptr)
                return;

            if (item < root->data) {
                remove(item, root->left);
            } else if (root->data < item) {
                remove(item, root->right);
            } else if (root->left != nullptr && root->right != nullptr) {
                // Two children: replace with the in-order successor, then
                // delete that successor from the right subtree. It has no left
                // child, so that removal hits the simple case below.
                root->data = min_node(root->right)->data;
                remove(root->data, root->right);
            } else {
                // Zero or one child: splice the (possibly null) child into place.
                Node* node_to_remove = root;
                root = root->left != nullptr ? root->left : root->right;
                delete node_to_remove;
            }
        }

        // Prints the subtree rotated 90 degrees counter-clockwise. The right
        // subtree comes first, then this node indented by two spaces per depth
        // level, then the left subtree. Prints nothing for an empty subtree.
        void display(const Node* root, std::ostream& out, unsigned int counter = 0) const {
            if (root == nullptr)
                return;

            display(root->right, out, counter + 1);
            for (unsigned int i = 0; i < counter; ++i)
                out << "  ";
            out << root->data << '\n';
            display(root->left, out, counter + 1);
        }

    public:
        /// Constructs an empty tree.
        BinarySearchTree() : m_root(nullptr) {}

        /// Constructs a deep copy of `tree`.
        BinarySearchTree(const BinarySearchTree& tree) : m_root(copy(tree.m_root)) {}

        /// Takes ownership of `tree`'s nodes; `tree` is left empty.
        BinarySearchTree(BinarySearchTree&& tree) noexcept
        : m_root(std::exchange(tree.m_root, nullptr)) {}

        ~BinarySearchTree() {
            destroy(m_root);
        }

        /// Replaces the contents with a deep copy of `rhs`.
        /// Strong exception guarantee: if the copy throws, *this is unchanged.
        BinarySearchTree& operator=(const BinarySearchTree& rhs) {
            if (&rhs != this) {
                Node* fresh = copy(rhs.m_root);  // may throw; nothing modified yet
                destroy(m_root);
                m_root = fresh;
            }
            return *this;
        }

        /// Releases the current contents and takes ownership of `rhs`'s nodes.
        BinarySearchTree& operator=(BinarySearchTree&& rhs) noexcept {
            if (&rhs != this) {
                destroy(m_root);
                m_root = std::exchange(rhs.m_root, nullptr);
            }
            return *this;
        }

        /// @return true if a value equivalent to `item` is stored in the tree.
        bool contains(const T& item) const {
            const Node* cur = m_root;
            while (cur != nullptr) {
                if (item < cur->data)
                    cur = cur->left;
                else if (cur->data < item)
                    cur = cur->right;
                else
                    return true;
            }
            return false;
        }

        /// Inserts a copy of `item`; does nothing if an equivalent value exists.
        void insert(const T& item) {
            // Walk the link that would hold `item`, so the new node can be
            // attached with a single descent (no separate contains() pass).
            Node** link = &m_root;
            while (*link != nullptr) {
                Node* cur = *link;
                if (item < cur->data)
                    link = &cur->left;
                else if (cur->data < item)
                    link = &cur->right;
                else
                    return;  // already present
            }
            *link = new Node(item);
        }

        /// Removes the value equivalent to `item`, if present.
        void remove(const T& item) {
            remove(item, m_root);
        }

        /// @return the smallest stored value.
        /// @throws std::invalid_argument if the tree is empty.
        const T& find_min() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");
            return min_node(m_root)->data;
        }

        /// @return the largest stored value.
        /// @throws std::invalid_argument if the tree is empty.
        const T& find_max() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");
            return max_node(m_root)->data;
        }

        /// Writes the tree sideways to `out` (right subtree on top, two spaces
        /// per depth level), or "<empty>" followed by a newline if it has no nodes.
        void print_tree(std::ostream& out=std::cout) const {
            if (m_root == nullptr)
                out << "<empty>\n";
            else
                display(m_root, out);
        }

        /// @return the value stored at the root.
        /// @throws std::invalid_argument if the tree is empty.
        const T& root_data() const {
            if (m_root == nullptr)
                throw std::invalid_argument("tree is empty");
            return m_root->data;
        }

        /// Removes every value from the tree.
        void make_empty() {
            make_empty(m_root);
        }

        /// Frees the subtree referenced by `node` and sets `node` to nullptr.
        void make_empty(Node* &node) {
            destroy(node);
            node = nullptr;
        }
};