Skip to content

Commit

Permalink
resolving merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
anupambhatnagar committed Feb 14, 2020
2 parents 803e62f + 53c5fda commit 23df766
Show file tree
Hide file tree
Showing 82 changed files with 549 additions and 398 deletions.
10 changes: 5 additions & 5 deletions Project/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(gameObject.transform.rotation.z);
AddVectorObs(gameObject.transform.rotation.x);
AddVectorObs(ball.transform.position - gameObject.transform.position);
AddVectorObs(m_BallRb.velocity);
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
sensor.AddObservation(m_BallRb.velocity);
}

public override void AgentAction(float[] vectorAction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(gameObject.transform.rotation.z);
AddVectorObs(gameObject.transform.rotation.x);
AddVectorObs((ball.transform.position - gameObject.transform.position));
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation((ball.transform.position - gameObject.transform.position));
}

public override void AgentAction(float[] vectorAction)
Expand Down
4 changes: 2 additions & 2 deletions Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ public override void InitializeAgent()
{
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(m_Position, 20);
sensor.AddOneHotObservation(m_Position, 20);
}

public override void AgentAction(float[] vectorAction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(gameObject.transform.localPosition);
AddVectorObs(target.transform.localPosition);
sensor.AddObservation(gameObject.transform.localPosition);
sensor.AddObservation(target.transform.localPosition);
}

public override void AgentAction(float[] vectorAction)
Expand Down
30 changes: 15 additions & 15 deletions Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,29 +72,29 @@ public override void InitializeAgent()
/// <summary>
/// Add relevant information on each body part to observations.
/// </summary>
public void CollectObservationBodyPart(BodyPart bp)
public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
{
var rb = bp.rb;
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
sensor.AddObservation(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground

var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity);
AddVectorObs(velocityRelativeToLookRotationToTarget);
sensor.AddObservation(velocityRelativeToLookRotationToTarget);

var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
AddVectorObs(angularVelocityRelativeToLookRotationToTarget);
sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget);

if (bp.rb.transform != body)
{
var localPosRelToBody = body.InverseTransformPoint(rb.position);
AddVectorObs(localPosRelToBody);
AddVectorObs(bp.currentXNormalizedRot); // Current x rot
AddVectorObs(bp.currentYNormalizedRot); // Current y rot
AddVectorObs(bp.currentZNormalizedRot); // Current z rot
AddVectorObs(bp.currentStrength / m_JdController.maxJointForceLimit);
sensor.AddObservation(localPosRelToBody);
sensor.AddObservation(bp.currentXNormalizedRot); // Current x rot
sensor.AddObservation(bp.currentYNormalizedRot); // Current y rot
sensor.AddObservation(bp.currentZNormalizedRot); // Current z rot
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
m_JdController.GetCurrentJointForces();

Expand All @@ -106,21 +106,21 @@ public override void CollectObservations()
RaycastHit hit;
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
{
AddVectorObs(hit.distance);
sensor.AddObservation(hit.distance);
}
else
AddVectorObs(10.0f);
sensor.AddObservation(10.0f);

// Forward & up to help with orientation
var bodyForwardRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.forward);
AddVectorObs(bodyForwardRelativeToLookRotationToTarget);
sensor.AddObservation(bodyForwardRelativeToLookRotationToTarget);

var bodyUpRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(body.up);
AddVectorObs(bodyUpRelativeToLookRotationToTarget);
sensor.AddObservation(bodyUpRelativeToLookRotationToTarget);

foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
{
CollectObservationBodyPart(bodyPart);
CollectObservationBodyPart(bodyPart, sensor);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,21 @@ public override void InitializeAgent()
{
base.InitializeAgent();
m_AgentRb = GetComponent<Rigidbody>();
Monitor.verticalOffset = 1f;
m_MyArea = area.GetComponent<FoodCollectorArea>();
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();

SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
AddVectorObs(localVelocity.x);
AddVectorObs(localVelocity.z);
AddVectorObs(System.Convert.ToInt32(m_Frozen));
AddVectorObs(System.Convert.ToInt32(m_Shoot));
sensor.AddObservation(localVelocity.x);
sensor.AddObservation(localVelocity.z);
sensor.AddObservation(System.Convert.ToInt32(m_Frozen));
sensor.AddObservation(System.Convert.ToInt32(m_Shoot));
}
}

