Skip to content

Commit adeebb5

Browse files
wendycwongmaurever
wendycwong
andauthored
h2oaiGH-16149: fix gam rebalance bug (h2oai#16153)
* add R test that reproduce the error. * add more to test. * h2oaiGH-16149: fixed key passed to rebalance dataset to avoid collision. * Adopt adam code review comments. Co-authored-by: Veronika Maurerová <[email protected]>
1 parent 0c60e3e commit adeebb5

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

h2o-algos/src/main/java/hex/gam/GAM.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,7 @@ public void computeImpl() {
927927
if (error_count() > 0) // if something goes wrong, let's throw a fit
928928
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(GAM.this);
929929
// add gamified columns to training frame
930-
Frame newTFrame = new Frame(rebalance(adaptTrain(), false, _result+".temporary.train"));
930+
Frame newTFrame = new Frame(rebalance(adaptTrain(), false, Key.make()+".temporary.train"));
931931
verifyGamTransformedFrame(newTFrame);
932932

933933
if (error_count() > 0) // if something goes wrong during gam transformation, let's throw a fit again!
@@ -937,7 +937,7 @@ public void computeImpl() {
937937
int[] singleGamColsCount = new int[]{_cubicSplineNum, _iSplineNum, _mSplineNum};
938938
_valid = rebalance(adaptValidFrame(_parms.valid(), _valid, _parms, _gamColNamesCenter, _binvD,
939939
_zTranspose, _knots, _zTransposeCS, _allPolyBasisList, _gamColMeansRaw, _oneOGamColStd, singleGamColsCount),
940-
false, _result + ".temporary.valid");
940+
false, Key.make() + ".temporary.valid");
941941
}
942942
DKV.put(newTFrame); // This one will cause deleted vectors if add to Scope.track
943943
Frame newValidFrame = _valid == null ? null : new Frame(_valid);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
setwd(normalizePath(dirname(R.utils::commandArgs(asValues=TRUE)$"f")))
2+
source("../../../scripts/h2o-r-test-setup.R")
3+
4+
library(data.table)
5+
6+
# This test was provided by a customer. No exit condition is needed as
7+
# the test before my fix always failed. As long as this test completes
8+
# successfully, it should be good enough.
9+
test.gam.dataset.error <- function(n) {
10+
sum_insured <- seq(1, 200000, length.out = n)
11+
d2 <-
12+
data.table(
13+
sum_insured = sum_insured,
14+
sqrt = sqrt(sum_insured),
15+
sine = sin(2 * pi * sum_insured / 40000)
16+
)
17+
d2[, sine := 0.3 * sqrt * sine , ]
18+
d2[, y := pmax(0, sqrt + sine) , ]
19+
20+
d2[, x := sum_insured]
21+
d2[, x2 := rev(x) , ] # flip axis
22+
23+
# import the dataset
24+
h2o_data2 <- as.h2o(d2)
25+
26+
model2 <-
27+
h2o.gam(
28+
y = "y",
29+
gam_columns = c("x2"),
30+
bs = c(2),
31+
spline_orders = c(3),
32+
splines_non_negative = c(F),
33+
training_frame = h2o_data2,
34+
family = "tweedie",
35+
tweedie_variance_power = 1.1,
36+
scale = c(0),
37+
lambda = 0,
38+
alpha = 0,
39+
keep_gam_cols = T,
40+
non_negative = TRUE,
41+
num_knots = c(10)
42+
)
43+
print("model building completed.")
44+
}
45+
46+
test.model.gam.dataset.error <- function() {
47+
# test for n=1005
48+
test.gam.dataset.error(1005)
49+
# test for n=1001;
50+
test.gam.dataset.error(1001)
51+
}
52+
53+
doTest("General Additive Model dataset size 1001 and 1005 error", test.model.gam.dataset.error)

0 commit comments

Comments
 (0)