[알고리즘] 세그먼트 트리 (indexed tree)

2020. 1. 9. 00:26Algorithm

인덱스 트리, 세그먼트라 부르는 알고리즘이다.

주로 데이터 삽입, 삭제보다는 값 갱신이 자주 있고, 구간 합을 구하는 데 사용하는 알고리즘이다.

 

주로 원노트에 그림과 함께 정리하는데 그것을 갖고와서 설명하겠다.

추천 문제

사탕상자 2243

구간 합 구하기 2042

 

사탕상자 코드

더보기
import java.io.*;
import java.util.*;

public class Main {

	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st;
		int N = Integer.parseInt(br.readLine());
		int a, b, c;
//		int Max = 4;
		int Max = 1000001;
		Candy candy = new Candy(Max);
		// IndexedTree tree= new IndexedTree(1000001);

		for (int i = 0; i < N; i++) {
			st = new StringTokenizer(br.readLine());
			a = Integer.parseInt(st.nextToken());
			if (a == 1) {// 이떄만 출력하기
				// 사탕을 꺼내는 경우,
				// 사탕상자에서 한개꺼내기
				// b 꺼낼사탕의 순위
				b = Integer.parseInt(st.nextToken());
				int index = candy.find(1, 1, Max, b);
				System.out.println(index);
				candy.update(1, 1, Max, index, -1);
			} else if (a == 2) {// 사탕을 넣는 경우
				// b는 사탕의 맛
				b = Integer.parseInt(st.nextToken());
				// c 사탕의 갯수
				// c==1 넣는 경우, c==-1 빼는 경우
				c = Integer.parseInt(st.nextToken());
				candy.update(1, 1, Max, b, c);
			}
//			System.out.println(candy);
		}

	}

	static class Candy {
		long[] flavor;

		@Override
		public String toString() {
			return "Candy [flavor=" + Arrays.toString(flavor) + "]";
		}

		public Candy(int N) {
			int h = (int) Math.ceil(Math.log(N) / Math.log(2) + 1);
			int size = (int) Math.pow(2, h);
			flavor = new long[size];
		}

		// a ==2 일때 작동
		void update(int node, int start, int end, int target, int value) {
			if (start > target || target > end) {
				return;
			} else {
				flavor[node] += value;
				if (start == end) {
					return;
				} else {
					int mid = (start + end) / 2;
					update(node * 2, start, mid, target, value);
					update(node * 2 + 1, mid + 1, end, target, value);
				}
			}
		}

		// a==1일때 target번째 사탕의 맛을 출력한다
		int find(int node, int start, int end, long rank) {
			// 갯수를 찾아 내려가서 그게 무슨 맛인지
			if (start == end) {
				return start;
			}

			// 자식들 사이즈 비교해서 내려가기
			// 아 왼쪽자식에 없으면 부모-왼쪽자식!
			int mid = (start + end) / 2;
			long leftchild = flavor[2 * node];
			long rightchild = flavor[2 * node + 1];
			if (rank <= leftchild) {
				return find(2 * node, start, mid, rank);
			} else {
				return find(2 * node + 1, mid + 1, end, rank - leftchild);
			}
			// return find(2 * node, start, mid, rank);
		}

	}
}

구간 합 구하기 코드

더보기
import java.io.*;
import java.util.*;

public class Main {

	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		int N = Integer.parseInt(st.nextToken()); // N개의 수
		int M = Integer.parseInt(st.nextToken()); // 수의 변경이 일어나는 회수
		int K = Integer.parseInt(st.nextToken()); // 구간의 합을 구하는 회수
		arr = new int[N + 1];
		for (int i = 1; i <= N; i++) {
			arr[i] = Integer.parseInt(br.readLine());
		}
		SegmentTree tree = new SegmentTree(N);
		tree.makeTree(1, N, 1);
		for (int i = 0; i < M + K; i++) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			int c = Integer.parseInt(st.nextToken());
			if (a == 1) {
				int diff = c - arr[b];
				tree.update(1, 1, N, b, diff);
				arr[b] = c;
			} else if (a == 2) {
				System.out.println(tree.query(1, 1, N, b, c));
			}
		}
		// a=1 인 경우 b번째 수를 c로 바꾼다
		// a=2인 경우 b~c까지의 합을 구한다.

	}

	static int[] arr;

	static class SegmentTree {
		long[] tree;

		@Override
		public String toString() {
			return "SegmentTree [tree=" + Arrays.toString(tree) + "]";
		}

		public SegmentTree(int N) {
			int h = (int) Math.ceil(Math.log(N) / Math.log(2)) + 1;
			int size = (int) Math.pow(2, h);
			tree = new long[size];
		}

		public long makeTree(int left, int right, int node) {
			if (left == right) {// 끝까지 왔으니 값을 돌려보내준다
				return tree[node] = arr[left];
			}
			int mid = (left + right) / 2;
			tree[node] += makeTree(left, mid, node * 2);// 왼쪽으로
			tree[node] += makeTree(mid + 1, right, node * 2 + 1);// 오른쪽으로

			return tree[node];
		}

		// left부터 right까지 더하겠다
		public long query(int node, int left, int right, int targetL, int targetR) {
			// 먼저 범위 넘어가면 쳐내기
			if (targetR < left || right < targetL) {// 잘못된 범위
				return 0;
			} else if (targetL <= left && right <= targetR) {
				return tree[node]; // targetL~~targetR까지의 구간 합
			} else {
				// 1~4라면 1~2 + 3~4 인것
				int mid = (left + right) / 2;
				return query(node * 2, left, mid, targetL, targetR)
						+ query(node * 2 + 1, mid + 1, right, targetL, targetR);
			}
		}

		public void update(int node, int left, int right, int targetIndex, long diff) {
			if (targetIndex < left || right < targetIndex) {// 잘못된 범위
				return;
			} else {
				tree[node] += diff; // 차이를 더해준다
				if (left == right) {// 끝까지 내려왔다
					return;
				} else {
					int mid = (left + right) / 2;
					update(node * 2, left, mid, targetIndex, diff);
					update(node * 2 + 1, mid + 1, right, targetIndex, diff);
				}
			}
		}// 바꾸면 맨처음 값 배열도 갱신해주기
	}
}

'Algorithm' 카테고리의 다른 글

[BOJ/17135] - 캐슬 디펜스  (0) 2020.06.06
[BOJ/17136] - 색종이 붙이기  (0) 2020.06.06
[BOJ/2146] - 다리 만들기  (0) 2020.06.06
[BOJ/17822] - 원판돌리기  (0) 2020.06.05
[백준] - 2805/나무자르기  (0) 2020.01.08