Skip to content

Commit

Permalink
fix default for sample fraction without replacement in pure cpp version
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Aug 22, 2019
1 parent ce29a66 commit 46c2f4c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cpp_version/src/utility/ArgumentHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ R package "ranger" under GPL3 license.
namespace ranger {

ArgumentHandler::ArgumentHandler(int argc, char **argv) :
caseweights(""), depvarname(""), fraction(1), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict(
caseweights(""), depvarname(""), fraction(0), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict(
""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), nthreads(
DEFAULT_NUM_THREADS), predall(false), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), maxdepth(
DEFAULT_MAXDEPTH), file(""), impmeasure(DEFAULT_IMPORTANCE_MODE), targetpartitionsize(0), mtry(0), outprefix(
Expand Down
9 changes: 8 additions & 1 deletion src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
prediction_mode = true;
}

// Sample fraction to vector
// Sample fraction default and convert to vector
if (sample_fraction == 0) {
if (sample_with_replacement) {
sample_fraction = DEFAULT_SAMPLE_FRACTION_REPLACE;
} else {
sample_fraction = DEFAULT_SAMPLE_FRACTION_NOREPLACE;
}
}
std::vector<double> sample_fraction_vector = { sample_fraction };

// Call other init function
Expand Down
3 changes: 2 additions & 1 deletion src/globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ const uint DEFAULT_MAXDEPTH = 0;
const PredictionType DEFAULT_PREDICTIONTYPE = RESPONSE;
const uint DEFAULT_NUM_RANDOM_SPLITS = 1;

//const std::vector<double> DEFAULT_SAMPLE_FRACTION = std::vector<double>({1});
const double DEFAULT_SAMPLE_FRACTION_REPLACE = 1;
const double DEFAULT_SAMPLE_FRACTION_NOREPLACE = 0.632;

// Interval to print progress in seconds
const double STATUS_INTERVAL = 30.0;
Expand Down

0 comments on commit 46c2f4c

Please sign in to comment.