Algorithm

[Algorithm] 누적합(prefix sum) 알고리즘 with Java

eunmiee 2023. 9. 20. 03:41

안녕하세요 !

두근두근 첫번째 포스팅 주제는 바로 "누적합 알고리즘" 입니다.

최근 알고리즘 수업 TA 를 진행하면서 누적합 알고리즘에 대한 설명을 학생들에게 직접 알려주어야 하는 기회가 생겼습니다..!

그래서 저도 누적합 알고리즘에 대해 공부를 하게 되었는데요.

공부하며 이해한 내용들을 정리해보며 오늘의 포스팅 주제로 담아보겠습니다 !


00 시작하며


누적합 알고리즘이 뭔가요?

누적합 알고리즘은 말 그대로 누적된 합을 찾는 알고리즘인데요.

그래서 그게 어디에 해당하는 누적합인데?

배열이 주어졌을 때, 해당 인덱스 범위 내에서의 원소들의 합을 빠르게 계산하는 알고리즘입니다.

예를 들면, [1,2,3,4,5]라는 배열이 주어졌을 때,

이를 누적합 알고리즘을 사용하여 계산하면 [1,3,6,10,15]라는 결과값을 출력하게 되는거죠.

 

 

01 1차원 누적합


가장 먼저 1차원 누적합에 대해 이해를 해본 후 2차원 누적합까지 접근해보도록 하겠습니다.

여러분들은 만약 1차원 배열에서 인덱스 3 ~ 5번 범위안에 있는 값의 합을 구하려면 어떻게 접근하실건가요?

저는 이 알고리즘에 대한 개념이 없을 때에는 단순히 for문을 돌리며 모든 인덱스를 탐색하다가 해당 범위안의 인덱스 부분만 if문으로 조건을 걸어 계산을 했었답니다.. 하하😅

나중에 문제풀이를 하면서 부끄럽지만 이부분에 대한 코드도 보여드리도록 하겠습니다..

당연히 이렇게 접근을 한다면 많은 시간이 걸리게 되겠죠? 아마도 시간복잡도로 따진다면 O(N)이 될거 같네요..

 

그렇다면 어떻게 접근을 해야 더 빠르게 구간합을 찾을 수 있을까요?

 

바로 누적합 알고리즘을 사용하면 되는데요!

누적합 알고리즘은 말그대로 "누적 합"을 구하여 구간의 합을 구할 수 있답니다.

 

우선 0번째 인덱스부터 시작하는 모든 구간에 대해 합을 누적시키며 값을 저장해보도록 하겠습니다.

 

예를 들어 설명을 해보도록 할게요.

 

Q. 누적합은 어떻게 구하나요?


누적합 알고리즘을 사용하여 0부터 시작하는 모든 인덱스 범위 내의 원소들의 합을 구해보며 이해해봅시다.

현재 배열에는 6, 4, 3, 2, 9라는 값이 순차적으로 들어있습니다.

우선 현재 누적합을 구하기 위해서는 추가적으로 누적합을 저장할 배열이 하나 더 필요합니다. prefix라고 우선 이름을 붙여두도록 할게요.

prefix 배열의 가장 첫번째 인덱스에는 아직은 누적합이 0이기 때문에 array 배열의 첫번째 인덱스 값을 그대로 가져옵니다.

 

자, 이렇게 첫번째 0~0번째 인덱스까지의 누적합을 구했습니다.

다음으로 0 ~ 1번째 인덱스까지의 누적합을 구해볼게요.

 

Q. 여기서 생각해야할 점은 여러분들은 어떻게 누적합을 구할 건가요?

 

흠.. 그냥 array[0] + array[1] 로 구하면 되는거 아닌가요?

물론 그런 방법도 있겠지만 그렇게 접근을 하게 된다면 인덱스의 범위가 점점 커져갈수록 우리는 계속해서 array[0] + array[1] + array[2] + ... + array[n] 이런식으로 구하게 될 것이예요. 결국에는 모든 array에 대한 index를 모두 돌면서 누적합을 찾아야하는 것이죠.. 그렇게 되면 인덱스의 범위가 커질수록 점점 속도가 느려지겠죠?

 

그래서 다른 방법으로 접근을 해볼거예요.

누적합에서는 현재 prefix[0]에 이미 0번 인덱스까지의 누적합을 가지고 있으니 prefix[0] + array[1]로 구하면 이후에 인덱스의 범위가 커지더라도 prefix[n-1] + array[n]으로 구할 수 있으니 더 간결하고 빠르게 계산을 할 수 있겠죠?

 

 

 

이러한 방법을 사용하면서 인덱스 0 ~ 2, 0 ~ 3, 0 ~ 4번째 인덱스의 누적합을 구해보도록 할게요!

이렇게해서 누적합을 구해보았습니다. 이제 0부터 시작하는 모든 인덱스 범위에 대한 누적합은 모두 구한거겠죠?

 

 

