diff --git a/clean_JaxGCRL/README.md b/clean_JaxGCRL/README.md new file mode 100644 index 0000000..9ca5669 --- /dev/null +++ b/clean_JaxGCRL/README.md @@ -0,0 +1 @@ +# Clean-JaxGCRL \ No newline at end of file diff --git a/clean_JaxGCRL/buffer.py b/clean_JaxGCRL/buffer.py new file mode 100644 index 0000000..7f213ab --- /dev/null +++ b/clean_JaxGCRL/buffer.py @@ -0,0 +1,234 @@ +import jax +import flax +import functools +import jax.numpy as jnp + +from jax import flatten_util +from brax.training.types import PRNGKey + +@flax.struct.dataclass +class ReplayBufferState: + """Contains data related to a replay buffer.""" + + data: jnp.ndarray + insert_position: jnp.ndarray + sample_position: jnp.ndarray + key: PRNGKey + +class TrajectoryUniformSamplingQueue(): + """ + Base class for limited-size FIFO reply buffers. + + Implements an `insert()` method which behaves like a limited-size queue. + I.e. it adds samples to the end of the queue and, if necessary, removes the + oldest samples form the queue in order to keep the maximum size within the + specified limit. + + Derived classes must implement the `sample()` method. + """ + def __init__( + self, + max_replay_size: int, + dummy_data_sample, + sample_batch_size: int, + num_envs: int, + episode_length: int, + ): + + self._flatten_fn = jax.vmap(jax.vmap(lambda x: flatten_util.ravel_pytree(x)[0])) + dummy_flatten, self._unflatten_fn = flatten_util.ravel_pytree(dummy_data_sample) + self._unflatten_fn = jax.vmap(jax.vmap(self._unflatten_fn)) + data_size = len(dummy_flatten) + + self._data_shape = (max_replay_size, num_envs, data_size) + self._data_dtype = dummy_flatten.dtype + self._sample_batch_size = sample_batch_size + self._size = 0 + self.num_envs = num_envs + self.episode_length = episode_length + + def init(self, key): + return ReplayBufferState( + data=jnp.zeros(self._data_shape, self._data_dtype), + sample_position=jnp.zeros((), jnp.int32), + insert_position=jnp.zeros((), jnp.int32), + key=key, + ) + + def insert(self, buffer_state, samples): + """Insert data into the replay buffer.""" + self.check_can_insert(buffer_state, samples, 1) + return self.insert_internal(buffer_state, samples) + + def check_can_insert(self, buffer_state, samples, shards): + """Checks whether insert operation can be performed.""" + assert isinstance(shards, int), "This method should not be JITed." + insert_size = jax.tree_util.tree_flatten(samples)[0][0].shape[0] // shards + if self._data_shape[0] < insert_size: + raise ValueError( + "Trying to insert a batch of samples larger than the maximum replay" + f" size. num_samples: {insert_size}, max replay size" + f" {self._data_shape[0]}" + ) + self._size = min(self._data_shape[0], self._size + insert_size) + + def check_can_sample(self, buffer_state, shards): + """Checks whether sampling can be performed. Do not JIT this method.""" + pass + + def insert_internal( + self, buffer_state, samples + ): + """Insert data in the replay buffer. + + Args: + buffer_state: Buffer state + samples: Sample to insert with a leading batch size. + + Returns: + New buffer state. 
+ """ + if buffer_state.data.shape != self._data_shape: + raise ValueError( + f"buffer_state.data.shape ({buffer_state.data.shape}) " + f"doesn't match the expected value ({self._data_shape})" + ) + + update = self._flatten_fn(samples) #Updates has shape (unroll_len, num_envs, self._data_shape[-1]) + data = buffer_state.data #shape = (max_replay_size, num_envs, data_size) + + # If needed, roll the buffer to make sure there's enough space to fit + # `update` after the current position. + position = buffer_state.insert_position + roll = jnp.minimum(0, len(data) - position - len(update)) + data = jax.lax.cond(roll, lambda: jnp.roll(data, roll, axis=0), lambda: data) + position = position + roll + + # Update the buffer and the control numbers. + data = jax.lax.dynamic_update_slice_in_dim(data, update, position, axis=0) + position = (position + len(update)) % (len(data) + 1) # so whenever roll happens, position becomes len(data), else it is increased by len(update), what is the use of doing % (len(data) + 1)?? + sample_position = jnp.maximum(0, buffer_state.sample_position + roll) #what is the use of this line? sample_position always remains 0 as roll can never be positive + + return buffer_state.replace( + data=data, + insert_position=position, + sample_position=sample_position, + ) + + def sample(self, buffer_state): + """Sample a batch of data.""" + self.check_can_sample(buffer_state, 1) + return self.sample_internal(buffer_state) + + def sample_internal(self, buffer_state): + if buffer_state.data.shape != self._data_shape: + raise ValueError( + f"Data shape expected by the replay buffer ({self._data_shape}) does " + f"not match the shape of the buffer state ({buffer_state.data.shape})" + ) + key, sample_key, shuffle_key = jax.random.split(buffer_state.key, 3) + # Note: this is the number of envs to sample but it can be modified if there is OOM + shape = self.num_envs + + # Sampling envs idxs + envs_idxs = jax.random.choice(sample_key, jnp.arange(self.num_envs), shape=(shape,), replace=False) + + @functools.partial(jax.jit, static_argnames=("rows", "cols")) + def create_matrix(rows, cols, min_val, max_val, rng_key): + rng_key, subkey = jax.random.split(rng_key) + start_values = jax.random.randint(subkey, shape=(rows,), minval=min_val, maxval=max_val) + row_indices = jnp.arange(cols) + matrix = start_values[:, jnp.newaxis] + row_indices + return matrix + + @jax.jit + def create_batch(arr_2d, indices): + return jnp.take(arr_2d, indices, axis=0, mode="wrap") + + create_batch_vmaped = jax.vmap(create_batch, in_axes=(1, 0)) + + matrix = create_matrix( + shape, + self.episode_length, + buffer_state.sample_position, + buffer_state.insert_position - self.episode_length, + sample_key, + ) + + ''' + The function create_batch will be called for every envs_idxs of buffer_state.data and every row of matrix. + Because every row of matrix has consecutive indices of self.episode_length, for every + envs_idx of envs_idxs, we will sample a random self.episode_length length sequence from + buffer_state.data[:, envs_idx, :]. But I don't think the code ensures that this sequence + won't be across episodes? 
+ + flatten_crl_fn takes care of this + ''' + batch = create_batch_vmaped(buffer_state.data[:, envs_idxs, :], matrix) + transitions = self._unflatten_fn(batch) + return buffer_state.replace(key=key), transitions + + @staticmethod + @functools.partial(jax.jit, static_argnames=("buffer_config")) + def flatten_crl_fn(buffer_config, transition, sample_key): + + gamma, obs_dim, goal_start_idx, goal_end_idx = buffer_config + + # Because it's vmaped transition.obs.shape is of shape (episode_len, obs_dim) + seq_len = transition.observation.shape[0] + arrangement = jnp.arange(seq_len) + is_future_mask = jnp.array(arrangement[:, None] < arrangement[None], dtype=jnp.float32) # upper triangular matrix of shape seq_len, seq_len where all non-zero entries are 1 + discount = gamma ** jnp.array(arrangement[None] - arrangement[:, None], dtype=jnp.float32) + probs = is_future_mask * discount + + # probs is an upper triangular matrix of shape seq_len, seq_len of the form: + # [[0. , 0.99 , 0.98010004, 0.970299 , 0.960596 ], + # [0. , 0. , 0.99 , 0.98010004, 0.970299 ], + # [0. , 0. , 0. , 0.99 , 0.98010004], + # [0. , 0. , 0. , 0. , 0.99 ], + # [0. , 0. , 0. , 0. , 0. ]] + # assuming seq_len = 5 + # the same result can be obtained using probs = is_future_mask * (gamma ** jnp.cumsum(is_future_mask, axis=-1)) + + single_trajectories = jnp.concatenate( + [transition.extras["state_extras"]["seed"][:, jnp.newaxis].T] * seq_len, axis=0 + ) + # array of seq_len x seq_len where a row is an array of seeds that correspond to the episode index from which that time-step was collected + # timesteps collected from the same episode will have the same seed. All rows of the single_trajectories are same. + + probs = probs * jnp.equal(single_trajectories, single_trajectories.T) + jnp.eye(seq_len) * 1e-5 + #ith row of probs will be non zero only for time indices that + # 1) are greater than i + # 2) have the same seed as the ith time index + + goal_index = jax.random.categorical(sample_key, jnp.log(probs)) + future_state = jnp.take(transition.observation, goal_index[:-1], axis=0) #the last goal_index cannot be considered as there is no future. 
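# For intuition, the goal-sampling weights built above can be reproduced in isolation.
# Minimal sketch only: the helper name `future_goal_probs` and the toy `seeds` values are
# assumptions for this example, not part of buffer.py. Step j is a candidate goal for step i
# with weight gamma**(j - i) when j > i and both steps share the same trajectory id ("seed");
# the same-episode mask is what prevents sampled goals from crossing episode boundaries.
import jax
import jax.numpy as jnp

def future_goal_probs(seeds, gamma=0.99):
    seq_len = seeds.shape[0]
    t = jnp.arange(seq_len)
    is_future = (t[:, None] < t[None, :]).astype(jnp.float32)          # strictly upper-triangular mask
    discount = gamma ** (t[None, :] - t[:, None]).astype(jnp.float32)  # gamma**(j - i)
    same_episode = (seeds[:, None] == seeds[None, :]).astype(jnp.float32)
    # The small diagonal keeps every row normalisable even when a step has no same-episode future.
    return is_future * discount * same_episode + jnp.eye(seq_len) * 1e-5

seeds = jnp.array([7, 7, 7, 8, 8])  # two episodes inside one sampled window
goal_index = jax.random.categorical(jax.random.PRNGKey(0), jnp.log(future_goal_probs(seeds)))
# Rows 0-1 almost always pick indices 1-2 (same episode, geometrically discounted); row 2,
# the last step of episode 7, has no same-episode future and falls back to its own index.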
+ future_action = jnp.take(transition.action, goal_index[:-1], axis=0) + goal = future_state[:, goal_start_idx : goal_end_idx] + future_state = future_state[:, : obs_dim] + state = transition.observation[:-1, : obs_dim] #all states are considered + new_obs = jnp.concatenate([state, goal], axis=1) + + extras = { + "policy_extras": {}, + "state_extras": { + "truncation": jnp.squeeze(transition.extras["state_extras"]["truncation"][:-1]), + "seed": jnp.squeeze(transition.extras["state_extras"]["seed"][:-1]), + }, + "state": state, + "future_state": future_state, + "future_action": future_action, + } + + return transition._replace( + observation=jnp.squeeze(new_obs), #this has shape (num_envs, episode_length-1, obs_size) + action=jnp.squeeze(transition.action[:-1]), + reward=jnp.squeeze(transition.reward[:-1]), + discount=jnp.squeeze(transition.discount[:-1]), + extras=extras, + ) + + def size(self, buffer_state: ReplayBufferState) -> int: + return ( + buffer_state.insert_position - buffer_state.sample_position + ) \ No newline at end of file diff --git a/clean_JaxGCRL/envs/ant.py b/clean_JaxGCRL/envs/ant.py new file mode 100644 index 0000000..8adc950 --- /dev/null +++ b/clean_JaxGCRL/envs/ant.py @@ -0,0 +1,185 @@ +import os +from typing import Tuple + +from brax import base +from brax import math +from brax.envs.base import PipelineEnv, State +from brax.io import mjcf +import jax +from jax import numpy as jp +import mujoco + +class Ant(PipelineEnv): + def __init__( + self, + ctrl_cost_weight=0.5, + use_contact_forces=False, + contact_cost_weight=5e-4, + healthy_reward=1.0, + terminate_when_unhealthy=True, + healthy_z_range=(0.2, 1.0), + contact_force_range=(-1.0, 1.0), + reset_noise_scale=0.1, + exclude_current_positions_from_observation=True, + backend="generalized", + **kwargs, + ): + path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets', "ant.xml") + sys = mjcf.load(path) + + n_frames = 5 + + if backend in ["spring", "positional"]: + sys = sys.replace(dt=0.005) + n_frames = 10 + + if backend == "mjx": + sys = sys.tree_replace( + { + "opt.solver": mujoco.mjtSolver.mjSOL_NEWTON, + "opt.disableflags": mujoco.mjtDisableBit.mjDSBL_EULERDAMP, + "opt.iterations": 1, + "opt.ls_iterations": 4, + } + ) + + if backend == "positional": + # TODO: does the same actuator strength work as in spring + sys = sys.replace( + actuator=sys.actuator.replace( + gear=200 * jp.ones_like(sys.actuator.gear) + ) + ) + + kwargs["n_frames"] = kwargs.get("n_frames", n_frames) + + super().__init__(sys=sys, backend=backend, **kwargs) + + self._ctrl_cost_weight = ctrl_cost_weight + self._use_contact_forces = use_contact_forces + self._contact_cost_weight = contact_cost_weight + self._healthy_reward = healthy_reward + self._terminate_when_unhealthy = terminate_when_unhealthy + self._healthy_z_range = healthy_z_range + self._contact_force_range = contact_force_range + self._reset_noise_scale = reset_noise_scale + self._exclude_current_positions_from_observation = ( + exclude_current_positions_from_observation + ) + + if self._use_contact_forces: + raise NotImplementedError("use_contact_forces not implemented.") + + def reset(self, rng: jax.Array) -> State: + """Resets the environment to an initial state.""" + + rng, rng1, rng2 = jax.random.split(rng, 3) + + low, hi = -self._reset_noise_scale, self._reset_noise_scale + q = self.sys.init_q + jax.random.uniform( + rng1, (self.sys.q_size(),), minval=low, maxval=hi + ) + qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),)) + + # set the target q, qd + _, 
target = self._random_target(rng) + q = q.at[-2:].set(target) + qd = qd.at[-2:].set(0) + + pipeline_state = self.pipeline_init(q, qd) + obs = self._get_obs(pipeline_state) + + reward, done, zero = jp.zeros(3) + metrics = { + "reward_forward": zero, + "reward_survive": zero, + "reward_ctrl": zero, + "reward_contact": zero, + "x_position": zero, + "y_position": zero, + "distance_from_origin": zero, + "x_velocity": zero, + "y_velocity": zero, + "forward_reward": zero, + "dist": zero, + "success": zero, + "success_easy": zero + } + info = {"seed": 0} + state = State(pipeline_state, obs, reward, done, metrics) + state.info.update(info) + return state + + # Todo rename seed to traj_id + def step(self, state: State, action: jax.Array) -> State: + """Run one timestep of the environment's dynamics.""" + pipeline_state0 = state.pipeline_state + pipeline_state = self.pipeline_step(pipeline_state0, action) + + if "steps" in state.info.keys(): + seed = state.info["seed"] + jp.where(state.info["steps"], 0, 1) + else: + seed = state.info["seed"] + info = {"seed": seed} + + velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt + forward_reward = velocity[0] + + min_z, max_z = self._healthy_z_range + is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0) + is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy) + if self._terminate_when_unhealthy: + healthy_reward = self._healthy_reward + else: + healthy_reward = self._healthy_reward * is_healthy + ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action)) + contact_cost = 0.0 + + obs = self._get_obs(pipeline_state) + reward = forward_reward + healthy_reward - ctrl_cost - contact_cost + done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0 + + dist = jp.linalg.norm(obs[:2] - obs[-2:]) + success = jp.array(dist < 0.5, dtype=float) + success_easy = jp.array(dist < 2., dtype=float) + + state.metrics.update( + reward_forward=forward_reward, + reward_survive=healthy_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + x_position=pipeline_state.x.pos[0, 0], + y_position=pipeline_state.x.pos[0, 1], + distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]), + x_velocity=velocity[0], + y_velocity=velocity[1], + dist=dist, + success=success, + success_easy=success_easy + ) + state.info.update(info) + return state.replace( + pipeline_state=pipeline_state, obs=obs, reward=reward, done=done + ) + + def _get_obs(self, pipeline_state: base.State) -> jax.Array: + """Observe ant body position and velocities.""" + # remove target q, qd + qpos = pipeline_state.q[:-2] + qvel = pipeline_state.qd[:-2] + + target_pos = pipeline_state.x.pos[-1][:2] + + if self._exclude_current_positions_from_observation: + qpos = qpos[2:] + + return jp.concatenate([qpos] + [qvel] + [target_pos]) + + def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]: + """Returns a target location in a random circle slightly above xy plane.""" + rng, rng1, rng2 = jax.random.split(rng, 3) + dist = 10 + ang = jp.pi * 2.0 * jax.random.uniform(rng2) + target_x = dist * jp.cos(ang) + target_y = dist * jp.sin(ang) + return rng, jp.array([target_x, target_y]) \ No newline at end of file diff --git a/clean_JaxGCRL/envs/ant_maze.py b/clean_JaxGCRL/envs/ant_maze.py new file mode 100644 index 0000000..fe4011c --- /dev/null +++ b/clean_JaxGCRL/envs/ant_maze.py @@ -0,0 +1,303 @@ +import os +from typing import Tuple + +from brax import base +from brax import math +from brax.envs.base import PipelineEnv, State +from 
brax.io import mjcf +import jax +from jax import numpy as jp +import mujoco +import xml.etree.ElementTree as ET + +# This is based on original Ant environment from Brax +# https://github.com/google/brax/blob/main/brax/envs/ant.py +# Maze creation dapted from: https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/locomotion/maze_env.py + +RESET = R = 'r' +GOAL = G = 'g' + + +U_MAZE = [[1, 1, 1, 1, 1], + [1, R, G, G, 1], + [1, 1, 1, G, 1], + [1, G, G, G, 1], + [1, 1, 1, 1, 1]] + +U_MAZE_EVAL = [[1, 1, 1, 1, 1], + [1, R, 0, 0, 1], + [1, 1, 1, 0, 1], + [1, G, G, G, 1], + [1, 1, 1, 1, 1]] + +BIG_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1], + [1, R, G, 1, 1, G, G, 1], + [1, G, G, 1, G, G, G, 1], + [1, 1, G, G, G, 1, 1, 1], + [1, G, G, 1, G, G, G, 1], + [1, G, 1, G, G, 1, G, 1], + [1, G, G, G, 1, G, G, 1], + [1, 1, 1, 1, 1, 1, 1, 1]] + +BIG_MAZE_EVAL = [[1, 1, 1, 1, 1, 1, 1, 1], + [1, R, 0, 1, 1, G, G, 1], + [1, 0, 0, 1, 0, G, G, 1], + [1, 1, 0, 0, 0, 1, 1, 1], + [1, 0, 0, 1, 0, 0, 0, 1], + [1, 0, 1, G, 0, 1, G, 1], + [1, 0, G, G, 1, G, G, 1], + [1, 1, 1, 1, 1, 1, 1, 1]] + +HARDEST_MAZE = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, R, G, G, G, 1, G, G, G, G, G, 1], + [1, G, 1, 1, G, 1, G, 1, G, 1, G, 1], + [1, G, G, G, G, G, G, 1, G, G, G, 1], + [1, G, 1, 1, 1, 1, G, 1, 1, 1, G, 1], + [1, G, G, 1, G, 1, G, G, G, G, G, 1], + [1, 1, G, 1, G, 1, G, 1, G, 1, 1, 1], + [1, G, G, 1, G, G, G, 1, G, G, G, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + +MAZE_HEIGHT = 0.5 + +def find_robot(structure, size_scaling): + for i in range(len(structure)): + for j in range(len(structure[0])): + if structure[i][j] == RESET: + return i * size_scaling, j * size_scaling + +def find_goals(structure, size_scaling): + goals = [] + for i in range(len(structure)): + for j in range(len(structure[0])): + if structure[i][j] == GOAL: + goals.append([i * size_scaling, j * size_scaling]) + + return jp.array(goals) + +# Create a xml with maze and a list of possible goal positions +def make_maze(maze_layout_name, maze_size_scaling): + if maze_layout_name == "u_maze": + maze_layout = U_MAZE + elif maze_layout_name == "u_maze_eval": + maze_layout = U_MAZE_EVAL + elif maze_layout_name == "big_maze": + maze_layout = BIG_MAZE + elif maze_layout_name == "big_maze_eval": + maze_layout = BIG_MAZE_EVAL + elif maze_layout_name == "hardest_maze": + maze_layout = HARDEST_MAZE + else: + raise ValueError(f"Unknown maze layout: {maze_layout_name}") + + xml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'assets', "ant_maze.xml") + + robot_x, robot_y = find_robot(maze_layout, maze_size_scaling) + possible_goals = find_goals(maze_layout, maze_size_scaling) + + tree = ET.parse(xml_path) + worldbody = tree.find(".//worldbody") + + for i in range(len(maze_layout)): + for j in range(len(maze_layout[0])): + struct = maze_layout[i][j] + if struct == 1: + ET.SubElement( + worldbody, "geom", + name="block_%d_%d" % (i, j), + pos="%f %f %f" % (i * maze_size_scaling, + j * maze_size_scaling, + MAZE_HEIGHT / 2 * maze_size_scaling), + size="%f %f %f" % (0.5 * maze_size_scaling, + 0.5 * maze_size_scaling, + MAZE_HEIGHT / 2 * maze_size_scaling), + type="box", + material="", + contype="1", + conaffinity="1", + rgba="0.7 0.5 0.3 1.0", + ) + + + torso = tree.find(".//numeric[@name='init_qpos']") + data = torso.get("data") + torso.set("data", f"{robot_x} {robot_y} " + data) + + tree = tree.getroot() + xml_string = ET.tostring(tree) + + return xml_string, possible_goals + +class AntMaze(PipelineEnv): + def __init__( + self, + ctrl_cost_weight=0.5, + 
use_contact_forces=False, + contact_cost_weight=5e-4, + healthy_reward=1.0, + terminate_when_unhealthy=True, + healthy_z_range=(0.2, 1.0), + contact_force_range=(-1.0, 1.0), + reset_noise_scale=0.1, + exclude_current_positions_from_observation=True, + backend="generalized", + maze_layout_name="u_maze", + maze_size_scaling=4.0, + **kwargs, + ): + xml_string, possible_goals = make_maze(maze_layout_name, maze_size_scaling) + + sys = mjcf.loads(xml_string) + self.possible_goals = possible_goals + + n_frames = 5 + + if backend in ["spring", "positional"]: + sys = sys.replace(dt=0.005) + n_frames = 10 + + if backend == "mjx": + sys = sys.tree_replace( + { + "opt.solver": mujoco.mjtSolver.mjSOL_NEWTON, + "opt.disableflags": mujoco.mjtDisableBit.mjDSBL_EULERDAMP, + "opt.iterations": 1, + "opt.ls_iterations": 4, + } + ) + + if backend == "positional": + # TODO: does the same actuator strength work as in spring + sys = sys.replace( + actuator=sys.actuator.replace( + gear=200 * jp.ones_like(sys.actuator.gear) + ) + ) + + kwargs["n_frames"] = kwargs.get("n_frames", n_frames) + + super().__init__(sys=sys, backend=backend, **kwargs) + + self._ctrl_cost_weight = ctrl_cost_weight + self._use_contact_forces = use_contact_forces + self._contact_cost_weight = contact_cost_weight + self._healthy_reward = healthy_reward + self._terminate_when_unhealthy = terminate_when_unhealthy + self._healthy_z_range = healthy_z_range + self._contact_force_range = contact_force_range + self._reset_noise_scale = reset_noise_scale + self._exclude_current_positions_from_observation = ( + exclude_current_positions_from_observation + ) + + if self._use_contact_forces: + raise NotImplementedError("use_contact_forces not implemented.") + + def reset(self, rng: jax.Array) -> State: + """Resets the environment to an initial state.""" + + rng, rng1, rng2 = jax.random.split(rng, 3) + + low, hi = -self._reset_noise_scale, self._reset_noise_scale + q = self.sys.init_q + jax.random.uniform( + rng1, (self.sys.q_size(),), minval=low, maxval=hi + ) + qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),)) + + # set the target q, qd + _, target = self._random_target(rng) + q = q.at[-2:].set(target) + qd = qd.at[-2:].set(0) + + pipeline_state = self.pipeline_init(q, qd) + obs = self._get_obs(pipeline_state) + + reward, done, zero = jp.zeros(3) + metrics = { + "reward_forward": zero, + "reward_survive": zero, + "reward_ctrl": zero, + "reward_contact": zero, + "x_position": zero, + "y_position": zero, + "distance_from_origin": zero, + "x_velocity": zero, + "y_velocity": zero, + "forward_reward": zero, + "dist": zero, + "success": zero, + "success_easy": zero + } + info = {"seed": 0} + state = State(pipeline_state, obs, reward, done, metrics) + state.info.update(info) + return state + + # Todo rename seed to traj_id + def step(self, state: State, action: jax.Array) -> State: + """Run one timestep of the environment's dynamics.""" + pipeline_state0 = state.pipeline_state + pipeline_state = self.pipeline_step(pipeline_state0, action) + + if "steps" in state.info.keys(): + seed = state.info["seed"] + jp.where(state.info["steps"], 0, 1) + else: + seed = state.info["seed"] + info = {"seed": seed} + + velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt + forward_reward = velocity[0] + + min_z, max_z = self._healthy_z_range + is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0) + is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy) + if self._terminate_when_unhealthy: + healthy_reward = 
self._healthy_reward + else: + healthy_reward = self._healthy_reward * is_healthy + ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action)) + contact_cost = 0.0 + + obs = self._get_obs(pipeline_state) + done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0 + + dist = jp.linalg.norm(obs[:2] - obs[-2:]) + success = jp.array(dist < 0.5, dtype=float) + success_easy = jp.array(dist < 2., dtype=float) + reward = -dist + healthy_reward - ctrl_cost - contact_cost + state.metrics.update( + reward_forward=forward_reward, + reward_survive=healthy_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + x_position=pipeline_state.x.pos[0, 0], + y_position=pipeline_state.x.pos[0, 1], + distance_from_origin=math.safe_norm(pipeline_state.x.pos[0]), + x_velocity=velocity[0], + y_velocity=velocity[1], + forward_reward=forward_reward, + dist=dist, + success=success, + success_easy=success_easy + ) + state.info.update(info) + return state.replace( + pipeline_state=pipeline_state, obs=obs, reward=reward, done=done + ) + + def _get_obs(self, pipeline_state: base.State) -> jax.Array: + """Observe ant body position and velocities.""" + qpos = pipeline_state.q[:-2] + qvel = pipeline_state.qd[:-2] + + target_pos = pipeline_state.x.pos[-1][:2] + + if self._exclude_current_positions_from_observation: + qpos = qpos[2:] + + return jp.concatenate([qpos] + [qvel] + [target_pos]) + + def _random_target(self, rng: jax.Array) -> Tuple[jax.Array, jax.Array]: + """Returns a random target location chosen from possibilities specified in the maze layout.""" + idx = jax.random.randint(rng, (1,), 0, len(self.possible_goals)) + return rng, jp.array(self.possible_goals[idx])[0] \ No newline at end of file diff --git a/clean_JaxGCRL/envs/assets/ant.xml b/clean_JaxGCRL/envs/assets/ant.xml new file mode 100644 index 0000000..ad22093 --- /dev/null +++ b/clean_JaxGCRL/envs/assets/ant.xml @@ -0,0 +1,103 @@ + + + \ No newline at end of file diff --git a/clean_JaxGCRL/envs/assets/ant_maze.xml b/clean_JaxGCRL/envs/assets/ant_maze.xml new file mode 100644 index 0000000..0b44a5e --- /dev/null +++ b/clean_JaxGCRL/envs/assets/ant_maze.xml @@ -0,0 +1,105 @@ + + + \ No newline at end of file diff --git a/clean_JaxGCRL/evaluator.py b/clean_JaxGCRL/evaluator.py new file mode 100644 index 0000000..6656134 --- /dev/null +++ b/clean_JaxGCRL/evaluator.py @@ -0,0 +1,85 @@ +import jax +import time +import numpy as np +import jax.numpy as jnp +import flax.linen as nn + +from brax import envs +from envs.ant import Ant +from typing import NamedTuple +from collections import namedtuple + +def generate_unroll(actor_step, training_state, env, env_state, unroll_length, extra_fields=()): + """Collect trajectories of given unroll_length.""" + + @jax.jit + def f(carry, unused_t): + state = carry + nstate, transition = actor_step(training_state, env, state, extra_fields=extra_fields) + return nstate, transition + + final_state, data = jax.lax.scan(f, env_state, (), length=unroll_length) + return final_state, data + +class CrlEvaluator(): + + def __init__(self, actor_step, eval_env, num_eval_envs, episode_length, key): + + self._key = key + self._eval_walltime = 0. 
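# Side note (sketch only; `toy_step` and its carry below are assumptions for this example,
# not part of this file): the `jax.lax.scan` inside `generate_unroll` above returns the final
# carry plus a pytree whose leaves gain a leading axis of length `unroll_length`, i.e. the
# (final_state, stacked transitions) pair that generate_unroll hands back.
import jax
import jax.numpy as jnp

def toy_step(carry, _):
    per_step_output = {"obs": jnp.full((3,), carry)}   # stands in for a Transition
    return carry + 1, per_step_output

final_carry, stacked = jax.lax.scan(toy_step, jnp.array(0), (), length=5)
# final_carry == 5 and stacked["obs"].shape == (5, 3): one row per unroll step.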
+ + eval_env = envs.training.EvalWrapper(eval_env) + + def generate_eval_unroll(training_state, key): + reset_keys = jax.random.split(key, num_eval_envs) + eval_first_state = eval_env.reset(reset_keys) + return generate_unroll( + actor_step, + training_state, + eval_env, + eval_first_state, + unroll_length=episode_length)[0] + + self._generate_eval_unroll = jax.jit(generate_eval_unroll) + self._steps_per_unroll = episode_length * num_eval_envs + + def run_evaluation(self, training_state, training_metrics, aggregate_episodes = True): + """Run one epoch of evaluation.""" + self._key, unroll_key = jax.random.split(self._key) + + t = time.time() + eval_state = self._generate_eval_unroll(training_state, unroll_key) + eval_metrics = eval_state.info["eval_metrics"] + eval_metrics.active_episodes.block_until_ready() + epoch_eval_time = time.time() - t + metrics = {} + aggregating_fns = [ + (np.mean, ""), + # (np.std, "_std"), + # (np.max, "_max"), + # (np.min, "_min"), + ] + + for (fn, suffix) in aggregating_fns: + metrics.update( + { + f"eval/episode_{name}{suffix}": ( + fn(eval_metrics.episode_metrics[name]) if aggregate_episodes else eval_metrics.episode_metrics[name] + ) + for name in ['reward', 'success', 'success_easy', 'dist', 'distance_from_origin'] + } + ) + + # We check in how many env there was at least one step where there was success + if "success" in eval_metrics.episode_metrics: + metrics["eval/episode_success_any"] = np.mean( + eval_metrics.episode_metrics["success"] > 0.0 + ) + + metrics["eval/avg_episode_length"] = np.mean(eval_metrics.episode_steps) + metrics["eval/epoch_eval_time"] = epoch_eval_time + metrics["eval/sps"] = self._steps_per_unroll / epoch_eval_time + self._eval_walltime = self._eval_walltime + epoch_eval_time + metrics = {"eval/walltime": self._eval_walltime, **training_metrics, **metrics} + + return metrics \ No newline at end of file diff --git a/clean_JaxGCRL/train_crl_jax_brax.py b/clean_JaxGCRL/train_crl_jax_brax.py new file mode 100644 index 0000000..55e84a3 --- /dev/null +++ b/clean_JaxGCRL/train_crl_jax_brax.py @@ -0,0 +1,646 @@ +import os +import jax +import flax +import tyro +import time +import optax +import wandb +import pickle +import random +import wandb_osh +import numpy as np +import flax.linen as nn +import jax.numpy as jnp + +from brax import envs +from etils import epath +from dataclasses import dataclass +from collections import namedtuple +from typing import NamedTuple, Any +from wandb_osh.hooks import TriggerWandbSyncHook +from flax.training.train_state import TrainState +from flax.linen.initializers import variance_scaling + +from evaluator import CrlEvaluator +from buffer import TrajectoryUniformSamplingQueue + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + seed: int = 1 + torch_deterministic: bool = True + cuda: bool = True + track: bool = False + wandb_project_name: str = "exploration" + wandb_entity: str = 'raj19' + wandb_mode: str = 'online' + wandb_dir: str = '.' + wandb_group: str = '.' 
+ capture_video: bool = False + checkpoint: bool = False + + #environment specific arguments + env_id: str = "ant" + episode_length: int = 1000 + # to be filled in runtime + obs_dim: int = 0 + goal_start_idx: int = 0 + goal_end_idx: int = 0 + + # Algorithm specific arguments + total_env_steps: int = 50000000 + num_epochs: int = 50 + num_envs: int = 1024 + num_eval_envs: int = 128 + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + alpha_lr: float = 3e-4 + batch_size: int = 256 + gamma: float = 0.99 + logsumexp_penalty_coeff: float = 0.1 + + max_replay_size: int = 10000 + min_replay_size: int = 1000 + + unroll_length: int = 62 + + # to be filled in runtime + env_steps_per_actor_step : int = 0 + """number of env steps per actor step (computed in runtime)""" + num_prefill_env_steps : int = 0 + """number of env steps to fill the buffer before starting training (computed in runtime)""" + num_prefill_actor_steps : int = 0 + """number of actor steps to fill the buffer before starting training (computed in runtime)""" + num_training_steps_per_epoch : int = 0 + """the number of training steps per epoch(computed in runtime)""" + +class SA_encoder(nn.Module): + norm_type = "layer_norm" + @nn.compact + def __call__(self, s: jnp.ndarray, a: jnp.ndarray): + + lecun_unfirom = variance_scaling(1/3, "fan_in", "uniform") + bias_init = nn.initializers.zeros + + if self.norm_type == "layer_norm": + normalize = lambda x: nn.LayerNorm()(x) + else: + normalize = lambda x: x + + x = jnp.concatenate([s, a], axis=-1) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(64, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + return x + +class G_encoder(nn.Module): + norm_type = "layer_norm" + @nn.compact + def __call__(self, g: jnp.ndarray): + + lecun_unfirom = variance_scaling(1/3, "fan_in", "uniform") + bias_init = nn.initializers.zeros + + if self.norm_type == "layer_norm": + normalize = lambda x: nn.LayerNorm()(x) + else: + normalize = lambda x: x + + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(g) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(64, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + return x + +class Actor(nn.Module): + action_size: int + norm_type = "layer_norm" + + LOG_STD_MAX = 2 + LOG_STD_MIN = -5 + + @nn.compact + def __call__(self, x): + if self.norm_type == "layer_norm": + normalize = lambda x: nn.LayerNorm()(x) + else: + normalize = lambda x: x + + lecun_unfirom = variance_scaling(1/3, "fan_in", "uniform") + bias_init = nn.initializers.zeros + + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) 
+ x = nn.swish(x) + x = nn.Dense(1024, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + x = normalize(x) + x = nn.swish(x) + + mean = nn.Dense(self.action_size, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + log_std = nn.Dense(self.action_size, kernel_init=lecun_unfirom, bias_init=bias_init)(x) + + log_std = nn.tanh(log_std) + log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats + + return mean, log_std + +@flax.struct.dataclass +class TrainingState: + """Contains training state for the learner""" + env_steps: jnp.ndarray + gradient_steps: jnp.ndarray + actor_state: TrainState + critic_state: TrainState + alpha_state: TrainState + +class Transition(NamedTuple): + """Container for a transition""" + observation: jnp.ndarray + action: jnp.ndarray + reward: jnp.ndarray + discount: jnp.ndarray + extras: jnp.ndarray = () + +def load_params(path: str): + with epath.Path(path).open('rb') as fin: + buf = fin.read() + return pickle.loads(buf) + +def save_params(path: str, params: Any): + """Saves parameters in flax format.""" + with epath.Path(path).open('wb') as fout: + fout.write(pickle.dumps(params)) + +if __name__ == "__main__": + + args = tyro.cli(Args) + + args.env_steps_per_actor_step = args.num_envs * args.unroll_length + args.num_prefill_env_steps = args.min_replay_size * args.num_envs + args.num_prefill_actor_steps = np.ceil(args.min_replay_size / args.unroll_length) + args.num_training_steps_per_epoch = (args.total_env_steps - args.num_prefill_env_steps) // (args.num_epochs * args.env_steps_per_actor_step) + + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + + if args.track: + + if args.wandb_group == '.': + args.wandb_group = None + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + mode=args.wandb_mode, + group=args.wandb_group, + dir=args.wandb_dir, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + + if args.wandb_mode == 'offline': + wandb_osh.set_log_level("ERROR") + trigger_sync = TriggerWandbSyncHook() + + if args.checkpoint: + from pathlib import Path + save_path = Path(args.wandb_dir) / Path(run_name) + os.mkdir(path=save_path) + + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, buffer_key, env_key, eval_env_key, actor_key, sa_key, g_key = jax.random.split(key, 7) + + # Environment setup + if args.env_id == "ant": + from envs.ant import Ant + env = Ant( + backend="spring", + exclude_current_positions_from_observation=False, + terminate_when_unhealthy=True, + ) + + args.obs_dim = 29 + args.goal_start_idx = 0 + args.goal_end_idx = 2 + + elif "maze" in args.env_id: + from envs.ant_maze import AntMaze + env = AntMaze( + backend="spring", + exclude_current_positions_from_observation=False, + terminate_when_unhealthy=True, + maze_layout_name=args.env_id[4:] + ) + + args.obs_dim = 29 + args.goal_start_idx = 0 + args.goal_end_idx = 2 + + else: + raise NotImplementedError + + env = envs.training.wrap( + env, + episode_length=args.episode_length, + ) + + obs_size = env.observation_size + action_size = env.action_size + env_keys = jax.random.split(env_key, args.num_envs) + env_state = jax.jit(env.reset)(env_keys) + env.step = jax.jit(env.step) + + # Network setup + # Actor + actor = Actor(action_size=action_size) + actor_state = TrainState.create( + apply_fn=actor.apply, + params=actor.init(actor_key, np.ones([1, obs_size])), + tx=optax.adam(learning_rate=args.actor_lr) + ) + + # 
Critic + sa_encoder = SA_encoder() + sa_encoder_params = sa_encoder.init(sa_key, np.ones([1, args.obs_dim]), np.ones([1, action_size])) + g_encoder = G_encoder() + g_encoder_params = g_encoder.init(g_key, np.ones([1, args.goal_end_idx - args.goal_start_idx])) + c = jnp.asarray(0.0, dtype=jnp.float32) + critic_state = TrainState.create( + apply_fn=None, + params={"sa_encoder": sa_encoder_params, "g_encoder": g_encoder_params}, + tx=optax.adam(learning_rate=args.critic_lr), + ) + + # Entropy coefficient + target_entropy = -0.5 * action_size + log_alpha = jnp.asarray(0.0, dtype=jnp.float32) + alpha_state = TrainState.create( + apply_fn=None, + params={"log_alpha": log_alpha}, + tx=optax.adam(learning_rate=args.alpha_lr), + ) + + # Trainstate + training_state = TrainingState( + env_steps=jnp.zeros(()), + gradient_steps=jnp.zeros(()), + actor_state=actor_state, + critic_state=critic_state, + alpha_state=alpha_state, + ) + + #Replay Buffer + dummy_obs = jnp.zeros((obs_size,)) + dummy_action = jnp.zeros((action_size,)) + + dummy_transition = Transition( + observation=dummy_obs, + action=dummy_action, + reward=0.0, + discount=0.0, + extras={ + "state_extras": { + "truncation": 0.0, + "seed": 0.0, + } + }, + ) + + def jit_wrap(buffer): + buffer.insert_internal = jax.jit(buffer.insert_internal) + buffer.sample_internal = jax.jit(buffer.sample_internal) + return buffer + + replay_buffer = jit_wrap( + TrajectoryUniformSamplingQueue( + max_replay_size=args.max_replay_size, + dummy_data_sample=dummy_transition, + sample_batch_size=args.batch_size, + num_envs=args.num_envs, + episode_length=args.episode_length, + ) + ) + buffer_state = jax.jit(replay_buffer.init)(buffer_key) + + def deterministic_actor_step(training_state, env, env_state, extra_fields): + means, _ = actor.apply(training_state.actor_state.params, env_state.obs) + actions = nn.tanh( means ) + + nstate = env.step(env_state, actions) + state_extras = {x: nstate.info[x] for x in extra_fields} + + return nstate, Transition( + observation=env_state.obs, + action=actions, + reward=nstate.reward, + discount=1-nstate.done, + extras={"state_extras": state_extras}, + ) + + def actor_step(actor_state, env, env_state, key, extra_fields): + means, log_stds = actor.apply(actor_state.params, env_state.obs) + stds = jnp.exp(log_stds) + actions = nn.tanh( means + stds * jax.random.normal(key, shape=means.shape, dtype=means.dtype) ) + + nstate = env.step(env_state, actions) + state_extras = {x: nstate.info[x] for x in extra_fields} + + return nstate, Transition( + observation=env_state.obs, + action=actions, + reward=nstate.reward, + discount=1-nstate.done, + extras={"state_extras": state_extras}, + ) + + @jax.jit + def get_experience(actor_state, env_state, buffer_state, key): + @jax.jit + def f(carry, unused_t): + env_state, current_key = carry + current_key, next_key = jax.random.split(current_key) + env_state, transition = actor_step(actor_state, env, env_state, current_key, extra_fields=("truncation", "seed")) + return (env_state, next_key), transition + + (env_state, _), data = jax.lax.scan(f, (env_state, key), (), length=args.unroll_length) + + buffer_state = replay_buffer.insert(buffer_state, data) + return env_state, buffer_state + + def prefill_replay_buffer(training_state, env_state, buffer_state, key): + @jax.jit + def f(carry, unused): + del unused + training_state, env_state, buffer_state, key = carry + key, new_key = jax.random.split(key) + env_state, buffer_state = get_experience( + training_state.actor_state, + env_state, + buffer_state, + 
key, + + ) + training_state = training_state.replace( + env_steps=training_state.env_steps + args.env_steps_per_actor_step, + ) + return (training_state, env_state, buffer_state, new_key), () + + return jax.lax.scan(f, (training_state, env_state, buffer_state, key), (), length=args.num_prefill_actor_steps)[0] + + @jax.jit + def update_actor_and_alpha(transitions, training_state, key): + def actor_loss(actor_params, critic_params, log_alpha, transitions, key): + obs = transitions.observation # expected_shape = batch_size, obs_size + goal_size + state = obs[:, :args.obs_dim] + future_state = transitions.extras["future_state"] + goal = future_state[:, args.goal_start_idx : args.goal_end_idx] + observation = jnp.concatenate([state, goal], axis=1) + + means, log_stds = actor.apply(actor_params, observation) + stds = jnp.exp(log_stds) + x_ts = means + stds * jax.random.normal(key, shape=means.shape, dtype=means.dtype) + action = nn.tanh(x_ts) + log_prob = jax.scipy.stats.norm.logpdf(x_ts, loc=means, scale=stds) + log_prob -= jnp.log((1 - jnp.square(action)) + 1e-6) + log_prob = log_prob.sum(-1) # dimension = B + + sa_encoder_params, g_encoder_params = critic_params["sa_encoder"], critic_params["g_encoder"] + sa_repr = sa_encoder.apply(sa_encoder_params, state, action) + g_repr = g_encoder.apply(g_encoder_params, goal) + + qf_pi = -jnp.sqrt(jnp.sum((sa_repr - g_repr) ** 2, axis=-1)) + + actor_loss = jnp.mean( jnp.exp(log_alpha) * log_prob - (qf_pi) ) + + return actor_loss, log_prob + + def alpha_loss(alpha_params, log_prob): + alpha = jnp.exp(alpha_params["log_alpha"]) + alpha_loss = alpha * jnp.mean(jax.lax.stop_gradient(-log_prob - target_entropy)) + return jnp.mean(alpha_loss) + + (actorloss, log_prob), actor_grad = jax.value_and_grad(actor_loss, has_aux=True)(training_state.actor_state.params, training_state.critic_state.params, training_state.alpha_state.params['log_alpha'], transitions, key) + new_actor_state = training_state.actor_state.apply_gradients(grads=actor_grad) + + alphaloss, alpha_grad = jax.value_and_grad(alpha_loss)(training_state.alpha_state.params, log_prob) + new_alpha_state = training_state.alpha_state.apply_gradients(grads=alpha_grad) + + training_state = training_state.replace(actor_state=new_actor_state, alpha_state=new_alpha_state) + + metrics = { + "sample_entropy": -log_prob, + "actor_loss": actorloss, + "alph_aloss": alphaloss, + "log_alpha": training_state.alpha_state.params["log_alpha"], + } + + return training_state, metrics + + @jax.jit + def update_critic(transitions, training_state, key): + def critic_loss(critic_params, transitions, key): + sa_encoder_params, g_encoder_params = critic_params["sa_encoder"], critic_params["g_encoder"] + + obs = transitions.observation[:, :args.obs_dim] + action = transitions.action + + sa_repr = sa_encoder.apply(sa_encoder_params, obs, action) + g_repr = g_encoder.apply(g_encoder_params, transitions.observation[:, args.obs_dim:]) + + # InfoNCE + logits = -jnp.sqrt(jnp.sum((sa_repr[:, None, :] - g_repr[None, :, :]) ** 2, axis=-1)) # shape = BxB + critic_loss = -jnp.mean(jnp.diag(logits) - jax.nn.logsumexp(logits, axis=1)) + + # logsumexp regularisation + logsumexp = jax.nn.logsumexp(logits + 1e-6, axis=1) + critic_loss += args.logsumexp_penalty_coeff * jnp.mean(logsumexp**2) + + I = jnp.eye(logits.shape[0]) + correct = jnp.argmax(logits, axis=1) == jnp.argmax(I, axis=1) + logits_pos = jnp.sum(logits * I) / jnp.sum(I) + logits_neg = jnp.sum(logits * (1 - I)) / jnp.sum(1 - I) + + return critic_loss, (logsumexp, I, correct, 
logits_pos, logits_neg) + + (loss, (logsumexp, I, correct, logits_pos, logits_neg)), grad = jax.value_and_grad(critic_loss, has_aux=True)(training_state.critic_state.params, transitions, key) + new_critic_state = training_state.critic_state.apply_gradients(grads=grad) + training_state = training_state.replace(critic_state = new_critic_state) + + metrics = { + "categorical_accuracy": jnp.mean(correct), + "logits_pos": logits_pos, + "logits_neg": logits_neg, + "logsumexp": logsumexp.mean(), + "critic_loss": loss, + } + + return training_state, metrics + + @jax.jit + def sgd_step(carry, transitions): + training_state, key = carry + key, critic_key, actor_key, = jax.random.split(key, 3) + + training_state, actor_metrics = update_actor_and_alpha(transitions, training_state, actor_key) + + training_state, critic_metrics = update_critic(transitions, training_state, critic_key) + + training_state = training_state.replace(gradient_steps = training_state.gradient_steps + 1) + + metrics = {} + metrics.update(actor_metrics) + metrics.update(critic_metrics) + + return (training_state, key,), metrics + + @jax.jit + def training_step(training_state, env_state, buffer_state, key): + experience_key1, experience_key2, sampling_key, training_key = jax.random.split(key, 4) + + # update buffer + env_state, buffer_state = get_experience( + training_state.actor_state, + env_state, + buffer_state, + experience_key1, + ) + + training_state = training_state.replace( + env_steps=training_state.env_steps + args.env_steps_per_actor_step, + ) + + # sample actor-step worth of transitions + buffer_state, transitions = replay_buffer.sample(buffer_state) + + # process transitions for training + batch_keys = jax.random.split(sampling_key, transitions.observation.shape[0]) + transitions = jax.vmap(TrajectoryUniformSamplingQueue.flatten_crl_fn, in_axes=(None, 0, 0))( + (args.gamma, args.obs_dim, args.goal_start_idx, args.goal_end_idx), transitions, batch_keys + ) + + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, (-1,) + x.shape[2:], order="F"), + transitions, + ) + permutation = jax.random.permutation(experience_key2, len(transitions.observation)) + transitions = jax.tree_util.tree_map(lambda x: x[permutation], transitions) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, (-1, args.batch_size) + x.shape[1:]), + transitions, + ) + + # take actor-step worth of training-step + (training_state, _,), metrics = jax.lax.scan(sgd_step, (training_state, training_key), transitions) + + return (training_state, env_state, buffer_state,), metrics + + @jax.jit + def training_epoch( + training_state, + env_state, + buffer_state, + key, + ): + @jax.jit + def f(carry, unused_t): + ts, es, bs, k = carry + k, train_key = jax.random.split(k, 2) + (ts, es, bs,), metrics = training_step(ts, es, bs, train_key) + return (ts, es, bs, k), metrics + + (training_state, env_state, buffer_state, key), metrics = jax.lax.scan(f, (training_state, env_state, buffer_state, key), (), length=args.num_training_steps_per_epoch) + + metrics["buffer_current_size"] = replay_buffer.size(buffer_state) + return training_state, env_state, buffer_state, metrics + + key, prefill_key = jax.random.split(key, 2) + + training_state, env_state, buffer_state, _ = prefill_replay_buffer( + training_state, env_state, buffer_state, prefill_key + ) + + '''Setting up evaluator''' + evaluator = CrlEvaluator( + deterministic_actor_step, + env, + num_eval_envs=args.num_eval_envs, + episode_length=args.episode_length, + key=eval_env_key, + ) + + 
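# For reference, the contrastive objective minimised in `update_critic` above, written out on
# toy embeddings (the random `sa_repr` / `g_repr` below are placeholders for this sketch only).
# Row i of `logits` holds minus the L2 distance from the i-th state-action embedding to every
# goal embedding in the batch; InfoNCE pushes the diagonal (matched pairs) up against each
# row's logsumexp, and the squared-logsumexp term penalises large logits overall.
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
batch, repr_dim = 256, 64
sa_repr = jax.random.normal(k1, (batch, repr_dim))
g_repr = jax.random.normal(k2, (batch, repr_dim))

logits = -jnp.sqrt(jnp.sum((sa_repr[:, None, :] - g_repr[None, :, :]) ** 2, axis=-1))  # (batch, batch)
info_nce = -jnp.mean(jnp.diag(logits) - jax.nn.logsumexp(logits, axis=1))
penalty = 0.1 * jnp.mean(jax.nn.logsumexp(logits + 1e-6, axis=1) ** 2)  # logsumexp_penalty_coeff
critic_loss = info_nce + penalty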
training_walltime = 0 + print('starting training....') + for ne in range(args.num_epochs): + + t = time.time() + + key, epoch_key = jax.random.split(key) + training_state, env_state, buffer_state, metrics = training_epoch(training_state, env_state, buffer_state, epoch_key) + + metrics = jax.tree_util.tree_map(jnp.mean, metrics) + metrics = jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics) + + epoch_training_time = time.time() - t + training_walltime += epoch_training_time + + sps = (args.env_steps_per_actor_step * args.num_training_steps_per_epoch) / epoch_training_time + metrics = { + "training/sps": sps, + "training/walltime": training_walltime, + "training/envsteps": training_state.env_steps.item(), + **{f"training/{name}": value for name, value in metrics.items()}, + } + + metrics = evaluator.run_evaluation(training_state, metrics) + + print(metrics) + + if args.checkpoint: + # Save current policy and critic params. + params = (training_state.alpha_state.params, training_state.actor_state.params, training_state.critic_state.params) + path = f"{save_path}/step_{int(training_state.env_steps)}.pkl" + save_params(path, params) + + if args.track: + wandb.log(metrics, step=ne) + + if args.wandb_mode == 'offline': + trigger_sync() + + if args.checkpoint: + # Save current policy and critic params. + params = (training_state.alpha_state.params, training_state.actor_state.params, training_state.critic_state.params) + path = f"{save_path}/final.pkl" + save_params(path, params) + +# (50000000 - 1024 x 1000) / 50 x 1024 x 62 = 15 #number of actor steps per epoch (which is equal to the number of training steps) +# 1024 x 999 / 256 = 4000 #number of gradient steps per actor step +# 1024 x 62 / 4000 = 16 #ratio of env steps per gradient step \ No newline at end of file
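# The step accounting in the three comments above, with explicit parentheses (all values are
# the Args defaults from this file; the variable names in this check are illustrative).
num_envs, unroll_length, episode_length = 1024, 62, 1000
total_env_steps, num_epochs = 50_000_000, 50
min_replay_size, batch_size = 1000, 256

env_steps_per_actor_step = num_envs * unroll_length                    # 63_488
num_prefill_env_steps = min_replay_size * num_envs                     # 1_024_000
training_steps_per_epoch = (total_env_steps - num_prefill_env_steps) // (num_epochs * env_steps_per_actor_step)
grad_steps_per_actor_step = num_envs * (episode_length - 1) // batch_size
env_steps_per_grad_step = env_steps_per_actor_step / grad_steps_per_actor_step

print(training_steps_per_epoch, grad_steps_per_actor_step, round(env_steps_per_grad_step))
# -> 15 3996 16   (the comment above rounds 3996 up to 4000)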