参考《算法竞赛进阶指南》、AcWing题库

A*

在探讨 AA^{\ast} 算法之前, 我们先来回顾一下优先队列 BFS 算法。该算法维护了一个优先队列 (二叉堆), 不断从堆中取出 “当前代价最小” 的状态 (堆顶) 进行扩展。每个状态第一次从堆中被取出时, 就得到了从初态到该状态的最小代价。

如果给定一个 “目标状态”, 需要求出从初态到目标状态的最小代价, 那么优先队列 BFS 的这个 “优先策略” 显然是不完善的。一个状态的当前代价最小, 只能说明从起始状态到该状态的代价很小, 而在未来的搜索中, 从该状态到目标状态可能会花费很大的代价。另外一些状态虽然当前代价略大, 但是未来到目标状态的代价可能会很小, 于是从起始状态到目标状态的总代价反而更优。优先队列 BFS 会优先选择前者的分支, 导致求出最优解的搜索量增大。比如在优先队列 BFS 的示意图中, 产生最优解的搜索路径 (5+2+1)(5+2+1) 的后半部分就很晩才得以扩展。

为了提高搜索效率, 我们很自然地想到, 可以对未来可能产生的代价进行预估。详细地讲, 我们设计一个 “估价函数”, 以任意 “状态” 为输入, 计算出从该状态到目标状态所需代价的估计值。在搜索中, 仍然维护一个堆, 不断从堆中取出 “当前代价+未来估价” 最小的状态进行扩展。

为了保证第一次从堆中取出目标状态时得到的就是最优解, 我们设计的估价函数需要满足一个基本准则:

  • 设当前状态 state 到目标状态所需代价的估计值为 f\mathrm{f} (state)。
  • 设在未来的搜索中, 实际求出的从当前状态 state 到目标状态的最小代价为 g\mathrm{g} (state)。
  • 对于任意的 state, 应该有 f\mathrm{f} (state) g\leq \mathrm{g} (state)。

也就是说, 估价函数的估值不能大于未来实际代价, 估价比实际代价更优。为什么要遵守这个准则呢? 我们不妨看看, 如果某些估值大于未来实际代价, 那么将发生什么情况。

假设我们的估价函数 f\mathrm{f} 在每个状态上的值如下页图所示:

根据 “每次取出当前代价+未来估价最小的状态” 的策略, 会得到如下过程:

我们可以看到, 本来在最优解搜索路径上的状态被错误地估计了较大的代价, 被压在堆中无法取出, 从而导致非最优解搜索路径上的状态不断扩展, 直至在目标状态上产生错误的答案。

如果我们设计估价函数时遵守上述准则, 保证估值不大于未来实际代价, 那么即使估价不太准确, 导致非最优解搜索路径上的状态 ss 先被扩展, 但是随着 “当前代价” 的不断累加, 在目标状态被取出之前的某个时刻:

  1. 根据 ss 并非最优, ss 的 “当前代价” 就会大于从起始状态到目标状态的最小代价。
  2. 对于最优解搜索路径上的状态 tt, 因为 f(t)g(t)\mathrm{f}(t) \leq \mathrm{g}(t), 所以 tt 的 “当前代价” 加上 f(t)\mathrm{f}(t) 必定小于等于 tt 的 “当前代价” 加上 g(t)\mathrm{g}(t), 而后者的含义就是从起始状态到目标状态的最小代价。

结合以上两点, 可知 “ tt 的当前代价加上 f(t)\mathrm{f}(t) ” 小于 ss 的当前代价。因此, tt 就会被从堆中取出进行扩展, 最终更新到目标状态上, 产生最优解。

这种带有估价函数的优先队列 BFS 就称为 AA^{\ast} 算法。只要保证对于任意状态 state, 都有 f(state)g(state)\mathrm{f}( state ) \leq \mathrm{g} (state) , AA^{\ast} 算法就一定能目标状态第一次从堆中被取出时得到最优解。 (~~并且在搜索过程中每个状态只需要被扩展一次 (之后再被取出就可以直接忽略)~~中间点可能被扩展(入队)多次,中间点第一次出队不一定是最短距离,只有终点第一次出队可保证是最短距离)。估价 f\mathrm{f} (state) 越准确、越接近 g\mathrm{g} (state), A\mathrm{A}^{\ast} 算法的效率就越高。如果估价始终为 0 , 就等价于普通的优先队列 BFS。

