Add Gaussian random walk for CTRM (omron-sinicx#11)
yonetaniryo authored Dec 12, 2022
1 parent 097271f commit cc927c9
Showing 3 changed files with 159 additions and 170 deletions.
1 change: 1 addition & 0 deletions scripts/config/sampler/ctrm.yaml
@@ -5,6 +5,7 @@ params:
   num_samples: 50
   num_rw_samples: 15
   max_T: 64
+  rw_type: "uniform"
   prob_rw_decay_high: 25.0
   prob_rw_decay_low: 5.0
   prob_rw_after_goal: 1.0
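
With this one-line addition, the random-walk distribution becomes a config choice: the default "uniform" preserves the previous behavior, while "normal" enables the new Gaussian variant (these are the two values the sampler checks; see learned_sampler.py below). A minimal sketch of the same file with the Gaussian variant enabled:

```yaml
# sketch: scripts/config/sampler/ctrm.yaml with the Gaussian random walk enabled
params:
  num_samples: 50
  num_rw_samples: 15
  max_T: 64
  rw_type: "normal"
  prob_rw_decay_high: 25.0
  prob_rw_decay_low: 5.0
  prob_rw_after_goal: 1.0
```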
303 changes: 145 additions & 158 deletions src/jaxmapp/roadmap/learned_sampler.py
@@ -79,6 +79,7 @@ class CTRMSampler(DefaultSampler):
         15  # number of trajectories to sample with high random-walk decay
     )
     max_T: int = 64  # maximum number of timesteps for each traj
+    rw_type: str = "uniform"  # type of random walk; either "uniform" or "normal"
     prob_rw_decay_high: float = (
         25.0  # parameter for invoking random walk in learned trajectories
     )
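
The new `rw_type` field selects how the raw random-walk values are drawn inside `sample_uniform_i` (shown in full in the rewritten `build_sample_next` below). A standalone sketch of the two draws, using the same clipped-half-normal construction as the diff (the helper name `draw_rw_values` is illustrative):

```python
import jax
import jax.numpy as jnp


def draw_rw_values(key, num_rw_attempts, rw_type="uniform"):
    """Draw 2 * num_rw_attempts values in [0, 1]: half for magnitudes, half for angles."""
    if rw_type == "uniform":
        return jax.random.uniform(key, shape=(2 * num_rw_attempts,))
    # "normal": fold a standard normal to [0, inf) and clip to [0, 1],
    # which biases the draws toward small values relative to uniform
    return jnp.clip(
        jnp.abs(jax.random.normal(key, shape=(2 * num_rw_attempts,))), 0.0, 1.0
    )


vals = draw_rw_values(jax.random.PRNGKey(0), 15, rw_type="normal")
```

Because these values feed both the step magnitude and the angle `theta = 2 * pi * value`, the Gaussian variant concentrates random-walk steps around short displacements (and angles near zero) rather than spreading them uniformly.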
@@ -120,6 +121,7 @@ def set_model_and_params(
         self.default_num_neighbors = model.num_neighbors
         self.inference_fn = get_inference_fn(model)
         self.params = params
+        self.sample_next = self.build_sample_next()

     def sample_trajectories(
         self, key: PRNGKey, num_samples: int, instance: Instance
@@ -161,17 +163,12 @@ def sample_trajectories(
             has_reached_goals,
         ]
         for t in range(self.max_T):
-            loop_carry = sample_next(
+            loop_carry = self.sample_next(
                 t,
                 loop_carry,
-                self.params,
-                self.inference_fn,
                 prob_rw_decay=self.prob_rw_decay_high
                 if trial_id > self.num_rw_samples
                 else self.prob_rw_decay_low,
-                prob_rw_after_goal=self.prob_rw_after_goal,
-                num_rw_attempts=self.num_rw_attempts,
-                max_speed_discount=self.max_speed_discount,
             )
             has_reached_goals = loop_carry[-1]
             if jnp.all(has_reached_goals):
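
For reference, the `prob_rw_decay` argument chosen here feeds an exponential schedule inside `sample_next` (see the diff below), so trials using the high decay almost never random-walk late in a trajectory. A quick sketch of that schedule (the function name is illustrative):

```python
import jax.numpy as jnp


def rw_probability(t, makespan, prob_rw_decay):
    # probability of taking a random-walk step at timestep t,
    # as computed inside sample_next
    return jnp.exp(-prob_rw_decay * t / makespan)


# at t == makespan: exp(-25.0) ~ 1.4e-11 for the high decay,
# while exp(-5.0) ~ 6.7e-3 for the low decay
```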
@@ -186,160 +183,150 @@ def sample_trajectories(
         return pos_carry


-@partial(
-    jax.jit,
-    static_argnames=(
-        "inference_fn",
-        "num_rw_attempts",
-        "max_speed_discount",
-    ),
-)
-def sample_next(
-    t: int,
-    loop_carry: list[Array],
-    params: dict,
-    inference_fn: Callable,
-    prob_rw_decay: float,
-    prob_rw_after_goal: float,
-    num_rw_attempts: int,
-    max_speed_discount: float,
-) -> list[Array]:
-    """
-    Compiled function for sampling next vertex using the trained model
+    def build_sample_next(self):
+
-
-    Args:
-        t (int): current timestep
-        loop_carry (list[Array]): loop_carry
-        params (dict): model parameters
-        inference_fn (Callable): inference function of the model
-        prob_rw_decay (float): parameter for decaying the probability of random walk
-        prob_rw_after_goal (float): the probability of random walk after reaching the goal
-        num_rw_attempts (int): number of resampling with the random walk
-        max_speed_discount (float): parameter to make the movement valid within a single time step
+        def sample_next(
+            t: int,
+            loop_carry: list[Array],
+            prob_rw_decay: float,
+        ) -> list[Array]:
+            """
+            Compiled function for sampling next vertex using the trained model
-
-    Returns:
-        list[Array]: updated loop carry
-    """
+
+            Args:
+                t (int): current timestep
+                loop_carry (list[Array]): loop_carry
+                prob_rw_decay (float): parameter for decaying the probability of random walk
-    # extract elements from loop_carry and update pos_carry
-    (
-        key,
-        current_pos,
-        previous_pos,
-        goals,
-        max_speeds,
-        rads,
-        occupancy,
-        cost_map,
-        sdf,
-        pos_carry,
-        trial_id,
-        makespan,
-        has_reached_goals,
-    ) = loop_carry
-
-    # determine random walk probability
-    key0, key1, key = jax.random.split(key, 3)
-    prob_random_walk = jnp.exp(-prob_rw_decay * t / makespan)
-    has_reached_goals = jax.vmap(valid_linear_move, in_axes=(0, 0, 0, 0, None))(
-        current_pos,
-        goals,
-        max_speeds,
-        rads,
-        sdf,
-    )
-    prob_random_walk = jax.vmap(
-        lambda x: jax.lax.cond(
-            x, lambda _: prob_rw_after_goal, lambda _: prob_random_walk, None
-        )
-    )(has_reached_goals)
-
-    # generate next motion candidates
-    next_motion_learned = inference_fn(
-        params,
-        key0,
-        current_pos,
-        previous_pos,
-        goals,
-        max_speeds,
-        rads,
-        occupancy,
-        cost_map,
-    )
+
+            Returns:
+                list[Array]: updated loop carry
+            """
-
-    # clip next motion to ensure validity
-    next_motion_mag = next_motion_learned[:, 2]
-    next_motion_mag = jnp.minimum(next_motion_mag, max_speeds * max_speed_discount)
-    next_motion_learned = next_motion_learned.at[:, 2].set(next_motion_mag)
-    next_motion_dir = next_motion_learned[:, :2]
-    vec_mag = jnp.expand_dims(jnp.linalg.norm(next_motion_dir, axis=-1), -1)
-    vec_mag_avoid_zero = jnp.where(vec_mag == 0, 1, vec_mag)
-    next_motion_dir = next_motion_dir / vec_mag_avoid_zero
-    next_motion_learned = next_motion_learned.at[:, :2].set(next_motion_dir)
-
-    def sample_uniform_i(key, max_speed):
-        random_vals = jax.random.uniform(key, shape=(2 * num_rw_attempts,))
-        mag = max_speed * random_vals[:num_rw_attempts] * max_speed_discount
-        theta = jnp.pi * 2 * random_vals[num_rw_attempts:]
-        next_motion_ = jnp.vstack((jnp.sin(theta), jnp.cos(theta), mag)).T
-        return next_motion_
-
-    next_motion_random = jax.vmap(sample_uniform_i, in_axes=(None, 0), out_axes=1)(
-        key, max_speeds
-    )
-    next_motion_learned = jax.vmap(
-        lambda x, nmr, nml: jax.lax.cond(
-            jax.random.uniform(key1) < x, lambda _: nmr, lambda _: nml, None
-        )
-    )(prob_random_walk, next_motion_random[0], next_motion_learned)
-    next_motion_learned = jnp.expand_dims(next_motion_learned, 0)
+            # extract elements from loop_carry and update pos_carry
+            (
+                key,
+                current_pos,
+                previous_pos,
+                goals,
+                max_speeds,
+                rads,
+                occupancy,
+                cost_map,
+                sdf,
+                pos_carry,
+                trial_id,
+                makespan,
+                has_reached_goals,
+            ) = loop_carry
-
-    next_motion = jnp.concatenate(
-        (next_motion_learned, next_motion_random, jnp.zeros_like(next_motion_learned)),
-        axis=0,
-    )
-    next_pos_cands = (next_motion[:, :, :2] * next_motion[:, :, 2:3]) + current_pos
-    # determine next position
-    validity = jax.vmap(
-        jax.vmap(valid_linear_move, in_axes=(0, 0, 0, 0, None)),
-        in_axes=(None, 0, None, None, None),
-    )(
-        current_pos,
-        next_pos_cands,
-        max_speeds,
-        rads,
-        sdf,
-    ).T
-    selected_id = jax.vmap(
-        lambda x: num_rw_attempts
-        + 1
-        - jax.lax.fori_loop(
-            0,
-            num_rw_attempts + 2,
-            lambda i, x: jax.lax.cond(x[0][i], lambda _: [x[0], i], lambda _: x, None),
-            [x, 0],
-        )[1]
-    )(jnp.fliplr(validity))
-    next_pos = jax.vmap(lambda x, y: x[y], in_axes=(1, 0))(next_pos_cands, selected_id)
-
-    # update positions and pack everything back into loop_carry
-    previous_pos = current_pos
-    current_pos = next_pos
-    pos_carry = pos_carry.at[trial_id, t].set(current_pos)
-    loop_carry = [
-        key,
-        current_pos,
-        previous_pos,
-        goals,
-        max_speeds,
-        rads,
-        occupancy,
-        cost_map,
-        sdf,
-        pos_carry,
-        trial_id,
-        makespan,
-        has_reached_goals,
-    ]
-
-    return loop_carry
+
+            # determine random walk probability
+            key0, key1, key = jax.random.split(key, 3)
+            prob_random_walk = jnp.exp(-prob_rw_decay * t / makespan)
+            has_reached_goals = jax.vmap(valid_linear_move, in_axes=(0, 0, 0, 0, None))(
+                current_pos,
+                goals,
+                max_speeds,
+                rads,
+                sdf,
+            )
+            prob_random_walk = jax.vmap(
+                lambda x: jax.lax.cond(
+                    x, lambda _: self.prob_rw_after_goal, lambda _: prob_random_walk, None
+                )
+            )(has_reached_goals)
+
+            # generate next motion candidates
+            next_motion_learned = self.inference_fn(
+                self.params,
+                key0,
+                current_pos,
+                previous_pos,
+                goals,
+                max_speeds,
+                rads,
+                occupancy,
+                cost_map,
+            )
+
+            # clip next motion to ensure validity
+            next_motion_mag = next_motion_learned[:, 2]
+            next_motion_mag = jnp.minimum(
+                next_motion_mag, max_speeds * self.max_speed_discount
+            )
+            next_motion_learned = next_motion_learned.at[:, 2].set(next_motion_mag)
+            next_motion_dir = next_motion_learned[:, :2]
+            vec_mag = jnp.expand_dims(jnp.linalg.norm(next_motion_dir, axis=-1), -1)
+            vec_mag_avoid_zero = jnp.where(vec_mag == 0, 1, vec_mag)
+            next_motion_dir = next_motion_dir / vec_mag_avoid_zero
+            next_motion_learned = next_motion_learned.at[:, :2].set(next_motion_dir)
+
+            def sample_uniform_i(key, max_speed):
+                # draw 2 * num_rw_attempts values in [0, 1]: the first half
+                # becomes step magnitudes, the second half angles
+                if self.rw_type == "uniform":
+                    random_vals = jax.random.uniform(
+                        key, shape=(2 * self.num_rw_attempts,)
+                    )
+                else:
+                    # "normal": half-normal clipped to [0, 1] (Gaussian random walk)
+                    random_vals = jnp.clip(
+                        jnp.abs(
+                            jax.random.normal(key, shape=(2 * self.num_rw_attempts,))
+                        ),
+                        a_min=0.0,
+                        a_max=1.0,
+                    )
+                mag = max_speed * random_vals[: self.num_rw_attempts] * self.max_speed_discount
+                theta = jnp.pi * 2 * random_vals[self.num_rw_attempts :]
+                next_motion_ = jnp.vstack((jnp.sin(theta), jnp.cos(theta), mag)).T
+                return next_motion_
+
+            next_motion_random = jax.vmap(sample_uniform_i, in_axes=(None, 0), out_axes=1)(
+                key, max_speeds
+            )
+            next_motion_learned = jax.vmap(
+                lambda x, nmr, nml: jax.lax.cond(
+                    jax.random.uniform(key1) < x, lambda _: nmr, lambda _: nml, None
+                )
+            )(prob_random_walk, next_motion_random[0], next_motion_learned)
+            next_motion_learned = jnp.expand_dims(next_motion_learned, 0)
+
+            next_motion = jnp.concatenate(
+                (
+                    next_motion_learned,
+                    next_motion_random,
+                    jnp.zeros_like(next_motion_learned),
+                ),
+                axis=0,
+            )
+            next_pos_cands = (next_motion[:, :, :2] * next_motion[:, :, 2:3]) + current_pos
+            # determine next position: per agent, pick the first valid candidate
+            # (learned motion, then the random-walk samples, then staying put)
+            validity = jax.vmap(
+                jax.vmap(valid_linear_move, in_axes=(0, 0, 0, 0, None)),
+                in_axes=(None, 0, None, None, None),
+            )(
+                current_pos,
+                next_pos_cands,
+                max_speeds,
+                rads,
+                sdf,
+            ).T
+            selected_id = jax.vmap(
+                lambda x: self.num_rw_attempts
+                + 1
+                - jax.lax.fori_loop(
+                    0,
+                    self.num_rw_attempts + 2,
+                    lambda i, x: jax.lax.cond(
+                        x[0][i], lambda _: [x[0], i], lambda _: x, None
+                    ),
+                    [x, 0],
+                )[1]
+            )(jnp.fliplr(validity))
+            next_pos = jax.vmap(lambda x, y: x[y], in_axes=(1, 0))(
+                next_pos_cands, selected_id
+            )
+
+            # update positions and pack everything back into loop_carry
+            previous_pos = current_pos
+            current_pos = next_pos
+            pos_carry = pos_carry.at[trial_id, t].set(current_pos)
+            loop_carry = [
+                key,
+                current_pos,
+                previous_pos,
+                goals,
+                max_speeds,
+                rads,
+                occupancy,
+                cost_map,
+                sdf,
+                pos_carry,
+                trial_id,
+                makespan,
+                has_reached_goals,
+            ]
+
+            return loop_carry
+
+        return jax.jit(sample_next)
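
Beyond adding the Gaussian option, the commit restructures `sample_next` from a module-level function jitted with `static_argnames` into a closure built per sampler instance: `self.params`, `self.inference_fn`, and the `rw_type` branch are captured at trace time, so the call site only passes what changes per step. A minimal sketch of the same pattern (the class and names here are illustrative, not part of jaxmapp):

```python
import jax
import jax.numpy as jnp


class Stepper:
    """Illustrative stand-in for CTRMSampler's build_sample_next pattern."""

    def __init__(self, scale, rw_type="uniform"):
        self.scale = scale
        self.rw_type = rw_type  # plain Python branch, resolved during tracing

    def build_step(self):
        def step(key, x):
            if self.rw_type == "uniform":
                noise = jax.random.uniform(key, x.shape)
            else:
                noise = jnp.abs(jax.random.normal(key, x.shape))
            return x + self.scale * noise

        # self.* values are closed over, so they need not be passed as
        # arguments or declared static on every call
        return jax.jit(step)


stepper = Stepper(scale=0.1, rw_type="normal")
step = stepper.build_step()
x = step(jax.random.PRNGKey(0), jnp.zeros(3))
```

Note that the `rw_type` branch is evaluated while `sample_next` is traced, so changing the field afterwards requires rebuilding the jitted function (here, by calling `build_sample_next` again).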
25 changes: 13 additions & 12 deletions tutorials/1. Quickstart.ipynb

Large diffs are not rendered by default.
