From 4be7a28e586d2d1f0c50a660c09de4abe17a5046 Mon Sep 17 00:00:00 2001 From: Dan Frankowski Date: Wed, 18 Mar 2020 09:45:18 -0500 Subject: [PATCH] Add prediction.ridgeLinear without se.fit --- DESCRIPTION | 8 +++--- NAMESPACE | 1 + NEWS.md | 4 +++ R/prediction_ridgeLinear.R | 48 ++++++++++++++++++++++++++++++++++ tests/testthat/tests-methods.R | 16 ++++++++++++ 5 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 R/prediction_ridgeLinear.R diff --git a/DESCRIPTION b/DESCRIPTION index cde7b57..09ca2f2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -3,8 +3,8 @@ Type: Package Title: Tidy, Type-Safe 'prediction()' Methods Description: A one-function package containing 'prediction()', a type-safe alternative to 'predict()' that always returns a data frame. The 'summary()' method provides a data frame with average predictions, possibly over counterfactual versions of the data (a la the 'margins' command in 'Stata'). Marginal effect estimation is provided by the related package, 'margins' . The package currently supports common model types (e.g., "lm", "glm") from the 'stats' package, as well as numerous other model classes from other add-on packages. See the README or main package documentation page for a complete listing. License: MIT + file LICENSE -Version: 0.3.15 -Date: 2019-12-24 +Version: 0.3.16 +Date: 2020-03-18 Authors@R: c(person("Thomas J.", "Leeper", role = c("aut", "cre"), email = "thosjleeper@gmail.com", @@ -13,7 +13,8 @@ Authors@R: c(person("Thomas J.", "Leeper", email = "carlganz@ucla.edu"), person("Vincent", "Arel-Bundock", role = "ctb", email = "vincent.arel-bundock@umontreal.ca", - comment = c(ORCID = "0000-0003-2042-7063")) + comment = c(ORCID = "0000-0003-2042-7063")), + person("Dan", "Frankowski", role = "ctb") ) URL: https://github.com/leeper/prediction BugReports: https://github.com/leeper/prediction/issues @@ -55,6 +56,7 @@ Enhances: plm, pscl, quantreg, + ridge (>= 2.5), rpart, sampleSelection, speedglm, diff --git a/NAMESPACE b/NAMESPACE index 3daf9dc..24f6276 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -48,6 +48,7 @@ S3method(prediction,ivreg) S3method(prediction,knnreg) S3method(prediction,kqr) S3method(prediction,ksvm) +S3method(prediction,linearRidge) S3method(prediction,lm) S3method(prediction,lme) S3method(prediction,loess) diff --git a/NEWS.md b/NEWS.md index 5c186bc..78db439 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +## prediction 0.3.16 + +* Add `prediction.ridgeLinear` + ## prediction 0.3.15 * `build_datalist()` now works correctly with data.table datasets. (#34, #35, h/t Dan Schrage) diff --git a/R/prediction_ridgeLinear.R b/R/prediction_ridgeLinear.R new file mode 100644 index 0000000..0f2196a --- /dev/null +++ b/R/prediction_ridgeLinear.R @@ -0,0 +1,48 @@ +## Like prediction.default, but without calculate_se + +#' @rdname prediction +#' @export +prediction.ridgeLinear <- + function(model, + data = find_data(model, parent.frame()), + at = NULL, + type = "response", + vcov = stats::vcov(model), + calculate_se = FALSE, + ...) { + if (calculate_se) { + stop(paste0("calculate_se not implemented")) + } + # extract predicted values + data <- data + if (missing(data) || is.null(data)) { + pred <- predict(model, type = type, ...) + pred <- make_data_frame(fitted = pred, se.fitted = rep(NA_real_, length(pred))) + } else { + # setup data + if (!is.null(at)) { + data <- build_datalist(data, at = at, as.data.frame = TRUE) + at_specification <- attr(data, "at_specification") + } + # calculate predictions + tmp <- predict(model, newdata = data, type = type, ...) + # cbind back together + pred <- make_data_frame(data, fitted = tmp, se.fitted = rep(NA_real_, nrow(data))) + } + + # variance(s) of average predictions + J <- NULL + vc <- NA_real_ + + # output + structure(pred, + class = c("prediction", "data.frame"), + at = if (is.null(at)) at else at_specification, + type = type, + call = if ("call" %in% names(model)) model[["call"]] else NULL, + model_class = class(model), + row.names = seq_len(nrow(pred)), + vcov = vc, + jacobian = J, + weighted = FALSE) + } diff --git a/tests/testthat/tests-methods.R b/tests/testthat/tests-methods.R index 49b7542..299018e 100644 --- a/tests/testthat/tests-methods.R +++ b/tests/testthat/tests-methods.R @@ -248,6 +248,22 @@ if (require("kernlab", quietly = TRUE)) { }) } +if (require("ridge", quietly = TRUE)) { + test_that("Test prediction() for 'ridge'", { + data("mtcars", package = "datasets") + + model1 <- lm(mpg ~ wt + cyl, data = mtcars) + model2 <- linearRidge(mpg ~ wt + cyl, data = mtcars, lambda = 0) + + preds1 <- prediction(model1) + preds2 <- prediction(model2) + expect_equal(preds1$fitted, preds2$fitted, tolerance = 0.0001, label = "predictions") + expect_true(inherits(preds2, "prediction"), label = "'prediction' class is correct") + # expect_true(all(c("fitted", "se.fitted") %in% names(p)), label = "'fitted' and 'se.fitted' columns returned") + expect_true(all(c("fitted") %in% names(preds2)), label = "'fitted' column returned") + }) +} + if (require("lme4", quietly = TRUE)) { test_that("Test prediction() for 'merMod'", { data("cbpp", package = "lme4")