PS/BOJ
백준 15480번: LCA와 쿼리 (C++)
도비(Doby)
2022. 6. 4. 09:55
https://www.acmicpc.net/problem/15480
15480번: LCA와 쿼리
첫째 줄에 정점의 개수 N(1 ≤ N ≤ 100,000)이 주어진다. 둘째 줄부터 N-1개의 줄에는 트리 T의 간선 정보 u와 v가 주어지다. u와 v는 트리의 간선을 나타내는 두 정점이다. 다음 줄에는 쿼리의 개수 M(
www.acmicpc.net
Solved By: LCA
쿼리마다 루트를 r로 만들어서 parent와 level을 초기화시키는 것은 시간 초과가 납니다.
3가지 케이스로 나누어서 풀 수 있습니다.
[root node를 1로 잡은 트리를 만들었을 때]
- u, v가 r의 자식 노드인 경우
- u 혹은 v 중 하나만 r의 자식 노드인 경우
- 둘 다 r의 자식 노드가 아닌 경우
1번의 경우 lca(u, v)가 답이 됩니다. 2번의 경우는 lca(u, r) 혹은 lca(v, u)가 답이 됩니다.
3번의 경우는 2가지로 나뉩니다.
- 1의 left child {u, r}, right child {v}인 경우
- 1의 left child {u, v, r}인 경우
- 1의 left child {r}, right child {u, v}인 경우
이 3가지 경우를 보았을 때, 모두 lca(u, v), lca(v, r), lca(u, r) 중 가장 level이 높은 노드가 답이 된다는 특징이 있습니다.
3번에서 나온 특징을 1번과 2번에 적용해보아도 답이 된다는 특징을 가집니다.
즉, 쿼리에 대한 답은 다음과 같이 볼 수 있습니다.
$$Answer Node = max(level[LCA(u, v)], level[LCA(u, r)], level[LCA(v, r)])$$
#include <iostream>
#include <vector>
#define MAX 100001
#define LOG_MAX 17
using namespace std;
int parent[MAX][LOG_MAX];
vector<int> adj[MAX];
int level[MAX];
int n, m;
void dfs(int now, int par){
for(int i = 0; i < adj[now].size(); i++){
int next = adj[now][i];
if(next == par) continue;
level[next] = level[now] + 1;
parent[next][0] = now;
dfs(next, now);
}
}
void swap(int* a, int* b){
int* temp = a;
a = b;
b = temp;
}
int lca(int a, int b){
if(level[a] < level[b]) swap(a, b);
int diff = level[a] - level[b];
for(int i = LOG_MAX - 1; i >= 0; i--){
if(diff >= 1 << i){
diff -= 1 << i;
a = parent[a][i];
}
}
if(a != b){
for(int i = LOG_MAX - 1; i >= 0; i--){
if(parent[a][i] != 0 && parent[a][i] != parent[b][i]){
a = parent[a][i];
b = parent[b][i];
}
}
a = parent[a][0];
}
return a;
}
int main(){
cin >> n;
for(int i = 0; i < n - 1; i++){
int a, b; cin >> a >> b;
adj[a].push_back(b);
adj[b].push_back(a);
}
dfs(1, -1);
for(int j = 1; j < LOG_MAX; j++){
for(int i = 1; i <= n; i++){
parent[i][j] = parent[parent[i][j - 1]][j - 1];
}
}
vector<int> res;
cin >> m;
for(int i = 0; i < m; i++){
int r, u, v; cin >> r >> u >> v;
int lcaUV = lca(u, v), lcaUR = lca(u, r), lcaVR = lca(v, r);
int ans = level[lcaUV] > level[lcaUR] ? lcaUV : lcaUR;
ans = level[ans] > level[lcaVR] ? ans : lcaVR;
res.push_back(ans);
}
for(int i = 0; i < res.size(); i++){
cout << res[i] << '\n';
}
return 0;
}