Skip to content

Commit

Permalink
Allow user to override preset from the command line, also fixing an i…
Browse files Browse the repository at this point in the history
…ssue with new Gym api
  • Loading branch information
yonkshi committed Nov 12, 2021
1 parent 0dba889 commit c40cf13
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 21 deletions.
4 changes: 2 additions & 2 deletions dedo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
versions.append(obj_name)
# Register v0 as random material textures and the rest of task versions.
cls_nm = 'DeformRobotEnv' if task.startswith('FoodPacking') else 'DeformEnv'
register(id=task+'-v'+str(0), entry_point='dedo.envs:'+cls_nm)
register(id=task+'-v'+str(0), entry_point='dedo.envs:'+cls_nm, order_enforce=False)
for version_id, obj_name in enumerate(versions):
register(id=task+'-v'+str(version_id+1),
entry_point='dedo.envs:'+cls_nm)
entry_point='dedo.envs:'+cls_nm, order_enforce=False)

# Register dual-arm robot tasks.
register(id='HangGarmentRobot-v1', entry_point='dedo.envs:DeformRobotEnv')
19 changes: 18 additions & 1 deletion dedo/demo_preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dedo.utils.bullet_manipulator import convert_all

# TODO(yonkshi): remove this before release
from dedo.internal.waypoint_utils import create_traj
from dedo.internal.waypoint_utils import create_traj, create_traj_savgol


def play(env, num_episodes, args):
Expand Down Expand Up @@ -174,8 +174,25 @@ def build_traj(env, preset_wp, left_or_right, anchor_idx, ctrl_freq, robot):
# exit(1)
# WARNING: old code below.
traj_pos_vel = create_traj(init_anc_pos, wp[:, :3], steps, ctrl_freq)

pos_traj = traj_pos_vel[:, :3]
vel_traj = traj_pos_vel[:, 3:]
# traj_pos_vel = create_traj_savgol(init_anc_pos, wp[:, :3], steps, ctrl_freq)
#
# # TODO Debug viz
# import matplotlib
# matplotlib.use('TkAgg')
# from mpl_toolkits import mplot3d
# import matplotlib.pyplot as plt
# fig = plt.figure()
# ax = plt.axes(projection='3d')
#
# ax.plot3D(pos_traj[:, 0], pos_traj[:, 1], pos_traj[:, 2], label='default', linestyle="",marker=".")
# ax.plot3D(traj_pos_vel[:, 0], traj_pos_vel[:, 1], traj_pos_vel[:, 2], label='savgol', linestyle="", marker=".")
# ax.plot3D(wp[:, 0], wp[:, 1], wp[:, 2], label='WP', linestyle="", marker="o")
# plt.legend()
# plt.show()
# print('debug end')
# plot_traj(pos_traj)
from scipy.interpolate import interp1d
# xi = interp1d(ids, waypoints[:, 0], kind='cubic')(interp_i)
Expand Down
13 changes: 5 additions & 8 deletions dedo/envs/deform_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TOTE_MAJOR_VERSIONS, TOTE_VARS_PER_VERSION)
from ..utils.procedural_utils import (
gen_procedural_hang_cloth, gen_procedural_button_cloth)
from ..utils.args import preset_override_util


class DeformEnv(gym.Env):
Expand Down Expand Up @@ -139,15 +140,13 @@ def load_objects(self, sim, args, debug):
args.num_holes = args.version
deform_obj = gen_procedural_hang_cloth(
self.args, 'procedural_hang_cloth', DEFORM_INFO)
for arg_nm, arg_val in DEFORM_INFO[deform_obj].items():
setattr(args, arg_nm, arg_val)
preset_override_util(args, DEFORM_INFO[deform_obj])
elif self.args.task == 'ButtonProc': # procedural gen. for buttoning
args.num_holes = 2
args.node_density = 15
deform_obj, hole_centers = gen_procedural_button_cloth(
self.args, 'proc_button_cloth', DEFORM_INFO)
for arg_nm, arg_val in DEFORM_INFO[deform_obj].items():
setattr(args, arg_nm, arg_val)
preset_override_util(args, DEFORM_INFO[deform_obj])
# Move buttons to match hole position.
h1, h2 = hole_centers
h1 = (-h1[1], 0, h1[2]+2)
Expand Down Expand Up @@ -186,11 +185,9 @@ def load_objects(self, sim, args, debug):
else:
deform_obj = TASK_INFO[args.task][args.version - 1]

