Skip to content

Commit 77c3e5b

Browse files
authored
Added A* algorithm (TheAlgorithms#1913)
* a* algorithm * changes after build error * intent changes * fix after review * ImportMissmatchError * Build failed fix * doctest changes * doctest changes
1 parent eef6393 commit 77c3e5b

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed

machine_learning/astar.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import numpy as np
2+
3+
'''
4+
The A* algorithm combines features of uniform-cost search and pure
5+
heuristic search to efficiently compute optimal solutions.
6+
A* algorithm is a best-first search algorithm in which the cost
7+
associated with a node is f(n) = g(n) + h(n),
8+
where g(n) is the cost of the path from the initial state to node n and
9+
h(n) is the heuristic estimate or the cost or a path
10+
from node n to a goal.A* algorithm introduces a heuristic into a
11+
regular graph-searching algorithm,
12+
essentially planning ahead at each step so a more optimal decision
13+
is made.A* also known as the algorithm with brains
14+
'''
15+
16+
17+
class Cell(object):
18+
'''
19+
Class cell represents a cell in the world which have the property
20+
position : The position of the represented by tupleof x and y
21+
co-ordinates initially set to (0,0)
22+
parent : This contains the parent cell object which we visited
23+
before arrinving this cell
24+
g,h,f : The parameters for constructing the heuristic function
25+
which can be any function. for simplicity used line
26+
distance
27+
'''
28+
def __init__(self):
29+
self.position = (0, 0)
30+
self.parent = None
31+
32+
self.g = 0
33+
self.h = 0
34+
self.f = 0
35+
'''
36+
overrides equals method because otherwise cell assign will give
37+
wrong results
38+
'''
39+
def __eq__(self, cell):
40+
return self.position == cell.position
41+
42+
def showcell(self):
43+
print(self.position)
44+
45+
46+
class Gridworld(object):
47+
48+
'''
49+
Gridworld class represents the external world here a grid M*M
50+
matrix
51+
w : create a numpy array with the given world_size default is 5
52+
'''
53+
54+
def __init__(self, world_size=(5, 5)):
55+
self.w = np.zeros(world_size)
56+
self.world_x_limit = world_size[0]
57+
self.world_y_limit = world_size[1]
58+
59+
def show(self):
60+
print(self.w)
61+
62+
'''
63+
get_neighbours
64+
As the name suggests this function will return the neighbours of
65+
the a particular cell
66+
'''
67+
def get_neigbours(self, cell):
68+
neughbour_cord = [
69+
(-1, -1), (-1, 0), (-1, 1), (0, -1),
70+
(0, 1), (1, -1), (1, 0), (1, 1)]
71+
current_x = cell.position[0]
72+
current_y = cell.position[1]
73+
neighbours = []
74+
for n in neughbour_cord:
75+
x = current_x + n[0]
76+
y = current_y + n[1]
77+
if (
78+
(x >= 0 and x < self.world_x_limit) and
79+
(y >= 0 and y < self.world_y_limit)):
80+
c = Cell()
81+
c.position = (x, y)
82+
c.parent = cell
83+
neighbours.append(c)
84+
return neighbours
85+
86+
'''
87+
Implementation of a start algorithm
88+
world : Object of the world object
89+
start : Object of the cell as start position
90+
stop : Object of the cell as goal position
91+
'''
92+
93+
94+
def astar(world, start, goal):
95+
'''
96+
>>> p = Gridworld()
97+
>>> start = Cell()
98+
>>> start.position = (0,0)
99+
>>> goal = Cell()
100+
>>> goal.position = (4,4)
101+
>>> astar(p, start, goal)
102+
[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
103+
'''
104+
_open = []
105+
_closed = []
106+
_open.append(start)
107+
108+
while _open:
109+
min_f = np.argmin([n.f for n in _open])
110+
current = _open[min_f]
111+
_closed.append(_open.pop(min_f))
112+
if current == goal:
113+
break
114+
for n in world.get_neigbours(current):
115+
for c in _closed:
116+
if c == n:
117+
continue
118+
n.g = current.g + 1
119+
x1, y1 = n.position
120+
x2, y2 = goal.position
121+
n.h = (y2 - y1)**2 + (x2 - x1)**2
122+
n.f = n.h + n.g
123+
124+
for c in _open:
125+
if c == n and c.f < n.f:
126+
continue
127+
_open.append(n)
128+
path = []
129+
while current.parent is not None:
130+
path.append(current.position)
131+
current = current.parent
132+
path.append(current.position)
133+
path = path[::-1]
134+
return path
135+
136+
if __name__ == '__main__':
137+
'''
138+
sample run
139+
'''
140+
# object for the world
141+
p = Gridworld()
142+
# stat position and Goal
143+
start = Cell()
144+
start.position = (0, 0)
145+
goal = Cell()
146+
goal.position = (4, 4)
147+
print("path from {} to {} ".format(start.position, goal.position))
148+
s = astar(p, start, goal)
149+
# Just for visual Purpose
150+
for i in s:
151+
p.w[i] = 1
152+
print(p.w)

0 commit comments

Comments
 (0)