Expand Down
14 changes: 7 additions & 7 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ public override void InitializeAgent()
{
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
{
// There are no numeric observations to collect as this environment uses visual
// observations.

// Mask the necessary actions if selected by the user.
if (maskActions)
{
SetMask();
SetMask(actionMasker);
}
}

/// <summary>
/// Applies the mask for the agents action to disallow unnecessary actions.
/// </summary>
void SetMask()
void SetMask(ActionMasker actionMasker)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
Expand All @@ -55,22 +55,22 @@ void SetMask()

if (positionX == 0)
{
SetActionMask(k_Left);
actionMasker.SetActionMask(k_Left);
}

if (positionX == maxPosition)
{
SetActionMask(k_Right);
actionMasker.SetActionMask(k_Right);
}

if (positionZ == 0)
{
SetActionMask(k_Down);
actionMasker.SetActionMask(k_Down);
}

if (positionZ == maxPosition)
{
SetActionMask(k_Up);
actionMasker.SetActionMask(k_Up);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ public override void InitializeAgent()
m_GroundMaterial = m_GroundRenderer.material;
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
AddVectorObs(GetStepCount() / (float)maxStep);
sensor.AddObservation(GetStepCount() / (float)maxStep);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ public override void InitializeAgent()
m_SwitchLogic = areaSwitch.GetComponent<PyramidSwitch>();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
if (useVectorObs)
{
AddVectorObs(m_SwitchLogic.GetState());
AddVectorObs(transform.InverseTransformDirection(m_AgentRb.velocity));
sensor.AddObservation(m_SwitchLogic.GetState());
sensor.AddObservation(transform.InverseTransformDirection(m_AgentRb.velocity));
}
}

Expand Down
24 changes: 12 additions & 12 deletions Project/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,22 @@ public override void InitializeAgent()
/// We collect the normalized rotations, angularal velocities, and velocities of both
/// limbs of the reacher as well as the relative position of the target and hand.
/// </summary>
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(pendulumA.transform.localPosition);
AddVectorObs(pendulumA.transform.rotation);
AddVectorObs(m_RbA.angularVelocity);
AddVectorObs(m_RbA.velocity);
sensor.AddObservation(pendulumA.transform.localPosition);
sensor.AddObservation(pendulumA.transform.rotation);
sensor.AddObservation(m_RbA.angularVelocity);
sensor.AddObservation(m_RbA.velocity);

AddVectorObs(pendulumB.transform.localPosition);
AddVectorObs(pendulumB.transform.rotation);
AddVectorObs(m_RbB.angularVelocity);
AddVectorObs(m_RbB.velocity);
sensor.AddObservation(pendulumB.transform.localPosition);
sensor.AddObservation(pendulumB.transform.rotation);
sensor.AddObservation(m_RbB.angularVelocity);
sensor.AddObservation(m_RbB.velocity);

AddVectorObs(goal.transform.localPosition);
AddVectorObs(hand.transform.localPosition);
sensor.AddObservation(goal.transform.localPosition);
sensor.AddObservation(hand.transform.localPosition);

AddVectorObs(m_GoalSpeed);
sensor.AddObservation(m_GoalSpeed);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Collections.Generic;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

Expand Down Expand Up @@ -561,4 +561,4 @@ void Initialize()
s_RedStyle = s_ColorStyle[5];
}
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace MLAgentsExamples
public class ProjectSettingsOverrides : MonoBehaviour
{
// Original values
float m_OriginalMonitorVerticalOffset;

Vector3 m_OriginalGravity;
float m_OriginalFixedDeltaTime;
float m_OriginalMaximumDeltaTime;
Expand All @@ -16,9 +16,6 @@ public class ProjectSettingsOverrides : MonoBehaviour
[Tooltip("Increase or decrease the scene gravity. Use ~3x to make things less floaty")]
public float gravityMultiplier = 1.0f;

[Header("Display Settings")]
public float monitorVerticalOffset;

[Header("Advanced physics settings")]
[Tooltip("The interval in seconds at which physics and other fixed frame rate updates (like MonoBehaviour's FixedUpdate) are performed.")]
public float fixedDeltaTime = .02f;
Expand All @@ -32,15 +29,13 @@ public class ProjectSettingsOverrides : MonoBehaviour
public void Awake()
{
// Save the original values
m_OriginalMonitorVerticalOffset = Monitor.verticalOffset;
m_OriginalGravity = Physics.gravity;
m_OriginalFixedDeltaTime = Time.fixedDeltaTime;
m_OriginalMaximumDeltaTime = Time.maximumDeltaTime;
m_OriginalSolverIterations = Physics.defaultSolverIterations;
m_OriginalSolverVelocityIterations = Physics.defaultSolverVelocityIterations;

// Override
Monitor.verticalOffset = monitorVerticalOffset;
Physics.gravity *= gravityMultiplier;
Time.fixedDeltaTime = fixedDeltaTime;
Time.maximumDeltaTime = maximumDeltaTime;
Expand All @@ -52,7 +47,6 @@ public void Awake()

public void OnDestroy()
{
Monitor.verticalOffset = m_OriginalMonitorVerticalOffset;
Physics.gravity = m_OriginalGravity;
Time.fixedDeltaTime = m_OriginalFixedDeltaTime;
Time.maximumDeltaTime = m_OriginalMaximumDeltaTime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

public class TemplateAgent : Agent
{
public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
}

Expand Down
20 changes: 10 additions & 10 deletions Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ public override void InitializeAgent()
SetResetParameters();
}

public override void CollectObservations()
public override void CollectObservations(VectorSensor sensor)
{
AddVectorObs(m_InvertMult * (transform.position.x - myArea.transform.position.x));
AddVectorObs(transform.position.y - myArea.transform.position.y);
AddVectorObs(m_InvertMult * m_AgentRb.velocity.x);
AddVectorObs(m_AgentRb.velocity.y);
sensor.AddObservation(m_InvertMult * (transform.position.x - myArea.transform.position.x));
sensor.AddObservation(transform.position.y - myArea.transform.position.y);
sensor.AddObservation(m_InvertMult * m_AgentRb.velocity.x);
sensor.AddObservation(m_AgentRb.velocity.y);

AddVectorObs(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
AddVectorObs(ball.transform.position.y - myArea.transform.position.y);
AddVectorObs(m_InvertMult * m_BallRb.velocity.x);
AddVectorObs(m_BallRb.velocity.y);
sensor.AddObservation(m_InvertMult * (ball.transform.position.x - myArea.transform.position.x));
sensor.AddObservation(ball.transform.position.y - myArea.transform.position.y);
sensor.AddObservation(m_InvertMult * m_BallRb.velocity.x);
sensor.AddObservation(m_BallRb.velocity.y);

AddVectorObs(m_InvertMult * gameObject.transform.rotation.z);
sensor.AddObservation(m_InvertMult * gameObject.transform.rotation.z);
}

public override void AgentAction(float[] vectorAction)
Expand Down
Loading

0 comments on commit 23df766

Please sign in to comment.