并查集概念引入
并查集用于解决一些元素分组的问题。它管理一系列不相交的集合,并支持两种操作:
合并:把两个不相交的集合合并为一个集合。
查询:查询两个元素是否在同一个集合中。
故事背景:大陆上有下面六位鼎鼎大名的忍者,且各自为王!
鸣人用嘴遁,轻而易举的让雷影和小李相信自己,所以就这样被鸣人收服了,以后就跟着鸣人混了!佐助用力量,直接收服了卡卡西,佩恩也十分欣赏佐助,所以他们两个就跟随佐助,一起搞事情!
此时卡卡西看雷影不爽,想用忍术收拾他,奈何发现自己好像打不过,于是向自己的老大佐助求助,让他去KO雷影,结果刚想出招,雷影果断喊出自己的老大鸣人,只见鸣人一招螺旋丸再加一招嘴遁,顺利收服了佐助,让他成为村里的情报人员。此时两大阵营合并为一个阵营,即佐助集合被合并到鸣人集合。
并查集思想
用树结构表示一个集合
代表元法:用树的根结点代表这个集合
树的存储结构:双亲表示法
程序设计
初始化
每个结点为一个集合
将每个结点的父结点设为其自身。
int p[N]; // p[i]表示结点i的父结点编号
void init(int n) //n:结点数量
{
for (int i = 1; i <= n; ++i)
p[i] = i;
}
查询
- 用树的根结点代表这个集合
// 返回x结点所在集合的根结点
int find(int x)
{
if(p[x] == x)// 如果x是根结点,返回x
return x;
// 如果不是,返回x父结点所在集合的根结点
return find(p[x]);
}
合并
将不属于同一个集合的两个集合进行合并。
// 合并x结点和y结点所在的集合
void merge(int x, int y)
{
// 将x的代表的父结点设为y的代表节点
int px = find(x), py = find(y);
if(px != py) p[px] = py;
}
路径压缩
形成长链问题
动画演示
递归记忆化存储写法
int find(int x)
{
if(x != p[x])
p[x] = find(p[x]);
return p[x];
}
非递归写法
int find(int x)
{
int k, t, r; // r:根结点
k = r = x;
while(r != p[r]) // 查找根结点
r = p[r]; // 找到根结点,用r记录
while(k != r) // 将从x到r的整条路径上的结点的父结点都设为r
{
t = p[k]; // 用t暂存k的父结点
p[k] = r; // p[k]指向根结点
k = t; // k指向暂存的父结点
}
return r; // 返回根结点
}
按秩合并
应该将谁的父结点设为谁?
答:合并后树的深度应尽可能小
秩(rank):树的高度
按秩合并:把高度较小的树的根结点连接到高度较大的树的根结点上。
int rk[N]; // rk[i]:以i结点为根结点的树的高度
void merge(int x, int y) //合并i结点和j结点所在的集合
{
int px = find(x), py = find(y); // 先找到两个根结点
if(px == py) return; // 根结点相同,就不用合并了
if(rk[x] < rk[y]) // 高度低的树作为高度高的树的子树
p[px] = py;
else if(rk[x] > rk[y])
p[py] = px;
else // 如果高度相同,则新的树的高度+1
{
p[px] = py;
rk[y]++;
}
}
两种优化的比较
小试牛刀
家庭问题
参考程序
#include <iostream>
using namespace std;
const int N = 110;
int n, m;
int p[N]; // p[i] i 所在集合的 代表
int ans1 = 0, ans2 = 0;
int cnt[N];
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i++) p[i] = i;
for(int i = 1; i <= m; i++){
int x, y;
cin >> x >> y;
int px = find(x), py = find(y);
if(px != py) p[px] = py;
}
for(int i = 1; i <= n; i++)
if(p[i] == i)
ans1++; // 团体的个数
for(int i = 1; i <= n; i++)
cnt[find(i)]++;
for(int i = 1; i <= n; i++)
ans2 = max(ans2, cnt[i]);
cout << ans1 << " " << ans2 << endl;
return 0;
}
亲戚
参考程序
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1e5 + 10;
int n, m;
int p[N], s[N];
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int a, int b){
int pa = find(a), pb = find(b);
if(pa != pb){
p[pb] = pa;
s[pa] += s[pb];
}
}
int main() {
cin >> n >> m;
for(int i = 1; i <= n; i++){
p[i] = i;
s[i] = 1;
}
while(m--){
char op[2];
int a, b;
cin >> op;
if(*op == 'M'){
cin >> a >> b;
merge(a, b);
}
else{
cin >> a;
cout << s[find(a)] << endl;
}
}
return 0;
}
团伙
参考程序
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1010;
int n, m;
int p[N], e[N];
int res = 0;
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int x, int y){
if(find(x) != find(y)) p[find(x)] = y;
}
int main() {
cin >> n >> m;
for(int i = 1; i <= n; i++) p[i] = i;
while(m--){
int k, x, y;
cin >> k >> x >> y;
if(!k) merge(x, y);
else{
if(!e[x]) e[x] = y;
if(!e[y]) e[y] = x;
merge(p[y], e[x]);
merge(p[x], e[y]);
}
}
for(int i = 1; i <= n; i++)
if(p[i] == i)
res++;
cout << res << endl;
return 0;
}
格子游戏
参考程序
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 40010;
int n, m;
int p[N];
int get(int x, int y) {
return x * n + y;
}
// 路径压缩
int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
int main() {
cin >> n >> m;
for (int i = 0; i < n * n; i++) p[i] = i;
int ans = 0;
for (int i = 1; i <= m; i++) {
int x, y;
char d;
cin >> x >> y >> d;
x--, y--;
int a = get(x, y);
int b;
if (d == 'D') b = get(x + 1, y);
else b = get(x, y + 1);
int pa = find(a), pb = find(b);
if (pa == pb) {
ans = i;
break;
}
else p[pb] = pa;
}
if (ans) cout << ans << endl;
else puts("draw");
return 0;
}
打击犯罪
参考程序
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
const int N = 1010;
int n, m;
int p[N], s[N];
vector<int> q[N];
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int x, int y){
p[y] = x;
s[x] += s[y];
}
int main() {
cin >> n;
for(int i = 1; i <= n; i++){
int cnt;
cin >> cnt;
while(cnt--){
int x;
cin >> x;
q[i].push_back(x);
}
}
for(int i = 1; i <= n; i++){
p[i] = i;
s[i] = 1;
}
for(int i = n; i >= 1; i--){
for(int j = 0; j < q[i].size(); j++){
if(q[i][j] > i){
if(find(i) != find(q[i][j])) merge(find(i), find(q[i][j]));
else continue;
if(s[i] > n / 2){
cout << i << endl;
return 0;
}
}
}
}
return 0;
}
搭配购买
参考程序
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1e4 + 10;
int n, m, V;
int p[N], v[N], w[N], f[N];
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int a, int b){
int pa = find(a), pb = find(b);
if(pa != pb){
p[pb] = pa;
v[pa] += v[pb];
w[pa] += w[pb];
}
}
int main() {
cin >> n >> m >> V;
for(int i = 1; i <= n; i++) p[i] = i;
for(int i = 1; i <= n; i++) cin >> v[i] >> w[i];
while(m--){
int a, b;
cin >> a >> b;
merge(a, b);
}
for(int i = 1; i <= n; i++){
if(p[i] != i) continue;
for(int j = V; j >= v[i]; j--)
f[j] = max(f[j], f[j - v[i]] + w[i]);
}
cout << f[V] << endl;
return 0;
}