Skip to content

Commit

Permalink
[BC] Add base_pose and tcp_pose in MS1 envs' observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayuan-Gu committed May 26, 2023
1 parent 38d9fdc commit 538ab6c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mani_skill2/envs/ms1/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_actor_state,
get_articulation_padded_state,
parse_urdf_config,
vectorize_pose
)
from mani_skill2.sensors.camera import CameraConfig

Expand Down Expand Up @@ -179,6 +180,12 @@ def check_actor_static(self, actor: sapien.Actor, max_v=None, max_ang_v=None):
# ---------------------------------------------------------------------------- #
# Observation
# ---------------------------------------------------------------------------- #
def _get_obs_agent(self):
obs = super()._get_obs_agent()
if self._obs_mode not in ["state", "state_dict"]:
obs["base_pose"] = vectorize_pose(self.agent.base_pose)
return obs

def _get_obs_extra(self) -> OrderedDict:
obs = OrderedDict()
if self._obs_mode in ["state", "state_dict"]:
Expand Down
12 changes: 12 additions & 0 deletions mani_skill2/envs/ms1/move_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
transform_points,
)
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import get_entity_by_name, vectorize_pose
from mani_skill2.utils.trimesh_utils import get_actor_visual_mesh

from .base_env import MS1BaseEnv
Expand Down Expand Up @@ -106,6 +107,10 @@ def _load_agent(self):
self._scene, self._control_freq, self._control_mode, config=self._agent_cfg
)

links = self.agent.robot.get_links()
self.left_tcp: sapien.Link = get_entity_by_name(links, "left_panda_hand_tcp")
self.right_tcp: sapien.Link = get_entity_by_name(links, "right_panda_hand_tcp")

# -------------------------------------------------------------------------- #
# Reset
# -------------------------------------------------------------------------- #
Expand Down Expand Up @@ -408,3 +413,10 @@ def _get_task_articulations(self):
def set_state(self, state: np.ndarray):
super().set_state(state)
self._prev_actor_pose = self.bucket.pose

def _get_obs_extra(self):
obs = super()._get_obs_extra()
if self._obs_mode not in ["state", "state_dict"]:
obs["left_tcp_pose"] = vectorize_pose(self.left_tcp.pose)
obs["right_tcp_pose"] = vectorize_pose(self.right_tcp.pose)
return obs
5 changes: 5 additions & 0 deletions mani_skill2/envs/ms1/open_cabinet_door_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mani_skill2.utils.common import np_random, random_choice
from mani_skill2.utils.geometry import angle_distance, transform_points
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import get_entity_by_name, vectorize_pose
from mani_skill2.utils.trimesh_utils import (
get_articulation_meshes,
get_visual_body_meshes,
Expand Down Expand Up @@ -152,6 +153,9 @@ def _load_agent(self):
self._scene, self._control_freq, self._control_mode, config=self._agent_cfg
)

links = self.agent.robot.get_links()
self.tcp: sapien.Link = get_entity_by_name(links, "right_panda_hand_tcp")

# -------------------------------------------------------------------------- #
# Reset
# -------------------------------------------------------------------------- #
Expand Down Expand Up @@ -392,6 +396,7 @@ def _get_obs_extra(self) -> OrderedDict:
target_joint_axis=self.target_joint_axis,
target_link_pos=self.target_link_pos,
)
obs["tcp_pose"] = vectorize_pose(self.tcp.pose)
return obs

def _get_obs_priviledged(self):
Expand Down
12 changes: 12 additions & 0 deletions mani_skill2/envs/ms1/push_chair.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mani_skill2.utils.common import np_random
from mani_skill2.utils.geometry import transform_points
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.sapien_utils import get_entity_by_name, vectorize_pose
from mani_skill2.utils.trimesh_utils import get_actor_visual_mesh

from .base_env import MS1BaseEnv
Expand Down Expand Up @@ -119,6 +120,10 @@ def _load_agent(self):
self._scene, self._control_freq, self._control_mode, config=self._agent_cfg
)

links = self.agent.robot.get_links()
self.left_tcp: sapien.Link = get_entity_by_name(links, "left_panda_hand_tcp")
self.right_tcp: sapien.Link = get_entity_by_name(links, "right_panda_hand_tcp")

def _set_chair_links_mesh(self):
self.links_info = {}
for link in self.chair.get_links():
Expand Down Expand Up @@ -356,3 +361,10 @@ def _get_task_articulations(self):
def set_state(self, state: np.ndarray):
super().set_state(state)
self._prev_actor_pose = self.root_link.pose

def _get_obs_extra(self):
obs = super()._get_obs_extra()
if self._obs_mode not in ["state", "state_dict"]:
obs["left_tcp_pose"] = vectorize_pose(self.left_tcp.pose)
obs["right_tcp_pose"] = vectorize_pose(self.right_tcp.pose)
return obs

0 comments on commit 538ab6c

Please sign in to comment.