-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
document support/handling of missings + tests
- Loading branch information
Showing
13 changed files
with
206 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "EvoTrees" | ||
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" | ||
authors = ["jeremiedb <[email protected]>"] | ||
version = "0.16.1" | ||
version = "0.16.2" | ||
|
||
[deps] | ||
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
using Statistics | ||
using StatsBase: sample | ||
using EvoTrees: sigmoid, logit | ||
using EvoTrees: check_args, check_parameter | ||
using CategoricalArrays | ||
using DataFrames | ||
using Random: seed! | ||
|
||
# prepare a dataset | ||
seed!(123) | ||
nobs = 1_000 | ||
x_num = rand(nobs) .* 5 | ||
lvls = ["a", "b", "c"] | ||
x_cat = categorical(rand(lvls, nobs), levels=lvls, ordered=false) | ||
x_bool = rand(Bool, nobs) | ||
|
||
x_num_m1 = Vector{Union{Missing,Float64}}(copy(x_num)) | ||
x_num_m2 = Vector{Any}(copy(x_num)) | ||
lvls_m1 = ["a", "b", "c", missing] | ||
x_cat_m1 = categorical(rand(lvls_m1, nobs), levels=lvls) | ||
x_bool_m1 = Vector{Union{Missing,Bool}}(copy(x_bool)) | ||
|
||
# train-eval split | ||
is = collect(1:nobs) | ||
i_sample = sample(is, nobs, replace=false) | ||
train_size = 0.8 | ||
i_train = i_sample[1:floor(Int, train_size * nobs)] | ||
i_eval = i_sample[floor(Int, train_size * nobs)+1:end] | ||
|
||
# target var | ||
y_tot = sin.(x_num) .* 0.5 .+ 0.5 | ||
y_tot = logit(y_tot) + randn(nobs) | ||
y_tot = sigmoid(y_tot) | ||
target_name = "y" | ||
y_tot = sigmoid(y_tot) | ||
y_tot_m1 = allowmissing(y_tot) | ||
y_tot_m1[1] = missing | ||
|
||
config = EvoTreeRegressor( | ||
loss=:linear, | ||
nrounds=100, | ||
nbins=16, | ||
lambda=0.5, | ||
gamma=0.1, | ||
eta=0.05, | ||
max_depth=3, | ||
min_weight=1.0, | ||
rowsample=0.5, | ||
colsample=1.0, | ||
rng=123, | ||
) | ||
|
||
@testset "DataFrames - missing features" begin | ||
|
||
df_tot = DataFrame(x_num=x_num, x_bool=x_bool, x_cat=x_cat, y=y_tot) | ||
dtrain, deval = df_tot[i_train, :], df_tot[i_eval, :] | ||
|
||
model = fit_evotree( | ||
config, | ||
dtrain; | ||
target_name) | ||
|
||
@test model.info[:fnames] == [:x_num, :x_bool, :x_cat] | ||
|
||
# keep only fnames <= Real or Categorical | ||
df_tot = DataFrame(x_num=x_num, x_num_m1=x_num_m1, x_num_m2=x_num_m2, | ||
x_cat_m1=x_cat_m1, x_bool_m1=x_bool_m1, y=y_tot) | ||
dtrain, deval = df_tot[i_train, :], df_tot[i_eval, :] | ||
|
||
model = fit_evotree( | ||
config, | ||
dtrain; | ||
target_name, | ||
deval) | ||
|
||
@test model.info[:fnames] == [:x_num] | ||
|
||
model = fit_evotree( | ||
config, | ||
dtrain; | ||
target_name, | ||
fnames=[:x_num]) | ||
|
||
@test model.info[:fnames] == [:x_num] | ||
|
||
# specifyin features with missings should error | ||
@test_throws AssertionError fit_evotree( | ||
config, | ||
dtrain; | ||
deval, | ||
fnames=[:x_num, :x_num_m1, :x_num_m2, :x_cat_m1, :x_bool_m1], | ||
target_name) | ||
|
||
end | ||
|
||
@testset "DataFrames - missing in target errors" begin | ||
|
||
df_tot = DataFrame(x_num=x_num, x_bool=x_bool, x_cat=x_cat, y=y_tot_m1) | ||
dtrain, deval = df_tot[i_train, :], df_tot[i_eval, :] | ||
|
||
@test_throws AssertionError fit_evotree( | ||
config, | ||
dtrain; | ||
target_name) | ||
|
||
end | ||
|
||
@testset "Matrix - missing features" begin | ||
|
||
x_tot = allowmissing(hcat(x_num_m1)) | ||
@test_throws AssertionError fit_evotree( | ||
config; | ||
x_train=x_tot, | ||
y_train=y_tot) | ||
|
||
x_tot = Matrix{Any}(hcat(x_num_m2)) | ||
@test_throws AssertionError fit_evotree( | ||
config; | ||
x_train=x_tot, | ||
y_train=y_tot) | ||
|
||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters