forked from FluxML/model-zoo
Commit
Merge remote-tracking branch 'origin/trebuchet'
Showing 19 changed files with 2,629 additions and 800 deletions.
@@ -0,0 +1,132 @@
using Flux, Gym, Printf
using Flux.Optimise: Optimiser
using Flux.Tracker: data
using Statistics: mean
using DataStructures: CircularBuffer
using Distributions: sample
using CuArrays

# Load game environment
env = CartPoleEnv()
reset!(env)

# ----------------------------- Parameters -------------------------------------

STATE_SIZE = length(env.state)          # 4
ACTION_SIZE = length(env.action_space)  # 2
MEM_SIZE = 100_000
BATCH_SIZE = 64
γ = 1f0    # discount rate

# Exploration params
ϵ = 1f0        # Initial exploration rate
ϵ_MIN = 1f-2   # Final exploration rate
ϵ_DECAY = 995f-3

# Optimiser params
η = 1f-2       # Learning rate
η_decay = 1f-3

memory = CircularBuffer{Any}(MEM_SIZE)  # Used to remember past results (replay buffer)

# ------------------------------ Model Architecture ----------------------------

model = Chain(Dense(STATE_SIZE, 24, tanh),
              Dense(24, 48, tanh),
              Dense(48, ACTION_SIZE)) |> gpu

loss(x, y) = Flux.mse(model(x |> gpu), y)

opt = Optimiser(ADAM(η), InvDecay(η_decay))

# ----------------------------- Helper Functions -------------------------------

get_ϵ(e) = max(ϵ_MIN, min(ϵ, 1f0 - log10(e * ϵ_DECAY)))
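# ϵ-greedy schedule: get_ϵ maps the episode number e to an exploration rate that
# decreases as training progresses, clipped to stay between ϵ_MIN and the current ϵ.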

remember(state, action, reward, next_state, done) =
    push!(memory, (data(state), action, reward, data(next_state), done))

function action(state, train=true)
    # With probability ϵ (based on the global episode counter e), explore randomly
    train && rand() <= get_ϵ(e) && (return rand(-1:1))
    act_values = model(state |> gpu)
    a = Flux.onecold(act_values)
    return a == 2 ? 1 : -1
end

inv_action(a) = a == 1 ? 2 : 1

function replay()
    global ϵ
    batch_size = min(BATCH_SIZE, length(memory))
    minibatch = sample(memory, batch_size, replace = false)

    x = []
    y = []
    for (iter, (state, action, reward, next_state, done)) in enumerate(minibatch)
        target = reward
        if !done
            target += γ * maximum(data(model(next_state |> gpu)))
        end

        target_f = data(model(state |> gpu))
        target_f[action] = target

        push!(x, state)
        push!(y, target_f)
    end
    x = hcat(x...) |> gpu
    y = hcat(y...) |> gpu
    Flux.train!(loss, params(model), [(x, y)], opt)

    ϵ *= ϵ > ϵ_MIN ? ϵ_DECAY : 1.0f0
end
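
# Note: the replay step above regresses the Q-network toward the one-step
# Q-learning target for each sampled transition,
#   target = r + γ * max over a′ of Q(s′, a′)   (or just r on terminal transitions),
# i.e. the standard DQN update on a minibatch drawn from the replay buffer.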

function episode!(env, train=true)
    done = false
    total_reward = 0f0
    frames = 1
    while !done && frames <= 200
        #render(env)
        s = env.state
        a = action(s, train)
        s′, r, done, _ = step!(env, a)
        total_reward += r
        train && remember(s, inv_action(a), r, s′, done)
        frames += 1
    end

    total_reward
end

# -------------------------------- Testing -------------------------------------

function test()
    score_mean = 0f0
    for _=1:100
        reset!(env)
        total_reward = episode!(env, false)
        score_mean += total_reward / 100
    end

    return score_mean
end

# ------------------------------ Training --------------------------------------

e = 1

while true
    global e
    reset!(env)
    total_reward = episode!(env)
    print("Episode: $e | Score: $total_reward | ")
    score_mean = test()
    print("Mean score over 100 test episodes: $(@sprintf "%6.2f" score_mean)")
    if score_mean > 195
        println("\nCartPole-v0 solved!")
        break
    end
    println()
    replay()
    e += 1
end
@@ -0,0 +1,142 @@
# This implementation of DQN on cartpole is to verify the cartpole.jl env
using Flux, Gym, Printf
using Flux.Tracker: track, @grad, data #, gradient
using Flux.Optimise: Optimiser, _update_params! #, update!
using Statistics: mean
using DataStructures: CircularBuffer
using CuArrays

# Load game environment
env = CartPoleEnv()
reset!(env)

#ctx = Ctx(env)

#display(ctx.s)
#using Blink  # when not on Juno
#body!(Blink.Window(), ctx.s)

# ----------------------------- Parameters -------------------------------------

STATE_SIZE = length(env.state)
ACTION_SIZE = length(env.action_space)
MAX_TRAIN_REWARD = env.x_threshold * env.θ_threshold_radians
SEQ_LEN = 8

# Optimiser params
η = 3f-2

# ------------------------------ Model Architecture ----------------------------

sign(x::TrackedArray) = track(sign, x)
@grad sign(x) = Base.sign.(data(x)), x̄ -> (x̄,)
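
# Note: the custom gradient above acts as a straight-through-style estimator: the
# forward pass applies sign element-wise, while the backward pass hands the
# incoming gradient through unchanged, so the non-differentiable sign layer does
# not block backpropagation into the network.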

model = Chain(Dense(STATE_SIZE, 24, relu),
              Dense(24, 48, relu),
              Dense(48, 1, tanh), x -> sign(x)) |> gpu

opt = ADAM(η)

action(state) = model(state)

function loss(rewards)
    ep_len = size(rewards, 1)
    max_rewards = ones(Float32, ep_len) * MAX_TRAIN_REWARD |> gpu
    Flux.mse(rewards, max_rewards)
end
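
# Note: unlike the Q-learning script, this loss does not fit action values; it
# pushes the per-step shaped reward (train_reward below) toward its maximum
# possible value. Since actions and the environment state are tracked, gradients
# can flow from the reward back through the environment dynamics into the policy.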

# ----------------------------- Helper Functions -------------------------------

function train_reward()
    state = env.state
    x, ẋ, θ, θ̇ = state[1:1], state[2:2], state[3:3], state[4:4]
    # Custom reward for training:
    # product of a triangular function over the x-axis and one over the θ-axis.
    # Min reward = 0, Max reward = env.x_threshold * env.θ_threshold_radians
    # (a worked example follows this function)
    x_upper = env.x_threshold .- x
    x_lower = env.x_threshold .+ x

    r_x = max.(0f0, min.(x_upper, x_lower))

    θ_upper = env.θ_threshold_radians .- θ
    θ_lower = env.θ_threshold_radians .+ θ

    r_θ = max.(0f0, min.(θ_upper, θ_lower))

    return r_x .* r_θ
end
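
# Worked example (assuming the usual CartPole thresholds x_threshold = 2.4 and
# θ_threshold_radians ≈ 0.21): a centered, upright pole (x = 0, θ = 0) receives
# the maximum reward 2.4 * 0.21 ≈ 0.5, and the reward falls off linearly to 0 as
# |x| or |θ| approaches its threshold.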

function replay(rewards)
    #grads = gradient(() -> loss(rewards), params(model))
    #for p in params(model)
    #    update!(opt, p, grads[p])
    #end

    Flux.back!(loss(rewards))
    _update_params!(opt, params(model))
end

#@which _update_params!(opt, params(model))

function episode!(env, train=true)
    done = false
    total_reward = 0f0
    rewards = []
    frames = 1
    while !done && frames <= 200
        #render(env, ctx)
        #sleep(0.01)

        a = action(env.state)
        s′, r, done, _ = step!(env, a)
        total_reward += r

        train && push!(rewards, train_reward())

        if train && (frames % SEQ_LEN == 0 || done)
            rewards = vcat(rewards...)
            replay(rewards)
            rewards = []
            # Detach the state from the finished sequence's computation graph
            env.state = param(env.state.data)
        end

        frames += 1
    end
    total_reward
end

# -------------------------------- Testing -------------------------------------

function test()
    score_mean = 0f0
    for _=1:100
        reset!(env)
        total_reward = episode!(env, false)
        score_mean += total_reward / 100
    end
    return score_mean
end

# ------------------------------ Training --------------------------------------

e = 1

while true
    global e
    reset!(env)
    total_reward = episode!(env)
    print("Episode: $e | Score: $total_reward | ")

    score_mean = test()
    score_mean_str = @sprintf "%6.2f" score_mean
    print("Mean score over 100 test episodes: " * score_mean_str)

    println()

    if score_mean > 195
        println("CartPole-v0 solved!")
        break
    end
    e += 1
end