Skip to content

Commit

Permalink
fix hist cumul
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedb committed Dec 25, 2022
1 parent 2d7f99c commit 0e30320
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 85 deletions.
2 changes: 1 addition & 1 deletion experiments/debug-softmax-split.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Xtrain, ytrain = MLJBase.reformat(model, selectrows(X, train), selectrows(y, tra
rng = StableRNG(6)
params_evo = EvoTreeClassifier(;
T = Float32,
nrounds = 5,
nrounds = 200,
lambda = 0.0,
gamma = 0.0,
eta = 0.1,
Expand Down
6 changes: 3 additions & 3 deletions experiments/readme_plots_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ y_train, y_eval = Y[i_train], Y[i_eval]
# linear
params1 = EvoTreeRegressor(T=Float32,
loss=:linear,
nrounds=1, nbins=16,
nrounds=500, nbins=64,
lambda=0.1, gamma=0.1, eta=0.1,
max_depth=3, min_weight=1.0,
max_depth=6, min_weight=1.0,
rowsample=0.5, colsample=1.0,
device="gpu")

Expand Down Expand Up @@ -140,4 +140,4 @@ plot!(x_train[:, 1][x_perm], pred_train_gaussian[x_perm, 1], color="navy", linew
plot!(x_train[:, 1][x_perm], pred_train_gaussian[x_perm, 2], color="darkred", linewidth=1.5, label="sigma")
plot!(x_train[:, 1][x_perm], pred_q20[x_perm, 1], color="green", linewidth=1.5, label="q20")
plot!(x_train[:, 1][x_perm], pred_q80[x_perm, 1], color="green", linewidth=1.5, label="q80")
savefig("figures/gaussian-sinus-gpu.png")
savefig("figures/gaussian-sinus-gpu.png")
Binary file modified figures/gaussian-sinus-gpu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/regression_sinus_gpu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
69 changes: 2 additions & 67 deletions src/find_split.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,85 +237,20 @@ end
Generic fallback
"""
# function update_gains!(
# node::TrainNode,
# js::Vector,
# params::EvoTypes{L,T},
# K,
# monotone_constraints,
# ) where {L,T}

# h = node.h
# hL = node.hL
# hR = node.hR
# gains = node.gains

# cumsum!(hL, h, dims = 2)
# hR .= view(hL, :, params.nbins:params.nbins, :) .- hL
# @info "minimum(hR[3,:,:])" minimum(hR[3,:,:])

# @inbounds for j in js
# monotone_constraint = monotone_constraints[j]
# @inbounds for bin = 1:params.nbins
# if hL[end, bin, j] > params.min_weight && hR[end, bin, j] > params.min_weight
# if monotone_constraint != 0
# predL = pred_scalar(view(hL, :, bin, j), params)
# predR = pred_scalar(view(hR, :, bin, j), params)
# end
# if (monotone_constraint == 0) ||
# (monotone_constraint == -1 && predL > predR) ||
# (monotone_constraint == 1 && predL < predR)

# gains[bin, j] =
# get_gain(params, view(hL, :, bin, j)) +
# get_gain(params, view(hR, :, bin, j))
# end
# end
# end
# end
# return nothing
# end

function update_gains!(
node::TrainNode,
js::Vector,
params::EvoTypes{L,T},
K,
monotone_constraints,
) where {L,T}

KK = 2 * K + 1
h = node.h
hL = node.hL
hR = node.hR
gains = node.gains
= node.∑

@inbounds for j in js
@inbounds for k = 1:KK
val = h[k, 1, j]
hL[k, 1, j] = val
hR[k, 1, j] = ∑[k] - val
end
@inbounds for bin = 2:params.nbins
@inbounds for k = 1:KK
val = h[k, bin, j]
hL[k, bin, j] = hL[k, bin-1, j] + val
hR[k, bin, j] = hR[k, bin-1, j] - val
end
end
end

# hL2 = copy(hL) .* 0
# hR2 = copy(hR) .* 0
# cumsum!(hL2, h, dims = 2)
# hR2 .= view(hL2, :, params.nbins:params.nbins, :) .- hL2
# @info "max abs diff hL" maximum(abs.(hL[3, :, :] .- hL2[3, :, :]))
# @info "max abs diff hR" maximum(abs.(hR[3, :, :] .- hR2[3, :, :]))
@info "minimum(hR[3,:,:])" minimum(hR[3, :, :])
# @info "minimum(hR2[3,:,:])" minimum(hR2[3, :, :])
# @info "node.∑" node.∑
# @info "sum(node.h, dims=2)" sum(node.h, dims=2)
cumsum!(hL, h, dims = 2)
hR .= view(hL, :, params.nbins:params.nbins, :) .- hL

@inbounds for j in js
monotone_constraint = monotone_constraints[j]
Expand Down
16 changes: 2 additions & 14 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ function grow_tree!(
# grow while there are remaining active nodes
while length(n_current) > 0 && depth <= params.max_depth
offset = 0 # identifies breakpoint for each node set within a depth

if depth < params.max_depth
for n_id in eachindex(n_current)
n = n_current[n_id]
Expand All @@ -212,20 +212,8 @@ function grow_tree!(
if depth == params.max_depth || nodes[n].∑[end] <= params.min_weight
pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
else
update_gains!(nodes[n], js, params, K, monotone_constraints)
update_gains!(nodes[n], js, params, monotone_constraints)
best = findmax(nodes[n].gains)
# if depth in [2]
# @info "minimum(nodes[n].h[3,:,:])" minimum(nodes[n].h[3,:,:])
# @info "minimum(nodes[n].h[4,:,:])" minimum(nodes[n].h[4,:,:])
# @info "nodes[n].hL" nodes[n].hL[:, best[2][1], best[2][2]]
# @info "nodes[n].hR" nodes[n].hR[:, best[2][1], best[2][2]]
# end
if depth in [2,3,4]
@info "depth" depth
# @info "best" best
# @info "nodes[n].gain" nodes[n].gain
# @info "nodes[n].∑" nodes[n].∑
end
if best[2][1] != params.nbins && best[1] > nodes[n].gain + params.gamma
tree.gain[n] = best[1] - nodes[n].gain
tree.cond_bin[n] = best[2][1]
Expand Down

0 comments on commit 0e30320

Please sign in to comment.