얼렁뚱땅 JAVA 알고리즘

[JAVA] 구간 합, 차를 빠르게 구할 수 있는 세그먼트 트리

MOSTAR 2023. 4. 14. 00:18

문제가 쉽다. != 내가 실력이 늘었다.

그저 시간을 빡세게 잡기 때문에 내가 생각한 방법으로는 통과될 수 없는 것이다.

구간 합, 차와 같은 문제를 실제로 만나면 나는 Prefix Sum으로 주로 문제 해결을 하고자 한다.

 

그런데, PrefixSum은 특정 구간의 합을 구할 때 O(1)이라는 장점은 분명하게 존재하지만,

단점은 중간 값을 수정한다면 O(N)의 시간이 걸린다는 단점이 존재한다.

 

그래서 저걸로 풀면 시간초과 뜰 때가 있다.

여기서 배워야 하는 것이 세그먼트 트리다.

세그먼트 트리의 갱신은 O(NlogN)으로 시간을 줄일 수 있다.

몰라 ? 일단 외워.

 

import java.util.*;
import java.io.*;

class SegmentTree {
	long [] tree;
	int treeSize;
	
	SegmentTree(int n) {
		int h = (int)Math.ceil(Math.log(n)/Math.log(2));
		this.treeSize = (int)Math.pow(2, h+1);
		tree = new long[treeSize];
	}
	
	long init(long [] arr, int node, int start, int end) {
		
		// 리프노드 인거
		if(start == end) {
			return tree[node] = arr[start];
		} 
		
		return tree[node] = init(arr, node*2, start, (start+end)/2)
				+ init(arr,node*2+1, (start+end)/2+1, end);
	}
	
	void update(int node, int start, int end, int idx, long diff) {
		if(idx < start || end < idx) {
			return;
		}
		
		tree[node] += diff;
		
		if(start != end) {
			update(node*2, start, (start+end)/2, idx, diff);
			update(node*2+1, (start+end)/2+1, end, idx, diff);
		}
	}
	
	long sum(int node, int start, int end, int left, int right) {
		
		if(left > end || right < start) {
			return 0;
		}
		
		if(left <= start && end <= right) {
			return tree[node];
		}
		
		return sum(node*2, start, (start+end)/2, left, right) +
				sum(node*2+1, (start+end)/2+1, end, left, right);
	}

}


public class 세그먼트트리 {

	public static void main(String[] args) throws IOException{
		// TODO Auto-generated method stub
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		String [] tmp = br.readLine().strip().split(" ");
		int n = Integer.parseInt(tmp[0]);
		int times = Integer.parseInt(tmp[1])+Integer.parseInt(tmp[2]);
		
		long [] arr = new long[n+1];
		for(int i=1;i<=n;i++) {
			arr[i] = Integer.parseInt(br.readLine().strip());
		}
		
		SegmentTree tree = new SegmentTree(n);
		tree.init(arr, 1, 1, n);
		
		for(int i=0;i<times;i++) {
			tmp = br.readLine().strip().split(" ");
			int type = Integer.parseInt(tmp[0].strip());
			if(type == 1) {
				int update_node = Integer.parseInt(tmp[1]);
				int update_num = Integer.parseInt(tmp[2]);
				long diff = update_num - arr[update_node];
				tree.update(1, 1, n, update_node, diff);
				
			} else {
				int left = Integer.parseInt(tmp[1]);
				int right = Integer.parseInt(tmp[2]);
				long sum_value = tree.sum(1, 1, n, left, right);
				System.out.println(sum_value);
			}
		}
	}

}