forked from Evovest/EvoTrees.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
importance.jl
36 lines (30 loc) · 899 Bytes
/
importance.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
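# Feature importance accumulated from a single tree.
# Disabled variant that iterates over `tree.nodes` rather than the per-field arrays:
# function importance!(gain::AbstractVector, tree::Tree)
#     @inbounds for node in tree.nodes
#         if node.split
#             gain[node.feat] += node.gain
#         end
#     end
# end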
"""
    importance!(gain::AbstractVector, tree::Tree)

Accumulate the split gains of `tree` into `gain`, in place.

For every index `n` where `tree.split[n]` is `true`, `tree.gain[n]` is added to
`gain[tree.feat[n]]`. Entries of `gain` for features that are never split on are
left unchanged. Returns `nothing`.

# Throws
- `BoundsError`: if a feature index stored in `tree.feat` is not a valid index of `gain`.
"""
function importance!(gain::AbstractVector, tree::Tree)
    # `@inbounds` is kept for the reads of the tree's own arrays.
    # NOTE(review): this assumes `tree.feat` and `tree.gain` share the indices of
    # `tree.split` - confirm against the `Tree` definition.
    @inbounds for n in eachindex(tree.split)
        if tree.split[n]
            feat = tree.feat[n]
            # The length of `gain` is chosen by the caller, not by the tree, so the
            # write must be checked explicitly: under `@inbounds` an undersized
            # `gain` would otherwise be written out of bounds without any error.
            checkbounds(gain, feat)
            gain[feat] += tree.gain[n]
        end
    end
    return nothing
end
"""
importance(model::GBTree, vars::AbstractVector)
Sorted normalized feature importance based on loss function gain.
"""
function importance(model::GBTree, vars::AbstractVector)
gain = zeros(length(vars))
# Loop importance over all trees and sort results.
for tree in model.trees
importance!(gain, tree)
end
gain .= gain ./ sum(gain)
pairs = collect(Dict(zip(string.(vars), gain)))
sort!(pairs, by = x -> -x[2])
return pairs
end