
Commit

fixed unity connection issue.
yun-long committed Sep 7, 2020
1 parent d7b47d1 commit 4580185
Showing 5 changed files with 94 additions and 92 deletions.
2 changes: 1 addition & 1 deletion flightlib/src/envs/vec_env.cpp
@@ -165,7 +165,7 @@ void VecEnv<EnvBase>::perAgentStep(int agent_id, Ref<MatrixRowMajor<>> act,
template<typename EnvBase>
bool VecEnv<EnvBase>::setUnity(bool render) {
unity_render_ = render;
if (unity_render_ && unity_bridge_ptr_ != nullptr) {
if (unity_render_ && unity_bridge_ptr_ == nullptr) {
// create unity bridge
unity_bridge_ptr_ = UnityBridge::getInstance();
// add objects to Unity
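The old condition tested `unity_bridge_ptr_ != nullptr`, so the "create unity bridge" branch could only run when a bridge already existed and the bridge was therefore never instantiated; the fix flips the check so the bridge is created exactly once, when the pointer is still null. A minimal Python sketch of the same lazy-initialization guard (illustrative only, not the repository's C++ code; all names below are stand-ins):

```python
class UnityBridgeStub:
    """Placeholder standing in for the real UnityBridge instance."""
    pass


class VecEnvSketch:
    """Illustrates the corrected guard in VecEnv::setUnity (names are illustrative)."""

    def __init__(self):
        self.unity_render = False
        self.unity_bridge = None            # no bridge exists until rendering is requested

    def set_unity(self, render):
        self.unity_render = render
        # Fixed condition: create the bridge only when rendering is on AND no
        # bridge has been created yet (the old code tested "already exists",
        # so it never created one).
        if self.unity_render and self.unity_bridge is None:
            self.unity_bridge = UnityBridgeStub()   # stands in for UnityBridge::getInstance()
            # ... scene objects would be registered with the bridge here ...
        return True
```
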
5 changes: 1 addition & 4 deletions flightrl/examples/run_drone_control.py
@@ -30,8 +30,6 @@ def configure_random_seed(seed, env=None):

def parser():
parser = argparse.ArgumentParser()
parser.add_argument('--quad_env_cfg', type=str, default=os.path.abspath("../../configs/env.yaml"),
help='configuration file of the quad environment')
parser.add_argument('--train', type=int, default=1,
help="To train new model or simply test pre-trained model")
parser.add_argument('--render', type=int, default=0,
@@ -58,7 +56,6 @@ def main():
else:
cfg["env"]["render"] = "no"

print("render: ", args.render)
env = wrapper.FlightEnvVec(QuadrotorEnv_v1(
dump(cfg, Dumper=RoundTripDumper), False))

@@ -109,7 +106,7 @@ def main():
# # Testing mode with a trained weight
else:
model = PPO2.load(args.weight)
test_model(env, model, render=True)
test_model(env, model, render=args.render)


if __name__ == "__main__":
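In the training script, the duplicated `--quad_env_cfg` argument and the stray debug print are removed, and the test branch now forwards the `--render` flag instead of forcing rendering on, so evaluation can run headless. A rough sketch of the resulting test-mode flow; the imports, config path, and default weight are assumptions, while the names used below (args.train, args.render, args.weight, cfg["env"]["render"], FlightEnvVec, QuadrotorEnv_v1, PPO2.load, test_model) come from the diff:

```python
# Sketch only; mirrors the structure of run_drone_control.py after this commit.
import argparse

from ruamel.yaml import YAML, dump, RoundTripDumper
from stable_baselines import PPO2                      # or the repository's own PPO2 variant
from flightgym import QuadrotorEnv_v1
import rpg_baselines.envs.vec_env_wrapper as wrapper
from rpg_baselines.ppo.ppo2_test import test_model

p = argparse.ArgumentParser()
p.add_argument('--train', type=int, default=1)
p.add_argument('--render', type=int, default=0)
p.add_argument('--weight', type=str, default='./ppo2_model.zip')   # placeholder default
args = p.parse_args()

cfg = YAML().load(open("path/to/env.yaml", "r"))        # placeholder config path
cfg["env"]["render"] = "yes" if args.render else "no"   # forwarded to the C++ environment

env = wrapper.FlightEnvVec(QuadrotorEnv_v1(dump(cfg, Dumper=RoundTripDumper), False))

if not args.train:
    model = PPO2.load(args.weight)
    test_model(env, model, render=args.render)           # previously hard-coded render=True
```
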
167 changes: 84 additions & 83 deletions flightrl/rpg_baselines/envs/env_wrapper.py
@@ -3,89 +3,90 @@
import numpy as np
import time


class EnvWrapper(gym.Env):
def __init__(self, env):
self.env = env
self.env.init()
self.num_obs = env.getObsDim()
self.num_act = env.getActDim()

self._observation_space = gym.spaces.Box(
np.ones(self.num_obs) * -np.Inf,
np.ones(self.num_obs) * np.Inf,
dtype=np.float32)
# the actions are eventually constrained by the action space.
self._action_space = gym.spaces.Box(
low=np.ones(self.num_act) * -1.,
high=np.ones(self.num_act) * 1.,
dtype=np.float32)
self.observation = np.zeros(self.num_obs, dtype=np.float32)
self.reward = np.float32(0.0)
self.done = False

gym.Env.__init__(self)
#
self._max_episode_steps = 300
def seed(self, seed=None):
self.env.setSeed(seed)
def step(self, action):
self.reward = self.env.step(action, self.observation)
terminal_reward = 0.0
self.done = self.env.isTerminalState(terminal_reward)
return self.observation.copy(), self.reward, \
self.done, [dict(reward_run=self.reward, reward_ctrl=0.0)]

def reset(self):
self.reward = 0.0
self.env.reset(self.observation)
return self.observation.copy()

def reset_and_update_info(self):
return self.reset(),

def obs(self):
self.env.getObs(self.observation)
return self.observation

def close(self,):
return True
def getQuadState(self,):
quad_state = np.zeros(10, dtype=np.float32)
self.env.getQuadState(quad_state)
quad_correct = np.zeros(10, dtype=np.float32)
quad_correct[0:3] = quad_state[0:3]
quad_correct[3] = quad_state[9]
quad_correct[4] = quad_state[6]
quad_correct[5] = quad_state[7]
quad_correct[6] = quad_state[8]
quad_correct[7:10] = quad_state[3:6]
return quad_correct
def getGateState(self,):
gate_state = np.zeros(9, dtype=np.float32)
self.env.getGateState(gate_state)
return gate_state

def connectFlightmare(self):
self.env.connectFlightmare()

def disconnectFlightmare(self):
self.env.disconnectFlightmare()
@property
def observation_space(self):
return self._observation_space

@property
def action_space(self):
return self._action_space
@property
def max_episode_steps(self):
return self._max_episode_steps
def __init__(self, env):
self.env = env
self.env.init()
self.num_obs = env.getObsDim()
self.num_act = env.getActDim()

self._observation_space = gym.spaces.Box(
np.ones(self.num_obs) * -np.Inf,
np.ones(self.num_obs) * np.Inf,
dtype=np.float32)
# the actions are eventually constrained by the action space.
self._action_space = gym.spaces.Box(
low=np.ones(self.num_act) * -1.,
high=np.ones(self.num_act) * 1.,
dtype=np.float32)
self.observation = np.zeros(self.num_obs, dtype=np.float32)
self.reward = np.float32(0.0)
self.done = False

gym.Env.__init__(self)
#
self._max_episode_steps = 300

def seed(self, seed=None):
self.env.setSeed(seed)

def step(self, action):
self.reward = self.env.step(action, self.observation)
terminal_reward = 0.0
self.done = self.env.isTerminalState(terminal_reward)
return self.observation.copy(), self.reward, \
self.done, [dict(reward_run=self.reward, reward_ctrl=0.0)]

def reset(self):
self.reward = 0.0
self.env.reset(self.observation)
return self.observation.copy()

def reset_and_update_info(self):
return self.reset(),

def obs(self):
self.env.getObs(self.observation)
return self.observation

def close(self,):
return True

def getQuadState(self,):
quad_state = np.zeros(10, dtype=np.float32)
self.env.getQuadState(quad_state)
quad_correct = np.zeros(10, dtype=np.float32)
quad_correct[0:3] = quad_state[0:3]
quad_correct[3] = quad_state[9]
quad_correct[4] = quad_state[6]
quad_correct[5] = quad_state[7]
quad_correct[6] = quad_state[8]
quad_correct[7:10] = quad_state[3:6]
return quad_correct

def getGateState(self,):
gate_state = np.zeros(9, dtype=np.float32)
self.env.getGateState(gate_state)
return gate_state

def connectUnity(self):
self.env.connectUnity()

def disconnectUnity(self):
self.env.disconnectUnity()

@property
def observation_space(self):
return self._observation_space

@property
def action_space(self):
return self._action_space

@property
def max_episode_steps(self):
return self._max_episode_steps

# def main():
# import os
@@ -97,7 +98,7 @@ def max_episode_steps(self):
# env = EnvWrapper(env)

# obs = env.reset()

# obs = env.obs()
# print(obs)

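Besides the whitespace reformatting, the single-environment wrapper's public methods are renamed from connectFlightmare/disconnectFlightmare to connectUnity/disconnectUnity to match the C++ bridge. A hedged usage sketch (the availability and constructor arguments of the flightgym binding are assumptions; the wrapper methods are from the diff):

```python
import numpy as np
from flightgym import QuadrotorEnv_v1              # pybind11 binding built by flightlib
from rpg_baselines.envs.env_wrapper import EnvWrapper

env = EnvWrapper(QuadrotorEnv_v1())                # constructor arguments are illustrative
env.connectUnity()                                 # was env.connectFlightmare() before this commit

obs = env.reset()
for _ in range(env.max_episode_steps):
    action = np.zeros(env.num_act, dtype=np.float32)   # replace with a policy's output
    obs, reward, done, info = env.step(action)
    if done:
        break

env.disconnectUnity()                              # release the render connection
```
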
8 changes: 4 additions & 4 deletions flightrl/rpg_baselines/envs/vec_env_wrapper.py
@@ -92,11 +92,11 @@ def render(self, mode='human'):
def close(self):
self.wrapper.close()

def connectFlightmare(self):
self.wrapper.connectFlightmare()
def connectUnity(self):
self.wrapper.connectUnity()

def disconnectFlightmare(self):
self.wrapper.disconnectFlightmare()
def disconnectUnity(self):
self.wrapper.disconnectUnity()

@property
def num_envs(self):
4 changes: 4 additions & 0 deletions flightrl/rpg_baselines/ppo/ppo2_test.py
@@ -31,6 +31,8 @@ def test_model(env, model, render=False):

max_ep_length = env.max_episode_steps
num_rollouts = 5
if render:
env.connectUnity()
for n_roll in range(num_rollouts):
pos, euler, dpos, deuler = [], [], [], []
actions = []
@@ -91,6 +93,8 @@
ax_action3.step(t, actions[:, 3], color="C{0}".format(
n_roll), label="act [0, 1, 2, 3] -- trail: {0}".format(n_roll))
#
if render:
env.disconnectUnity()
ax_z.legend()
ax_dz.legend()
ax_euler_z.legend()
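With the rename in place, test_model now opens the Unity bridge only when render=True and closes it once the rollouts finish. A condensed sketch of that connect/rollout/disconnect pattern; the try/finally is my addition for robustness, not part of the commit, and the plotting and state logging done by the real test_model are omitted:

```python
def run_rollouts(env, model, num_rollouts=5, render=False):
    """Sketch of the render guard added to test_model in this commit."""
    if render:
        env.connectUnity()                     # open the bridge to the Unity renderer
    try:
        for _ in range(num_rollouts):
            obs = env.reset()
            for _ in range(env.max_episode_steps):
                act, _ = model.predict(obs, deterministic=True)
                obs, reward, done, info = env.step(act)
    finally:
        if render:
            env.disconnectUnity()              # always release the bridge afterwards
```
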
