레드블랙 트리는 스스로 균형을 잡는 이진트리 라는 점에서
앞에서 살펴본 AVL과 비슷합니다.
또한, 레드블랙 트리는 실제로 키값을 2개 가지진 않지만 2-3트리의 성질도 가지고 있습니다.
AVL트리와 2-3트리의 성질을 다 가지고 있는 것이 레드블랙트리입니다.
아래 사이트에서 레드블랙트리가 어떻게 작동하는지 직관적으로 확인할 수 있습니다.
https://www.cs.usfca.edu/~galles/visualization/RedBlack.html
레드블랙트리의 특징
1. 이름에서 알 수 있듯, 노드는 빨간색 혹은 검정색의 색을 가질 수 있습니다.
- 색이 있는 이유 -> 빨간색 노드 + 검은색 노드 한 쌍이 마치 2-3노드에서 키값이 두개인 노드의 역할과 비슷합니다.
2-3노드의 기능을 가져오면서 구현은 편하게 하기 위해 이진트리의 형태를 한 것이라고 보면 됩니다.
- 여기서는 enum을 활용하여 빨간색을 kRed, 검은색을 kBlack 이라 표현하겠습니다.
enum Color { kRed, kBlack };
2. 루트 노드의 색은 항상 검정색이어야합니다.
- 빨간색 노드는 노드가 2노드(자식이2개) 인지 3노드(자식이 3개) 인지 구분할 떄 사용되는 것이기 때문입니다.
(실제로는 2-3노드처럼 작용하지만 구현은 이진트리의 형태로 하기에 빨간 노드를 사용하여 구분)
3. 균형이 맞추어져 있어야 합니다. ( 균형은 검정 노드로 판단. 빨간색의 노드는 균형에 영향을 미치지 않습니다. )
- 빨간 노드는 위의 2번에서 말했듯 노드의 종류를 구분하기 위한 것이므로 높이에 영향을 주지 않습니다.
아래 그림 또한 균형이 무너지지 않은 상태입니다.
4. 빨간색 노드는 연속될 수 없습니다. ( 어느 한 노드가 빨간색 노드라면 그 부모와 자식은 반드시 검정색 노드여야 합니다.)
- 2-3트리의 관점을 볼 때 빨간 노드는 부모 노드(반드시 검정)가 3노드라는 것을 나타내주는 노드이므로
빨간색 노드가 연속될 수는 없습니다.
5. 빨간 노드는 왼쪽으로 기울어져있다.
- 위에서 빨간 노드는 3노드인 것을 알려주기 위한 노드라고 설명드렸습니다.
여기에 추가로, 2-3트리의 관점에서 봤을 때 두개의 키값 중 작은 것을 빨간 노드로 표현했다고 생각하시면 됩니다.
레드블랙트리의 구현
먼저, 문자열 트리를 구현하기 위해 string -> Key로 사용하겠습니다.
using Key = std::string;
using Value = int;
class Node)
노드의 색을 표현하는 color와 높이를 나타내는 size (균형 확인하기 위함) 를 추가했습니다.
class Node {
public:
Key key;
Value val;
Color color;
int size;
Node* left;
Node* right;
Node(Key key, Value val, int N, Color color)
: key(key),
val(val),
size(N),
color(color),
left(nullptr),
right(nullptr) {}
private:
};
class RedBlackTree
AVL클래스에서 추가되거나 변경된 점들만 살펴보겠습니다.
- 노드의 색을 구별하기 위해 빨간색인지 아닌지 판별하는 함수가 추가되었습니다.
bool IsRed(Node* x) {
if (x == nullptr) return false;
return x->color == Color::kRed;
}
- search와 구조는 같지만 값이 있는지 없는지 판별하기 위해 ( 삭제할 때 사용 ) Contains 함수가 추가되었습니다.
bool Contains(Key key) { return Contains(root, key); }
bool Contains(Node* x, Key key) {
if (x == nullptr) return false;
if (key < x->key)
return Contains(x->left, key);
else if (key > x->key)
return Contains(x->right, key);
else
return true;
return false;
}
- 트리에서 가장 작은 값과 가장 큰 값을 찾아주는 Min, Max 함수가 추가되었습니다.
Key Min() { return Min(root)->key; }
Node* Min(Node* x) {
if (x->left == nullptr)
return x;
else
return Min(x->left);
return nullptr;
}
Key Max() { return Max(root)->key; }
Node* Max(Node* x) {
if (x->right == nullptr)
return x;
else
return Max(x->right);
return nullptr;
}
- 트리를 회전시키는 RotateLeft & RotateRight 함수가 변경되었습니다.
AVL과 비슷하지만 색을 구분해주는 부분, size를 바꿔주는 부분이 생겼습니다.
회전할 때 자식 노드였던 것이 부모 노드의 위치로 올라오면 색, 사이즈는 바뀌어야 합니다.(조건에 따라)
Node* RotateLeft(Node* h) {
Node* x = h->right;
h->right = x->left;
x->left = h;
x->color = h->color;
h->color = Color::kRed;
x->size = h->size;
h->size = 1 + Size(h->left) + Size(h->right);
return x;
}
Node* RotateRight(Node* h) {
Node* x = h->left;
h->left = x->right;
x->right = h;
x->color = h->color;
h->color = Color::kRed;
x->size = h->size;
h->size = 1 + Size(h->left) + Size(h->right);
return x;
}
- 노드의 색을 반전시켜주는 FlipColors 함수가 추가되었습니다.
색을 인자로 받는 함수는 색을 반전시켜주고
노드를 인자로 받는 함수는 그 노드와 자식들까지 색을 반전시켜줍니다. (왜인지는 아래에서 설명)
void FlipColors(Color& color) {
if (color == Color::kBlack)
color = Color::kRed;
else
color = Color::kBlack;
}
void FlipColors(Node* h) {
FlipColors(h->color);
FlipColors(h->left->color);
FlipColors(h->right->color);
}
- 트리의 균형이 잘 맞는지 확인하는 Balance 함수가 변경되었습니다.
여기서는 균형의 확인이 아닌 균형을 맞추는 용도로 사용됩니다.
(균형을 확인하기 위해 Node에 size를 추가해주었기에 별도로 확인하는 함수 구현할 필요 없음)
Node* Balance(Node* h) {
assert(h != nullptr);
if (IsRed(h->left) && IsRed(h->right)) { //두 자식 모두 빨간 노드이면 부모 노드까지 반전
FlipColors(h);
}
if (IsRed(h->left) && IsRed(h->left->left)) { //빨간 노드가 연속되지 않도록
h = RotateRight(h);
FlipColors(h);
}
if (!IsRed(h->left) && IsRed(h->right)) { //빨간 노드는 왼쪽으로 가야함
h = RotateLeft(h);
}
h->size = 1 + Size(h->left) + Size(h->right); //높이 재조정
return h; //균형을 맞춘 후 헤드 반환
}
-데이터를 삽입하는 Insert함수는 거의 동일합니다.
AVL트리같이 데이터를 삽입하고 위의 Balance함수를 이용해 균형을 맞춰줍니다.
균형을 맞추고 난 뒤 루트 노드의 색을 검정색으로 바꿔줍니다.
void Insert(Key key, Value val) {
root = Insert(root, key, val);
root->color = Color::kBlack;
}
Node* Insert(Node* h, Key key, Value val) {
if (h == nullptr) return new Node(key, val, 1, Color::kRed);
if (key < h->key) {
h->left = Insert(h->left, key, val);
} else if (key > h->key) {
h->right = Insert(h->right, key, val);
} else {
h->val = val;
}
return Balance(h);
}
- 핵심이라고도 할 수 있는 MoveRedLeft, MoveRedRight 함수입니다.
- 함수의 이름대로 빨간 노드를 왼쪽으로 , 오른쪽으로 옮기는 함수입니다.
그 이유는 아래의 삭제 부분에서 설명하겠습니다.
Node* MoveRedLeft(Node* h) {
FlipColors(h);
if (IsRed(h->right->left)) {
h->right = RotateRight(h->right);
h = RotateLeft(h);
FlipColors(h);
}
return h;
}
Node* MoveRedRight(Node* h) {
FlipColors(h);
if (IsRed(h->left->left)) {
h = RotateRight(h);
FlipColors(h);
}
return h;
}
-가장 작은 데이터를 찾아 삭제하는 DeleteMin 함수입니다.
void DeleteMin() {
if (!IsRed(root->left) && !IsRed(root->right)) root->color = Color::kRed;
root = DeleteMin(root);
if (!IsEmpty()) root->color = Color::kBlack;
}
Node* DeleteMin(Node* h) { //왼쪽 자식이 없다는 것은
if (h->left == nullptr) { //가장 작다는 의미이므로 바로 삭제
delete h;
return nullptr;
}
//가장 작은 노드를 삭제해야하는데 왼쪽 자식, 그 왼쪽 자식까지 검정색인상황
//검정 노드를 삭제하면 트리의 균형이 깨질 수 있으므로 빨간 노드를 왼쪽으로 옮김
//옮긴 뒤 빨간 노드를 삭제하면 균형에 변화가 없음
//따라서 MoveRedLeft가 중요하다고 볼 수 있다.
if (!IsRed(h->left) && !IsRed(h->left->left)) {
h = MoveRedLeft(h);
}
h->left = DeleteMin(h->left);
return Balance(h);
}
-가장 큰 데이터를 찾아 삭제하는 DeleteMax 함수입니다.
void DeleteMax() {
if (!IsRed(root->left) && !IsRed(root->right)) root->color = Color::kRed;
root = DeleteMax(root);
if (!IsEmpty()) root->color = Color::kBlack;
}
Node* DeleteMax(Node* h) {
if (IsRed(h->left)) h = RotateRight(h);
if (h->right == nullptr) { //오른쪽이 없다 = 가장 크다
delete h; //바로 삭제
return nullptr;
}
//DeleteMin에서도 설명했듯 균형을 위해
//빨간노드를 오른쪽으로.
if (!IsRed(h->right) && !IsRed(h->right->left)) {
h = MoveRedRight(h);
Print2D(h);
}
h->right = DeleteMax(h->right);
return Balance(h);
}
-마지막으로 특정 Key값을 찾아 삭제하는 Delete함수입니다.
- Key값을 기준으로 왼쪽 혹은 오른쪽으로 내려가면서 Rotate, MoveRed를 이용해 균형을 맞춰줍니다.
void Delete(Key key) {
if (!Contains(key)) return;
if (!IsRed(root->left) && !IsRed(root->right)) root->color = Color::kRed;
root = Delete(root, key);
if (!IsEmpty()) root->color = Color::kBlack;
}
Node* Delete(Node* h, Key key) {
if (key < h->key) {
if (!IsRed(h->left) && !IsRed(h->left->left)) {
h = MoveRedLeft(h);
}
h->left = Delete(h->left, key);
} else {
if (IsRed(h->left)) {
h = RotateRight(h);
}
if (key == h->key && h->right == nullptr) {
delete h;
return nullptr;
}
if (!IsRed(h->right) && !IsRed(h->right->left)) {
h = MoveRedRight(h);
}
if (key == h->key) {
h->key = Min(h->right)->key;
h->val = Min(h->right)->val;
h->right = DeleteMin(h->right);
} else {
h->right = Delete(h->right, key);
}
}
return Balance(h);
}
여기까지 레드블랙트리에 대해 알아보았습니다.
아래는 전체 클래스 코드입니다.
디버깅용 Print해주는 함수도 있습니다.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>
using Key = std::string;
using Value = int;
enum Color { kRed, kBlack };
class Node {
public:
Key key;
Value val;
Color color;
int size;
Node* left;
Node* right;
Node(Key key, Value val, int N, Color color)
: key(key),
val(val),
size(N),
color(color),
left(nullptr),
right(nullptr) {}
private:
};
class RedBlackBST {
public:
Node* root = nullptr;
bool IsEmpty() { return root == nullptr; }
bool IsRed(Node* x) {
if (x == nullptr) return false;
return x->color == Color::kRed;
}
int Size() { return Size(root); }
int Size(Node* x) {
if (x == nullptr) return 0;
return x->size;
}
Value Search(Key key) { return Search(root, key); }
Value Search(Node* x, Key key) {
if (x == nullptr) return -1;
if (key < x->key)
return Search(x->left, key);
else if (key > x->key)
return Search(x->right, key);
else
return x->val;
return -1;
}
bool Contains(Key key) { return Contains(root, key); }
bool Contains(Node* x, Key key) {
if (x == nullptr) return false;
if (key < x->key)
return Contains(x->left, key);
else if (key > x->key)
return Contains(x->right, key);
else
return true;
return false;
}
Key Min() { return Min(root)->key; }
Node* Min(Node* x) {
if (x->left == nullptr)
return x;
else
return Min(x->left);
return nullptr;
}
Key Max() { return Max(root)->key; }
Node* Max(Node* x) {
if (x->right == nullptr)
return x;
else
return Max(x->right);
return nullptr;
}
Node* RotateLeft(Node* h) {
Node* x = h->right;
h->right = x->left;
x->left = h;
x->color = h->color;
h->color = Color::kRed;
x->size = h->size;
h->size = 1 + Size(h->left) + Size(h->right);
return x;
}
Node* RotateRight(Node* h) {
Node* x = h->left;
h->left = x->right;
x->right = h;
x->color = h->color;
h->color = Color::kRed;
x->size = h->size;
h->size = 1 + Size(h->left) + Size(h->right);
return x;
}
void FlipColors(Color& color) {
if (color == Color::kBlack)
color = Color::kRed;
else
color = Color::kBlack;
}
void FlipColors(Node* h) {
FlipColors(h->color);
FlipColors(h->left->color);
FlipColors(h->right->color);
}
Node* Balance(Node* h) {
assert(h != nullptr);
if (IsRed(h->left) && IsRed(h->right)) {
FlipColors(h);
}
if (IsRed(h->left) && IsRed(h->left->left)) {
h = RotateRight(h);
FlipColors(h);
}
if (!IsRed(h->left) && IsRed(h->right)) {
h = RotateLeft(h);
}
h->size = 1 + Size(h->left) + Size(h->right);
return h;
}
void Insert(Key key, Value val) {
root = Insert(root, key, val);
root->color = Color::kBlack;
}
Node* Insert(Node* h, Key key, Value val) {
if (h == nullptr) return new Node(key, val, 1, Color::kRed);
if (key < h->key) {
h->left = Insert(h->left, key, val);
} else if (key > h->key) {
h->right = Insert(h->right, key, val);
} else {
h->val = val;
}
return Balance(h);
}
Node* MoveRedLeft(Node* h) {
std::cout << "MoveRedLeft() " << h->key << '\n';
FlipColors(h);
if (IsRed(h->right->left)) {
h->right = RotateRight(h->right);
h = RotateLeft(h);
FlipColors(h);
}
return h;
}
Node* MoveRedRight(Node* h) {
std::cout << "MoveRedRight() " << h->key << '\n';
FlipColors(h);
if (IsRed(h->left->left)) {
h = RotateRight(h);
FlipColors(h);
}
return h;
}
void DeleteMin() {
if (!IsRed(root->left) && !IsRed(root->right)) root->color = Color::kRed;
root = DeleteMin(root);
if (!IsEmpty()) root->color = Color::kBlack;
}
Node* DeleteMin(Node* h) {
std::cout << "DeleteMin() " << h->key << '\n';
if (h->left == nullptr) {
std::cout << "Delete node " << h->key << '\n';
delete h;
return nullptr;
}
if (!IsRed(h->left) && !IsRed(h->left->left)) {
h = MoveRedLeft(h);
Print2D(h);
}
h->left = DeleteMin(h->left);
return Balance(h);
}
void DeleteMax() {
if (!IsRed(root->left) && !IsRed(root->right)) root->color = Color::kRed;
root = DeleteMax(root);
if (!IsEmpty()) root->color = Color::kBlack;
}
Node* DeleteMax(Node* h) {
std::cout << "DeleteMax() " << h->key << '\n';
if (IsRed(h->left)) h = RotateRight(h);
if (h->right == nullptr) {
std::cout << "Delete node " << h->key << '\n';
delete h;
return nullptr;
}
if (!IsRed(h->right) && !IsRed(h->right->left)) {
h = MoveRedRight(h);
Print2D(h);
}
h->right = DeleteMax(h->right);
return Balance(h);
}
void Delete(Key key) {
if (!Contains(key)) return;
if (!IsRed(root->left) && !IsRed(root->right)) root->color = Color::kRed;
root = Delete(root, key);
if (!IsEmpty()) root->color = Color::kBlack;
}
Node* Delete(Node* h, Key key) {
if (key < h->key) {
if (!IsRed(h->left) && !IsRed(h->left->left)) {
h = MoveRedLeft(h);
}
h->left = Delete(h->left, key);
} else {
if (IsRed(h->left)) {
h = RotateRight(h);
}
if (key == h->key && h->right == nullptr) {
std::cout << "Delete node " << h->key << '\n';
delete h;
return nullptr;
}
if (!IsRed(h->right) && !IsRed(h->right->left)) {
h = MoveRedRight(h);
}
if (key == h->key) {
h->key = Min(h->right)->key;
h->val = Min(h->right)->val;
h->right = DeleteMin(h->right);
} else {
h->right = Delete(h->right, key);
}
}
return Balance(h);
}
int Height() { return Height(root); }
int Height(Node* x) {
if (!x) return -1;
return 1 + std::max(Height(x->left), Height(x->right));
}
std::vector<std::string> screen;
void PrintLine(int x, std::string s, std::string& line) {
for (const auto c : s) line[x++] = c;
}
void Print2D() { Print2D(root); }
void Print2D(Node* root) {
if (!root)
std::cout << "Empty" << '\n';
else {
int h = Height(root) + 1, w = 4 * int(pow(2, h - 1));
screen.clear();
screen.resize(h * 2, std::string(w, ' '));
Print2D(root, w / 2 - 2, 0, h - 1);
for (const auto& l : screen) std::cout << l << '\n';
}
}
void Print2D(Node* n, int x, int level, int s) {
// cout << x << " " << level << " " << s << endl;
PrintLine(x, (IsRed(n) ? "*" : " ") + n->key, screen[2 * level]);
x -= int(pow(2, s));
if (n->left) {
PrintLine(x, " /", screen[2 * level + 1]);
Print2D(n->left, x, level + 1, s - 1);
}
if (n->right) {
PrintLine(x + 2 * int(pow(2, s)), "\\", screen[2 * level + 1]);
Print2D(n->right, x + 2 * int(pow(2, s)), level + 1, s - 1);
}
}
private:
};
'DataStructure' 카테고리의 다른 글
std::map , std::unordered_map (0) | 2024.09.26 |
---|---|
AVL vs RedBlack (0) | 2024.09.25 |
2-3Tree (2-3트리) (0) | 2024.09.25 |
AVL트리를 이용한 영어사전 만들기 (2) | 2024.09.25 |
AVL Tree (0) | 2024.09.24 |