#include <algorithm>
#include <iostream>
using namespace std;
/// A single tree node: a stored value, links to two children, and the cached
/// height of the subtree rooted at this node (a lone node has height 1).
///
/// The node does not free its children; whoever allocates nodes is
/// responsible for deleting them.
template<class T>
class Node {
public:
    T value{};                 // stored value; value-initialized (e.g. 0, "") by default
    Node<T>* left = nullptr;   // left child, or nullptr
    Node<T>* right = nullptr;  // right child, or nullptr
    int height = 1;            // height of the subtree rooted here; 1 for a new leaf

    /// Creates a leaf holding a value-initialized T. Unlike the previous
    /// version, `height` is initialized and non-arithmetic T is supported.
    Node() = default;

    /// Creates a leaf holding a copy of `value`.
    Node(const T& value) : value(value) {}
};
/// Height-balanced (AVL) binary search tree holding unique values of type T.
///
/// Every insert/remove rebalances the nodes on the modified path with
/// rotations, keeping subtree heights within 1 of each other, so height and
/// per-operation cost stay O(log n). T must support `operator<` and copy
/// assignment; equal values (neither `a < b` nor `b < a`) are stored once.
///
/// The tree owns every node it allocates and frees them in clear() and the
/// destructor. Copying is disabled because a member-wise copy would make two
/// trees delete the same nodes.
template<class T>
class BST {
    int size = 0;             // number of nodes allocated by insert and not yet freed
    Node<T>* root = nullptr;  // owning pointer to the root; nullptr when empty

public:
    BST() = default;
    BST(const BST&) = delete;
    BST& operator=(const BST&) = delete;

    ~BST() {
        clear();
    }

    /// @return true when the tree holds no values.
    bool isEmpty() const {
        return root == nullptr;
    }

    /// @return number of values currently stored.
    int getSize() const {
        return size;
    }

    /// @return read-only view of the root node, or nullptr when empty.
    const Node<T>* getRoot() const {
        return root;
    }

    /// @return height(left subtree) - height(right subtree); 0 for a null node.
    ///         A missing child counts as height 0.
    int getBalanceFactor(const Node<T>* node) const {
        if (node == nullptr) {
            return 0;
        }
        return heightOf(node->left) - heightOf(node->right);
    }

    /// Rotates the subtree rooted at `node` to the right.
    /// @pre node->left != nullptr
    /// @return the new subtree root (the former left child).
    Node<T>* rightRotate(Node<T>* node) {
        Node<T>* pivot = node->left;
        Node<T>* moved = pivot->right;
        pivot->right = node;
        node->left = moved;
        updateHeight(node);   // `node` is now below `pivot`, so fix it first
        updateHeight(pivot);
        return pivot;
    }

    /// Compatibility overload: `value` is ignored.
    Node<T>* rightRotate(Node<T>* node, T /*value*/) {
        return rightRotate(node);
    }

    /// Rotates the subtree rooted at `node` to the left.
    /// @pre node->right != nullptr
    /// @return the new subtree root (the former right child).
    Node<T>* leftRotate(Node<T>* node) {
        Node<T>* pivot = node->right;
        Node<T>* moved = pivot->left;
        pivot->left = node;
        node->right = moved;
        updateHeight(node);   // `node` is now below `pivot`, so fix it first
        updateHeight(pivot);
        return pivot;
    }

    /// Compatibility overload: `value` is ignored.
    Node<T>* leftRotate(Node<T>* node, T /*value*/) {
        return leftRotate(node);
    }

    /// Inserts `value` into the subtree rooted at `node` (nullptr = empty).
    /// A value already present leaves the subtree unchanged. Newly allocated
    /// nodes are counted in this tree's size. O(log n).
    /// @return the root of the subtree after rebalancing.
    Node<T>* insert(Node<T>* node, T value) {
        if (node == nullptr) {
            ++size;
            return new Node<T>(value);
        }
        if (value < node->value) {
            node->left = insert(node->left, value);
        } else if (node->value < value) {
            node->right = insert(node->right, value);
        } else {
            return node;  // duplicate: nothing changed, no rebalancing needed
        }
        return rebalance(node);
    }

    /// Inserts `value` into this tree (no-op if already present).
    void insert(const T& value) {
        root = insert(root, value);
    }

    /// Removes `value` from the subtree rooted at `node`, if present, and
    /// frees its node. O(log n).
    /// @return the root of the subtree after rebalancing (nullptr if it
    ///         became empty).
    Node<T>* remove(Node<T>* node, T value) {
        if (node == nullptr) {
            return nullptr;  // value not present
        }
        if (value < node->value) {
            node->left = remove(node->left, value);
        } else if (node->value < value) {
            node->right = remove(node->right, value);
        } else if (node->left == nullptr || node->right == nullptr) {
            // Zero or one child: the (possibly null) child replaces this node.
            // That child is already balanced with a correct height.
            Node<T>* child = (node->left != nullptr) ? node->left : node->right;
            delete node;
            --size;
            return child;
        } else {
            // Two children: copy the in-order successor (minimum of the right
            // subtree) into this node, then delete the successor's node.
            Node<T>* successor = node->right;
            while (successor->left != nullptr) {
                successor = successor->left;
            }
            node->value = successor->value;
            node->right = remove(node->right, node->value);
        }
        return rebalance(node);
    }

    /// Removes `value` from this tree (no-op if absent).
    void remove(const T& value) {
        root = remove(root, value);
    }

    /// Frees every node in O(n) and leaves the tree empty.
    void clear() {
        destroy(root);
        root = nullptr;
        size = 0;
    }

private:
    /// Height of a possibly-null subtree (0 for nullptr).
    static int heightOf(const Node<T>* node) {
        return node == nullptr ? 0 : node->height;
    }

    /// Recomputes `node->height` from its children's cached heights.
    static void updateHeight(Node<T>* node) {
        node->height = 1 + std::max(heightOf(node->left), heightOf(node->right));
    }

    /// Refreshes `node`'s height and applies whichever single or double
    /// rotation restores |balance| <= 1. Uses the children's balance factors,
    /// so it is correct after both insertions and deletions.
    /// @return the new subtree root.
    Node<T>* rebalance(Node<T>* node) {
        updateHeight(node);
        const int balance = getBalanceFactor(node);
        if (balance > 1) {
            // Left-heavy. Left-Right case: straighten the left child first.
            if (getBalanceFactor(node->left) < 0) {
                node->left = leftRotate(node->left);
            }
            return rightRotate(node);
        }
        if (balance < -1) {
            // Right-heavy. Right-Left case: straighten the right child first.
            if (getBalanceFactor(node->right) > 0) {
                node->right = rightRotate(node->right);
            }
            return leftRotate(node);
        }
        return node;
    }

    /// Post-order deletion of every node in the subtree rooted at `node`.
    static void destroy(Node<T>* node) {
        if (node == nullptr) {
            return;
        }
        destroy(node->left);
        destroy(node->right);
        delete node;
    }
};
int main() {