PS/BOJ

백준 15480번: LCA와 쿼리 (C++)

도비(Doby) 2022. 6. 4. 09:55

https://www.acmicpc.net/problem/15480

 

15480번: LCA와 쿼리

첫째 줄에 정점의 개수 N(1 ≤ N ≤ 100,000)이 주어진다. 둘째 줄부터 N-1개의 줄에는 트리 T의 간선 정보 u와 v가 주어지다. u와 v는 트리의 간선을 나타내는 두 정점이다. 다음 줄에는 쿼리의 개수 M(

www.acmicpc.net


Solved By: LCA

 

쿼리마다 루트를 r로 만들어서 parent와 level을 초기화시키는 것은 시간 초과가 납니다.

3가지 케이스로 나누어서 풀 수 있습니다.

 

[root node를 1로 잡은 트리를 만들었을 때]

  1. u, v가 r의 자식 노드인 경우
  2. u 혹은 v 중 하나만 r의 자식 노드인 경우
  3. 둘 다 r의 자식 노드가 아닌 경우

1번의 경우 lca(u, v)가 답이 됩니다. 2번의 경우는 lca(u, r) 혹은 lca(v, u)가 답이 됩니다.

3번의 경우는 2가지로 나뉩니다.

  1. 1의 left child {u, r}, right child {v}인 경우
  2. 1의 left child {u, v, r}인 경우
  3. 1의 left child {r}, right child {u, v}인 경우

이 3가지 경우를 보았을 때, 모두 lca(u, v), lca(v, r), lca(u, r) 중 가장 level이 높은 노드가 답이 된다는 특징이 있습니다.

3번에서 나온 특징을 1번과 2번에 적용해보아도 답이 된다는 특징을 가집니다.

 

즉, 쿼리에 대한 답은 다음과 같이 볼 수 있습니다.

$$Answer  Node = max(level[LCA(u, v)], level[LCA(u, r)], level[LCA(v, r)])$$

 

#include <iostream>
#include <vector>
#define MAX 100001
#define LOG_MAX 17
using namespace std;

int parent[MAX][LOG_MAX];
vector<int> adj[MAX];
int level[MAX];

int n, m;

void dfs(int now, int par){
    for(int i = 0; i < adj[now].size(); i++){
        int next = adj[now][i];
        
        if(next == par) continue;
        
        level[next] = level[now] + 1;
        parent[next][0] = now;
        
        dfs(next, now);
    }
}

void swap(int* a, int* b){
    int* temp = a;
    a = b;
    b = temp;
}

int lca(int a, int b){
    if(level[a] < level[b]) swap(a, b);
    
    int diff = level[a] - level[b];
    
    for(int i = LOG_MAX - 1; i >= 0; i--){
        if(diff >= 1 << i){
            diff -= 1 << i;
            a = parent[a][i];
        }
    }
    
    if(a != b){
        for(int i = LOG_MAX - 1; i >= 0; i--){
            if(parent[a][i] != 0 && parent[a][i] != parent[b][i]){
                a = parent[a][i];
                b = parent[b][i];
            }
        }
        
        a = parent[a][0];
    }
    
    return a;
}

int main(){
    cin >> n;
    for(int i = 0; i < n - 1; i++){
        int a, b; cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    
    dfs(1, -1);
    
    for(int j = 1; j < LOG_MAX; j++){
        for(int i = 1; i <= n; i++){
            parent[i][j] = parent[parent[i][j - 1]][j - 1];
        }
    }
    
    vector<int> res;
    cin >> m;
    
    for(int i = 0; i < m; i++){
        int r, u, v; cin >> r >> u >> v;
        int lcaUV = lca(u, v), lcaUR = lca(u, r), lcaVR = lca(v, r);
        int ans = level[lcaUV] > level[lcaUR] ? lcaUV : lcaUR;
        ans = level[ans] > level[lcaVR] ? ans : lcaVR;
        res.push_back(ans);
    }
    
    for(int i = 0; i < res.size(); i++){
        cout << res[i] << '\n';
    }
    
    return 0;
}