ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 세그먼트 트리 (Segment Tree)
    알고리즘/Algorithm Theory 2024. 3. 12. 13:48

    - 일반적인 구간 합 계산

    int arr[5] = {3, 2, 1, 5, 4};

     

    arr에서의 2번째에서 4번째까지의 합 : 2 + 1 + 5 = 8

    배열의 길이가 N 일 때, 일반적인 구간 합 계산으로는 O(N)의 시간 복잡도를 가진다.

    만약 N의 길이가 조금만 커져도 해당 연산을 반복하게 되면,  매우 오랜 시간을 필요로 하게 된다..

     

    해당 경우, 세그먼트 트리를 사용하여, O(logN)의 시간 복잡도로 구간 합을 구할 수 있다.

     

     

     

    - 세그먼트 트리

    해당 배열을 세그먼트 트리로 나타내면 아래와 같다.

     

    트리는 노드는 9 개, 높이는 3인 이진트리 형태를 가진다.

    루트 노드 : 총 합(전체 배열의 합)

    리프 노드 : 배열의 각 원소 값

    트리의 높이 : math.ceil(log2(N)) //ceil은 올림 함수

     

     

     

    - 코드

    #include <iostream>
    #include <vector>
    #include <map>
    #include <unordered_map>
    #include <cstring>
    #include <queue>
    #include <stack>
    using namespace std;
    vector<int> tree;
    int N = 5;
    int arr[5] = { 3, 2, 1, 5, 4 };
    int MakeSegmentTree(int node, int start, int end) 
    {
    	if (start == end)
    	{
    		tree[node] = arr[start];	// 리프 노드에 도달하면 
    		return tree[node];
    	}
    	int mid = (start + end) / 2;
    	int leftResult = MakeSegmentTree(node * 2, start, mid);
    	int rightResult = MakeSegmentTree(node * 2 + 1, mid + 1, end);
    	tree[node] = leftResult + rightResult;
    	return tree[node];
    }
    
    //노드 번호, start ~ end 해당 노드가 포함하고 있는 범위, left ~ right 구하고자 하는 범위
    int Sum(int node, int start, int end, int left, int right)
    {
    	// 부분합이 범위를 벗어나는 경우
    	if (left > end || right < start) return 0;
    	// 부분합이 범위 안에 속하는 경우
    	if (left <= start && end <= right) return tree[node];
    	// 부분합이 범위에 걸치는 경우
    	int mid = (start + end) / 2;
    	int leftResult = Sum(node * 2, start, mid, left, right);
    	int rightResult = Sum(node * 2 + 1, mid + 1, end, left, right);
    	return leftResult + rightResult;
    }
    
    void UpdateSegmentTree(int node, int start, int end, int index, int diff)
    {
    	// 해당 인덱스가 범위를 벗어나는 경우
    	if (index < start || index > end) return;
    	
    	tree[node] = tree[node] + diff;
    	if (start != end)
    	{
    		int mid = (start + end) / 2;
    		UpdateSegmentTree(node * 2, start, mid, index, diff);
    		UpdateSegmentTree(node * 2 + 1, mid + 1, end, index, diff);
    	}
    }
    
    int main()
    {
    	ios::sync_with_stdio(false);
    	cin.tie(0);
    	cout.tie(0);
    	int N = 5;
    	int height = ceil(log2(N));
    	int size = (1 << (height + 1));
    	tree.resize(size);
    
    	MakeSegmentTree(1, 0, N - 1);
    	cout << "value : " << Sum(1, 1, N, 1, 5) << "\n";
    	int index = 0;
    	int newValue = 20;
    	int diff = newValue - arr[index];
    	arr[index] = newValue;
    	UpdateSegmentTree(1, 0, N - 1, index, diff);
    	cout << "newValue : " << Sum(1, 1, N, 1, 5);
    	return 0;
    }
Designed by Tistory.