Commit

Merge remote-tracking branch 'origin/trebuchet'
MikeInnes committed Mar 5, 2019
2 parents 227bd29 + 10bd26c commit cdda5ca
Showing 19 changed files with 2,629 additions and 800 deletions.
132 changes: 132 additions & 0 deletions games/differentiable-programming/cartpole/DQN.jl
@@ -0,0 +1,132 @@
using Flux, Gym, Printf
using Flux.Optimise: Optimiser
using Flux.Tracker: data
using Statistics: mean
using DataStructures: CircularBuffer
using Distributions: sample
using CuArrays

# Load game environment
env = CartPoleEnv()
reset!(env)

# ----------------------------- Parameters -------------------------------------

STATE_SIZE = length(env.state) # 4
ACTION_SIZE = length(env.action_space) # 2
MEM_SIZE = 100_000
BATCH_SIZE = 64
γ = 1f0 # discount rate

# Exploration params
ϵ = 1f0 # Initial exploration rate
ϵ_MIN = 1f-2 # Final exploration rate
ϵ_DECAY = 995f-3

# Optimiser params
η = 1f-2 # Learning rate
η_decay = 1f-3

memory = CircularBuffer{Any}(MEM_SIZE) # Used to remember past results

# ------------------------------ Model Architecture ----------------------------

model = Chain(Dense(STATE_SIZE, 24, tanh),
              Dense(24, 48, tanh),
              Dense(48, ACTION_SIZE)) |> gpu

loss(x, y) = Flux.mse(model(x |> gpu), y)

opt = Optimiser(ADAM(η), InvDecay(η_decay))
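# Note: Optimiser chains ADAM with InvDecay, which (in the Flux of this vintage) scales
# the step at the n-th update by roughly 1/(1 + η_decay * n), so the effective learning
# rate shrinks as training progresses.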

# ----------------------------- Helper Functions -------------------------------

# Anneal the exploration rate from ϵ down toward ϵ_MIN as the episode count e grows
get_ϵ(e) = max(ϵ_MIN, min(ϵ, 1f0 - log10(e * ϵ_DECAY)))

remember(state, action, reward, next_state, done) =
    push!(memory, (data(state), action, reward, data(next_state), done))

function action(state, train=true)
    train && rand() <= get_ϵ(e) && (return rand(-1:1))
    act_values = model(state |> gpu)
    a = Flux.onecold(act_values)
    return a == 2 ? 1 : -1
end
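
# The network outputs one Q-value per discrete action; Flux.onecold picks the index of
# the larger one, with index 2 mapped to the force +1 and index 1 to -1. The ϵ-greedy
# branch samples from rand(-1:1), which can also yield 0. inv_action below converts a
# force back into the index used when writing Q-targets in replay().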

inv_action(a) = a == 1 ? 2 : 1

function replay()
    global ϵ
    batch_size = min(BATCH_SIZE, length(memory))
    minibatch = sample(memory, batch_size, replace = false)

    x = []
    y = []
    for (iter, (state, action, reward, next_state, done)) in enumerate(minibatch)
        target = reward
        if !done
            target += γ * maximum(data(model(next_state |> gpu)))
        end

        target_f = data(model(state |> gpu))
        target_f[action] = target

        push!(x, state)
        push!(y, target_f)
    end
    x = hcat(x...) |> gpu
    y = hcat(y...) |> gpu
    Flux.train!(loss, params(model), [(x, y)], opt)

    ϵ *= ϵ > ϵ_MIN ? ϵ_DECAY : 1.0f0
end
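
# replay() above fits the network to one-step Bellman targets,
#     Q(s, a) ← r + γ · max_a′ Q(s′, a′),
# with terminal transitions using just r, and then decays ϵ once per call.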

function episode!(env, train=true)
    done = false
    total_reward = 0f0
    frames = 1
    while !done && frames <= 200
        #render(env)
        s = env.state
        a = action(s, train)
        s′, r, done, _ = step!(env, a)
        total_reward += r
        train && remember(s, inv_action(a), r, s′, done)
        frames += 1
    end

    total_reward
end

# -------------------------------- Testing -------------------------------------

function test()
    score_mean = 0f0
    for _=1:100
        reset!(env)
        total_reward = episode!(env, false)
        score_mean += total_reward / 100
    end

    return score_mean
end

# ------------------------------ Training --------------------------------------

e = 1

while true
    global e
    reset!(env)
    total_reward = episode!(env)
    print("Episode: $e | Score: $total_reward | ")
    score_mean = test()
    print("Mean score over 100 test episodes: $(@sprintf "%6.2f" score_mean)")
    if score_mean > 195
        println("\nCartPole-v0 solved!")
        break
    end
    println()
    replay()
    e += 1
end
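
# Rough post-training check (a sketch using only the definitions above): run a few
# greedy episodes and print their scores.
#for _ in 1:5
#    reset!(env)
#    println("Greedy episode score: ", episode!(env, false))
#end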
142 changes: 142 additions & 0 deletions games/differentiable-programming/cartpole/DiffRL.jl
@@ -0,0 +1,142 @@
# Differentiable-programming approach to the cartpole.jl env: the policy is trained by
# backpropagating a shaped reward through the environment dynamics.
using Flux, Gym, Printf
using Flux.Tracker: track, @grad, data#, gradient
using Flux.Optimise: Optimiser, _update_params!#, update!
using Statistics: mean
using DataStructures: CircularBuffer
using CuArrays

# Load game environment
env = CartPoleEnv()
reset!(env)

#ctx = Ctx(env)

#display(ctx.s)
#using Blink# when not on Juno
#body!(Blink.Window(), ctx.s)

# ----------------------------- Parameters -------------------------------------

STATE_SIZE = length(env.state)
ACTION_SIZE = length(env.action_space)
MAX_TRAIN_REWARD = env.x_threshold*env.θ_threshold_radians
SEQ_LEN = 8

# Optimiser params
η = 3f-2

# ------------------------------ Model Architecture ----------------------------

sign(x::TrackedArray) = track(sign, x)
@grad sign(x) = Base.sign.(data(x)), x̄ -> (x̄,)
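# sign has zero gradient almost everywhere, which would block learning. The custom
# @grad above acts as a straight-through estimator: the forward pass still returns the
# elementwise Base.sign of the input, but the backward pass hands the incoming
# gradient x̄ through unchanged.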

model = Chain(Dense(STATE_SIZE, 24, relu),
              Dense(24, 48, relu),
              Dense(48, 1, tanh), x->sign(x)) |> gpu

opt = ADAM(η)

action(state) = model(state)

function loss(rewards)
    ep_len = size(rewards, 1)
    max_rewards = ones(Float32, ep_len) * MAX_TRAIN_REWARD |> gpu
    Flux.mse(rewards, max_rewards)
end
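
# The shaped rewards are tracked values computed from env.state, so minimising this MSE
# against the maximum attainable reward can backpropagate through the environment
# dynamics into the policy network; no replay buffer or value targets are needed.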

# ----------------------------- Helper Functions -------------------------------

function train_reward()
    state = env.state
    x, ẋ, θ, θ̇ = state[1:1], state[2:2], state[3:3], state[4:4]
    # Custom reward for training:
    # product of triangular functions over the x-axis and the θ-axis.
    # Min reward = 0, max reward = env.x_threshold * env.θ_threshold_radians
    x_upper = env.x_threshold .- x
    x_lower = env.x_threshold .+ x

    r_x = max.(0f0, min.(x_upper, x_lower))

    θ_upper = env.θ_threshold_radians .- θ
    θ_lower = env.θ_threshold_radians .+ θ

    r_θ = max.(0f0, min.(θ_upper, θ_lower))

    return r_x .* r_θ
end
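
# For example, a perfectly centred state (x = 0, θ = 0) earns the maximum reward
# env.x_threshold * env.θ_threshold_radians, and the reward falls to 0 as either |x|
# or |θ| reaches its threshold.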


function replay(rewards)
    #grads = gradient(() -> loss(rewards), params(model))
    #for p in params(model)
    #    update!(opt, p, grads[p])
    #end

    Flux.back!(loss(rewards))
    _update_params!(opt, params(model))
end
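
# Flux.back! plus _update_params! is (roughly) the Tracker-era equivalent of the
# commented-out gradient/update! pair above: back! accumulates gradients on the tracked
# parameters and _update_params! applies one optimiser step to each of them.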

#@which _update_params!(opt, params(model))

function episode!(env, train=true)
    done = false
    total_reward = 0f0
    rewards = []
    frames = 1
    while !done && frames <= 200
        #render(env, ctx)
        #sleep(0.01)

        a = action(env.state)
        s′, r, done, _ = step!(env, a)
        total_reward += r

        train && push!(rewards, train_reward())

        if train && (frames % SEQ_LEN == 0 || done)
            rewards = vcat(rewards...)
            replay(rewards)
            rewards = []
            env.state = param(env.state.data)
        end

        frames += 1
    end
    total_reward
end
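
# Every SEQ_LEN frames (or at episode end) the accumulated shaped rewards are
# backpropagated and the parameters updated; env.state = param(env.state.data) then
# re-wraps the state so gradients do not flow across chunk boundaries, much like
# truncated backpropagation through time.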

# -------------------------------- Testing -------------------------------------

function test()
    score_mean = 0f0
    for _=1:100
        reset!(env)
        total_reward = episode!(env, false)
        score_mean += total_reward / 100
    end
    return score_mean
end

# ------------------------------ Training --------------------------------------

e = 1

while true
    global e
    reset!(env)
    total_reward = episode!(env)
    print("Episode: $e | Score: $total_reward | ")

    score_mean = test()
    score_mean_str = @sprintf "%6.2f" score_mean
    print("Mean score over 100 test episodes: " * score_mean_str)

    println()

    if score_mean > 195
        println("CartPole-v0 solved!")
        break
    end
    e += 1
end