From 0ca0445c134eafc659c33a7e4aea548719b53362 Mon Sep 17 00:00:00 2001 From: EmilHvitfeldt Date: Tue, 29 Jan 2019 13:09:05 -0800 Subject: [PATCH] added classprob prediction type --- R/nullmodel_data.R | 18 ++++++++++++++++-- .../{test-nullmodel.R => test_nullmodel.R} | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) rename tests/testthat/{test-nullmodel.R => test_nullmodel.R} (99%) diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index 3de382220..80380a98a 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -23,14 +23,28 @@ null_model_parsnip_data <- defaults = list() ), class = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" + ) + ), + classprob = list( pre = NULL, - post = NULL, + post = function(x, object) { + str(as_tibble(x)) + as_tibble(x) + }, func = c(fun = "predict"), args = list( object = quote(object$fit), new_data = quote(new_data), - type = "class" + type = "prob" ) ), numeric = list( diff --git a/tests/testthat/test-nullmodel.R b/tests/testthat/test_nullmodel.R similarity index 99% rename from tests/testthat/test-nullmodel.R rename to tests/testthat/test_nullmodel.R index 2d82f59c6..45a6a9932 100644 --- a/tests/testthat/test-nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -134,3 +134,4 @@ test_that('classification', { expect_true(!is.null(null_model$fit)) }) +