A\mathrm{A}^{*} 算法提高搜索效率的关键, 就在于能否设计出一个优秀的估价函数。估价函数在满足上述设计准则的前提下, 还应该尽可能反映未来实际代价的变化趋势和相对大小关系, 这样搜索才会较快地逼近最优解。接下来我们通过两道例题来具体地感受一下 AA^{*} 算法。


178. 第K短路

题目描述

给定一张 NN 个点(编号 1,2N1,2…N),MM 条边的有向图,求从起点 SS 到终点 TT 的第 KK 短路的长度,路径允许重复经过点或边。

注意: 每条最短路中至少要包含一条边。

输入格式

第一行包含两个整数 NNMM

接下来 MM 行,每行包含三个整数 A,BA,BLL,表示点 AA 与点 BB 之间存在有向边,且边长为 LL

最后一行包含三个整数 S,TS,TKK,分别表示起点 SS,终点 TT 和第 KK 短路。

输出格式

输出占一行,包含一个整数,表示第 KK 短路的长度,如果第 KK 短路不存在,则输出 1-1

数据范围

1S,TN10001 \le S,T \le N \le 1000,
0M1050 \le M \le 10^5,
1K10001 \le K \le 1000,
1L1001 \le L \le 100

输入样例:

1
2
3
4
2 2
1 2 5
2 1 4
1 2 2

输出样例:

1
14 

算法分析

一个比较直接的想法是使用优先队列 BFS 进行求解。优先队列 (堆) 中保存一些二元组 (x,dist)(x, d i s t), 其中 xx 为节点编号, dist 表示从 SS 沿着某条路径到 xx 的距离。

起初, 堆中只有 (S,0)(S, 0) 。我们不断从堆中取出 dist 值最小的二元组 (x,dist)(x, d i s t), 然后沿着从 xx 出发的每条边 (x,y)(x, y) 进行扩展, 把新的二元组 (y,dist+length(x,y))(y, d i s t+l e n g t h(x, y)) 插入到堆中 (无论堆中是否已经存在一个节点编号为 yy 的二元组)。

上一节我们已经讲到, 在优先队列 BFS 中, 某个状态第一次从堆中被取出时, 就得到了从初态到它的最小代价。读者用数学归纳法很容易得到一个推论: 对于任意正整数 ii 和任意节点 xx, 当第 ii 次从堆中取出包含节点 xx 的二元组时, 对应的 dist 值就是从 SSxx 的第 ii 短路。(所以, 当扩展到的节点 yy 已经被取出 KK 次时, 就没有必要再插入堆中了。注意:如果使用 AA^{\ast} 算法,加了估价函数后,中间的第一次出队不一定是最短距离,故不能只用取出 KK 次进行优化)。最后当节点 TTKK 次被取出时, 就得到了 SSTT 的第 KK 短路。

使用优先队列 BFS\mathrm{BFS} 在最坏情况下的复杂度为 O(K(N+M)log(N+M))O(K *(N+M) * \log (N+M)) 。这道题目给定了起点和终点, 求长度最短 (代价最小) 的路径, 可以考虑使用 A\mathrm{A}^{*} 算法提高搜索效率。

根据估价函数的设计准则, 在第 KK 短路中从 xxTT 的估计距离 f(x)\mathrm{f}(x) 应该不大于第 KK 短路中从 xxTT 的实际距离 g(x)\mathrm{g}(x) 。于是, 我们可以把估价函数 f(x)\mathrm{f}(x) 定为从 xxT\mathrm{T} 的最短路长度, 这样不但能保证 f(x)g(x)\mathrm{f}(x) \leq \mathrm{g}(x), 还能顺应 g(x)\mathrm{g}(x) 的实际变化趋势。 最终我们得到了以下 AA^{*} 算法:

  1. 预处理出各个节点 xx 到终点 TT 的最短路长度 f(x)\mathrm{f}(x) 一一这等价于在反向图上以 TT 为起点求解单源最短路径问题, 可以在 O((N+M)log(N+M))\mathrm{O}((N+M) \log (N+M)) 的时间内完成。
  2. 建立一个二叉堆, 存储一些二元组 (x,dist+f(x))(x, dist +\mathrm{f}(x)), 其中 xx 为节点编号, dist 表示从 SSxx 当前走过的距离。起初堆中只有 (S,0+f(0))(S, 0+\mathrm{f}(0))
  3. 从二叉堆中取出 dist+f(x)dist +\mathrm{f}(x) 值最小的二元组 (x,dist+f(x))(x, dist +\mathrm{f}(x)), 然后沿着从 xx 出发的每条边 (x,y)(x, y) 进行扩展。如果节点 yy 被取出的次数尚未达到 KK , 就把新的二元组 (y,dist+length(x,y)+f(y))(y,dist + length (x, y)+\mathrm{f}(y)) 插入堆中。
  4. 重复第 232 \sim 3 步, 直至第 KK 次取出包含终点 TT 的二元组, 此时二元组中的 dist 值就是从 SSTT 的第 KK 短路。

