prediction is a package focused on one function, prediction(), that provides type-safe methods for generating predictions from fitted regression models. prediction() is an S3 generic that always returns a "data.frame" class object rather than the mix of vectors, lists, etc. returned by the predict() methods for many model types. It provides a key piece of infrastructure for the margins package.
In addition to prediction(), the package provides a number of utility functions for generating useful predictions:
- find_data(), an S3 generic with methods that find the data frame used to estimate a regression model
- mean_or_mode() and median_or_mode(), which provide a convenient way to compute the data needed for predicted values at means (or at medians), respecting the differences between factor and numeric variables
- seq_range(), which generates a vector of n values based upon the range of values in a variable
- build_datalist(), which generates a list of data frames from an input data frame and a specified set of replacement at values (mimicking the atlist option of Stata's margins command)
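The helpers in the list above are simple to describe, so a base-R sketch may make their behaviour concrete. The sketch_* names below are hypothetical and this is not the package's own code; choosing the most frequent value as the "mode" of a factor, breaking ties by table() order, and returning a one-row data frame are assumptions made for illustration.

```r
# Hypothetical base-R sketches of the helper functions described above.
# The sketch_* names are illustrative; this is NOT the prediction package's code.

# Recover the data frame a model was fit on by re-evaluating the `data`
# argument stored in the model's call (assumes the call has one and that the
# object is still reachable from `env`).
sketch_find_data <- function(model, env = parent.frame()) {
  eval(model$call$data, envir = env)
}

# Most frequent value of a non-numeric vector; ties go to the first value in
# table() order. Factors keep their original levels.
sketch_mode <- function(x) {
  tab <- table(x)
  m <- names(tab)[which.max(tab)]
  if (is.factor(x)) factor(m, levels = levels(x)) else m
}

# Collapse a data frame to one row: numeric columns are summarised with `fun`,
# every other column with its mode.
sketch_central <- function(data, fun) {
  out <- lapply(data, function(col) {
    if (is.numeric(col)) fun(col, na.rm = TRUE) else sketch_mode(col)
  })
  as.data.frame(out, stringsAsFactors = FALSE)
}

sketch_mean_or_mode <- function(data) sketch_central(data, mean)
sketch_median_or_mode <- function(data) sketch_central(data, stats::median)

# n evenly spaced values spanning the observed range of x.
sketch_seq_range <- function(x, n = 2) {
  r <- range(x, na.rm = TRUE)
  seq(r[1], r[2], length.out = n)
}

# Example usage
fit <- lm(mpg ~ cyl * hp + wt, data = mtcars)
identical(sketch_find_data(fit), mtcars)  # TRUE
sketch_mean_or_mode(data.frame(x = c(1, 2, 6), g = factor(c("a", "b", "b"))))
sketch_seq_range(mtcars$hp, 5)  # 52.00 122.75 193.50 264.25 335.00
```

Keeping the numeric/non-numeric branch in one place (sketch_central) means the "at means" and "at medians" variants differ only in the summary function passed in.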
A major downside of the predict() methods for common modelling classes is that the result is not type-safe. Consider the following simple example:
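```r
library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
# the default call returns a numeric vector
class(predict(x))
## [1] "numeric"
# requesting standard errors changes the return type to a list
class(predict(x, se.fit = TRUE))
## [1] "list"
```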
prediction solves this issue by providing a wrapper around predict(), called prediction(), that always returns a tidy data frame with a very simple print() method:
```r
library("prediction")
(p <- prediction(x))
## Average prediction for 32 observations: 20.0906
class(p)
## [1] "prediction" "data.frame"
head(p)
## mpg cyl disp hp drat wt qsec vs am gear carb fitted se.fitted
## 1 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4 21.90488 0.6927034
## 2 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4 21.10933 0.6266557
## 3 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1 25.64753 0.6652076
## 4 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1 20.04859 0.6041400
## 5 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2 17.25445 0.7436172
## 6 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1 19.53360 0.6436862
```
The output always contains the original data (i.e., either data found using the find_data() function or passed to the data argument to prediction()). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
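To make the idea of a type-safe wrapper concrete, here is a minimal base-R sketch in the spirit of prediction(): an S3 generic whose default method always requests standard errors from predict() and binds them, together with the fitted values, to the original data. The name sketch_prediction is hypothetical, the fitted and se.fitted column names merely mirror the output shown above, and the sketch assumes the model's predict() method accepts newdata and se.fit; the real package handles many model-specific details that this ignores.

```r
# Hypothetical sketch of a type-safe wrapper around predict(); not the
# prediction package's implementation. Model-specific behaviour would go in
# additional S3 methods, which this sketch omits.
sketch_prediction <- function(model, ...) UseMethod("sketch_prediction")

sketch_prediction.default <- function(model, data = NULL, ...) {
  # Without explicit data, re-evaluate the model's `data` argument in the
  # formula's environment (an assumption that holds for typical lm()/glm() fits).
  if (is.null(data)) {
    data <- eval(model$call$data, envir = environment(stats::formula(model)))
  }
  # Always request standard errors so the shape of predict()'s result is fixed.
  pred <- stats::predict(model, newdata = data, se.fit = TRUE, ...)
  # Bind the original data to the fitted values and their standard errors.
  data.frame(data,
             fitted = as.numeric(pred$fit),
             se.fitted = as.numeric(pred$se.fit))
}

# Example usage: both calls return a data.frame
fit <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(sketch_prediction(fit))  # "data.frame"
logit <- glm(am ~ wt, data = mtcars, family = binomial)
head(sketch_prediction(logit, type = "response"))
```

Always asking for standard errors costs a little extra computation, but it means callers never have to branch on whether predict() handed back a vector or a list.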
Additionally, the vast majority of methods allow the passing of an at argument, which can be used to obtain predicted values using a modified version of the data held to specific values:
```r
prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
## Average predictions for 32 observations:
## at(hp) value
## 52.0 22.605
## 122.8 19.328
## 193.5 16.051
## 264.2 12.774
## 335.0 9.497
```
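Conceptually, the at argument can be understood as building one modified copy of the data for each requested value (the job build_datalist() is described as doing) and averaging the model's predictions over each copy. The base-R sketch below illustrates that idea; sketch_datalist and sketch_at_average are hypothetical names, crossing several at variables with expand.grid() is an assumption of this sketch, and only numeric at variables are handled.

```r
# Hypothetical sketch of the `at` mechanism (in the spirit of build_datalist());
# not the prediction package's implementation. Assumes numeric `at` variables.

# One modified copy of `data` per combination of the `at` values.
sketch_datalist <- function(data, at) {
  combos <- expand.grid(at, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  lapply(seq_len(nrow(combos)), function(i) {
    d <- data
    for (v in names(combos)) d[[v]] <- combos[[v]][i]  # hold v fixed for every row
    d
  })
}

# Average the model's predictions over each modified copy.
sketch_at_average <- function(model, data, at) {
  combos <- expand.grid(at, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  combos$value <- vapply(sketch_datalist(data, at),
                         function(d) mean(stats::predict(model, newdata = d)),
                         numeric(1))
  combos
}

# Example usage
fit <- lm(mpg ~ cyl * hp + wt, data = mtcars)
hp_values <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 5)
sketch_at_average(fit, mtcars, list(hp = hp_values))
```

Keeping the list of modified data frames separate from the averaging step means the same list could be reused for other summaries, such as medians or full prediction distributions.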
The currently supported model classes are:
- "lm" from stats::lm()
- "glm" from stats::glm(), MASS::glm.nb(), glmx::glmx(), glmx::hetglm()
- "ar" from stats::ar()
- "Arima" from stats::arima()
- "arima0" from stats::arima0()
- "betareg" from betareg::betareg()
- "clm" from ordinal::clm()
- "coxph" from survival::coxph()
- "crch" from crch::crch()
- "gam" from gam::gam()
- "gls" from nlme::gls()
- "hxlr" from crch::hxlr()
- "ivreg" from AER::ivreg()
- "lda" from MASS::lda()
- "loess" from stats::loess()
- "naiveBayes" from e1071::naiveBayes()
- "nls" from stats::nls()
- "nnet" from nnet::nnet(), nnet::multinom()
- "polr" from MASS::polr()
- "ppr" from stats::ppr()
- "princomp" from stats::princomp()
- "qda" from MASS::qda()
- "rlm" from MASS::rlm()
- "rq" from quantreg::rq()
- "selection" from sampleSelection::selection()
- "survreg" from survival::survreg()
- "svm" from e1071::svm()
- "svyglm" from survey::svyglm()
The development version of this package can be installed directly from GitHub using ghit:

```r
# install ghit first if it is not already available
if (!require("ghit")) {
    install.packages("ghit")
    library("ghit")
}
install_github("leeper/prediction")
```