From 46c2f4cea95cf25cd6177b4cccd1ca6aac5cd84b Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 22 Aug 2019 09:56:03 +0200 Subject: [PATCH] fix default for sample fraction without replacement in pure cpp version --- cpp_version/src/utility/ArgumentHandler.cpp | 2 +- src/Forest.cpp | 9 ++++++++- src/globals.h | 3 ++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/cpp_version/src/utility/ArgumentHandler.cpp b/cpp_version/src/utility/ArgumentHandler.cpp index afc4b6492..cc2afddc0 100644 --- a/cpp_version/src/utility/ArgumentHandler.cpp +++ b/cpp_version/src/utility/ArgumentHandler.cpp @@ -20,7 +20,7 @@ R package "ranger" under GPL3 license. namespace ranger { ArgumentHandler::ArgumentHandler(int argc, char **argv) : - caseweights(""), depvarname(""), fraction(1), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict( + caseweights(""), depvarname(""), fraction(0), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict( ""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), nthreads( DEFAULT_NUM_THREADS), predall(false), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), maxdepth( DEFAULT_MAXDEPTH), file(""), impmeasure(DEFAULT_IMPORTANCE_MODE), targetpartitionsize(0), mtry(0), outprefix( diff --git a/src/Forest.cpp b/src/Forest.cpp index efabc480a..12b13a7ad 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -80,7 +80,14 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode prediction_mode = true; } - // Sample fraction to vector + // Sample fraction default and convert to vector + if (sample_fraction == 0) { + if (sample_with_replacement) { + sample_fraction = DEFAULT_SAMPLE_FRACTION_REPLACE; + } else { + sample_fraction = DEFAULT_SAMPLE_FRACTION_NOREPLACE; + } + } std::vector<double> sample_fraction_vector = { sample_fraction }; // Call other init function diff --git a/src/globals.h b/src/globals.h index b2974423c..1b391296c 100644 
--- a/src/globals.h +++ b/src/globals.h @@ -94,7 +94,8 @@ const uint DEFAULT_MAXDEPTH = 0; const PredictionType DEFAULT_PREDICTIONTYPE = RESPONSE; const uint DEFAULT_NUM_RANDOM_SPLITS = 1; -//const std::vector<double> DEFAULT_SAMPLE_FRACTION = std::vector<double>({1}); +const double DEFAULT_SAMPLE_FRACTION_REPLACE = 1; +const double DEFAULT_SAMPLE_FRACTION_NOREPLACE = 0.632; // Interval to print progress in seconds const double STATUS_INTERVAL = 30.0;