AA^{*} 算法的复杂度上界与优先队列 BFSB F S 相同。不过, 因为估价函数的作用, 图中很多节点访问次数都远小于 KK, 上述 AA^{*} 算法已经能够比较快速地求出结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
//Author:XuHt
#include <queue>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
using namespace std;
const int N = 1006;
int n, m, st, ed, k, f[N], cnt[N];
bool v[N];
vector<pair<int, int> > e[N], fe[N];
priority_queue<pair<int, int> > pq;

void dijkstra() {
memset(f, 0x3f, sizeof(f));
memset(v, 0, sizeof(v));
f[ed] = 0;
pq.push(make_pair(0, ed));
while (pq.size()) {
int x = pq.top().second;
pq.pop();
if (v[x]) continue;
v[x] = 1;
for (unsigned int i = 0; i < fe[x].size(); i++) {
int y = fe[x][i].first, z = fe[x][i].second;
if (f[y] > f[x] + z) {
f[y] = f[x] + z;
pq.push(make_pair(-f[y], y));
}
}
}
}

void A_star() {
if (st == ed) ++k;
pq.push(make_pair(-f[st], st));
memset(cnt, 0, sizeof(cnt));
while (pq.size()) {
int x = pq.top().second;
int dist = -pq.top().first - f[x];
pq.pop();
++cnt[x];
if (cnt[ed] == k) {
cout << dist << endl;
return;
}
for (unsigned int i = 0; i < e[x].size(); i++) {
int y = e[x][i].first, z = e[x][i].second;
if (cnt[y] != k) pq.push(make_pair(-f[y] - dist - z, y));
}
}
cout << "-1" << endl;
}

int main() {
cin >> n >> m;
for (int i = 1; i <= m; i++) {
int x, y, z;
scanf("%d %d %d", &x, &y, &z);
e[x].push_back(make_pair(y, z));
fe[y].push_back(make_pair(x, z));
}
cin >> st >> ed >> k;
dijkstra();
A_star();
return 0;
}

Solution

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;

using PII = pair<int, int>;
using PIII = pair<int, PII>;
const int N = 1010, M = 2e5 + 10, INF = 0x3f3f3f3f;

struct Node{
int deva = INF, dreal, id;
bool operator < (const Node &x) const {
return deva + dreal > x.deva + x.dreal;
}
}node[M];



int n, m, st, ed, k;
int h[N], rh[N], e[M], w[M], ne[M], idx;
int dist[N];
bool used[N];

void add(int h[], int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}


void dijkstra() {
priority_queue<PII, vector<PII>, greater<PII>> heap;

node[ed].deva = 0;
heap.push({0, ed});

while (heap.size()) {
auto now = heap.top(); heap.pop();

int ver = now.second;
if (used[ver]) continue;
used[ver] = true;

for (int i = rh[ver]; ~i; i = ne[i]) {
int adj = e[i];
if (node[adj].deva > node[ver].deva + w[i]) {
node[adj].deva = node[ver].deva + w[i];
heap.push({node[adj].deva, adj});
}
}
}
}

int astar() {
if (node[st].deva == INF) return -1;

priority_queue<Node> heap;
heap.push({node[st].deva, 0, st});

int cnt = 0;
while (heap.size()) {
auto now = heap.top(); heap.pop();

if (now.id == ed) ++cnt;
if (cnt == k) return now.dreal;

for (int i = h[now.id]; ~i; i = ne[i]) {
int adj = e[i];
heap.push({node[adj].deva, now.dreal + w[i], adj});
}
}
return -1;
}


int main() {
cin >> n >> m;

memset(h, -1, sizeof h);
memset(rh, -1, sizeof rh);

for (int i = 1; i <= m; ++i) {
int a, b, c;
cin >> a >> b >> c;
add(h, a, b, c);
add(rh, b, a, c);
}

cin >> st >> ed >> k;
if (st == ed) ++k;

dijkstra();

cout << astar() << endl;

}

179. 八数码

题目描述

