Skip to content

Commit

Permalink
feat: add RL python sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-ziqi committed Jun 29, 2024
1 parent 9d2ed69 commit c153746
Show file tree
Hide file tree
Showing 12 changed files with 689 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ logs
*fldlar*
.cache
*.json
# *gr1t1*
__pycache__
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,22 @@ Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models/YOUR

### Simulation

Open a new terminal, launch the gazebo simulation environment
Open a terminal, launch the gazebo simulation environment

```bash
source devel/setup.bash
roslaunch rl_sar gazebo_<ROBOT>.launch
```

Where \<ROBOT\> can be `a1` or `gr1t1`.
Open a new terminal, launch the control program

```bash
source devel/setup.bash
rosrun rl_sar rl_sim      # C++ version
rosrun rl_sar rl_sim.py   # Python version (run one or the other)
```

Where \<ROBOT\> can be `a1`, `gr1t1`, or `gr1t2`.

Control:
* Press **\<Enter\>** to toggle simulation start/stop.
Expand Down
12 changes: 10 additions & 2 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,22 @@ catkin build

### 仿真

新建终端,启动gazebo仿真环境
打开一个终端,启动gazebo仿真环境

```bash
source devel/setup.bash
roslaunch rl_sar gazebo_<ROBOT>.launch
```

其中 \<ROBOT\> 可以是 `a1` 或 `gr1t1`。
打开一个新终端,启动控制程序

```bash
source devel/setup.bash
rosrun rl_sar rl_sim      # C++ version
rosrun rl_sar rl_sim.py   # Python version (run one or the other)
```

其中 \<ROBOT\> 可以是 `a1`、`gr1t1` 或 `gr1t2`。

控制:

Expand Down
8 changes: 8 additions & 0 deletions src/rl_sar/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ find_package(catkin REQUIRED COMPONENTS
geometry_msgs
robot_msgs
robot_joint_controller
rospy
)

find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
Expand All @@ -37,6 +38,7 @@ include_directories(${YAML_CPP_INCLUDE_DIR})
catkin_package(
CATKIN_DEPENDS
robot_joint_controller
rospy
)

include_directories(library/unitree_legged_sdk_3.2/include)
Expand Down Expand Up @@ -78,3 +80,9 @@ target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)

