A* 算法

定义

A* 搜索算法(A*search algorithm\text{A*search algorithm})是一种在图形平面上,对于有多个节点的路径求出最低通过成本的算法。它属于图遍历(英文:Graph traversal\text{Graph traversal})和最佳优先搜索算法(英文:Best-first search\text{Best-first search}),亦是 BFS 的优化,用到了启发式搜索的思维。

启发式搜索(英文:heuristic search\text{heuristic search})是一种在普通搜索算法的基础上引入了启发式函数的搜索算法。

启发式函数的作用是基于已有的信息对搜索的每一个分支选择都做估价,进而选择分支。简单来说,启发式搜索就是对取和不取都做分析,从中选取更优解或删去无效解。

过程

确定了起点 stst 和终点 eded,对于每个点 xx,计算出从 stst 开始的距离函数 g(x)g(x),到 eded 的预估代价距离函数 h(x)h(x) 和实际代价函数 h(x)h^*(x),不难推出每个点上的估价函数为:

f(x)=g(x)+h(x)f(x) = g(x) + h(x)

A* 算法的正确性在保证对于任意一点 xx,都有 hhh \le h^*

类似 Dijikstra 算法,A* 算法每一次从优先队列中取出一个 ff 值最小的元素并以此更新相邻的所有状态。

值得注意的是:当 h=0h = 0 时,A* 算法退化为 Dijkstra 算法,当 h=0h = 0 且边权为 11 时,会演变为 BFS。

K 短路

题目

按顺序求一个有向图上从结点 stst 到结点 eded 的所有路径最小的前任意多(不妨设为 kk)个。

解题思路

不难发现,这是很容易转化成 A* 算法的。

初始状态设为 stst,重点为 eded,距离函数 g(x)g(x) 表示从 ststxx 点的实际距离,估价函数 h(x)h(x) 表示从当前点到节点 eded 的估计距离。

通过一次反向的建图跑一边 Dijkstra,计算出终点 eded 到所有点的最短路,然后将初始状态依次加入到队列当中,每次取出 ff 值最小的一项,计算出相邻的所有点并全部加入到队列当中。当第 kk 次走到节点 eded 时,便是所得到的答案。

优化:当我们第 k+1k + 1 走到此节点时,直接跳过该状态。

参考代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Problem: 第K短路
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/180/
// Memory Limit: 64 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>

using namespace std;

// #define int long long

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;

const int N = 1e4 + 9;
const int INF = 0x3f3f3f3f;

int n, m, st, ed, k;
int dist[N], cnt[N];

struct Edge
{
int u, w; // 出点,权值
bool operator < (const Edge &u) const { return w > u.w; }
}; // 存边用的结构体

vector<Edge> g[N]; // 正向建图
vector<Edge> rg[N]; // 反向建图

struct Node
{
int f, g, idx; // 估价值 + 真实值,真实值,编号
bool operator<(const Node &u) const { return f > u.f; }
};

void dijkstra() // 求出终点到起点时扩展过的所有点的路(被用来当作估价函数,因为其值不会小于实际距离)
{
priority_queue<Edge> heap;
bitset<N> vis;
memset(dist, 0x3f, sizeof dist);

heap.push({ed, 0});
dist[ed] = 0;

while (!heap.empty())
{
auto t = heap.top(); heap.pop();

int u = t.u;
if (vis[u]) continue;
vis[u] = true;

for (auto y : rg[u])
{
int v = y.u, w = y.w;
if (dist[v] > dist[u] + w)
{
dist[v] = dist[u] + w;
heap.push({v, dist[v]});
}
}
}
}

int astar()
{
priority_queue<Node> heap;
heap.push({dist[st], 0, st}); // 初始点到终点的f值为dist[st] + 0,起点到自身的距离为 0

while (!heap.empty())
{
auto t = heap.top(); heap.pop();

int idx = t.idx, distance = t.g; // 取出编号和st点到该点的实际距离
cnt[idx] ++; // 小优化
if (cnt[ed] == k) return distance; // 如果终点遍历了 k 次,直接return掉

for (auto y : g[idx])
{
int v = y.u, w = y.w;
if (cnt[v] < k)
heap.push({distance + w + dist[v], distance + w, v}); // 该状态仍可扩展
}
}
return -1; // 无解时直接 return
}

signed main()
{
// ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= m; i ++)
{
int u, v, w; cin >> u >> v >> w;
g[u].push_back({v, w}); // 正向建图
rg[v].push_back({u, w}); // 反向建图
}

cin >> st >> ed >> k;
if (st == ed) k ++; // 题目要求,至少要有一条路

dijkstra();

// cout << dist[st] << '\n';
// if (dist[st] == INF) cout << -1 << '\n';
cout << astar() << '\n'; // 输出 k 短路

return 0;
}