在一个 3×33×3 的网格中,181 \sim 888 个数字和一个 X 恰好不重不漏地分布在这 3×33×3 的网格中。

例如:

1
2
3
1 2 3
X 4 6
7 5 8

在游戏过程中,可以把 X 与其上、下、左、右四个方向之一的数字交换(如果存在)。

我们的目的是通过交换,使得网格变为如下排列(称为正确排列):

1
2
3
1 2 3
4 5 6
7 8 X

例如,示例中图形就可以通过让 X 先后与右、下、右三个方向的数字交换成功得到正确排列。

交换过程如下:

1
2
3
1 2 3   1 2 3   1 2 3   1 2 3
X 4 6 4 X 6 4 5 6 4 5 6
7 5 8 7 5 8 7 X 8 7 8 X

X 与上下左右方向数字交换的行动记录为 udlr

现在,给你一个初始网格,请你通过最少的移动次数,得到正确排列。

输入格式

输入占一行,将 3×33×3 的初始网格描绘出来。

例如,如果初始网格如下所示:

1
2
3
1 2 3 
x 4 6
7 5 8

则输入为:1 2 3 x 4 6 7 5 8

输出格式

输出占一行,包含一个字符串,表示得到正确排列的完整行动记录。

如果答案不唯一,输出任意一种合法方案即可。

如果不存在解决方案,则输出 unsolvable

输入样例:

1
2  3  4  1  5  x  7  6  8 

输出样例

1
ullddrurdllurdruldr 

算法分析

先进行可解性判定。我们在排序-逆序对中已经提到过, 把除空格之外的所有数字排成一个序列, 求出该序列的逆序对数。如果初态和终态的逆序对数奇偶性相同, 那么这两个状态互相可达, 否则一定不可达。

若问题有解, 我们就采用 AA^{*} 算法搜索一种移动步数最少的方案。经过观察可以发现, 每次移动只能把一个数字与空格交换位置, 这样至多把一个数字向它在目标状态中的位置移近一步。即使每一步移动都是有意义的, 从任何一个状态到目标状态的移动步数也不可能小于所有数字当前位置与目标位置的曼哈顿距离之和。

于是, 对于任意状态 state, 我们可以把估价函数设计为所有数字在 state 中的位置与目标状态 end 中的位置的曼哈顿距离之和, 即:

f(state)=num=19(state_xnumend_xnum+state_ynumend_ynum)f(state)=\sum_{num=1}^{9}\left(\mid state\_x_{n u m}-end\_ x_{num}|+| state\_y_{num} - end\_y_{num} \mid\right)

其中 state_xnumstate\_x_{num} 表示在状态 state 下数字 num 的行号, state_ynumstate\_y_{num} 为列号。 我们不断从堆中取出“从初态到当前状态 state 已经移动的步数+f(state)” 最小的状态进行扩展, 当终态第一次被从堆中取出时, 就得到了答案。

AA^{*} 算法中, 为了保证效率, 每个状态只需要在第一次被取出时扩展一次。本题中的状态是一个八数码, 并非一个简单的节点编号, 所以需要使用 Hash 来记录每个八数码是否已经被取出并扩展过一次。我们可以选择取模配合开散列处理冲突, 或 STL map 等 Hash 方法。另外, 有一种名为 “康托展开” 的 Hash 方法, 可以对全排列进行编码和解码, 把 1N1 \sim N 排成的序列与 1N!1 \sim N ! 之间的整数建立一一映射关系, 非常适合八数码的 Hash, 感兴趣的读者可以自行查阅相关资料。

下面的参考程序实现了一个最基本的 AA^{*} 算法。虽然它没有用 “逆序对” 的结论判断是否有解, 也没有用高效的线性 Hash 方法 (直接采用了 STL map), 但足以在规定的时限内完成求解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<map>
using namespace std;
// state:八数码的状态(3*3九宫格压缩为一个整数)
// dist:当前代价 + 估价
struct rec{int state,dist;
rec(){}
rec(int s,int d){state=s,dist=d;}
};
int a[3][3];

map<int,int> d,f,go;
priority_queue<rec> q;
const int dx[4]={-1,0,0,1},dy[4]={0,-1,1,0};
char dir[4]={'u','l','r','d'};

bool operator <(rec a,rec b) {
return a.dist>b.dist;
}

// 把3*3的九宫格压缩为一个整数(9进制)
int calc(int a[3][3]) {
int val=0;
for(int i=0;i<3;i++)
for(int j=0;j<3;j++) {
val=val*9+a[i][j];
}
return val;
}