# Install the Python scripts into the package's bin destination so that
# they are available as programs of the rl_sar package.
catkin_install_python(PROGRAMS
scripts/rl_sim.py
scripts/rl_sdk.py
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
6 changes: 1 addition & 5 deletions src/rl_sar/launch/gazebo_a1.launch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<launch>
<arg name="wname" default="stairs"/>
<arg name="rname" default="a1"/>
<param name="robot_name" type="str" value="a1"/>
<param name="robot_name" type="str" value="$(arg rname)"/>
<param name="use_history" type="bool" value="true"/>
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
<arg name="robot_path" value="(find $(arg rname)_description)"/>
Expand Down Expand Up @@ -37,8 +37,6 @@
<!-- Load joint controller configurations from YAML file to parameter server -->
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>

<!-- <rosparam param="/a1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->

<!-- load the controllers -->
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
Expand All @@ -53,6 +51,4 @@
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
</node>

<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>

</launch>
6 changes: 1 addition & 5 deletions src/rl_sar/launch/gazebo_gr1t1.launch
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<launch>
<arg name="wname" default="stairs"/>
<arg name="rname" default="gr1t1"/>
<param name="robot_name" type="str" value="gr1t1"/>
<param name="robot_name" type="str" value="$(arg rname)"/>
<param name="use_history" type="bool" value="false"/>
<param name="ros_namespace" type="str" value="/$(arg rname)_gazebo/"/>
<arg name="robot_path" value="(find $(arg rname)_description)"/>
Expand Down Expand Up @@ -33,8 +33,6 @@
<!-- Load joint controller configurations from YAML file to parameter server -->
<rosparam file="$(arg dollar)$(arg robot_path)/config/robot_control.yaml" command="load"/>

<!-- <rosparam param="/gr1t1_gazebo/joint_state_controller/publish_rate">5000</rosparam> -->

<!-- load the controllers -->
<node pkg="controller_manager" type="spawner" name="controller_spawner" respawn="false"
output="screen" ns="/$(arg rname)_gazebo" args="joint_state_controller
Expand All @@ -47,6 +45,4 @@
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
</node>

<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>

</launch>
23 changes: 10 additions & 13 deletions src/rl_sar/library/rl_sdk/rl_sdk.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
#include "rl_sdk.hpp"

/* You may need to override this ComputeObservation() function
torch::Tensor RL::ComputeObservation()
torch::Tensor RL_XXX::ComputeObservation()
{
torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
torch::Tensor obs = torch::cat({
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions
},1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return clamped_obs;
}
*/

/* You may need to override this Forward() function
torch::Tensor RL::Forward()
torch::Tensor RL_XXX::Forward()
{
torch::autograd::GradMode::set_enabled(false);
torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
return clamped_actions;
}
*/
Expand Down
2 changes: 2 additions & 0 deletions src/rl_sar/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
<exec_depend>robot_state_publisher</exec_depend>
<exec_depend>roscpp</exec_depend>
<exec_depend>std_msgs</exec_depend>
<build_depend>rospy</build_depend>
<exec_depend>rospy</exec_depend>
<depend>robot_msgs</depend>
<depend>robot_joint_controller</depend>

Expand Down
37 changes: 37 additions & 0 deletions src/rl_sar/scripts/observation_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch

class ObservationBuffer:
    """Rolling history of the most recent observations for each environment.

    The history is stored as one flat row per environment made of
    ``include_history_steps`` consecutive chunks of ``num_obs`` values.
    Chunks are ordered from oldest (lowest column indices) to newest
    (highest column indices).
    """

    def __init__(self, num_envs, num_obs, include_history_steps):
        """Allocate a zero-filled history buffer.

        Arguments:
            num_envs: Number of rows (environments) in the buffer.
            num_obs: Number of values in a single observation.
            include_history_steps: Number of observations kept per row.
        """
        self.num_envs = num_envs
        self.num_obs = num_obs
        self.include_history_steps = include_history_steps

        self.num_obs_total = num_obs * include_history_steps

        self.obs_buf = torch.zeros(self.num_envs, self.num_obs_total, dtype=torch.float)

    def reset(self, reset_idxs, new_obs):
        """Overwrite the whole history of the selected rows with ``new_obs``.

        Arguments:
            reset_idxs: Index (or indices) of the rows to reset.
            new_obs: Observation that is tiled ``include_history_steps``
                times along the last dimension to fill every history slot.
        """
        self.obs_buf[reset_idxs] = new_obs.repeat(1, self.include_history_steps)

    def insert(self, new_obs):
        """Push ``new_obs`` as the newest observation, dropping the oldest.

        Arguments:
            new_obs: Observation written into the last ``num_obs`` columns
                of every row.
        """
        # Shift observations back. The source and destination ranges overlap
        # inside the same tensor, so the source is cloned before assignment.
        self.obs_buf[:, : self.num_obs * (self.include_history_steps - 1)] = self.obs_buf[
            :, self.num_obs : self.num_obs * self.include_history_steps
        ].clone()

        # Add new observation.
        self.obs_buf[:, -self.num_obs:] = new_obs

    def get_obs_vec(self, obs_ids):
        """Gets history of observations indexed by obs_ids.

        Arguments:
            obs_ids: An array of integers with which to index the desired
                observations, where 0 is the latest observation and
                include_history_steps - 1 is the oldest observation.

        Returns:
            The selected observations concatenated along the last dimension,
            ordered from the oldest requested observation to the newest.

        Raises:
            IndexError: If any id lies outside
                ``[0, include_history_steps)``. Without this check such an
                id would select an empty slice and silently shrink the
                result.
        """
        ordered_ids = sorted(obs_ids)
        for obs_id in ordered_ids:
            if not 0 <= obs_id < self.include_history_steps:
                raise IndexError(
                    f"obs_id {obs_id} is out of range for a history of "
                    f"{self.include_history_steps} steps"
                )

        obs = []
        # Largest id first: a larger id means an older observation, which
        # lives at a lower column offset in the buffer.
        for obs_id in reversed(ordered_ids):
            slice_idx = self.include_history_steps - obs_id - 1
            obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs])
        return torch.cat(obs, dim=-1)
Loading

0 comments on commit c153746

Please sign in to comment.