diff --git a/clean_JaxGCRL/README.md b/clean_JaxGCRL/README.md
new file mode 100644
index 0000000..9ca5669
--- /dev/null
+++ b/clean_JaxGCRL/README.md
@@ -0,0 +1 @@
+# Clean-JaxGCRL
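+
+A JAX/Brax implementation of contrastive goal-conditioned reinforcement learning (CRL).
+
+Contents:
+- `train_crl_jax_brax.py` - training script (tanh-Gaussian actor, contrastive InfoNCE critic, SAC-style entropy tuning).
+- `buffer.py` - trajectory replay buffer with uniform trajectory sampling and future-state goal relabelling.
+- `envs/` - goal-conditioned `ant` and `ant_maze` Brax environments plus their MuJoCo XML assets.
+- `evaluator.py` - evaluation rollouts and aggregated episode metrics.
+
+Example run with the default arguments (CLI flags are generated by `tyro` from the `Args` dataclass in `train_crl_jax_brax.py`):
+
+```bash
+python train_crl_jax_brax.py
+```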
\ No newline at end of file
diff --git a/clean_JaxGCRL/buffer.py b/clean_JaxGCRL/buffer.py
new file mode 100644
index 0000000..7f213ab
--- /dev/null
+++ b/clean_JaxGCRL/buffer.py
@@ -0,0 +1,234 @@
+import jax
+import flax
+import functools
+import jax.numpy as jnp
+
+from jax import flatten_util
+from brax.training.types import PRNGKey
+
+@flax.struct.dataclass
+class ReplayBufferState:
+ """Contains data related to a replay buffer."""
+
+ data: jnp.ndarray
+ insert_position: jnp.ndarray
+ sample_position: jnp.ndarray
+ key: PRNGKey
+
+class TrajectoryUniformSamplingQueue():
+ """
+    Limited-size FIFO replay buffer that stores a stream of transitions for each environment.
+
+    `insert()` behaves like a bounded queue: it appends new samples at the end and,
+    when necessary, drops the oldest samples so the buffer never exceeds the
+    specified maximum size.
+
+    `sample()` draws, for each environment, one random contiguous window of
+    `episode_length` timesteps.
+ """
+ def __init__(
+ self,
+ max_replay_size: int,
+ dummy_data_sample,
+ sample_batch_size: int,
+ num_envs: int,
+ episode_length: int,
+ ):
+
+ self._flatten_fn = jax.vmap(jax.vmap(lambda x: flatten_util.ravel_pytree(x)[0]))
+ dummy_flatten, self._unflatten_fn = flatten_util.ravel_pytree(dummy_data_sample)
+ self._unflatten_fn = jax.vmap(jax.vmap(self._unflatten_fn))
+ data_size = len(dummy_flatten)
+
+ self._data_shape = (max_replay_size, num_envs, data_size)
+ self._data_dtype = dummy_flatten.dtype
+ self._sample_batch_size = sample_batch_size
+ self._size = 0
+ self.num_envs = num_envs
+ self.episode_length = episode_length
+
+ def init(self, key):
+ return ReplayBufferState(
+ data=jnp.zeros(self._data_shape, self._data_dtype),
+ sample_position=jnp.zeros((), jnp.int32),
+ insert_position=jnp.zeros((), jnp.int32),
+ key=key,
+ )
+
+ def insert(self, buffer_state, samples):
+ """Insert data into the replay buffer."""
+ self.check_can_insert(buffer_state, samples, 1)
+ return self.insert_internal(buffer_state, samples)
+
+ def check_can_insert(self, buffer_state, samples, shards):
+ """Checks whether insert operation can be performed."""
+ assert isinstance(shards, int), "This method should not be JITed."
+ insert_size = jax.tree_util.tree_flatten(samples)[0][0].shape[0] // shards
+ if self._data_shape[0] < insert_size:
+ raise ValueError(
+ "Trying to insert a batch of samples larger than the maximum replay"
+ f" size. num_samples: {insert_size}, max replay size"
+ f" {self._data_shape[0]}"
+ )
+ self._size = min(self._data_shape[0], self._size + insert_size)
+
+ def check_can_sample(self, buffer_state, shards):
+ """Checks whether sampling can be performed. Do not JIT this method."""
+ pass
+
+ def insert_internal(
+ self, buffer_state, samples
+ ):
+ """Insert data in the replay buffer.
+
+ Args:
+ buffer_state: Buffer state
+ samples: Sample to insert with a leading batch size.
+
+ Returns:
+ New buffer state.
+ """
+ if buffer_state.data.shape != self._data_shape:
+ raise ValueError(
+ f"buffer_state.data.shape ({buffer_state.data.shape}) "
+ f"doesn't match the expected value ({self._data_shape})"
+ )
+
+        update = self._flatten_fn(samples)  # shape = (unroll_len, num_envs, data_size)
+        data = buffer_state.data  # shape = (max_replay_size, num_envs, data_size)
+
+ # If needed, roll the buffer to make sure there's enough space to fit
+ # `update` after the current position.
+ position = buffer_state.insert_position
+ roll = jnp.minimum(0, len(data) - position - len(update))
+ data = jax.lax.cond(roll, lambda: jnp.roll(data, roll, axis=0), lambda: data)
+ position = position + roll
+
+ # Update the buffer and the control numbers.
+ data = jax.lax.dynamic_update_slice_in_dim(data, update, position, axis=0)
+        # When a roll happened, position + len(update) equals len(data), so position becomes
+        # len(data); otherwise it simply advances by len(update). The modulo by len(data) + 1
+        # is therefore a no-op and is kept only as a safeguard.
+        position = (position + len(update)) % (len(data) + 1)
+        # roll is never positive, so sample_position stays clamped at 0.
+        sample_position = jnp.maximum(0, buffer_state.sample_position + roll)
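+        # Worked example (illustrative numbers): with max_replay_size=10, insert_position=8
+        # and len(update)=5, roll = min(0, 10 - 8 - 5) = -3, so the data is shifted towards
+        # the front by 3 rows, the update is written into rows 5..9, and the new
+        # insert_position is 10 (== len(data)).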
+
+ return buffer_state.replace(
+ data=data,
+ insert_position=position,
+ sample_position=sample_position,
+ )
+
+ def sample(self, buffer_state):
+ """Sample a batch of data."""
+ self.check_can_sample(buffer_state, 1)
+ return self.sample_internal(buffer_state)
+
+ def sample_internal(self, buffer_state):
+ if buffer_state.data.shape != self._data_shape:
+ raise ValueError(
+ f"Data shape expected by the replay buffer ({self._data_shape}) does "
+ f"not match the shape of the buffer state ({buffer_state.data.shape})"
+ )
+ key, sample_key, shuffle_key = jax.random.split(buffer_state.key, 3)
+        # Number of envs to sample from; reduce this if sampling from every env causes OOM.
+        shape = self.num_envs
+
+ # Sampling envs idxs
+ envs_idxs = jax.random.choice(sample_key, jnp.arange(self.num_envs), shape=(shape,), replace=False)
+
+ @functools.partial(jax.jit, static_argnames=("rows", "cols"))
+ def create_matrix(rows, cols, min_val, max_val, rng_key):
+ rng_key, subkey = jax.random.split(rng_key)
+ start_values = jax.random.randint(subkey, shape=(rows,), minval=min_val, maxval=max_val)
+ row_indices = jnp.arange(cols)
+ matrix = start_values[:, jnp.newaxis] + row_indices
+ return matrix
+
+ @jax.jit
+ def create_batch(arr_2d, indices):
+ return jnp.take(arr_2d, indices, axis=0, mode="wrap")
+
+ create_batch_vmaped = jax.vmap(create_batch, in_axes=(1, 0))
+
+ matrix = create_matrix(
+ shape,
+ self.episode_length,
+ buffer_state.sample_position,
+ buffer_state.insert_position - self.episode_length,
+ sample_key,
+ )
+
+        # create_batch is vmapped over the selected env indices and over the rows of
+        # `matrix`. Each row of `matrix` contains `episode_length` consecutive indices,
+        # so for every env in envs_idxs we take one random window of length
+        # `episode_length` from buffer_state.data[:, env_idx, :]. Such a window may
+        # span an episode boundary; flatten_crl_fn later masks out pairs of timesteps
+        # that come from different episodes (via the per-episode "seed").
+ batch = create_batch_vmaped(buffer_state.data[:, envs_idxs, :], matrix)
+ transitions = self._unflatten_fn(batch)
+ return buffer_state.replace(key=key), transitions
+
+ @staticmethod
+ @functools.partial(jax.jit, static_argnames=("buffer_config"))
+ def flatten_crl_fn(buffer_config, transition, sample_key):
+
+ gamma, obs_dim, goal_start_idx, goal_end_idx = buffer_config
+
+        # Because this function is vmapped over envs, transition.observation has shape (episode_len, obs_size)
+ seq_len = transition.observation.shape[0]
+ arrangement = jnp.arange(seq_len)
+ is_future_mask = jnp.array(arrangement[:, None] < arrangement[None], dtype=jnp.float32) # upper triangular matrix of shape seq_len, seq_len where all non-zero entries are 1
+ discount = gamma ** jnp.array(arrangement[None] - arrangement[:, None], dtype=jnp.float32)
+ probs = is_future_mask * discount
+
+ # probs is an upper triangular matrix of shape seq_len, seq_len of the form:
+ # [[0. , 0.99 , 0.98010004, 0.970299 , 0.960596 ],
+ # [0. , 0. , 0.99 , 0.98010004, 0.970299 ],
+ # [0. , 0. , 0. , 0.99 , 0.98010004],
+ # [0. , 0. , 0. , 0. , 0.99 ],
+ # [0. , 0. , 0. , 0. , 0. ]]
+ # assuming seq_len = 5
+ # the same result can be obtained using probs = is_future_mask * (gamma ** jnp.cumsum(is_future_mask, axis=-1))
+
+ single_trajectories = jnp.concatenate(
+ [transition.extras["state_extras"]["seed"][:, jnp.newaxis].T] * seq_len, axis=0
+ )
+        # (seq_len, seq_len) array whose rows are all identical: each row holds the per-timestep
+        # episode id ("seed"), so timesteps collected within the same episode share a value.
+
+ probs = probs * jnp.equal(single_trajectories, single_trajectories.T) + jnp.eye(seq_len) * 1e-5
+        # The i-th row of probs is non-zero only for time indices that
+        # 1) come after i, and
+        # 2) belong to the same episode (same seed) as timestep i.
+
+ goal_index = jax.random.categorical(sample_key, jnp.log(probs))
+        future_state = jnp.take(transition.observation, goal_index[:-1], axis=0)  # the last timestep is dropped since it has no future state
+ future_action = jnp.take(transition.action, goal_index[:-1], axis=0)
+ goal = future_state[:, goal_start_idx : goal_end_idx]
+ future_state = future_state[:, : obs_dim]
+ state = transition.observation[:-1, : obs_dim] #all states are considered
+ new_obs = jnp.concatenate([state, goal], axis=1)
+
+ extras = {
+ "policy_extras": {},
+ "state_extras": {
+ "truncation": jnp.squeeze(transition.extras["state_extras"]["truncation"][:-1]),
+ "seed": jnp.squeeze(transition.extras["state_extras"]["seed"][:-1]),
+ },
+ "state": state,
+ "future_state": future_state,
+ "future_action": future_action,
+ }
+
+ return transition._replace(
+            observation=jnp.squeeze(new_obs),  # after vmapping over envs: (num_envs, episode_length-1, obs_dim + goal_dim)
+ action=jnp.squeeze(transition.action[:-1]),
+ reward=jnp.squeeze(transition.reward[:-1]),
+ discount=jnp.squeeze(transition.discount[:-1]),
+ extras=extras,
+ )
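+
+    # flatten_crl_fn is applied per trajectory; the training script vmaps it over the env
+    # axis of the sampled batch, e.g. (sketch mirroring train_crl_jax_brax.py):
+    #
+    #   jax.vmap(TrajectoryUniformSamplingQueue.flatten_crl_fn, in_axes=(None, 0, 0))(
+    #       (gamma, obs_dim, goal_start_idx, goal_end_idx), transitions, batch_keys)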
+
+ def size(self, buffer_state: ReplayBufferState) -> int:
+ return (
+ buffer_state.insert_position - buffer_state.sample_position
+ )
\ No newline at end of file
diff --git a/clean_JaxGCRL/envs/ant.py b/clean_JaxGCRL/envs/ant.py
new file mode 100644
index 0000000..8adc950
--- /dev/null
+++ b/clean_JaxGCRL/envs/ant.py
@@ -0,0 +1,185 @@
+import os
+from typing import Tuple
+
+from brax import base
+from brax import math
+from brax.envs.base import PipelineEnv, State
+from brax.io import mjcf
+import jax
+from jax import numpy as jp
+import mujoco
+
+class Ant(PipelineEnv):
+ def __init__(
+ self,
+ ctrl_cost_weight=0.5,
+ use_contact_forces=False,
+ contact_cost_weight=5e-4,
+ healthy_reward=1.0,
+ terminate_when_unhealthy=True,
+ healthy_z_range=(0.2, 1.0),
+ contact_force_range=(-1.0, 1.0),
+ reset_noise_scale=0.1,
+ exclude_current_positions_from_observation=True,
+ backend="generalized",
+ **kwargs,
+ ):
+ path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets', "ant.xml")
+ sys = mjcf.load(path)
+
+ n_frames = 5
+
+ if backend in ["spring", "positional"]:
+ sys = sys.replace(dt=0.005)
+ n_frames = 10
+
+ if backend == "mjx":
+ sys = sys.tree_replace(
+ {
+ "opt.solver": mujoco.mjtSolver.mjSOL_NEWTON,
+ "opt.disableflags": mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
+ "opt.iterations": 1,
+ "opt.ls_iterations": 4,
+ }
+ )
+
+ if backend == "positional":
+ # TODO: does the same actuator strength work as in spring
+ sys = sys.replace(
+ actuator=sys.actuator.replace(
+ gear=200 * jp.ones_like(sys.actuator.gear)
+ )
+ )
+
+ kwargs["n_frames"] = kwargs.get("n_frames", n_frames)
+
+ super().__init__(sys=sys, backend=backend, **kwargs)
+
+ self._ctrl_cost_weight = ctrl_cost_weight
+ self._use_contact_forces = use_contact_forces
+ self._contact_cost_weight = contact_cost_weight
+ self._healthy_reward = healthy_reward
+ self._terminate_when_unhealthy = terminate_when_unhealthy
+ self._healthy_z_range = healthy_z_range
+ self._contact_force_range = contact_force_range
+ self._reset_noise_scale = reset_noise_scale
+ self._exclude_current_positions_from_observation = (
+ exclude_current_positions_from_observation
+ )
+
+ if self._use_contact_forces:
+ raise NotImplementedError("use_contact_forces not implemented.")
+
+ def reset(self, rng: jax.Array) -> State:
+ """Resets the environment to an initial state."""
+
+ rng, rng1, rng2 = jax.random.split(rng, 3)
+
+ low, hi = -self._reset_noise_scale, self._reset_noise_scale
+ q = self.sys.init_q + jax.random.uniform(
+ rng1, (self.sys.q_size(),), minval=low, maxval=hi
+ )
+ qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))
+
+ # set the target q, qd
+ _, target = self._random_target(rng)
+ q = q.at[-2:].set(target)
+ qd = qd.at[-2:].set(0)
+
+ pipeline_state = self.pipeline_init(q, qd)
+ obs = self._get_obs(pipeline_state)
+
+ reward, done, zero = jp.zeros(3)
+ metrics = {
+ "reward_forward": zero,
+ "reward_survive": zero,
+ "reward_ctrl": zero,
+ "reward_contact": zero,
+ "x_position": zero,
+ "y_position": zero,
+ "distance_from_origin": zero,
+ "x_velocity": zero,
+ "y_velocity": zero,
+ "forward_reward": zero,
+ "dist": zero,
+ "success": zero,
+ "success_easy": zero
+ }
+ info = {"seed": 0}
+ state = State(pipeline_state, obs, reward, done, metrics)
+ state.info.update(info)
+ return state
+
+ # Todo rename seed to traj_id
+ def step(self, state: State, action: jax.Array) -> State:
+ """Run one timestep of the environment's dynamics."""
+ pipeline_state0 = state.pipeline_state
+ pipeline_state = self.pipeline_step(pipeline_state0, action)
+
+ if "steps" in state.info.keys():
+ seed = state.info["seed"] + jp.where(state.info["steps"], 0, 1)
+ else:
+ seed = state.info["seed"]
+ info = {"seed": seed}
+
+ velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
+ forward_reward = velocity[0]
+
+ min_z, max_z = self._healthy_z_range
+ is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
+ is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)
+ if self._terminate_when_unhealthy:
+ healthy_reward = self._healthy_reward
+ else:
+ healthy_reward = self._healthy_reward * is_healthy
+ ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))
+ contact_cost = 0.0
+
+ obs = self._get_obs(pipeline_state)
+ reward = forward_reward + healthy_reward - ctrl_cost - contact_cost
+ done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
+
+ dist = jp.linalg.norm(obs[:2] - obs[-2:])
+ success = jp.array(dist < 0.5, dtype=float)
+ success_easy = jp.array(dist < 2., dtype=float)
+
+ state.metrics.update(
+ reward_forward=forward_reward,
+ reward_survive=healthy_reward,
+ reward_ctrl=-ctrl_cost,
+ reward_contact=-contact_cost,
+ x_position=pipeline_state.x.pos[0, 0],
+ y_position=pipeline_state.x.pos[0, 1],
+ distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]),
+ x_velocity=velocity[0],
+ y_velocity=velocity[1],
+ dist=dist,
+ success=success,
+ success_easy=success_easy
+ )
+ state.info.update(info)
+ return state.replace(
+ pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
+ )
+
+ def _get_obs(self, pipeline_state: base.State) -> jax.Array:
+ """Observe ant body position and velocities."""
+ # remove target q, qd
+ qpos = pipeline_state.q[:-2]
+ qvel = pipeline_state.qd[:-2]
+
+ target_pos = pipeline_state.x.pos[-1][:2]
+
+ if self._exclude_current_positions_from_observation:
+ qpos = qpos[2:]
+
+ return jp.concatenate([qpos] + [qvel] + [target_pos])
+
+ def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]:
+ """Returns a target location in a random circle slightly above xy plane."""
+ rng, rng1, rng2 = jax.random.split(rng, 3)
+ dist = 10
+ ang = jp.pi * 2.0 * jax.random.uniform(rng2)
+ target_x = dist * jp.cos(ang)
+ target_y = dist * jp.sin(ang)
+ return rng, jp.array([target_x, target_y])
\ No newline at end of file
diff --git a/clean_JaxGCRL/envs/ant_maze.py b/clean_JaxGCRL/envs/ant_maze.py
new file mode 100644
index 0000000..fe4011c
--- /dev/null
+++ b/clean_JaxGCRL/envs/ant_maze.py
@@ -0,0 +1,303 @@
+import os
+from typing import Tuple
+
+from brax import base
+from brax import math
+from brax.envs.base import PipelineEnv, State
+from brax.io import mjcf
+import jax
+from jax import numpy as jp
+import mujoco
+import xml.etree.ElementTree as ET
+
+# This is based on original Ant environment from Brax
+# https://github.com/google/brax/blob/main/brax/envs/ant.py
+# Maze creation adapted from: https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/locomotion/maze_env.py
+
+RESET = R = 'r'
+GOAL = G = 'g'
+
+
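+# Maze cell legend: 1 = wall block, 0 = empty cell, R = robot reset cell,
+# G = candidate goal cell (goals are sampled uniformly from the G cells).
+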
+U_MAZE = [[1, 1, 1, 1, 1],
+ [1, R, G, G, 1],
+ [1, 1, 1, G, 1],
+ [1, G, G, G, 1],
+ [1, 1, 1, 1, 1]]
+
+U_MAZE_EVAL = [[1, 1, 1, 1, 1],
+ [1, R, 0, 0, 1],
+ [1, 1, 1, 0, 1],
+ [1, G, G, G, 1],
+ [1, 1, 1, 1, 1]]
+
+BIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1],
+ [1, R, G, 1, 1, G, G, 1],
+ [1, G, G, 1, G, G, G, 1],
+ [1, 1, G, G, G, 1, 1, 1],
+ [1, G, G, 1, G, G, G, 1],
+ [1, G, 1, G, G, 1, G, 1],
+ [1, G, G, G, 1, G, G, 1],
+ [1, 1, 1, 1, 1, 1, 1, 1]]
+
+BIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1],
+ [1, R, 0, 1, 1, G, G, 1],
+ [1, 0, 0, 1, 0, G, G, 1],
+ [1, 1, 0, 0, 0, 1, 1, 1],
+ [1, 0, 0, 1, 0, 0, 0, 1],
+ [1, 0, 1, G, 0, 1, G, 1],
+ [1, 0, G, G, 1, G, G, 1],
+ [1, 1, 1, 1, 1, 1, 1, 1]]
+
+HARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [1, R, G, G, G, 1, G, G, G, G, G, 1],
+ [1, G, 1, 1, G, 1, G, 1, G, 1, G, 1],
+ [1, G, G, G, G, G, G, 1, G, G, G, 1],
+ [1, G, 1, 1, 1, 1, G, 1, 1, 1, G, 1],
+ [1, G, G, 1, G, 1, G, G, G, G, G, 1],
+ [1, 1, G, 1, G, 1, G, 1, G, 1, 1, 1],
+ [1, G, G, 1, G, G, G, 1, G, G, G, 1],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
+
+MAZE_HEIGHT = 0.5
+
+def find_robot(structure, size_scaling):
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] == RESET:
+ return i * size_scaling, j * size_scaling
+
+def find_goals(structure, size_scaling):
+ goals = []
+ for i in range(len(structure)):
+ for j in range(len(structure[0])):
+ if structure[i][j] == GOAL:
+ goals.append([i * size_scaling, j * size_scaling])
+
+ return jp.array(goals)
+
+# Create an XML string describing the maze, plus the list of candidate goal positions
+def make_maze(maze_layout_name, maze_size_scaling):
+ if maze_layout_name == "u_maze":
+ maze_layout = U_MAZE
+ elif maze_layout_name == "u_maze_eval":
+ maze_layout = U_MAZE_EVAL
+ elif maze_layout_name == "big_maze":
+ maze_layout = BIG_MAZE
+ elif maze_layout_name == "big_maze_eval":
+ maze_layout = BIG_MAZE_EVAL
+ elif maze_layout_name == "hardest_maze":
+ maze_layout = HARDEST_MAZE
+ else:
+ raise ValueError(f"Unknown maze layout: {maze_layout_name}")
+
+ xml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets', "ant_maze.xml")
+
+ robot_x, robot_y = find_robot(maze_layout, maze_size_scaling)
+ possible_goals = find_goals(maze_layout, maze_size_scaling)
+
+ tree = ET.parse(xml_path)
+ worldbody = tree.find(".//worldbody")
+
+ for i in range(len(maze_layout)):
+ for j in range(len(maze_layout[0])):
+ struct = maze_layout[i][j]
+ if struct == 1:
+ ET.SubElement(
+ worldbody, "geom",
+ name="block_%d_%d" % (i, j),
+ pos="%f %f %f" % (i * maze_size_scaling,
+ j * maze_size_scaling,
+ MAZE_HEIGHT / 2 * maze_size_scaling),
+ size="%f %f %f" % (0.5 * maze_size_scaling,
+ 0.5 * maze_size_scaling,
+ MAZE_HEIGHT / 2 * maze_size_scaling),
+ type="box",
+ material="",
+ contype="1",
+ conaffinity="1",
+ rgba="0.7 0.5 0.3 1.0",
+ )
+
+
+ torso = tree.find(".//numeric[@name='init_qpos']")
+ data = torso.get("data")
+ torso.set("data", f"{robot_x} {robot_y} " + data)
+
+ tree = tree.getroot()
+ xml_string = ET.tostring(tree)
+
+ return xml_string, possible_goals
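+
+# Example (illustrative): make_maze("u_maze", 4.0) returns the maze XML string with the
+# wall geoms added, together with an array of candidate goal (x, y) positions.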
+
+class AntMaze(PipelineEnv):
+ def __init__(
+ self,
+ ctrl_cost_weight=0.5,
+ use_contact_forces=False,
+ contact_cost_weight=5e-4,
+ healthy_reward=1.0,
+ terminate_when_unhealthy=True,
+ healthy_z_range=(0.2, 1.0),
+ contact_force_range=(-1.0, 1.0),
+ reset_noise_scale=0.1,
+ exclude_current_positions_from_observation=True,
+ backend="generalized",
+ maze_layout_name="u_maze",
+ maze_size_scaling=4.0,
+ **kwargs,
+ ):
+ xml_string, possible_goals = make_maze(maze_layout_name, maze_size_scaling)
+
+ sys = mjcf.loads(xml_string)
+ self.possible_goals = possible_goals
+
+ n_frames = 5
+
+ if backend in ["spring", "positional"]:
+ sys = sys.replace(dt=0.005)
+ n_frames = 10
+
+ if backend == "mjx":
+ sys = sys.tree_replace(
+ {
+ "opt.solver": mujoco.mjtSolver.mjSOL_NEWTON,
+ "opt.disableflags": mujoco.mjtDisableBit.mjDSBL_EULERDAMP,
+ "opt.iterations": 1,
+ "opt.ls_iterations": 4,
+ }
+ )
+
+ if backend == "positional":
+ # TODO: does the same actuator strength work as in spring
+ sys = sys.replace(
+ actuator=sys.actuator.replace(
+ gear=200 * jp.ones_like(sys.actuator.gear)
+ )
+ )
+
+ kwargs["n_frames"] = kwargs.get("n_frames", n_frames)
+
+ super().__init__(sys=sys, backend=backend, **kwargs)
+
+ self._ctrl_cost_weight = ctrl_cost_weight
+ self._use_contact_forces = use_contact_forces
+ self._contact_cost_weight = contact_cost_weight
+ self._healthy_reward = healthy_reward
+ self._terminate_when_unhealthy = terminate_when_unhealthy
+ self._healthy_z_range = healthy_z_range
+ self._contact_force_range = contact_force_range
+ self._reset_noise_scale = reset_noise_scale
+ self._exclude_current_positions_from_observation = (
+ exclude_current_positions_from_observation
+ )
+
+ if self._use_contact_forces:
+ raise NotImplementedError("use_contact_forces not implemented.")
+
+ def reset(self, rng: jax.Array) -> State:
+ """Resets the environment to an initial state."""
+
+ rng, rng1, rng2 = jax.random.split(rng, 3)
+
+ low, hi = -self._reset_noise_scale, self._reset_noise_scale
+ q = self.sys.init_q + jax.random.uniform(
+ rng1, (self.sys.q_size(),), minval=low, maxval=hi
+ )
+ qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))
+
+ # set the target q, qd
+ _, target = self._random_target(rng)
+ q = q.at[-2:].set(target)
+ qd = qd.at[-2:].set(0)
+
+ pipeline_state = self.pipeline_init(q, qd)
+ obs = self._get_obs(pipeline_state)
+
+ reward, done, zero = jp.zeros(3)
+ metrics = {
+ "reward_forward": zero,
+ "reward_survive": zero,
+ "reward_ctrl": zero,
+ "reward_contact": zero,
+ "x_position": zero,
+ "y_position": zero,
+ "distance_from_origin": zero,
+ "x_velocity": zero,
+ "y_velocity": zero,
+ "forward_reward": zero,
+ "dist": zero,
+ "success": zero,
+ "success_easy": zero
+ }
+ info = {"seed": 0}
+ state = State(pipeline_state, obs, reward, done, metrics)
+ state.info.update(info)
+ return state
+
+ # Todo rename seed to traj_id
+ def step(self, state: State, action: jax.Array) -> State:
+ """Run one timestep of the environment's dynamics."""
+ pipeline_state0 = state.pipeline_state
+ pipeline_state = self.pipeline_step(pipeline_state0, action)
+
+ if "steps" in state.info.keys():
+ seed = state.info["seed"] + jp.where(state.info["steps"], 0, 1)
+ else:
+ seed = state.info["seed"]
+ info = {"seed": seed}
+
+ velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
+ forward_reward = velocity[0]
+
+ min_z, max_z = self._healthy_z_range
+ is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
+ is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)
+ if self._terminate_when_unhealthy:
+ healthy_reward = self._healthy_reward
+ else:
+ healthy_reward = self._healthy_reward * is_healthy
+ ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))
+ contact_cost = 0.0
+
+ obs = self._get_obs(pipeline_state)
+ done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
+
+ dist = jp.linalg.norm(obs[:2] - obs[-2:])
+ success = jp.array(dist < 0.5, dtype=float)
+ success_easy = jp.array(dist < 2., dtype=float)
+ reward = -dist + healthy_reward - ctrl_cost - contact_cost
+ state.metrics.update(
+ reward_forward=forward_reward,
+ reward_survive=healthy_reward,
+ reward_ctrl=-ctrl_cost,
+ reward_contact=-contact_cost,
+ x_position=pipeline_state.x.pos[0, 0],
+ y_position=pipeline_state.x.pos[0, 1],
+ distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]),
+ x_velocity=velocity[0],
+ y_velocity=velocity[1],
+ forward_reward=forward_reward,
+ dist=dist,
+ success=success,
+ success_easy=success_easy
+ )
+ state.info.update(info)
+ return state.replace(
+ pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
+ )
+
+ def _get_obs(self, pipeline_state: base.State) -> jax.Array:
+ """Observe ant body position and velocities."""
+ qpos = pipeline_state.q[:-2]
+ qvel = pipeline_state.qd[:-2]
+
+ target_pos = pipeline_state.x.pos[-1][:2]
+
+ if self._exclude_current_positions_from_observation:
+ qpos = qpos[2:]
+
+ return jp.concatenate([qpos] + [qvel] + [target_pos])
+
+ def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]:
+ """Returns a random target location chosen from possibilities specified in the maze layout."""
+ idx = jax.random.randint(rng, (1,), 0, len(self.possible_goals))
+ return rng, jp.array(self.possible_goals[idx])[0]
\ No newline at end of file
diff --git a/clean_JaxGCRL/envs/assets/ant.xml b/clean_JaxGCRL/envs/assets/ant.xml
new file mode 100644
index 0000000..ad22093
--- /dev/null
+++ b/clean_JaxGCRL/envs/assets/ant.xml
@@ -0,0 +1,103 @@
+<!-- ant.xml: MuJoCo model definition for the Ant environment -->
\ No newline at end of file
diff --git a/clean_JaxGCRL/envs/assets/ant_maze.xml b/clean_JaxGCRL/envs/assets/ant_maze.xml
new file mode 100644
index 0000000..0b44a5e
--- /dev/null
+++ b/clean_JaxGCRL/envs/assets/ant_maze.xml
@@ -0,0 +1,105 @@
+<!-- ant_maze.xml: MuJoCo model for the Ant environment, used as the base model into which maze wall geoms are inserted -->
\ No newline at end of file
diff --git a/clean_JaxGCRL/evaluator.py b/clean_JaxGCRL/evaluator.py
new file mode 100644
index 0000000..6656134
--- /dev/null
+++ b/clean_JaxGCRL/evaluator.py
@@ -0,0 +1,85 @@
+import jax
+import time
+import numpy as np
+import jax.numpy as jnp
+import flax.linen as nn
+
+from brax import envs
+from envs.ant import Ant
+from typing import NamedTuple
+from collections import namedtuple
+
+def generate_unroll(actor_step, training_state, env, env_state, unroll_length, extra_fields=()):
+ """Collect trajectories of given unroll_length."""
+
+ @jax.jit
+ def f(carry, unused_t):
+ state = carry
+ nstate, transition = actor_step(training_state, env, state, extra_fields=extra_fields)
+ return nstate, transition
+
+ final_state, data = jax.lax.scan(f, env_state, (), length=unroll_length)
+ return final_state, data
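+# Note: `data` stacks the collected Transitions along a leading time axis of length
+# `unroll_length`; CrlEvaluator below only uses the returned `final_state`, whose
+# EvalWrapper info carries the aggregated episode metrics.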
+
+class CrlEvaluator():
+
+ def __init__(self, actor_step, eval_env, num_eval_envs, episode_length, key):
+
+ self._key = key
+ self._eval_walltime = 0.
+
+ eval_env = envs.training.EvalWrapper(eval_env)
+
+ def generate_eval_unroll(training_state, key):
+ reset_keys = jax.random.split(key, num_eval_envs)
+ eval_first_state = eval_env.reset(reset_keys)
+ return generate_unroll(
+ actor_step,
+ training_state,
+ eval_env,
+ eval_first_state,
+ unroll_length=episode_length)[0]
+
+ self._generate_eval_unroll = jax.jit(generate_eval_unroll)
+ self._steps_per_unroll = episode_length * num_eval_envs
+
+ def run_evaluation(self, training_state, training_metrics, aggregate_episodes = True):
+ """Run one epoch of evaluation."""
+ self._key, unroll_key = jax.random.split(self._key)
+
+ t = time.time()
+ eval_state = self._generate_eval_unroll(training_state, unroll_key)
+ eval_metrics = eval_state.info["eval_metrics"]
+ eval_metrics.active_episodes.block_until_ready()
+ epoch_eval_time = time.time() - t
+ metrics = {}
+ aggregating_fns = [
+ (np.mean, ""),
+ # (np.std, "_std"),
+ # (np.max, "_max"),
+ # (np.min, "_min"),
+ ]
+
+ for (fn, suffix) in aggregating_fns:
+ metrics.update(
+ {
+ f"eval/episode_{name}{suffix}": (
+ fn(eval_metrics.episode_metrics[name]) if aggregate_episodes else eval_metrics.episode_metrics[name]
+ )
+ for name in ['reward', 'success', 'success_easy', 'dist', 'distance_from_origin']
+ }
+ )
+
+        # Fraction of envs in which at least one step of the episode was a success.
+ if "success" in eval_metrics.episode_metrics:
+ metrics["eval/episode_success_any"] = np.mean(
+ eval_metrics.episode_metrics["success"] > 0.0
+ )
+
+ metrics["eval/avg_episode_length"] = np.mean(eval_metrics.episode_steps)
+ metrics["eval/epoch_eval_time"] = epoch_eval_time
+ metrics["eval/sps"] = self._steps_per_unroll / epoch_eval_time
+ self._eval_walltime = self._eval_walltime + epoch_eval_time
+ metrics = {"eval/walltime": self._eval_walltime, **training_metrics, **metrics}
+
+ return metrics
\ No newline at end of file
diff --git a/clean_JaxGCRL/train_crl_jax_brax.py b/clean_JaxGCRL/train_crl_jax_brax.py
new file mode 100644
index 0000000..55e84a3
--- /dev/null
+++ b/clean_JaxGCRL/train_crl_jax_brax.py
@@ -0,0 +1,646 @@
+import os
+import jax
+import flax
+import tyro
+import time
+import optax
+import wandb
+import pickle
+import random
+import wandb_osh
+import numpy as np
+import flax.linen as nn
+import jax.numpy as jnp
+
+from brax import envs
+from etils import epath
+from dataclasses import dataclass
+from collections import namedtuple
+from typing import NamedTuple, Any
+from wandb_osh.hooks import TriggerWandbSyncHook
+from flax.training.train_state import TrainState
+from flax.linen.initializers import variance_scaling
+
+from evaluator import CrlEvaluator
+from buffer import TrajectoryUniformSamplingQueue
+
+@dataclass
+class Args:
+ exp_name: str = os.path.basename(__file__)[: -len(".py")]
+ seed: int = 1
+ torch_deterministic: bool = True
+ cuda: bool = True
+ track: bool = False
+ wandb_project_name: str = "exploration"
+ wandb_entity: str = 'raj19'
+ wandb_mode: str = 'online'
+ wandb_dir: str = '.'
+ wandb_group: str = '.'
+ capture_video: bool = False
+ checkpoint: bool = False
+
+    # Environment-specific arguments
+ env_id: str = "ant"
+ episode_length: int = 1000
+    # filled in at runtime
+ obs_dim: int = 0
+ goal_start_idx: int = 0
+ goal_end_idx: int = 0
+
+ # Algorithm specific arguments
+ total_env_steps: int = 50000000
+ num_epochs: int = 50
+ num_envs: int = 1024
+ num_eval_envs: int = 128
+ actor_lr: float = 3e-4
+ critic_lr: float = 3e-4
+ alpha_lr: float = 3e-4
+ batch_size: int = 256
+ gamma: float = 0.99
+ logsumexp_penalty_coeff: float = 0.1
+
+ max_replay_size: int = 10000
+ min_replay_size: int = 1000
+
+ unroll_length: int = 62
+
+    # filled in at runtime
+    env_steps_per_actor_step: int = 0
+    """number of env steps per actor step (computed at runtime)"""
+    num_prefill_env_steps: int = 0
+    """number of env steps used to fill the buffer before training starts (computed at runtime)"""
+    num_prefill_actor_steps: int = 0
+    """number of actor steps used to fill the buffer before training starts (computed at runtime)"""
+    num_training_steps_per_epoch: int = 0
+    """number of training steps per epoch (computed at runtime)"""
+
+class SA_encoder(nn.Module):
+ norm_type = "layer_norm"
+ @nn.compact
+ def __call__(self, s: jnp.ndarray, a: jnp.ndarray):
+
+ lecun_unfirom = variance_scaling(1/3, "fan_in", "uniform")
+ bias_init = nn.initializers.zeros
+
+ if self.norm_type == "layer_norm":
+ normalize = lambda x: nn.LayerNorm()(x)
+ else:
+ normalize = lambda x: x
+
+ x = jnp.concatenate([s, a], axis=-1)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(64, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ return x
+
+class G_encoder(nn.Module):
+ norm_type = "layer_norm"
+ @nn.compact
+ def __call__(self, g: jnp.ndarray):
+
+ lecun_unfirom = variance_scaling(1/3, "fan_in", "uniform")
+ bias_init = nn.initializers.zeros
+
+ if self.norm_type == "layer_norm":
+ normalize = lambda x: nn.LayerNorm()(x)
+ else:
+ normalize = lambda x: x
+
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(g)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(64, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ return x
+
+class Actor(nn.Module):
+ action_size: int
+ norm_type = "layer_norm"
+
+ LOG_STD_MAX = 2
+ LOG_STD_MIN = -5
+
+ @nn.compact
+ def __call__(self, x):
+ if self.norm_type == "layer_norm":
+ normalize = lambda x: nn.LayerNorm()(x)
+ else:
+ normalize = lambda x: x
+
+ lecun_unfirom = variance_scaling(1/3, "fan_in", "uniform")
+ bias_init = nn.initializers.zeros
+
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+ x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ x = normalize(x)
+ x = nn.swish(x)
+
+ mean = nn.Dense(self.action_size, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+ log_std = nn.Dense(self.action_size, kernel_init=lecun_unfirom, bias_init=bias_init)(x)
+
+ log_std = nn.tanh(log_std)
+ log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats
+
+ return mean, log_std
+
+@flax.struct.dataclass
+class TrainingState:
+ """Contains training state for the learner"""
+ env_steps: jnp.ndarray
+ gradient_steps: jnp.ndarray
+ actor_state: TrainState
+ critic_state: TrainState
+ alpha_state: TrainState
+
+class Transition(NamedTuple):
+ """Container for a transition"""
+ observation: jnp.ndarray
+ action: jnp.ndarray
+ reward: jnp.ndarray
+ discount: jnp.ndarray
+    extras: Any = ()
+
+def load_params(path: str):
+ with epath.Path(path).open('rb') as fin:
+ buf = fin.read()
+ return pickle.loads(buf)
+
+def save_params(path: str, params: Any):
+ """Saves parameters in flax format."""
+ with epath.Path(path).open('wb') as fout:
+ fout.write(pickle.dumps(params))
+
+if __name__ == "__main__":
+
+ args = tyro.cli(Args)
+
+ args.env_steps_per_actor_step = args.num_envs * args.unroll_length
+ args.num_prefill_env_steps = args.min_replay_size * args.num_envs
+    args.num_prefill_actor_steps = int(np.ceil(args.min_replay_size / args.unroll_length))
+ args.num_training_steps_per_epoch = (args.total_env_steps - args.num_prefill_env_steps) // (args.num_epochs * args.env_steps_per_actor_step)
+
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+
+ if args.track:
+
+ if args.wandb_group == '.':
+ args.wandb_group = None
+
+ wandb.init(
+ project=args.wandb_project_name,
+ entity=args.wandb_entity,
+ mode=args.wandb_mode,
+ group=args.wandb_group,
+ dir=args.wandb_dir,
+ config=vars(args),
+ name=run_name,
+ monitor_gym=True,
+ save_code=True,
+ )
+
+ if args.wandb_mode == 'offline':
+ wandb_osh.set_log_level("ERROR")
+ trigger_sync = TriggerWandbSyncHook()
+
+ if args.checkpoint:
+ from pathlib import Path
+ save_path = Path(args.wandb_dir) / Path(run_name)
+ os.mkdir(path=save_path)
+
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ key = jax.random.PRNGKey(args.seed)
+ key, buffer_key, env_key, eval_env_key, actor_key, sa_key, g_key = jax.random.split(key, 7)
+
+ # Environment setup
+ if args.env_id == "ant":
+ from envs.ant import Ant
+ env = Ant(
+ backend="spring",
+ exclude_current_positions_from_observation=False,
+ terminate_when_unhealthy=True,
+ )
+
+ args.obs_dim = 29
+ args.goal_start_idx = 0
+ args.goal_end_idx = 2
+
+ elif "maze" in args.env_id:
+ from envs.ant_maze import AntMaze
+ env = AntMaze(
+ backend="spring",
+ exclude_current_positions_from_observation=False,
+ terminate_when_unhealthy=True,
+ maze_layout_name=args.env_id[4:]
+ )
+
+ args.obs_dim = 29
+ args.goal_start_idx = 0
+ args.goal_end_idx = 2
+
+ else:
+ raise NotImplementedError
+
+ env = envs.training.wrap(
+ env,
+ episode_length=args.episode_length,
+ )
+
+ obs_size = env.observation_size
+ action_size = env.action_size
+ env_keys = jax.random.split(env_key, args.num_envs)
+ env_state = jax.jit(env.reset)(env_keys)
+ env.step = jax.jit(env.step)
+
+ # Network setup
+ # Actor
+ actor = Actor(action_size=action_size)
+ actor_state = TrainState.create(
+ apply_fn=actor.apply,
+ params=actor.init(actor_key, np.ones([1, obs_size])),
+ tx=optax.adam(learning_rate=args.actor_lr)
+ )
+
+ # Critic
+ sa_encoder = SA_encoder()
+ sa_encoder_params = sa_encoder.init(sa_key, np.ones([1, args.obs_dim]), np.ones([1, action_size]))
+ g_encoder = G_encoder()
+ g_encoder_params = g_encoder.init(g_key, np.ones([1, args.goal_end_idx - args.goal_start_idx]))
+ critic_state = TrainState.create(
+ apply_fn=None,
+ params={"sa_encoder": sa_encoder_params, "g_encoder": g_encoder_params},
+ tx=optax.adam(learning_rate=args.critic_lr),
+ )
+
+ # Entropy coefficient
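+    # (SAC-style automatic temperature tuning: log_alpha is trained so that the policy's
+    # entropy stays near target_entropy = -0.5 * |A|; see alpha_loss in update_actor_and_alpha.)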
+ target_entropy = -0.5 * action_size
+ log_alpha = jnp.asarray(0.0, dtype=jnp.float32)
+ alpha_state = TrainState.create(
+ apply_fn=None,
+ params={"log_alpha": log_alpha},
+ tx=optax.adam(learning_rate=args.alpha_lr),
+ )
+
+ # Trainstate
+ training_state = TrainingState(
+ env_steps=jnp.zeros(()),
+ gradient_steps=jnp.zeros(()),
+ actor_state=actor_state,
+ critic_state=critic_state,
+ alpha_state=alpha_state,
+ )
+
+ #Replay Buffer
+ dummy_obs = jnp.zeros((obs_size,))
+ dummy_action = jnp.zeros((action_size,))
+
+ dummy_transition = Transition(
+ observation=dummy_obs,
+ action=dummy_action,
+ reward=0.0,
+ discount=0.0,
+ extras={
+ "state_extras": {
+ "truncation": 0.0,
+ "seed": 0.0,
+ }
+ },
+ )
+
+ def jit_wrap(buffer):
+ buffer.insert_internal = jax.jit(buffer.insert_internal)
+ buffer.sample_internal = jax.jit(buffer.sample_internal)
+ return buffer
+
+ replay_buffer = jit_wrap(
+ TrajectoryUniformSamplingQueue(
+ max_replay_size=args.max_replay_size,
+ dummy_data_sample=dummy_transition,
+ sample_batch_size=args.batch_size,
+ num_envs=args.num_envs,
+ episode_length=args.episode_length,
+ )
+ )
+ buffer_state = jax.jit(replay_buffer.init)(buffer_key)
+
+ def deterministic_actor_step(training_state, env, env_state, extra_fields):
+ means, _ = actor.apply(training_state.actor_state.params, env_state.obs)
+ actions = nn.tanh( means )
+
+ nstate = env.step(env_state, actions)
+ state_extras = {x: nstate.info[x] for x in extra_fields}
+
+ return nstate, Transition(
+ observation=env_state.obs,
+ action=actions,
+ reward=nstate.reward,
+ discount=1-nstate.done,
+ extras={"state_extras": state_extras},
+ )
+
+ def actor_step(actor_state, env, env_state, key, extra_fields):
+ means, log_stds = actor.apply(actor_state.params, env_state.obs)
+ stds = jnp.exp(log_stds)
+ actions = nn.tanh( means + stds * jax.random.normal(key, shape=means.shape, dtype=means.dtype) )
+
+ nstate = env.step(env_state, actions)
+ state_extras = {x: nstate.info[x] for x in extra_fields}
+
+ return nstate, Transition(
+ observation=env_state.obs,
+ action=actions,
+ reward=nstate.reward,
+ discount=1-nstate.done,
+ extras={"state_extras": state_extras},
+ )
+
+ @jax.jit
+ def get_experience(actor_state, env_state, buffer_state, key):
+ @jax.jit
+ def f(carry, unused_t):
+ env_state, current_key = carry
+ current_key, next_key = jax.random.split(current_key)
+ env_state, transition = actor_step(actor_state, env, env_state, current_key, extra_fields=("truncation", "seed"))
+ return (env_state, next_key), transition
+
+ (env_state, _), data = jax.lax.scan(f, (env_state, key), (), length=args.unroll_length)
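+        # `data` leaves have shape (unroll_length, num_envs, ...), which is the layout insert() expects.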
+
+ buffer_state = replay_buffer.insert(buffer_state, data)
+ return env_state, buffer_state
+
+ def prefill_replay_buffer(training_state, env_state, buffer_state, key):
+ @jax.jit
+ def f(carry, unused):
+ del unused
+ training_state, env_state, buffer_state, key = carry
+ key, new_key = jax.random.split(key)
+ env_state, buffer_state = get_experience(
+ training_state.actor_state,
+ env_state,
+ buffer_state,
+ key,
+
+ )
+ training_state = training_state.replace(
+ env_steps=training_state.env_steps + args.env_steps_per_actor_step,
+ )
+ return (training_state, env_state, buffer_state, new_key), ()
+
+ return jax.lax.scan(f, (training_state, env_state, buffer_state, key), (), length=args.num_prefill_actor_steps)[0]
+
+ @jax.jit
+ def update_actor_and_alpha(transitions, training_state, key):
+ def actor_loss(actor_params, critic_params, log_alpha, transitions, key):
+ obs = transitions.observation # expected_shape = batch_size, obs_size + goal_size
+ state = obs[:, :args.obs_dim]
+ future_state = transitions.extras["future_state"]
+ goal = future_state[:, args.goal_start_idx : args.goal_end_idx]
+ observation = jnp.concatenate([state, goal], axis=1)
+
+ means, log_stds = actor.apply(actor_params, observation)
+ stds = jnp.exp(log_stds)
+ x_ts = means + stds * jax.random.normal(key, shape=means.shape, dtype=means.dtype)
+ action = nn.tanh(x_ts)
+ log_prob = jax.scipy.stats.norm.logpdf(x_ts, loc=means, scale=stds)
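+            # Per-dimension tanh change of variables: log pi(a|s) = log N(x; mu, sigma) - log(1 - tanh(x)^2);
+            # the 1e-6 below is for numerical stability, and dimensions are summed afterwards.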
+ log_prob -= jnp.log((1 - jnp.square(action)) + 1e-6)
+ log_prob = log_prob.sum(-1) # dimension = B
+
+ sa_encoder_params, g_encoder_params = critic_params["sa_encoder"], critic_params["g_encoder"]
+ sa_repr = sa_encoder.apply(sa_encoder_params, state, action)
+ g_repr = g_encoder.apply(g_encoder_params, goal)
+
+ qf_pi = -jnp.sqrt(jnp.sum((sa_repr - g_repr) ** 2, axis=-1))
+
+ actor_loss = jnp.mean( jnp.exp(log_alpha) * log_prob - (qf_pi) )
+
+ return actor_loss, log_prob
+
+ def alpha_loss(alpha_params, log_prob):
+ alpha = jnp.exp(alpha_params["log_alpha"])
+ alpha_loss = alpha * jnp.mean(jax.lax.stop_gradient(-log_prob - target_entropy))
+ return jnp.mean(alpha_loss)
+
+ (actorloss, log_prob), actor_grad = jax.value_and_grad(actor_loss, has_aux=True)(training_state.actor_state.params, training_state.critic_state.params, training_state.alpha_state.params['log_alpha'], transitions, key)
+ new_actor_state = training_state.actor_state.apply_gradients(grads=actor_grad)
+
+ alphaloss, alpha_grad = jax.value_and_grad(alpha_loss)(training_state.alpha_state.params, log_prob)
+ new_alpha_state = training_state.alpha_state.apply_gradients(grads=alpha_grad)
+
+ training_state = training_state.replace(actor_state=new_actor_state, alpha_state=new_alpha_state)
+
+ metrics = {
+ "sample_entropy": -log_prob,
+ "actor_loss": actorloss,
+ "alph_aloss": alphaloss,
+ "log_alpha": training_state.alpha_state.params["log_alpha"],
+ }
+
+ return training_state, metrics
+
+ @jax.jit
+ def update_critic(transitions, training_state, key):
+ def critic_loss(critic_params, transitions, key):
+ sa_encoder_params, g_encoder_params = critic_params["sa_encoder"], critic_params["g_encoder"]
+
+ obs = transitions.observation[:, :args.obs_dim]
+ action = transitions.action
+
+ sa_repr = sa_encoder.apply(sa_encoder_params, obs, action)
+ g_repr = g_encoder.apply(g_encoder_params, transitions.observation[:, args.obs_dim:])
+
+ # InfoNCE
+ logits = -jnp.sqrt(jnp.sum((sa_repr[:, None, :] - g_repr[None, :, :]) ** 2, axis=-1)) # shape = BxB
+ critic_loss = -jnp.mean(jnp.diag(logits) - jax.nn.logsumexp(logits, axis=1))
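+            # Sketch of the objective: logits[i, j] = -||sa_repr_i - g_repr_j||_2, so row i scores
+            # state-action pair i against every goal in the batch, and the loss is softmax
+            # cross-entropy with the matching goal on the diagonal,
+            # i.e. -mean_i(logits[i, i] - logsumexp_j logits[i, j]).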
+
+ # logsumexp regularisation
+ logsumexp = jax.nn.logsumexp(logits + 1e-6, axis=1)
+ critic_loss += args.logsumexp_penalty_coeff * jnp.mean(logsumexp**2)
+
+ I = jnp.eye(logits.shape[0])
+ correct = jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1)
+ logits_pos = jnp.sum(logits * I) / jnp.sum(I)
+ logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I)
+
+ return critic_loss, (logsumexp, I, correct, logits_pos, logits_neg)
+
+ (loss, (logsumexp, I, correct, logits_pos, logits_neg)), grad = jax.value_and_grad(critic_loss, has_aux=True)(training_state.critic_state.params, transitions, key)
+ new_critic_state = training_state.critic_state.apply_gradients(grads=grad)
+ training_state = training_state.replace(critic_state = new_critic_state)
+
+ metrics = {
+ "categorical_accuracy": jnp.mean(correct),
+ "logits_pos": logits_pos,
+ "logits_neg": logits_neg,
+ "logsumexp": logsumexp.mean(),
+ "critic_loss": loss,
+ }
+
+ return training_state, metrics
+
+ @jax.jit
+ def sgd_step(carry, transitions):
+ training_state, key = carry
+ key, critic_key, actor_key, = jax.random.split(key, 3)
+
+ training_state, actor_metrics = update_actor_and_alpha(transitions, training_state, actor_key)
+
+ training_state, critic_metrics = update_critic(transitions, training_state, critic_key)
+
+ training_state = training_state.replace(gradient_steps = training_state.gradient_steps + 1)
+
+ metrics = {}
+ metrics.update(actor_metrics)
+ metrics.update(critic_metrics)
+
+ return (training_state, key,), metrics
+
+ @jax.jit
+ def training_step(training_state, env_state, buffer_state, key):
+ experience_key1, experience_key2, sampling_key, training_key = jax.random.split(key, 4)
+
+ # update buffer
+ env_state, buffer_state = get_experience(
+ training_state.actor_state,
+ env_state,
+ buffer_state,
+ experience_key1,
+ )
+
+ training_state = training_state.replace(
+ env_steps=training_state.env_steps + args.env_steps_per_actor_step,
+ )
+
+        # sample a batch of stored trajectories
+ buffer_state, transitions = replay_buffer.sample(buffer_state)
+
+ # process transitions for training
+ batch_keys = jax.random.split(sampling_key, transitions.observation.shape[0])
+ transitions = jax.vmap(TrajectoryUniformSamplingQueue.flatten_crl_fn, in_axes=(None, 0, 0))(
+ (args.gamma, args.obs_dim, args.goal_start_idx, args.goal_end_idx), transitions, batch_keys
+ )
+
+ transitions = jax.tree_util.tree_map(
+ lambda x: jnp.reshape(x, (-1,) + x.shape[2:], order="F"),
+ transitions,
+ )
+ permutation = jax.random.permutation(experience_key2, len(transitions.observation))
+ transitions = jax.tree_util.tree_map(lambda x: x[permutation], transitions)
+ transitions = jax.tree_util.tree_map(
+ lambda x: jnp.reshape(x, (-1, args.batch_size) + x.shape[1:]),
+ transitions,
+ )
+
+        # perform one actor step's worth of gradient updates
+ (training_state, _,), metrics = jax.lax.scan(sgd_step, (training_state, training_key), transitions)
+
+ return (training_state, env_state, buffer_state,), metrics
+
+ @jax.jit
+ def training_epoch(
+ training_state,
+ env_state,
+ buffer_state,
+ key,
+ ):
+ @jax.jit
+ def f(carry, unused_t):
+ ts, es, bs, k = carry
+ k, train_key = jax.random.split(k, 2)
+ (ts, es, bs,), metrics = training_step(ts, es, bs, train_key)
+ return (ts, es, bs, k), metrics
+
+ (training_state, env_state, buffer_state, key), metrics = jax.lax.scan(f, (training_state, env_state, buffer_state, key), (), length=args.num_training_steps_per_epoch)
+
+ metrics["buffer_current_size"] = replay_buffer.size(buffer_state)
+ return training_state, env_state, buffer_state, metrics
+
+ key, prefill_key = jax.random.split(key, 2)
+
+ training_state, env_state, buffer_state, _ = prefill_replay_buffer(
+ training_state, env_state, buffer_state, prefill_key
+ )
+
+    # Set up the evaluator
+ evaluator = CrlEvaluator(
+ deterministic_actor_step,
+ env,
+ num_eval_envs=args.num_eval_envs,
+ episode_length=args.episode_length,
+ key=eval_env_key,
+ )
+
+ training_walltime = 0
+ print('starting training....')
+ for ne in range(args.num_epochs):
+
+ t = time.time()
+
+ key, epoch_key = jax.random.split(key)
+ training_state, env_state, buffer_state, metrics = training_epoch(training_state, env_state, buffer_state, epoch_key)
+
+ metrics = jax.tree_util.tree_map(jnp.mean, metrics)
+ metrics = jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)
+
+ epoch_training_time = time.time() - t
+ training_walltime += epoch_training_time
+
+ sps = (args.env_steps_per_actor_step * args.num_training_steps_per_epoch) / epoch_training_time
+ metrics = {
+ "training/sps": sps,
+ "training/walltime": training_walltime,
+ "training/envsteps": training_state.env_steps.item(),
+ **{f"training/{name}": value for name, value in metrics.items()},
+ }
+
+ metrics = evaluator.run_evaluation(training_state, metrics)
+
+ print(metrics)
+
+ if args.checkpoint:
+ # Save current policy and critic params.
+ params = (training_state.alpha_state.params, training_state.actor_state.params, training_state.critic_state.params)
+ path = f"{save_path}/step_{int(training_state.env_steps)}.pkl"
+ save_params(path, params)
+
+ if args.track:
+ wandb.log(metrics, step=ne)
+
+ if args.wandb_mode == 'offline':
+ trigger_sync()
+
+ if args.checkpoint:
+ # Save current policy and critic params.
+ params = (training_state.alpha_state.params, training_state.actor_state.params, training_state.critic_state.params)
+ path = f"{save_path}/final.pkl"
+ save_params(path, params)
+
+# With the default arguments:
+#   (50_000_000 - 1024 * 1000) // (50 * 1024 * 62) = 15    actor steps (= training steps) per epoch
+#   1024 * 999 / 256 ≈ 4000                                 gradient steps per actor step
+#   1024 * 62 / 4000 ≈ 16                                   env steps per gradient step
\ No newline at end of file