JAxtar is a project with a JAX-native implementation of a parallelizable A* & Q* solver for neural heuristic search research. This project is inspired by mctx from Google DeepMind: if MCTS can be implemented entirely in pure JAX, why not A*?
MCTS, a form of tree search, is used in many RL techniques, starting with AlphaGo, but graph search (as opposed to tree search) does not seem to have received much attention. Nevertheless, there are puzzle-solving algorithms that combine neural heuristics with graph search, such as DeepCubeA with A* or Q*.
However, the most frustrating aspect of my brief research in this area (during my MSc) was the time it takes to pass data back and forth between the GPU and CPU. When a neural heuristic evaluates a single node, the CPU-GPU communication, rather than the computation itself, can consume 50% to 80% of the total processing time. To reduce this overhead, DeepCubeA evaluates many nodes together as a batch, which appears to work quite well.
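The batching idea can be sketched in a few lines of JAX: instead of one host-device round trip per node, a whole batch of states is scored by a single compiled call. The two-layer MLP and the helpers `init_mlp` and `heuristic` are hypothetical stand-ins for a learned heuristic, not DeepCubeA's or JAxtar's actual network.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden):
    """Random parameters for a tiny two-layer MLP (hypothetical stand-in heuristic)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(1),
    }


def heuristic(params, state):
    """Cost-to-go estimate for a single state given as a flat integer vector."""
    x = state.astype(jnp.float32)
    x = jax.nn.relu(x @ params["w1"] + params["b1"])
    return (x @ params["w2"] + params["b2"])[0]


# vmap adds a batch axis over states (params are shared, hence None);
# jit compiles the whole batched evaluation into one device program.
batched_heuristic = jax.jit(jax.vmap(heuristic, in_axes=(None, 0)))

k_params, k_states = jax.random.split(jax.random.PRNGKey(0))
params = init_mlp(k_params, in_dim=16, hidden=64)
states = jax.random.randint(k_states, (4096, 16), 0, 16)  # 4096 random 16-cell states
h = batched_heuristic(params, states)  # one call instead of 4096 round trips
```

The per-node cost of transferring data is then paid once per batch, which is the same reason a GPU-resident search benefits from expanding many nodes at a time.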
Still, these issues suggest that a more fundamental solution is needed. This led me to look for ways to run A* directly on the GPU, but I found that most existing implementations suffer from the following problems:
- Many are written in pure C and CUDA, which is not well suited to machine learning research.
- Some are written in JAX or PyTorch, but these are often limited to 2D grid environments or connectivity matrices, and cannot scale to state spaces that are too large (or even infinite) to hold in memory.
- The implementation itself often depends on the specific definition of the state or problem.
To address these challenges, I decided to develop code based on the following principles:
- Pure JAX implementation
  - Specifically for machine learning research.
- JAX-native priority queue
  - The A* algorithm requires a priority queue to process nodes in order of the lowest combined path cost and heuristic estimate, f = g + h.
  - However, standard Python heaps are built on lists, which cannot be JIT-compiled in JAX. Thus, a heap built on fixed-shape JAX arrays is necessary.
- Hashable state representation and a hashtable for JAX operations
  - This is crucial for tracking node status (open/closed) in A* and for efficiently retrieving parent state information.
  - Hashing is optional for simple, indexable states, but for complex or infinite state spaces it becomes essential for efficiently indexing and retrieving unique states.
- Fully batched and parallelized operations
  - GPUs provide massive parallelism but have slower cores than CPUs, so algorithms for GPUs must be highly parallelized to take advantage of their architecture.
- Puzzle-agnostic implementation
  - The implementation should be general enough to handle any puzzle with a defined state and action space.
  - This generality enables wider research and allows for formalizing 'strict' behaviors in future implementations.
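A JIT-friendly priority queue has to keep every array shape fixed, so push and pop operate on whole batches with static sizes. The minimal sketch below shows that constraint; note that it swaps in a sort-based buffer (O(n log n) per push) for simplicity, so it is not the heap described above, and `BatchedPQ`, `pq_init`, `pq_push`, and `pq_pop` are hypothetical names.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BatchedPQ(NamedTuple):
    keys: jax.Array  # (capacity,) float32, kept sorted ascending; +inf marks an empty slot
    vals: jax.Array  # (capacity,) int32 payload, e.g. an index into a state table


def pq_init(capacity):
    return BatchedPQ(
        jnp.full((capacity,), jnp.inf, jnp.float32),
        jnp.full((capacity,), -1, jnp.int32),
    )


@jax.jit
def pq_push(pq, keys, vals):
    """Insert a whole batch at once; if over capacity, the largest keys are dropped."""
    all_k = jnp.concatenate([pq.keys, keys.astype(jnp.float32)])
    all_v = jnp.concatenate([pq.vals, vals.astype(jnp.int32)])
    order = jnp.argsort(all_k)[: pq.keys.shape[0]]  # static slice size keeps shapes fixed
    return BatchedPQ(all_k[order], all_v[order])


@functools.partial(jax.jit, static_argnums=1)
def pq_pop(pq, k):
    """Remove the k smallest entries (always k slots; empty ones have key +inf)."""
    out_k, out_v = pq.keys[:k], pq.vals[:k]
    rest_k = jnp.concatenate([pq.keys[k:], jnp.full((k,), jnp.inf, jnp.float32)])
    rest_v = jnp.concatenate([pq.vals[k:], jnp.full((k,), -1, jnp.int32)])
    return BatchedPQ(rest_k, rest_v), out_k, out_v


pq = pq_init(8)
pq = pq_push(pq, jnp.array([5.0, 1.0, 3.0]), jnp.array([50, 10, 30]))
pq = pq_push(pq, jnp.array([2.0, 4.0]), jnp.array([20, 40]))
pq, k, v = pq_pop(pq, 3)  # keys 1, 2, 3 with payloads 10, 20, 30
```

Because `pq_pop` always returns exactly `k` entries, a search loop can expand a fixed-size batch of nodes per iteration and simply mask out slots whose key is `+inf`.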
This project features specially written components, including:
- a hash function builder that converts defined states into hash keys
- a hashtable supporting parallel lookup and insertion
- a priority queue that supports batched push and pop operations
- implementations of puzzles such as Rubik's Cube, Slide Puzzle, Lights Out, and Sokoban
- network heuristics and Q-functions designed for JIT-compilable integration with the A* & Q* algorithms
- a fully JIT-compiled A* & Q* algorithm for puzzles
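To make the hashing and hashtable components concrete, here is a minimal sketch assuming states are flat `uint8` vectors: a 32-bit FNV-1a hash and an open-addressing table with linear probing. Inserts run one after another through `lax.scan`; the parallel insertion mentioned above needs extra conflict handling that this sketch does not attempt. `fnv1a_hash`, `HashTable`, `lookup`, and `insert_batch` are hypothetical names, not the project's API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def fnv1a_hash(state):
    """32-bit FNV-1a over a flat uint8 state; uint32 arithmetic wraps modulo 2**32."""
    def step(h, byte):
        return (h ^ byte) * jnp.uint32(16777619), None
    h, _ = jax.lax.scan(step, jnp.uint32(2166136261), state.astype(jnp.uint32))
    return h


class HashTable(NamedTuple):
    states: jax.Array  # (capacity, dim) uint8 stored states
    used: jax.Array    # (capacity,) bool occupancy flags


def table_init(capacity, dim):
    return HashTable(jnp.zeros((capacity, dim), jnp.uint8), jnp.zeros((capacity,), bool))


def lookup(table, state):
    """Linear probing: return (slot, found). Assumes the table never fills up."""
    cap = table.used.shape[0]
    start = (fnv1a_hash(state) % cap).astype(jnp.int32)

    def occupied_by_other(i):
        return table.used[i] & ~jnp.all(table.states[i] == state)

    slot = jax.lax.while_loop(occupied_by_other, lambda i: (i + 1) % cap, start)
    return slot, table.used[slot]


@jax.jit
def insert_batch(table, states):
    """Insert states sequentially; returns the slot of each state and whether it was new."""
    def step(table, s):
        slot, found = lookup(table, s)
        table = HashTable(table.states.at[slot].set(s), table.used.at[slot].set(True))
        return table, (slot, ~found)
    table, (slots, is_new) = jax.lax.scan(step, table, states)
    return table, slots, is_new


t = table_init(16, 4)
batch = jnp.array([[1, 2, 3, 4], [4, 3, 2, 1], [1, 2, 3, 4]], dtype=jnp.uint8)
t, slots, is_new = insert_batch(t, batch)  # the duplicate third state is not new
```

The returned slot index doubles as a compact state ID, which is what a priority queue or a parent-pointer array can store instead of the full state.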
This project was quite challenging to develop, and it felt like performing acrobatics with JAX. However, I managed to create a fully functional version, and hopefully it will inspire you to discover something amazing as you delve into JAX.
We can find the optimal path using a jittable, batched A* search, as shown below. The results are not blazingly fast, but the search integrates well with neural-network heuristics.
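What "jittable" means for the search loop itself can be illustrated with a toy: the whole A* loop is written as one `jax.lax.while_loop` over fixed-size arrays, so it compiles into a single device program. This sketch assumes a small, explicitly enumerated graph (so dense arrays replace the hashtable) and pops one node per iteration with `argmin` instead of a batched priority queue; it is not JAxtar's implementation, and `astar`, `nbr`, `cost`, and `h` are hypothetical names.

```python
import jax
import jax.numpy as jnp


@jax.jit
def astar(nbr, cost, h, start, goal):
    """A* on an explicit graph; returns the optimal cost to `goal` (inf if unreachable).

    nbr:  (V, K) int32 neighbour table, -1 = no edge
    cost: (V, K) float32 non-negative edge costs
    h:    (V,)   float32 admissible heuristic toward `goal`
    """
    g = jnp.full(h.shape, jnp.inf).at[start].set(0.0)  # best known path cost
    open_ = jnp.zeros(h.shape, bool).at[start].set(True)

    def cond(carry):
        _, open_, done = carry
        return jnp.any(open_) & ~done

    def body(carry):
        g, open_, _ = carry
        u = jnp.argmin(jnp.where(open_, g + h, jnp.inf))  # pop the lowest f = g + h
        done = u == goal
        open_ = open_.at[u].set(False)
        v = nbr[u]
        valid = (v >= 0) & ~done
        v_safe = jnp.where(valid, v, 0)
        cand = jnp.where(valid, g[u] + cost[u], jnp.inf)  # inf entries never win the min
        new_g = g.at[v_safe].min(cand)
        open_ = open_ | (new_g < g)  # (re)open every improved neighbour
        return new_g, open_, done

    g, _, _ = jax.lax.while_loop(cond, body, (g, open_, jnp.array(False)))
    return g[goal]


# Undirected graph: 0-1 (1), 1-2 (1), 0-2 (5), 2-3 (1)
nbr = jnp.array([[1, 2, -1], [0, 2, -1], [0, 1, 3], [2, -1, -1]], dtype=jnp.int32)
cost = jnp.array([[1.0, 5.0, 0.0], [1.0, 1.0, 0.0], [5.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
h = jnp.array([3.0, 2.0, 1.0, 0.0])
print(astar(nbr, cost, h, 0, 3))  # 3.0, via 0 -> 1 -> 2 -> 3
```

Relaxing all neighbours of `u` with a single scatter-min is the same style of vectorized update that a batched version applies to many popped nodes at once.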
The following speed benchmarks were measured on an Nvidia RTX 5090 hosted at vast.ai.
You can easily test it yourself with the colab link below.
```
$ python main.py astar
Start state
┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃ E ┃ 2 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ B ┃ 9 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ F ┃ 4 ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃ 8 ┃ 6 ┃ 1 ┃
┗━━━┻━━━┻━━━┻━━━┛
Target state
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
Heuristic: 33.00
Search Time: 0.45 seconds
Search states: 1.65M(3.64M states/s)
Cost: 49.0
Solution found
```
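The `Heuristic: 33.00` line above equals the Manhattan distance of this start state, which suggests (the source does not say so explicitly) that the slide-puzzle heuristic used here is Manhattan distance. Manhattan distance is admissible for the sliding puzzle, which fits the optimal cost of 49 being above 33. A minimal sketch of that heuristic, with `manhattan` as a hypothetical name:

```python
import jax
import jax.numpy as jnp


def manhattan(board, size=4):
    """Sum of Manhattan distances of tiles 1..size*size-1 to their goal cells (0 = blank)."""
    board = jnp.asarray(board)
    pos = jnp.arange(size * size)  # current flat index of each entry
    goal = board - 1               # tile t belongs at flat index t - 1 in the target layout
    dist = jnp.abs(pos // size - goal // size) + jnp.abs(pos % size - goal % size)
    return jnp.sum(jnp.where(board == 0, 0, dist))  # the blank does not count


# Start state above, read row by row with hex tiles (A = 10, ..., F = 15), blank = 0
start = jnp.array([5, 14, 2, 3, 13, 11, 9, 7, 10, 15, 4, 12, 0, 8, 6, 1])
target = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0])
print(manhattan(start))  # 33
both = jax.vmap(manhattan)(jnp.stack([start, target]))  # values 33 and 0
```

Because the function is pure and shape-static, `jax.vmap` evaluates it for a whole batch of boards, which is how a heuristic plugs into a batched search.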
```
$ python main.py astar --vmap_size 20
Vmapped A* search, multiple initial state solution
Start states
┏━━━┳━━━┳━━━┳━━━┓ ┏━━━┳━━━┳━━━┳━━━┓        ...        ┏━━━┳━━━┳━━━┳━━━┓ ┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃ E ┃ 2 ┃ 3 ┃ ┃ 5 ┃ E ┃ 2 ┃ 3 ┃  (batch : (20,))  ┃ 5 ┃ E ┃ 2 ┃ 3 ┃ ┃ 5 ┃ E ┃ 2 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫                   ┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ B ┃ 9 ┃ 7 ┃ ┃ D ┃ B ┃ 9 ┃ 7 ┃                   ┃ D ┃ B ┃ 9 ┃ 7 ┃ ┃ D ┃ B ┃ 9 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫                   ┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ F ┃ 4 ┃ C ┃ ┃ A ┃ F ┃ 4 ┃ C ┃                   ┃ A ┃ F ┃ 4 ┃ C ┃ ┃ A ┃ F ┃ 4 ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫                   ┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫
┃   ┃ 8 ┃ 6 ┃ 1 ┃ ┃   ┃ 8 ┃ 6 ┃ 1 ┃                   ┃   ┃ 8 ┃ 6 ┃ 1 ┃ ┃   ┃ 8 ┃ 6 ┃ 1 ┃
┗━━━┻━━━┻━━━┻━━━┛ ┗━━━┻━━━┻━━━┻━━━┛                   ┗━━━┻━━━┻━━━┻━━━┛ ┗━━━┻━━━┻━━━┻━━━┛
Target state
┏━━━┳━━━┳━━━┳━━━┓ ┏━━━┳━━━┳━━━┳━━━┓        ...        ┏━━━┳━━━┳━━━┳━━━┓ ┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃ ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  (batch : (20,))  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃ ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫                   ┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃ ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃                   ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃ ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫                   ┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃ ┃ 9 ┃ A ┃ B ┃ C ┃                   ┃ 9 ┃ A ┃ B ┃ C ┃ ┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫                   ┣━━━╋━━━╋━━━╋━━━┫ ┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃ ┃ D ┃ E ┃ F ┃   ┃                   ┃ D ┃ E ┃ F ┃   ┃ ┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛ ┗━━━┻━━━┻━━━┻━━━┛                   ┗━━━┻━━━┻━━━┻━━━┛ ┗━━━┻━━━┻━━━┻━━━┛
vmap astar
# search_result, solved, solved_idx = jax.vmap(astar_fn, in_axes=(None, 0, 0, None))(inital_search_result, states, filled, target)
Search Time: 8.93 seconds (x19.6/20)
Search states: 33.1M (3.71M states/s) (x1.0 faster)
Solution found: 100.00%
# this means astar_fn is completely vmappable and jittable
```
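The comment above vmaps `astar_fn` with `in_axes=(None, 0, 0, None)`: `None` broadcasts an argument unchanged to every batch member, while `0` maps over its leading axis. The self-contained toy below shows the same pattern; it swaps in Bellman-Ford for A* to stay short, and `search`, `NUM_NODES`, and the graph are hypothetical.

```python
import jax
import jax.numpy as jnp

NUM_NODES = 4  # static: array shapes must be known at trace time


def search(graph, source, target):
    """Toy single-source shortest path (Bellman-Ford) returning dist(source, target)."""
    src, dst, w = graph
    dist = jnp.full((NUM_NODES,), jnp.inf).at[source].set(0.0)

    def relax(_, d):
        # Relax every edge at once; .min resolves several edges landing on the same node
        return d.at[dst].min(d[src] + w)

    dist = jax.lax.fori_loop(0, NUM_NODES - 1, relax, dist)
    return dist[target]


# The graph and the target are shared (None); the source is mapped over (0)
batched_search = jax.jit(jax.vmap(search, in_axes=(None, 0, None)))

# Undirected path 0-1-2-3 with unit costs plus a 0-2 shortcut of cost 5, as directed edges
src = jnp.array([0, 1, 1, 2, 0, 2, 2, 3])
dst = jnp.array([1, 0, 2, 1, 2, 0, 3, 2])
w = jnp.array([1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 1.0, 1.0])
out = batched_search((src, dst, w), jnp.arange(NUM_NODES), 3)
print(out)  # distances 3, 2, 1, 0
```

When the vmapped function contains a `jax.lax.while_loop`, as a search loop usually does, the batched loop keeps running until every batch member's condition is false, so the whole batch takes as long as its slowest member.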
```
$ python main.py astar -nn -h -p rubikscube -w 0.2
...
Heuristic: 14.51
Search Time: 0.98 seconds
Search states: 1.51M(1.54M states/s)
Cost: 22.0
Solution found
```
```
$ python main.py qstar -nn -h -p rubikscube -w 0.2
...
qvalues: 'l_cw': 16.9 | 'l_ccw': 17.5 | 'd_cw': 17.1 | 'd_ccw': 16.8 | 'f_cw': 17.4 | 'f_ccw': 17.9 | 'r_cw': 16.8 | 'r_ccw': 17.2 | 'b_cw': 17.3 | 'b_ccw': 16.3 | 'u_cw': 17.7 | 'u_ccw': 17.0
Search Time: 0.24 seconds
Search states: 1.46M(6.04M states/s)
Cost: 22.0
Solution found
```
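Q* differs from A* mainly in how children are scored: one Q-network call on a parent yields a value for every action, so all children get priorities without being generated and evaluated one by one. The sketch below assumes Q(s, a) already includes the transition cost, giving f(child) = g(s) + Q(s, a); the source does not define its Q-values, and `child_keys` is a hypothetical helper. The Q-values are copied from the `qvalues` line above.

```python
import jax.numpy as jnp

# Action order and Q-values copied from the qvalues line above (root state)
ACTIONS = ["l_cw", "l_ccw", "d_cw", "d_ccw", "f_cw", "f_ccw",
           "r_cw", "r_ccw", "b_cw", "b_ccw", "u_cw", "u_ccw"]
root_q = jnp.array([16.9, 17.5, 17.1, 16.8, 17.4, 17.9,
                    16.8, 17.2, 17.3, 16.3, 17.7, 17.0])


def child_keys(g_parents, q_parents):
    """Priorities for every (parent, action) pair from one Q-network output per parent.

    Assumes Q(s, a) includes the transition cost, so f(child) = g(s) + Q(s, a).
    g_parents: (B,) path costs; q_parents: (B, A) Q-values -> (B * A,) flat keys.
    """
    return (g_parents[:, None] + q_parents).reshape(-1)


keys = child_keys(jnp.array([0.0]), root_q[None, :])
best = ACTIONS[int(jnp.argmin(keys))]  # 'b_ccw', the smallest Q-value (16.3)
```

With a state-value heuristic, each of the 12 children would need its own network evaluation; here a single forward pass on the parent suffices, which is one plausible reason the Q* run reports a higher throughput (6.04M vs 1.54M states/s).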
[Figure: Rubikscube, Slidepuzzle, Lightsout, Maze, and Sokoban examples]
The following puzzles are not the kind typically solved with A*, but simple tests show that, depending on how the problem is formulated, they can be solved. Furthermore, this approach could be extended to TSP and many other combinatorial optimization problems (COPs), provided a good heuristic is available. How to train such heuristics still needs further investigation.
[Figure: Dotknot and TSP examples]
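One common way to phrase TSP as a graph-search problem, written in the same style as a puzzle definition, is: a state is (current city, set of visited cities), and an action moves to an unvisited city. The sketch below uses a simple hand-written admissible lower bound rather than a learned heuristic, omits the goal test and the closing edge back to city 0, and uses hypothetical names (`TSPState`, `expand`, `lower_bound`); the source does not say how its TSP is formulated.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TSPState(NamedTuple):
    pos: jax.Array      # int32 scalar: current city (the tour starts at city 0)
    visited: jax.Array  # (n,) bool: cities already on the partial tour


def expand(state, dist):
    """All n candidate children at once; `valid` masks cities already visited."""
    n = dist.shape[0]
    nxt = jnp.arange(n)
    valid = ~state.visited
    step_cost = jnp.where(valid, dist[state.pos, nxt], jnp.inf)
    child_visited = state.visited[None, :] | jnp.eye(n, dtype=bool)  # child i adds city i
    return TSPState(nxt, child_visited), step_cost, valid


def lower_bound(state, dist):
    """Admissible h: every unvisited city, and finally city 0, must still be entered once,
    and entering city j costs at least its cheapest incoming edge."""
    n = dist.shape[0]
    min_in = jnp.min(jnp.where(jnp.eye(n, dtype=bool), jnp.inf, dist), axis=0)
    return jnp.sum(jnp.where(state.visited, 0.0, min_in)) + min_in[0]


dist = jnp.array([[0.0, 1.0, 4.0, 2.0],
                  [1.0, 0.0, 3.0, 5.0],
                  [4.0, 3.0, 0.0, 1.0],
                  [2.0, 5.0, 1.0, 0.0]])
root = TSPState(jnp.int32(0), jnp.array([True, False, False, False]))
children, step_cost, valid = expand(root, dist)
print(lower_bound(root, dist))  # 4.0, below the optimal tour 0-1-2-3-0 of cost 7
```

Because `expand` returns a fixed number of children plus a validity mask, it fits the batched, fixed-shape style that a JAX search loop requires; a tighter learned heuristic would replace `lower_bound`.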
Please use this citation to reference this project.
```bibtex
@software{kyuseokjung2024jaxtar,
  title = {JA^{xtar}: GPU-accelerated Batched parallel A* & Q* solver in pure JAX!},
  author = {Kyuseok Jung},
  url = {https://github.com/tinker495/JAxtar},
  year = {2024},
}
```