그러면 여기서부터 더 적용해나가보며 구간합까지 구해볼까요?

 

Q. 조금 더 나아가서 구간 A ~ B까지의 합은 어떻게 구할까요?


만약 위의 배열에서 1 ~ 3번까지 구간의 합을 구하고 싶다면 어떻게 찾을 수 있을까요?

 

물론 여기서도 마찬가지로 for문으로 구현하여 해당하는 i번째에 조건문을 걸어 구현하는 방법도 있겠죠? 하지만 그렇게 구현하게 되면 배열의 값이 커질수록 많은 시간이 소요되게 됩니다.

 

그렇다면 방금 누적합 알고리즘을 활용해서 접근해볼까요?

위의 그림처럼 array[1] ~ array[3]까지의 구간합을 구하려면 array[3]번까지의 누적합에서 array[0]번까지의 누적합을 빼주면 됩니다.

 

그래서 우리는 누적합을 이용하여 한가지 공식을 찾을 수 있게 되는데 이는 다음과 같습니다.

a ~ b번째까지의 구간합 = prefix[b] - prefix[a-1]

만약 a = 1, b = 3이라고 할 때, prefix 배열을 사용하여 구간합을 구하면 prefix[3] - prefix[0] = 15 - 6 = 9가 나오게 되는 것이죠!

하지만 위의 그림대로라면 a-1번째에 의해 a가 0이라면 인덱스 범위에 대한 에러가 발생하므로, 코드로 구현할 때에는 다음과 같이 설계해야겠죠.

 

자, 이렇게 해서 저희는 1차원 누적합 알고리즘에 대해 이해하고, 구간합을 구하는 것까지 알아보았습니다.

 

 

이제 코드로 한번 살펴볼까요?

 

👩🏻‍💻 코드로 적용해보기


관련된 문제로 백준에 있는 구간 합 구하기4 문제를 풀어보며 복습해봅시다.

 

함께 풀어보기

https://www.acmicpc.net/problem/11659

 

11659번: 구간 합 구하기 4

첫째 줄에 수의 개수 N과 합을 구해야 하는 횟수 M이 주어진다. 둘째 줄에는 N개의 수가 주어진다. 수는 1,000보다 작거나 같은 자연수이다. 셋째 줄부터 M개의 줄에는 합을 구해야 하는 구간 i와 j

www.acmicpc.net


java code

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {

    static int[] index;
    static int[] prefix;
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());

        index = new int[n + 1];
        prefix = new int[n + 1];
        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            index[i] = Integer.parseInt(st.nextToken());
        }

        for (int i = 1; i <= n; i++) {
            prefix[i] = prefix[i - 1] + index[i];
        }

        for (int t = 0; t < m; t++) {
            st = new StringTokenizer(br.readLine());
            int i = Integer.parseInt(st.nextToken());
            int j = Integer.parseInt(st.nextToken());

            int sum = prefix[j] - prefix[i - 1];
            System.out.println(sum);
        }

    }
}

 


 

그렇다면, 여기서 2차원 배열에서 누적합 알고리즘은 어떻게 동작하고, 구간합은 어떻게 구하는걸까?

 


 

02 2차원 누적합


1차원에서 나아가 2차원 배열에서는 누적합 알고리즘을 어떻게 활용할까요?

* 위의 array 배열은 앞에 0번째 인덱스가 숨어있다고 가정하고 가장 왼쪽 위에 있는 1은 (1,1)로 생각해주세요!

1차원 배열에서 누적합을 구했던 것과 같이 적용을 하여 위의 그림처럼 구해볼 수 있습니다. 그렇다면 그 다음 (1, 1) ~ (2, 2)까지의 누적합을 구하고 싶을 때, 어떻게 찾으면 될까요?

 

그림을 통해 접근을 해보도록 합시다.

 

우선 array[1][1] ~ array[1][2]의 누적합은 prefix[1][3]에 저장이 되어 있을 것입니다. 마찬가지로 array[1][1] ~ array[2][1] 까지의 누적합은 prefix[2][1]에 저장이 되어 있습니다.

 

그렇다면 array 배열의 (1, 1) ~ (2, 2)까지의 누적합을 찾고 싶다면 어떻게 접근할 수 있을까요?

우리는 이미 array[1][1] ~ array[1][2], array[1][1] ~ array[2][1]까지의 누적합을 알고 있습니다. 그 누적합에서 공통된 array[1][1]의 값을 빼준다면 array[1][1] + array[1][2] + array[2][1]의 누적합을 알 수 있겠죠?

여기에 현재 찾고자하는 array[2][2]까지 더하면 (1, 1)부터 (2,2)까지의 구간합을 구할 수 있게 됩니다.

 

즉, 이를 수식으로 나타낸다면 array 배열에서 (1, 1) ~ (a, b)의 누적합을 구하는 공식은

prefix[a][b] = prefix[a-1][b] + prefix[a][b-1] - prefix[a-1][b-1] + array[a][b]

로 나타낼 수 있습니다.

 

