문제 링크




문제

트리(tree)는 사이클이 없는 무방향 그래프이다. 트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다. 트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다. 이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.

사진1

이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.

입력으로 루트가 있는 트리를 가중치가 있는 간선들로 줄 때, 트리의 지름을 구해서 출력하는 프로그램을 작성하시오. 아래와 같은 트리가 주어진다면 트리의 지름은 45가 된다.

사진2

트리의 노드는 1부터 n까지 번호가 매겨져 있다.

입력

파일의 첫 번째 줄은 노드의 개수 n(1 ≤ n ≤ 10,000)이다. 둘째 줄부터 n-1개의 줄에 각 간선에 대한 정보가 들어온다. 간선에 대한 정보는 세 개의 정수로 이루어져 있다. 첫 번째 정수는 간선이 연결하는 두 노드 중 부모 노드의 번호를 나타내고, 두 번째 정수는 자식 노드를, 세 번째 정수는 간선의 가중치를 나타낸다. 간선에 대한 정보는 부모 노드의 번호가 작은 것이 먼저 입력되고, 부모 노드의 번호가 같으면 자식 노드의 번호가 작은 것이 먼저 입력된다. 루트 노드의 번호는 항상 1이라고 가정하며, 간선의 가중치는 100보다 크지 않은 양의 정수이다.

출력

첫째 줄에 트리의 지름을 출력한다.




접근 과정

  1. 처음 문제를 봤을 때 이 문제가 끝단에 위치한 노드간 이동시 최장거리를 구하는 문제로 생각하였고, 입력 받은 트리에서 끝단에 위치한(간선이 하나 밖에 없는) 노드를 시작점으로 깊이 우선 탐색(DFS)을 진행하고 그 최장 거리를 반환하도록 코드를 작성하였다.
  2. 결과적으로 한번에 통과하였다. 하지만 중복된 구현이 있어 코드가 지저분한 느낌이 있기도 했고 무엇보다 본인의 실행 시간은 400ms 였는데 다른 사람의 답안 중 10ms 안으로 해결된 경우가 보여 다른 풀이 방법을 고민하기로 했다.
  3. 처음에는 메모이제이션을 통해 노드간 이동시 최장거리를 기록할까 했지만 최대 노드 수가 1만개이므로 메모리가 부족할 것 같았다. 실제로 채점 현황에서도 메모리 초과로 실패한 경우가 종종 보였다.
  4. 다음으로 떠오른 방법은 분할정복법이었다. 문제에서 트리의 루트 노드를 1로 지정해주었으니 루트 노드를 기준으로 트리를 분할해가며 문제를 해결할 수 있을 것 같았다.
  5. 전체 문제를 루트 노드를 시작으로 각 자식 노드의 문제로 분할한다. 각 노드는 자식 노드별로 끝단 까지 가중치를 합했을 때 최대값을 구하고 이를 더하여 결과값과 비교하다보면 최종 결과값을 얻을 수 있다. 이를 solution() 함수로 구현할 때 트리의 자식 노드가 2개 이하라는 보장이 없으므로 그에 대한 처리를 해주어야 했으며 결과적으로 8ms 실행시간으로 문제를 해결할 수 있었다.


소스 코드

DFS를 사용한 풀이

#include <iostream>
#include <vector>
using namespace std;

struct edge_t {
    int node;
    int weight;
};

vector<edge_t> tree[10001];
int bVisited[10001];

int dfs(int cur, int totalWeight) {
    if (totalWeight != 0 && tree[cur].size() == 1) {    // is leaf
        return totalWeight;
    }

    int max = 0;
    bVisited[cur] = true;
    for (auto iter = tree[cur].begin(); iter != tree[cur].end(); ++iter) {
        if (!bVisited[iter->node]) {
            int distance = dfs(iter->node, totalWeight + iter->weight);
            if (distance > max) {
                max = distance;
            }
        }
    }
    bVisited[cur] = false;
    return max;
}

int main() {
    int n;
    vector<int> leafs;

    cin >> n;
    for (int i = 1; i < n; ++i) {
        int parent, child, weight;
        cin >> parent >> child >> weight;
        tree[parent].push_back({child, weight});
        tree[child].push_back({parent, weight});
    }

    for (int i = 0; i < n; ++i) {
        if (tree[i].size() == 1) {
            leafs.push_back(i);
        }
    }

    int max = 0;
    for (auto iter = leafs.begin(); iter != leafs.end(); ++iter) {
        int distance = dfs(*iter, 0);
        if (distance > max) {
            max = distance;
        }
    }

    cout << max << endl;
    return 0;
}


Divide and Conquer를 사용한 풀이

#include <iostream>
#include <vector>
using namespace std;

struct edge_t {
    int leaf;
    int weight;
};

struct node_t {
    int parent;
    vector<edge_t> childs;
};

vector<node_t> tree;
int maxDistance;

int solution(int cur) {
    int sum = 0;
    int maxWeights[2] = {0, 0};
    vector<edge_t>& childs = tree[cur].childs;

    for (unsigned int i = 0; i < childs.size(); ++i) {
        if (childs[i].leaf != 0) {
            int weight = solution(childs[i].leaf) + childs[i].weight;

            if (weight > maxWeights[1]) {
                if (weight > maxWeights[0]) {
                    maxWeights[1] = maxWeights[0];
                    maxWeights[0] = weight;
                } else {
                    maxWeights[1] = weight;
                }

                sum = maxWeights[0] + maxWeights[1];
                if (sum > maxDistance) {
                    maxDistance = sum;
                }
            }
        }
    }
    return maxWeights[0];
}

int main() {
    int n;
    
    cin >> n;
    tree.resize(n + 1);

    for (int i = 1; i < n; ++i) {
        int parent, child, weight;
        cin >> parent >> child >> weight;
        tree[parent].childs.push_back({child, weight});
        tree[child].parent = parent;
    }

    solution(1);
    cout << maxDistance << endl;
    return 0;
}