Skip to content

Commit

Permalink
Add prediction.ridgeLinear without se.fit
Browse files Browse the repository at this point in the history
  • Loading branch information
dfrankow committed Mar 18, 2020
1 parent 095788b commit 4be7a28
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 3 deletions.
8 changes: 5 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ Type: Package
Title: Tidy, Type-Safe 'prediction()' Methods
Description: A one-function package containing 'prediction()', a type-safe alternative to 'predict()' that always returns a data frame. The 'summary()' method provides a data frame with average predictions, possibly over counterfactual versions of the data (a la the 'margins' command in 'Stata'). Marginal effect estimation is provided by the related package, 'margins' <https://cran.r-project.org/package=margins>. The package currently supports common model types (e.g., "lm", "glm") from the 'stats' package, as well as numerous other model classes from other add-on packages. See the README or main package documentation page for a complete listing.
License: MIT + file LICENSE
Version: 0.3.15
Date: 2019-12-24
Version: 0.3.16
Date: 2020-03-18
Authors@R: c(person("Thomas J.", "Leeper",
role = c("aut", "cre"),
email = "[email protected]",
Expand All @@ -13,7 +13,8 @@ Authors@R: c(person("Thomas J.", "Leeper",
email = "[email protected]"),
person("Vincent", "Arel-Bundock", role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0003-2042-7063"))
comment = c(ORCID = "0000-0003-2042-7063")),
person("Dan", "Frankowski", role = "ctb")
)
URL: https://github.com/leeper/prediction
BugReports: https://github.com/leeper/prediction/issues
Expand Down Expand Up @@ -55,6 +56,7 @@ Enhances:
plm,
pscl,
quantreg,
ridge (>= 2.5),
rpart,
sampleSelection,
speedglm,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ S3method(prediction,ivreg)
S3method(prediction,knnreg)
S3method(prediction,kqr)
S3method(prediction,ksvm)
S3method(prediction,linearRidge)
S3method(prediction,lm)
S3method(prediction,lme)
S3method(prediction,loess)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## prediction 0.3.16

* Add `prediction.ridgeLinear`

## prediction 0.3.15

* `build_datalist()` now works correctly with data.table datasets. (#34, #35, h/t Dan Schrage)
Expand Down
48 changes: 48 additions & 0 deletions R/prediction_ridgeLinear.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
## Like prediction.default, but without calculate_se

#' @rdname prediction
#' @export
prediction.ridgeLinear <-
function(model,
data = find_data(model, parent.frame()),
at = NULL,
type = "response",
vcov = stats::vcov(model),
calculate_se = FALSE,
...) {
if (calculate_se) {
stop(paste0("calculate_se not implemented"))
}
# extract predicted values
data <- data
if (missing(data) || is.null(data)) {
pred <- predict(model, type = type, ...)
pred <- make_data_frame(fitted = pred, se.fitted = rep(NA_real_, length(pred)))
} else {
# setup data
if (!is.null(at)) {
data <- build_datalist(data, at = at, as.data.frame = TRUE)
at_specification <- attr(data, "at_specification")
}
# calculate predictions
tmp <- predict(model, newdata = data, type = type, ...)
# cbind back together
pred <- make_data_frame(data, fitted = tmp, se.fitted = rep(NA_real_, nrow(data)))
}

# variance(s) of average predictions
J <- NULL
vc <- NA_real_

# output
structure(pred,
class = c("prediction", "data.frame"),
at = if (is.null(at)) at else at_specification,
type = type,
call = if ("call" %in% names(model)) model[["call"]] else NULL,
model_class = class(model),
row.names = seq_len(nrow(pred)),
vcov = vc,
jacobian = J,
weighted = FALSE)
}
16 changes: 16 additions & 0 deletions tests/testthat/tests-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,22 @@ if (require("kernlab", quietly = TRUE)) {
})
}

if (require("ridge", quietly = TRUE)) {
test_that("Test prediction() for 'ridge'", {
data("mtcars", package = "datasets")

model1 <- lm(mpg ~ wt + cyl, data = mtcars)
model2 <- linearRidge(mpg ~ wt + cyl, data = mtcars, lambda = 0)

preds1 <- prediction(model1)
preds2 <- prediction(model2)
expect_equal(preds1$fitted, preds2$fitted, tolerance = 0.0001, label = "predictions")
expect_true(inherits(preds2, "prediction"), label = "'prediction' class is correct")
# expect_true(all(c("fitted", "se.fitted") %in% names(p)), label = "'fitted' and 'se.fitted' columns returned")
expect_true(all(c("fitted") %in% names(preds2)), label = "'fitted' column returned")
})
}

if (require("lme4", quietly = TRUE)) {
test_that("Test prediction() for 'merMod'", {
data("cbpp", package = "lme4")
Expand Down

0 comments on commit 4be7a28

Please sign in to comment.