for arg_nm, arg_val in DEFORM_INFO[deform_obj].items():
setattr(args, arg_nm, arg_val)
preset_override_util(args, DEFORM_INFO[deform_obj])
if deform_obj in DEFORM_INFO:
for arg_nm, arg_val in DEFORM_INFO[deform_obj].items():
setattr(args, arg_nm, arg_val)
preset_override_util(args, DEFORM_INFO[deform_obj])

# Load rigid objects.
rigid_ids = []
Expand Down
10 changes: 5 additions & 5 deletions dedo/internal/waypoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@
from scipy.signal import savgol_filter


def create_traj_savgol(waypts, steps_per_waypt):
def create_traj_savgol(init_pos, waypts, steps_per_waypt, frequency):
'''
:param waypts:
:param steps_per_waypt:
:return:
'''
"""A scratch function to smooth trajectory."""
waypts = waypts.reshape(-1, 3)
waypts = np.concatenate([[init_pos], waypts], axis=0)
n_waypts = waypts.shape[0]
dists = []
for i in range(n_waypts-1):
dists.append(np.linalg.norm(waypts[i]-waypts[i+1]))
tot_dist = sum(dists)
t_max = n_waypts*steps_per_waypt
t_max = sum(steps_per_waypt)
dense_waypts = np.zeros((t_max, 3))
t = 0
for i in range(n_waypts-1):
Expand All @@ -33,8 +32,9 @@ def create_traj_savgol(waypts, steps_per_waypt):
if t < t_max:
dense_waypts[t:,:] = dense_waypts[t-1,:] # set rest to last entry
# dense_waypts = np.repeat(waypts, steps_per_waypt, axis=0) # simple repeat
window_len = int(sum(steps_per_waypt) / len(waypts) * 2 + 1)
dense_waypts = savgol_filter(dense_waypts,
window_length=int(steps_per_waypt*2+1),
window_length=window_len,
polyorder=4, axis=0)
print('dense_waypts', dense_waypts.shape)
return dense_waypts
Expand Down
21 changes: 20 additions & 1 deletion dedo/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
"""
import argparse

import sys
from .task_info import TASK_INFO
import re


def get_args_parser():
Expand Down Expand Up @@ -165,3 +166,21 @@ def get_args():
args_postprocess(args)
return args


def preset_override_util(args, deform_obj):
'''
Overrides args object with preset information (deform_obj).
Moreover users can override the override.
'''

# Regex to extract raw arg names from sys.argv
user_raw_args = []
for argv in sys.argv:
m = re.search("(?:--)([a-zA-Z0-9-_]+)(?:=)?", argv)
if m is not None:
user_raw_args.append(m.group(1)) # gets the var name

for arg_nm, arg_val in deform_obj.items():
if arg_nm in user_raw_args: # User overrides the preset arg
continue
setattr(args, arg_nm, arg_val)
11 changes: 7 additions & 4 deletions dedo/utils/procedural_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""
Utilities for generating procedural cloth.
TODO(yonkshi): add brief function-level descriptions, address TODO in
create_cloth_obj and clean up.
Procedural generation works as follows:
1. Generate a mesh, randomly carves out a square hole.
2. If hole > 1, also checks for overlapping of two holes. If overlap, randomly choose a new hole position. Repeat until no overlap found.
3. Saves hollowed mesh into .obj file in the /tmp/ directory.
Note: this code is for research i.e. quick experimentation; it has minimal
Expand Down Expand Up @@ -157,7 +161,7 @@ def gen_random_hole(node_density, dim_constraints):


def try_gen_holes(node_density, num_holes, constraints):
for i in range(1000): # 100 MC
for i in range(1000): # 1000 MC
if num_holes == 2:
holeA = gen_random_hole(node_density, constraints)
holeB = gen_random_hole(node_density, constraints)
Expand Down Expand Up @@ -216,7 +220,6 @@ def validate_and_integerize(hole):
# Check if file already exists
if not os.path.exists(os.path.join(data_path, "generated_cloth")):
os.makedirs(os.path.join(data_path, "generated_cloth"))
# TODO(yonkshi): hole outside loop; please check
fnm = "cloth_" + str(node_density) + "_" + str(hole[0]['x']) + "_" + \
str(hole[0]['y']) + "_" + str(hole[0]['x']) + "_" + \
str(min_point[0]) + "_" + str(min_point[1]) + "_" + \
Expand Down

0 comments on commit c40cf13

Please sign in to comment.