문제 링크

문제 해설이 아닌 풀이하면서 거쳐간 생각들을 기록하는 글입니다.


문제

정수로 이루어진 크기가 같은 배열 A, B, C, D가 있다.

A[a], B[b], C[c], D[d]의 합이 0인 (a, b, c, d) 쌍의 개수를 구하는 프로그램을 작성하시오.

입력

첫째 줄에 배열의 크기 n (1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄에는 A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 \(2^{28}\)이다.

출력

합이 0이 되는 쌍의 개수를 출력한다.




접근 과정

  1. 최대 4000개씩 되는 배열 4개에서 하나씩 골라 쌍을 만들어야 한다. 다만 이 작업을 단순하게 4중 포문에 돌려 확인하면 어우… 제한 시간이 1000초여도 모자랄 것이다.

  2. 처음 생각했던 방법은 평균 O(1)의 탐색 알고리즘을 가지는 해쉬맵에 C와 D 배열의 합을 계산하여 저장한뒤(최대 1600만개) A와 B 배열 순서쌍의 합과 일치하는게 있는지 확인하는 방법이었다. 이러면 시간복잡도가 \(O(N^2 + N^2 \times 1) = O(N^2)\)이 되므로 이론상 시간이 차고도 넘친다. 메모리 제한도 1GB면 충분하지 않을까? 라고 기대했지만 unordered_map이 생각 이상으로 메모리를 많이 먹어서 메모리 초과로 실패했다. 다른 사람들의 코드를 살펴보고 안 사실은 해쉬맵을 사용한 코드들은 직접 최적화한 해쉬맵을 구현하여 메모리 초과를 피해가고 있었다.

  3. 제한 시간이 10초이니 \(O(N^3)\)은 힘들꺼 같고 \(O(N^2 logN)\) 정도의 알고리즘이면 가능할 것 같았다. A, B 순서쌍으로 \(N^2\)을 잡고 C, D에 적당히 전처리를 한 뒤 logN으로 탐색이 가능한 이진탐색을 통해 합이 0이될 수 있는 경우를 찾는 아이디어였다. 그런데 생각보다 구현이 쉽지 않았다. 이진탐색으로 값을 찾는 것 자체는 쉽게 구현할 수 있었으나, 배열 내에 동일한 C, D 순서쌍의 합이 여럿 존재할 경우 단순한 이진 탐색으론 이를 모두 확인할 수 없던 것이다.

  4. 적당히 고민해도 풀리지 않아 그동안 공부를 미뤄왔던 c++ 이진탐색 함수인 lower_bound()와 upper_bound() 살펴보았다. 이진탐색 알고리즘 문제의 다른 사람들 코드를 살펴보면 종종 볼 수 있던 함수인데 미루고미루다 이번에서야 확인보았다. 함수의 정의와 돌아가는 예제를 보니 아! 소리가 나오면서 문제를 어떻게 해결할지 감을 잡을 수 있었다. 탐색 기준이 다른 두 번의 이진 탐색으로 배열 내의 중복된 합의 범위를 도출할 수 있고 해당 범위가 A, B 합과 일치하는 C, D 순서쌍 합의 갯수가 된다. 이를 코드에 적용하여 통과할 수 있었다.

  5. 이후로 다른 사람들의 코드를 열람해보니 투 포인터 알고리즘을 적용한 코드가 보여 비슷하게 작성해보았다. 함수를 사용하지 않아 이진 탐색에 비해 코드는 조금 복잡해졌다. 속도는 더 새로 작성한 코드가 더 빠르지만 시간 복잡도의 병목점인 순서쌍 합을 저장한 배열을 정렬하는 절차가 그대로였기 때문에 극적인 차이가 나지는 않았다. (약 700ms 정도 빨라짐)


소스 코드

투포인터

#include <algorithm>
#include <iostream>

int main() {
    std::ios_base::sync_with_stdio(false); std::cin.tie(nullptr);

    int n;
    int squareN;
    int input[4000][4];
    int* sumAB;
    int* sumCD;

    std::cin >> n;
    for (int i = 0; i < n; ++i) {
        for (int k = 0; k < 4; ++k) {
            std::cin >> input[i][k];
        }
    }

    squareN = n * n;
    sumAB = new int[squareN];
    sumCD = new int[squareN];

    int front = 0;
    int back = 0;
    for (int i = 0; i < squareN; ++i) {
        sumAB[i] = input[front][0] + input[back][1];
        sumCD[i] = input[front][2] + input[back][3];

        ++back;
        if (back == n) {
            back = 0;
            ++front;
        }
    }

    std::sort(sumAB, sumAB + squareN);
    std::sort(sumCD, sumCD + squareN);

    long long ans = 0;
    int abCur = 0;
    int cdCur = squareN - 1;

    while (abCur < squareN && cdCur >= 0) {
        int total = sumAB[abCur] + sumCD[cdCur];

        if (total == 0) {
            int ab = sumAB[abCur++];
            int cd = sumCD[cdCur--];
            long long abCount = 1;
            long long cdCount = 1;

            while (abCur < squareN && sumAB[abCur] == ab) { 
                ++abCount;
                ++abCur;
            }

            while (cdCur >= 0 && sumCD[cdCur] == cd) {
                ++cdCount;
                --cdCur;
            }

            ans += abCount * cdCount;
        } else if (total < 0) {
            ++abCur;
        } else {
            --cdCur;
        }
    }

    std::cout << ans;

    delete[] sumAB;
    delete[] sumCD;
    return 0;
}

이진탐색

#include <algorithm>
#include <iostream>

int main() {
    std::ios_base::sync_with_stdio(false); std::cin.tie(nullptr);

    int n;
    int squareN;
    int input[4000][4];
    int* sumAB;
    int* sumCD;

    std::cin >> n;
    for (int i = 0; i < n; ++i) {
        for (int k = 0; k < 4; ++k) {
            std::cin >> input[i][k];
        }
    }

    squareN = n * n;
    sumAB = new int[squareN];
    sumCD = new int[squareN];

    /* ----- 여기까진 위 코드와 동일 ----- */

    int front = 0;
    int back = 0;
    for (int i = 0; i < squareN; ++i) {
        sumAB[i] = -(input[front][0] + input[back][1]);
        sumCD[i] = input[front][2] + input[back][3];

        ++back;
        if (back == n) {
            back = 0;
            ++front;
        }
    }

    std::sort(sumAB, sumAB + squareN);
    std::sort(sumCD, sumCD + squareN);

    long long ans = 0;
    for (int i = 0; i < squareN; ++i) {
        int* l = std::lower_bound(sumCD, sumCD + squareN, sumAB[i]);
        if (*l == sumAB[i]) {
            int* r = std::upper_bound(sumCD, sumCD + squareN, sumAB[i]);
            ans += r - l;
        }
    }

    std::cout << ans;

    delete[] sumAB;
    delete[] sumCD;
    return 0;
}