이를 통해 구한 array 배열에서 (1, 1)에서부터 시작하는 모든 원소의 누적합은 다음과 같습니다.

 

Q. (x1, y1) ~ (x2, y2)까지의 구간합은 어떻게 구할까?


이제 우리는 앞서 누적합 알고리즘과 그것을 활용하여 구간합까지 구하는 법을 알았기 때문에 간단히 적용만 하면 됩니다 !

 

예시를 들어 보겠습니다. 만약 (2, 2) ~ (3, 3)의 구간합을 구하고자 할 때, 어떻게 구할 수 있을까요?

 

이렇게 (3, 3)까지의 모든 누적합에서 (1, 3)까지의 누적과 (3, 1)까지의 누적합을 빼준 뒤, 겹치는 부분인 (1, 1)을 더해주면 해당 구간의 합을 구할 수 있겠죠?

 

이를 누적합 표에 표시를 해두고, 공식을 유도해 봅시다.

 

(2, 2) ~ (4, 4)까지의 구간합은 위의 그림을 통해 적용해 보면 prefix[4][4] - prefix[4][2] - prefix[2][4] + prefix[1][1] = 18 - 3 - 6 + 1 = 10으로 구간 합을 확인할 수 있죠 :)

 

이를 적용해 2차원 배열의 구간합 공식은 다음과 같습니다.

(x1, y1) ~ (x2, y2)까지의 구간합 = prefix[x2][y2] - prefix[x2][y1] - prefix[x1][y2] + prefix[x1-1][y1-1]

 

이번에는 지금까지 알아봤던 2차원 배열에서의 누적합 알고리즘을 활용하여 코드로 적용해보도록 하겠습니다.

 

👩🏻‍💻 코드로 적용해보기


함께 풀어보기

https://www.acmicpc.net/problem/11660

 

11660번: 구간 합 구하기 5

첫째 줄에 표의 크기 N과 합을 구해야 하는 횟수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져 있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네

www.acmicpc.net


번외적으로 아까 앞서서 말했던 것처럼 저도 이 문제를 처음 접했을 땐 for문을 이용해서 구현을 했었는데요.. 그 결과는 다음과 같이 "시간초과"가 나버렸죠.. 허허🥲

시간초과...

참고하자면 저 당시의 코드는 이렇습니다...

import java.io.*;
import java.util.StringTokenizer;

public class Main {

    static int[][] graph;
    static Point[] point;

    static class Point {
        int x1;
        int y1;
        int x2;
        int y2;

        public Point(int x1, int y1, int x2, int y2) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
        }
    }
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());

        graph = new int[n][n];
        point = new Point[m];
        int sum = 0;


        for (int i = 0; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            for (int j = 0; j < n; j++) {
                graph[i][j] = Integer.parseInt(st.nextToken());
            }
        }

        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int x1 = Integer.parseInt(st.nextToken());
            int y1 = Integer.parseInt(st.nextToken());
            int x2 = Integer.parseInt(st.nextToken());
            int y2 = Integer.parseInt(st.nextToken());
            point[i] = new Point(x1, y1, x2, y2);
        }

        for (int t = 0; t < m; t++) {
            sum = 0;
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    if (point[t].x1 - 1 <= i && i <= point[t].x2 - 1 && point[t].y1 - 1 <= j && j <= point[t].y2 - 1) {
                        sum += graph[i][j];
                    }
                }
            }
            bw.write(sum);
        }
    }
}

 

그렇지만 ... ! 누적합 알고리즘을 이해한 저는 더이상 이렇게 풀지 않죠 !!

 

성공한 코드는 다음과 같습니다.

야호!

java code

import java.io.*;
import java.util.StringTokenizer;

public class Main {

    static int[][] graph;
    static int[][] prefix;



    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());

        graph = new int[n + 1][n + 1];
        prefix = new int[n + 1][n + 1];


        for (int i = 1; i < n + 1; i++) {
            st = new StringTokenizer(br.readLine());
            for (int j = 1; j < n + 1; j++) {
                graph[i][j] = Integer.parseInt(st.nextToken());
            }
        }

        for (int i = 1; i < prefix.length; i++) {
            for (int j = 1; j < prefix.length ; j++) {
                prefix[i][j] = prefix[i][j - 1] - prefix[i - 1][j - 1] + prefix[i - 1][j] + graph[i][j];
            }
        }


        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int x1 = Integer.parseInt(st.nextToken());
            int y1 = Integer.parseInt(st.nextToken());
            int x2 = Integer.parseInt(st.nextToken());
            int y2 = Integer.parseInt(st.nextToken());

            int sum = prefix[x2][y2] - prefix[x1-1][y2] - prefix[x2][y1-1] + prefix[x1-1][y1-1];
            System.out.println(sum);
        }

    }
}

 

이상으로 누적합 알고리즘에 대한 포스팅을 마치겠습니다.

부족하지만 긴 글 읽어주셔서 감사합니다 :)