Skip to content

Commit bde2d11

Browse files
feat: Added A star (A*) algorithm
1 parent 77b9f39 commit bde2d11

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

graph/a_star.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/**
2+
* @file
3+
* @brief Simple implementation of the [A* search
4+
* algorithm](https://en.wikipedia.org/wiki/A*_search_algorithm)
5+
*
6+
* @details A* is an informed search algorithm, which leverages a heuristic
7+
* function to estimate the cost from the current node to the goal. This
8+
* enables the algorithm to prioritise traversing edges that move you closer to
9+
* the goal. This can significantly reduce the time spent searching but, unlike
10+
* Dijkstra, doesn't compute all possible routes.
11+
*
12+
* @author [Jordan Hembrow](http://github.com/JordanHembrow5)
13+
*
14+
*/
15+
#include <cassert> /// For assert
16+
#include <cmath> /// For std::hypot, std::abs
17+
#include <iostream> /// For IO operations
18+
#include <queue> /// For std::priority_queue
19+
#include <vector> /// For std::vector
20+
21+
/**
22+
* @namespace graph
23+
* @brief Graph Algorithms
24+
*/
25+
26+
namespace graph {
27+
28+
class Node {
29+
public:
30+
Node(size_t idx, std::pair<int, int> pos, std::vector<size_t> conn = {}) {
31+
this->_idx = idx;
32+
this->_pos = pos;
33+
if (!conn.empty()) {
34+
for (size_t c : conn) {
35+
this->_connections.push_back(c);
36+
}
37+
}
38+
}
39+
40+
void add_connection(size_t conn) { this->_connections.push_back(conn); }
41+
size_t get_idx() { return this->_idx; }
42+
std::vector<size_t> get_connections() { return this->_connections; }
43+
std::pair<int, int> get_pos() { return this->_pos; }
44+
45+
private:
46+
size_t _idx;
47+
std::pair<int, int> _pos;
48+
std::vector<size_t> _connections;
49+
};
50+
51+
double heuristic_cost(std::pair<int, int> curr_pos,
52+
std::pair<int, int> end_pos) {
53+
return std::hypot(curr_pos.first - end_pos.first,
54+
curr_pos.second - end_pos.second);
55+
}
56+
57+
double traverse_cost(std::pair<int, int> curr_pos,
58+
std::pair<int, int> next_pos) {
59+
return std::hypot(curr_pos.first - next_pos.first,
60+
curr_pos.second - next_pos.second);
61+
}
62+
63+
double a_star(std::vector<Node> graph, size_t start_idx, size_t finish_idx) {
64+
if (start_idx == finish_idx) {
65+
return 0.0;
66+
}
67+
68+
// stores all the info required for our priority queue
69+
typedef struct {
70+
double heur_cost = 0.0;
71+
double curr_weight = 0.0;
72+
size_t node_idx = 0;
73+
} queue_info;
74+
75+
// Ensures our priority queue is sorted with the smallest cost at the top
76+
typedef struct {
77+
bool operator()(const queue_info l, const queue_info r) const {
78+
return (l.heur_cost + l.curr_weight) >
79+
(r.heur_cost + r.curr_weight);
80+
}
81+
} custom_less;
82+
83+
std::priority_queue<queue_info, std::vector<queue_info>, custom_less> pq;
84+
85+
// Start at the start point, with a total weight of zero
86+
queue_info q_info;
87+
q_info.node_idx = start_idx;
88+
pq.push(q_info);
89+
std::pair<int, int> end_pos = graph[finish_idx].get_pos();
90+
91+
while (!pq.empty()) {
92+
double curr_weight = pq.top().curr_weight;
93+
Node curr_node = graph[pq.top().node_idx];
94+
pq.pop(); // remove current node now we are exploring it
95+
for (const size_t &N_idx : curr_node.get_connections()) {
96+
double cost = curr_weight + traverse_cost(curr_node.get_pos(),
97+
graph[N_idx].get_pos());
98+
99+
if (N_idx == finish_idx) { // We found the finish
100+
return cost;
101+
}
102+
103+
queue_info q = {heuristic_cost(graph[N_idx].get_pos(), end_pos),
104+
cost, N_idx};
105+
pq.push(q);
106+
}
107+
}
108+
std::cout << "End point is not reachable from start point!" << std::endl;
109+
return -1;
110+
}
111+
112+
} // namespace graph
113+
114+
bool double_eq(double a, double b) { return std::abs(a - b) < 1e-4; }
115+
116+
void test() {
117+
std::vector<graph::Node> graph;
118+
graph::Node n0(0, {0, 0}, {1, 6}), n1(1, {5, 0}, {2}), n2(2, {5, 5}, {3}),
119+
n3(3, {10, 5}, {4}), n4(4, {10, 10}, {5}), n5(5, {11, 11}),
120+
n6(6, {0, 11}, {7}), n7(7, {16, 11}, {5});
121+
122+
graph.push_back(n0);
123+
graph.push_back(n1);
124+
graph.push_back(n2);
125+
graph.push_back(n3);
126+
graph.push_back(n4);
127+
graph.push_back(n5);
128+
graph.push_back(n6);
129+
graph.push_back(n7);
130+
131+
double shortest_dist = graph::a_star(graph, 0, 5);
132+
std::cout << "Test 1:\n"
133+
<< " Shortest distance: " << shortest_dist << std::endl;
134+
assert(double_eq(shortest_dist, 21.4142));
135+
136+
shortest_dist = graph::a_star(graph, 1, 1);
137+
std::cout << "Test 2:\n"
138+
<< " Shortest distance: " << shortest_dist << std::endl;
139+
assert(double_eq(shortest_dist, 0.0));
140+
std::cout << "\nTest is working correctly\n";
141+
}
142+
143+
int main() {
144+
test();
145+
return 0;
146+
}

0 commit comments

Comments
 (0)