// 从一个9进制数复原出3*3的九宫格,以及空格位置
pair<int,int> recover(int val,int a[3][3]) {
int x,y;
for(int i=2;i>=0;i--)
for(int j=2;j>=0;j--) {
a[i][j]=val%9;
val/=9;
if(a[i][j]==0) x=i,y=j;
}
return make_pair(x,y);
}

// 计算估价函数
int value(int a[3][3]) {
int val=0;
for(int i=0;i<3;i++)
for(int j=0;j<3;j++) {
if(a[i][j]==0) continue;
int x=(a[i][j]-1)/3;
int y=(a[i][j]-1)%3;
val+=abs(i-x)+abs(j-y);
}
return val;
}

// A*算法
int astar(int sx,int sy,int e) {
d.clear(); f.clear(); go.clear();
while(q.size()) q.pop();
int start=calc(a);
d[start]=0;
q.push(rec(start,0+value(a)));
while(q.size()) {
// 取出堆顶
int now=q.top().state; q.pop();
// 第一次取出目标状态时,得到答案
if(now==e) return d[now];
int a[3][3];
// 复原九宫格
pair<int,int> space=recover(now,a);
int x=space.first,y=space.second;
// 枚举空格的移动方向(上下左右)
for(int i=0;i<4;i++) {
int nx=x+dx[i], ny=y+dy[i];
if (nx<0||nx>2||ny<0||ny>2) continue;
swap(a[x][y],a[nx][ny]);
int next=calc(a);
// next状态没有访问过,或者能被更新
if(d.find(next)==d.end()||d[next]>d[now]+1) {
d[next]=d[now]+1;
// f和go记录移动的路线,以便输出方案
f[next]=now;
go[next]=i;
// 入堆
q.push(rec(next,d[next]+value(a)));
}
swap(a[x][y],a[nx][ny]);
}
}
return -1;
}

void print(int e) {
if(f.find(e)==f.end()) return;
print(f[e]);
putchar(dir[go[e]]);
}

int main() {
int end=0;
for(int i=1;i<=8;i++) end=end*9+i;
end*=9;
int x,y;
for(int i=0;i<3;i++)
for(int j=0;j<3;j++) {
char str[2];
scanf("%s",str);
if(str[0]=='x') a[i][j]=0,x=i,y=j;
else a[i][j]=str[0]-'0';
}
int ans=astar(x,y,end);
if(ans==-1) puts("unsolvable"); else print(end);
}

Solution

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <unordered_map>
using namespace std;

const int dx[4] = {-1, 0, 1, 0}, dy[4] = {0, 1, 0, -1};
const char dir[4] = {'u', 'r', 'd', 'l'};

string st, ed = "12345678x";
int n = 3;

struct Node{
int d, pos;
string s, op;
bool operator<(const Node&x) const {
return d > x.d;
}
};

int calc(string &s) {
int md = 0;
for (int i = 0; i < s.size(); ++i) {
if (s[i] == 'x') continue;
int num = s[i] - '1';
md += abs(num / 3 - i / 3) + abs(num % 3 - i % 3);
}
return md;
}

bool valid(int x, int y) {
return x >= 0 && x < n && y >= 0 && y < n;
}

string astar(string &st, string &ed) {
unordered_map<string, int> dist;
priority_queue<Node> q;

int pos = 0;
while (st[pos] != 'x') ++pos;
dist[st] = 0;
q.push({calc(st), pos, st, ""});

while (q.size()) {
auto now = q.top(); q.pop();

if (now.s == ed) return now.op;

for (int i = 0; i < 4; ++i) {
auto s = now.s;
auto pos = now.pos;
int x = pos / 3 + dx[i], y = pos % 3 + dy[i], npos = x * 3 + y;
if (!valid(x, y)) continue;
swap(s[pos], s[npos]);
if (!dist.count(s) || dist[s] > dist[now.s] + 1) {
dist[s] = dist[now.s] + 1;
q.push({dist[s] + calc(s), npos, s, now.op + dir[i]});
}
}
}
}

int main() {
string seq;
for (int i = 0; i < 9; ++i) {
char c;
cin >> c;
st += c;
if (c != 'x') seq += c;
}

int cnt = 0;
for (int i = 0; i < 8; ++i)
for (int j = i + 1; j < 8; ++j)
if (seq[i] > seq[j]) ++cnt;

if (cnt % 2) cout << "unsolvable";
else cout << astar(st, ed);
}