1+ /* *
2+ * @file
3+ * @brief Simple implementation of the [A* search
4+ * algorithm](https://en.wikipedia.org/wiki/A*_search_algorithm)
5+ *
6+ * @details A* is an informed search algorithm, which leverages a heuristic
7+ * function to estimate the cost from the current node to the goal. This
8+ * enables the algorithm to prioritise traversing edges that move you closer to
9+ * the goal. This can significantly reduce the time spent searching but, unlike
10+ * Dijkstra, doesn't compute all possible routes.
11+ *
12+ * @author [Jordan Hembrow](http://github.com/JordanHembrow5)
13+ *
14+ */
15+ #include < cassert> // / For assert
16+ #include < cmath> // / For std::hypot, std::abs
17+ #include < iostream> // / For IO operations
18+ #include < queue> // / For std::priority_queue
19+ #include < vector> // / For std::vector
20+
21+ /* *
22+ * @namespace graph
23+ * @brief Graph Algorithms
24+ */
25+
26+ namespace graph {
27+
28+ class Node {
29+ public:
30+ Node (size_t idx, std::pair<int , int > pos, std::vector<size_t > conn = {}) {
31+ this ->_idx = idx;
32+ this ->_pos = pos;
33+ if (!conn.empty ()) {
34+ for (size_t c : conn) {
35+ this ->_connections .push_back (c);
36+ }
37+ }
38+ }
39+
40+ void add_connection (size_t conn) { this ->_connections .push_back (conn); }
41+ size_t get_idx () { return this ->_idx ; }
42+ std::vector<size_t > get_connections () { return this ->_connections ; }
43+ std::pair<int , int > get_pos () { return this ->_pos ; }
44+
45+ private:
46+ size_t _idx;
47+ std::pair<int , int > _pos;
48+ std::vector<size_t > _connections;
49+ };
50+
51+ double heuristic_cost (std::pair<int , int > curr_pos,
52+ std::pair<int , int > end_pos) {
53+ return std::hypot (curr_pos.first - end_pos.first ,
54+ curr_pos.second - end_pos.second );
55+ }
56+
57+ double traverse_cost (std::pair<int , int > curr_pos,
58+ std::pair<int , int > next_pos) {
59+ return std::hypot (curr_pos.first - next_pos.first ,
60+ curr_pos.second - next_pos.second );
61+ }
62+
63+ double a_star (std::vector<Node> graph, size_t start_idx, size_t finish_idx) {
64+ if (start_idx == finish_idx) {
65+ return 0.0 ;
66+ }
67+
68+ // stores all the info required for our priority queue
69+ typedef struct {
70+ double heur_cost = 0.0 ;
71+ double curr_weight = 0.0 ;
72+ size_t node_idx = 0 ;
73+ } queue_info;
74+
75+ // Ensures our priority queue is sorted with the smallest cost at the top
76+ typedef struct {
77+ bool operator ()(const queue_info l, const queue_info r) const {
78+ return (l.heur_cost + l.curr_weight ) >
79+ (r.heur_cost + r.curr_weight );
80+ }
81+ } custom_less;
82+
83+ std::priority_queue<queue_info, std::vector<queue_info>, custom_less> pq;
84+
85+ // Start at the start point, with a total weight of zero
86+ queue_info q_info;
87+ q_info.node_idx = start_idx;
88+ pq.push (q_info);
89+ std::pair<int , int > end_pos = graph[finish_idx].get_pos ();
90+
91+ while (!pq.empty ()) {
92+ double curr_weight = pq.top ().curr_weight ;
93+ Node curr_node = graph[pq.top ().node_idx ];
94+ pq.pop (); // remove current node now we are exploring it
95+ for (const size_t &N_idx : curr_node.get_connections ()) {
96+ double cost = curr_weight + traverse_cost (curr_node.get_pos (),
97+ graph[N_idx].get_pos ());
98+
99+ if (N_idx == finish_idx) { // We found the finish
100+ return cost;
101+ }
102+
103+ queue_info q = {heuristic_cost (graph[N_idx].get_pos (), end_pos),
104+ cost, N_idx};
105+ pq.push (q);
106+ }
107+ }
108+ std::cout << " End point is not reachable from start point!" << std::endl;
109+ return -1 ;
110+ }
111+
112+ } // namespace graph
113+
114+ bool double_eq (double a, double b) { return std::abs (a - b) < 1e-4 ; }
115+
116+ void test () {
117+ std::vector<graph::Node> graph;
118+ graph::Node n0 (0 , {0 , 0 }, {1 , 6 }), n1 (1 , {5 , 0 }, {2 }), n2 (2 , {5 , 5 }, {3 }),
119+ n3 (3 , {10 , 5 }, {4 }), n4 (4 , {10 , 10 }, {5 }), n5 (5 , {11 , 11 }),
120+ n6 (6 , {0 , 11 }, {7 }), n7 (7 , {16 , 11 }, {5 });
121+
122+ graph.push_back (n0);
123+ graph.push_back (n1);
124+ graph.push_back (n2);
125+ graph.push_back (n3);
126+ graph.push_back (n4);
127+ graph.push_back (n5);
128+ graph.push_back (n6);
129+ graph.push_back (n7);
130+
131+ double shortest_dist = graph::a_star (graph, 0 , 5 );
132+ std::cout << " Test 1:\n "
133+ << " Shortest distance: " << shortest_dist << std::endl;
134+ assert (double_eq (shortest_dist, 21.4142 ));
135+
136+ shortest_dist = graph::a_star (graph, 1 , 1 );
137+ std::cout << " Test 2:\n "
138+ << " Shortest distance: " << shortest_dist << std::endl;
139+ assert (double_eq (shortest_dist, 0.0 ));
140+ std::cout << " \n Test is working correctly\n " ;
141+ }
142+
143+ int main () {
144+ test ();
145+ return 0 ;
146+ }
0 commit comments