Skip to content

Commit

Permalink
created trajectory tracking reward based environment, added assertion…
Browse files Browse the repository at this point in the history
…s for nan cases in reward function, unsure of cause still
  • Loading branch information
AvidEslami committed Sep 4, 2024
1 parent f760a4d commit 1de85dd
Show file tree
Hide file tree
Showing 15 changed files with 28,041 additions and 6 deletions.
2 changes: 1 addition & 1 deletion flightlib/include/flightlib/envs/env_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class EnvBase {

// control time step
Scalar sim_dt_{0.02};
Scalar max_t_{5.0};
Scalar max_t_{2.0};

// random variable generator
std::normal_distribution<Scalar> norm_dist_{0.0, 1.0};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma once

// std lib
#include <stdlib.h>
#include <cmath>
#include <iostream>

// yaml cpp
#include <yaml-cpp/yaml.h>

// flightlib

#include "quadenv_ctl.hpp"


namespace flightlib {

// namespace quadenv {

// enum Ctl : int {
// // observations
// kObs = 0,
// //
// kPos = 0,
// kNPos = 3,
// kOri = 3,
// kNOri = 3,
// kLinVel = 6,
// kNLinVel = 3,
// kAngVel = 9,
// kNAngVel = 3,
// kNObs = 12,
// // control actions
// kAct = 0,
// kNAct = 4,
// };
// };

/**
 * @brief Quadrotor environment derived from EnvBase that, besides the usual
 * quadrotor state / observation / action members, stores a trajectory
 * (`traj_`) and a step counter (`mid_train_step_`).
 *
 * Only the declaration lives here; the behaviour of every member function is
 * defined in the corresponding source file.
 */
class QuadrotorEnvByDataTraj final : public EnvBase {
 public:
  EIGEN_MAKE_ALIGNED_OPERATOR_NEW

  QuadrotorEnvByDataTraj();
  /// @param cfg_path path handed to the constructor; presumably the YAML
  ///        configuration file that ends up in `cfg_` -- confirm in the .cpp.
  QuadrotorEnvByDataTraj(const std::string &cfg_path);
  ~QuadrotorEnvByDataTraj();

  // - public OpenAI-gym-style functions
  /// Resets the environment and writes the resulting observation to `obs`.
  bool reset(Ref<Vector<>> obs, const bool random = true) override;
  /// Reset variant that additionally takes integer z and xy bounds
  /// (presumably the sampling range of the start position -- confirm).
  bool resetRange(Ref<Vector<>> obs, int lower_zbound, int upper_zbound,
                  int lower_xybound, int upper_xybound,
                  const bool random = true) override;
  /// Applies action `act`, writes the next observation to `obs` and returns a
  /// Scalar (presumably the step reward -- confirm in the .cpp).
  Scalar step(const Ref<Vector<>> act, Ref<Vector<>> obs) override;

  // - public set functions
  /// Reads environment parameters from the given YAML node.
  bool loadParam(const YAML::Node &cfg);

  // - public get functions
  bool getObs(Ref<Vector<>> obs) override;
  bool getAct(Ref<Vector<>> act) const;
  bool getAct(Command *const cmd) const;

  // - auxiliary functions
  /// Termination check; `reward` is an in/out reference the implementation
  /// may overwrite.
  bool isTerminalState(Scalar &reward) override;
  void addObjectsToUnity(std::shared_ptr<UnityBridge> bridge);

  friend std::ostream &operator<<(std::ostream &os,
                                  const QuadrotorEnvByDataTraj &quad_env);

 private:
  // quadrotor
  std::shared_ptr<Quadrotor> quadrotor_ptr_;
  QuadState quad_state_;
  Command cmd_;
  Logger logger_{"QuadrotorEnvByDataTraj"};

  // Reward weights used for training. They are zero-initialized so that they
  // never hold an indeterminate value if configuration loading does not
  // assign them; an uninitialized weight silently poisons the reward (it can
  // even be NaN).
  Scalar pos_coeff_{0.0};
  Scalar ori_coeff_{0.0};
  Scalar lin_vel_coeff_{0.0};
  Scalar ang_vel_coeff_{0.0};
  Scalar act_coeff_{0.0};

  // observations and actions (for RL)
  Vector<quadenv::kNObs> quad_obs_;
  Vector<quadenv::kNAct> quad_act_;

  // Store a trajectory as a list of states (pos, ori, lin_vel -> 10 elements)
  // NOTE(review): if Vector<10> is a fixed-size vectorizable Eigen type for
  // the configured Scalar, std::vector needs Eigen::aligned_allocator when
  // building before C++17 -- confirm against the Scalar typedef.
  std::vector<Vector<10>> traj_;
  // Step counter, started at 0 so it is never read uninitialized
  // (presumably the current index into traj_ -- confirm in the .cpp).
  int mid_train_step_{0};

  // reward function design (for model-free reinforcement learning)
  Vector<quadenv::kNObs> goal_state_;

  // action and observation normalization (for learning)
  Vector<quadenv::kNAct> act_mean_;
  Vector<quadenv::kNAct> act_std_;
  Vector<quadenv::kNObs> obs_mean_ = Vector<quadenv::kNObs>::Zero();
  Vector<quadenv::kNObs> obs_std_ = Vector<quadenv::kNObs>::Ones();

  YAML::Node cfg_;
  Matrix<3, 2> world_box_;
};

} // namespace flightlib
1 change: 1 addition & 0 deletions flightlib/include/flightlib/envs/vec_env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "flightlib/envs/quadrotor_env/quadrotor_env.hpp"
#include "flightlib/envs/quadrotor_env/quadrotor_hover_env.hpp"
#include "flightlib/envs/quadrotor_env/quadrotor_continuous_env.hpp"
#include "flightlib/envs/quadrotor_env/quadrotor_env_bydata_traj.hpp"
#include "flightlib/envs/quadrotor_env/quadrotor_env_bydata.hpp"

namespace flightlib {
Expand Down
7 changes: 4 additions & 3 deletions flightlib/src/envs/quadrotor_env/quadrotor_continuous_env.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ int flightpath = 3;
// 1: 9 Meter
// 2: 15 Meter

int line_counter = 8;
int eff_line_counter = 15;
int line_counter = 7;
int eff_line_counter = 7;

bool toced_continuous = false;

bool completed_lap = false;

bool assist = true;
bool assist = false;

QuadrotorContinuousEnv::QuadrotorContinuousEnv()
: QuadrotorContinuousEnv(getenv("FLIGHTMARE_PATH") +
Expand Down Expand Up @@ -509,6 +509,7 @@ bool QuadrotorContinuousEnv::isTerminalState(Scalar &reward) {
// }
// }
std::string csv_path = "/home/avidavid/Downloads/CPC16_Z1.csv";
// std::string csv_path = "/home/avidavid/Downloads/0.016.csv";
std::vector<std::string> track_data;
std::vector<double> coordinates;
loadCSV(track_data, csv_path);
Expand Down
Loading

0 comments on commit 1de85dd

Please sign in to comment.