Zygote #7

Merged · 3 commits · Nov 12, 2019
Project.toml: 9 changes (8 additions, 1 deletion)
@@ -1,16 +1,23 @@
name = "JLBoost"
uuid = "13d6d4a1-5e7f-472c-9ebc-8123a4fbb95f"
authors = ["evalparse <[email protected]>"]
authors = ["Dai ZJ <[email protected]>"]
version = "0.1.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JDF = "babc3d20-cd49-4f60-a736-a8f9c08892d3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SortingLab = "562c1548-17b8-5b69-83cf-d8aebec229f5"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
DataFrames = "0.19"
ForwardDiff = "0.10"
JDF = "0.2"
NNlib = "0.6"
SortingLab = "0.2"
Zygote = "0.4"
julia = "1"

[extras]
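For anyone checking out this branch, the tightened [compat] bounds can be sanity-checked locally; a minimal sketch, assuming a clone of JLBoost with this Project.toml in the current directory:

using Pkg

Pkg.activate(".")    # activate the JLBoost project environment
Pkg.instantiate()    # resolve and install versions satisfying [compat]
Pkg.status()         # should report DataFrames 0.19.x, Zygote 0.4.x, etc.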
src/JLBoost.jl: 112 changes (70 additions, 42 deletions)
@@ -3,13 +3,13 @@ module JLBoost
using DataFrames
using SortingLab
#using StatsBase
-#using Zygote:gradient, hessian
-using ForwardDiff:gradient, hessian
+using Zygote: gradient, hessian
+# using ForwardDiff: gradient, hessian
using Base.Iterators: drop
#using RCall

export JLBoostTreeNode, JLBoostTree, showlah
-export xgboost, best_split, _best_split
+export xgboost, best_split, _best_split, scoretree

include("JLBoostTree.jl")

Expand All @@ -30,7 +30,7 @@ logloss(w, y) = -(y*log(softmax(w)) + (1-y)*log(1-softmax(w)))
# logloss = logitbinarycrossentropy

g(loss_fn, y, prev_w) = begin
-gres = gradient(x->loss_fn(x[1], y), [prev_w])
+gres = gradient(x->loss_fn(x, y), prev_w)
gres[1]
end

@@ -39,6 +39,16 @@ h(loss_fn, y, prev_w) = begin
hres[1]
end

+# g_forwarddiff(loss_fn, y, prev_w) = begin
+# gres = ForwardDiff.gradient(x->loss_fn(x[1], y), [prev_w])
+# gres[1]
+# end
+
+# h_forwarddiff(loss_fn, y, prev_w) = begin
+# hres = ForwardDiff.hessian(x->loss_fn(x[1], y), [prev_w])
+# hres[1]
+# end

# update the weight once so that it starts at a better point
function update_weight(loss_fn, df, target, prev_w, lambda)
target_vec = df[target];
@@ -47,15 +57,22 @@ function update_weight(loss_fn, df, target, prev_w, lambda)
-sum(g.(loss_fn, target_vec, prev_w_vec))/(sum(h.(loss_fn, target_vec, prev_w_vec)) + lambda)
end
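Since the move from ForwardDiff to Zygote lets g and h differentiate with respect to a plain scalar rather than a one-element vector, a standalone sketch of the new call pattern may help. This is illustrative only and assumes softmax here is the scalar sigmoid, which is what its use in logloss implies:

using Zygote: gradient

softmax(w) = 1 / (1 + exp(-w))    # assumed definition (scalar sigmoid)
logloss(w, y) = -(y*log(softmax(w)) + (1-y)*log(1-softmax(w)))

g(loss_fn, y, prev_w) = gradient(x -> loss_fn(x, y), prev_w)[1]

w, y = 0.3, 1.0
g(logloss, y, w)    # ≈ softmax(w) - y, the analytic gradient of logloss

# update_weight then takes the usual second-order (Newton) step:
#   w_new = -sum(gᵢ) / (sum(hᵢ) + lambda)
# where hᵢ = softmax(w)*(1 - softmax(w)) is the second derivative of logloss.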

-function apply_split(df, feature, bsplit, lweight, rweight)
-df[df[feature] .<= bsplit,:prev_w] = df[df[feature] .<= bsplit,:prev_w] .+ lweight
-df[df[feature] .> bsplit,:prev_w] = df[df[feature] .> bsplit,:prev_w] .+ rweight
+"""
+apply_split(df::AbstractDataFrame, feature, split_at, lweight, rweight)
+
+Apply split to the dataframe df
+"""
+function apply_split(df::AbstractDataFrame, feature, split_at, lweight, rweight)
+df[df[feature] .<= split_at,:prev_w] = df[df[feature] .<= split_at,:prev_w] .+ lweight
+df[df[feature] .> split_at,:prev_w] = df[df[feature] .> split_at,:prev_w] .+ rweight
df
end
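A quick usage sketch for apply_split on a toy DataFrame (the column name x is hypothetical); note that it mutates prev_w in place, using the DataFrames 0.19-style indexing this package pins in its compat bounds:

using DataFrames

df = DataFrame(x = [1, 2, 3, 4], prev_w = zeros(4))
apply_split(df, :x, 2, -0.5, 0.5)
df[:prev_w]    # [-0.5, -0.5, 0.5, 0.5]: rows with x <= 2 got lweight, the rest rweight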
-# w = update_weight(logloss, df, target, prev_w, lambda)
-# df[prev_w] .+= w
-# (w, unique(df[prev_w])..., softmax(w))

"""
best_split(loss_fn, df::AbstractDataFrame, feature, target, prev_w, lambda, gamma; verbose = false)

Determine the best split of a given variable
"""
function best_split(loss_fn, df::DataFrame, feature, target, prev_w, lambda, gamma; verbose = false)
if verbose
println("Choosing a split on", feature)
@@ -79,6 +96,8 @@ Find the best (binary) split point by loss_fn(feature, target) given a sorted it
of feature
"""
function best_split(loss_fn, feature, target, prev_w, lambda::Number, gamma::Number, verbose = false)
+@assert length(feature) == length(target)
+@assert length(feature) == length(prev_w)
if issorted(feature)
res = _best_split(loss_fn, feature, target, prev_w, lambda, gamma, verbose)
else
@@ -123,31 +142,44 @@ function _best_split(loss_fn, feature, target, prev_w, lambda::Number, gamma::Nu
ch = cumsum(h.(loss_fn, target, prev_w))

max_cg = cg[end]
max_ch = ch[end]

last_feature = feature[1]
-cutpt = zero(Int)
-lweight = 0.0
-rweight = 0.0
-best_gain = typemin(Float64)
-
-for (i, (f, cg, ch)) in enumerate(zip(drop(feature,1) , @view(cg[2:end]), @view(ch[2:end])))
-if f != last_feature
-left_split = cg^2 /(ch + lambda)
-right_split = (max_cg-cg)^(2) / ((max_ch - ch) + lambda)
-no_split = max_cg^2 /(max_ch + lambda)
-gain = left_split + right_split - no_split - gamma
-if gain > best_gain
-println(i)
-println(best_gain)
-best_gain = gain
-cutpt = i
-end
-last_feature = f
-end
-end
+cutpt::Int = zero(Int)
+lweight::Float64 = 0.0
+rweight::Float64 = 0.0
+best_gain::Float64 = typemin(Float64)
+
+if length(feature) == 1
+no_split = max_cg^2 /(max_ch + lambda)
+gain = no_split - gamma
+cutpt = 1
+lweight = -cg[cutpt]/(ch[cutpt]+lambda)
+rweight = -(max_cg - cg[cutpt])/(max_ch - ch[cutpt] + lambda)
+else
+for (i, (f, cg, ch)) in enumerate(zip(drop(feature,1) , @view(cg[1:end-1]), @view(ch[1:end-1])))
+if f != last_feature
+left_split = cg^2 /(ch + lambda)
+right_split = (max_cg-cg)^(2) / ((max_ch - ch) + lambda)
+no_split = max_cg^2 /(max_ch + lambda)
+gain = left_split + right_split - no_split - gamma
+if gain > best_gain
+best_gain = gain
+cutpt = i
+lweight = -cg/(ch+lambda)
+rweight = -(max_cg - cg)/(max_ch - ch + lambda)
+end
+last_feature = f
+end
+end
+end

-(split_at = feature[cutpt], cutpt = cutpt, gain = best_gain, lweight = lweight, rweight = rweight)
+split_at = typemin(eltype(feature))
+if cutpt >= 1
+split_at = feature[cutpt]
+end
+
+(split_at = split_at, cutpt = cutpt, gain = best_gain, lweight = lweight, rweight = rweight)
end
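The loop above applies the standard XGBoost split criterion. With G_L = cg[i] and H_L = ch[i] the cumulative gradient/hessian left of cut i, and G = cg[end], H = ch[end] the totals, a worked check with made-up numbers:

# gain = G_L^2/(H_L + lambda) + G_R^2/(H_R + lambda) - G^2/(H + lambda) - gamma
G_L, H_L = -1.2, 0.8              # hypothetical left-side cumulative sums
G, H = -0.4, 1.9                  # hypothetical totals over all rows
lambda, gamma = 0.0, 0.0
G_R, H_R = G - G_L, H - H_L       # right side is total minus left
gain = G_L^2/(H_L + lambda) + G_R^2/(H_R + lambda) - G^2/(H + lambda) - gamma
# ≈ 2.3 > 0, so this cut would be kept, with leaf weights
# lweight = -G_L/(H_L + lambda) and rweight = -G_R/(H_R + lambda)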

# The main XGBoost function
@@ -158,8 +190,6 @@ function xgboost(df, target, features; prev_w = :prev_w, eta = 0.3, lambda = 0,
end

function xgboost(df, target, features, jlt::JLBoostTreeNode; prev_w = :prev_w, eta = 0.3, lambda = 0, gamma = 0, maxdepth = 6, subsample = 1)
-#println(maxdepth)
-
# initialise the weights to 0 if the column doesn't exist yet
if !(prev_w in names(df))
df[!, prev_w] .= 0.0
@@ -170,24 +200,22 @@ function xgboost(df, target, features, jlt::JLBoostTreeNode; prev_w = :prev_w, e
#df[prev_w] = df[prev_w] .+ jlt.weight

# compute the gain for all splits for all features
-all_splits = [best_split(logloss, df, feature, target, prev_w, lambda, gamma) for feature in features]
-return all_splits
+all_splits = [best_split(logloss, df, feature, target, prev_w, lambda, gamma) for feature in features]
+split_with_best_gain = all_splits[findmax(map(x->x.gain, all_splits))[2]]


# only apply the split to the tree if there is positive gain
if split_with_best_gain.gain > 0
# set the parent tree node
-jlt.split = split_with_best_gain.best_split
-jlt.splitfeature = split_with_best_gain.feature
+jlt.split = split_with_best_gain.split_at
+jlt.splitfeature = split_with_best_gain.feature

left_treenode = JLBoostTreeNode(split_with_best_gain.lweight)
right_treenode = JLBoostTreeNode(split_with_best_gain.rweight)

if maxdepth > 1
# now recursively apply the weights to left branch and right branch
-df_left = df[df[!, split_with_best_gain.feature] .<= split_with_best_gain.best_split,:]
-df_right = df[df[!, split_with_best_gain.feature] .> split_with_best_gain.best_split,:]
+df_left = df[df[!, split_with_best_gain.feature] .<= split_with_best_gain.split_at,:]
+df_right = df[df[!, split_with_best_gain.feature] .> split_with_best_gain.split_at,:]

left_treenode = xgboost(df_left, target, features, left_treenode; prev_w = prev_w, eta = eta, lambda = lambda, gamma = gamma, maxdepth = maxdepth - 1, subsample = subsample)
right_treenode = xgboost(df_right, target, features, right_treenode; prev_w = prev_w, eta = eta, lambda = lambda, gamma = gamma, maxdepth = maxdepth - 1, subsample = subsample)
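Taken together, a minimal end-to-end sketch of how this entry point is meant to be called (toy data and column names are hypothetical; anything beyond the signature shown above is assumed):

using DataFrames

df = DataFrame(feat1 = rand(100), feat2 = rand(100), y = rand(0:1, 100))

# fit a single boosted tree of depth 3 on the toy data
tree = xgboost(df, :y, [:feat1, :feat2]; lambda = 0.1, gamma = 0.0, maxdepth = 3)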