
Tidy, Type-Safe 'prediction()' Methods

License: MIT (LICENSE.md); the repository also includes a separate LICENSE file.

dfrankow/prediction

prediction is a package focused on one function, prediction(), that provides type-safe methods for generating predictions from fitted regression models. prediction() is an S3 generic that always returns a "data.frame" class object rather than the mix of vectors, lists, etc. returned by the predict() methods for many model types. It provides a key piece of infrastructure for the margins package.
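
To illustrate why an S3 generic makes this kind of type stability easy to enforce, here is a minimal sketch. The generic tidy_prediction() and its methods are hypothetical names invented for this example, not the package's API; each method calls predict() with se.fit = TRUE and packs the result into a data frame, so callers get the same type back whatever the model class.

```r
# Hypothetical sketch (not the package's source): an S3 generic whose
# methods always return a data.frame, whatever predict() itself returns.
tidy_prediction <- function(model, data, ...) {
  UseMethod("tidy_prediction")
}

# "lm" method: predict() with se.fit = TRUE returns a list, so unpack it
tidy_prediction.lm <- function(model, data, ...) {
  p <- predict(model, newdata = data, se.fit = TRUE)
  data.frame(data, fitted = unname(p$fit), se.fitted = unname(p$se.fit))
}

# "glm" method: predictions on the response scale by default
tidy_prediction.glm <- function(model, data, type = "response", ...) {
  p <- predict(model, newdata = data, type = type, se.fit = TRUE)
  data.frame(data, fitted = unname(p$fit), se.fitted = unname(p$se.fit))
}

fit_lm  <- lm(mpg ~ cyl * hp + wt, data = mtcars)
fit_glm <- glm(am ~ wt, data = mtcars, family = binomial)
class(tidy_prediction(fit_lm, mtcars))   # "data.frame"
class(tidy_prediction(fit_glm, mtcars))  # "data.frame"
```

Because "glm" objects also carry the "lm" class, S3 dispatch selects tidy_prediction.glm() for them, while plain "lm" fits use the lm method.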

In addition to prediction(), the package provides a number of utility functions for generating useful predictions:

  • find_data(), an S3 generic with methods that find the data frame used to estimate a regression model
  • mean_or_mode() and median_or_mode(), which provide a convenient way to compute the data needed for predicted values at means (or at medians), respecting the differences between factor and numeric variables.
  • seq_range(), which generates a vector of n values based upon the range of values in a variable
  • build_datalist(), which generates a list of data frames from an input data frame and a specified set of replacement values given as at (mimicking the atlist option of Stata's margins command)
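
The following base-R sketches show the kind of computation each helper performs. The sketch_* functions are hypothetical re-implementations written only for illustration, assuming the behaviour described in the list above; they are not the package's own code and may differ in details such as tie-breaking for modes or missing-value handling.

```r
# Hypothetical base-R sketches of the utility helpers (illustration only).

# Mean for numeric vectors, most frequent level for factors
sketch_mean_or_mode <- function(x) {
  if (is.factor(x)) {
    counts <- table(x)
    factor(names(counts)[which.max(counts)], levels = levels(x))
  } else {
    mean(x, na.rm = TRUE)
  }
}

# Same idea, but with the median for numeric vectors
sketch_median_or_mode <- function(x) {
  if (is.factor(x)) sketch_mean_or_mode(x) else median(x, na.rm = TRUE)
}

# n evenly spaced values spanning the observed range of x
sketch_seq_range <- function(x, n = 2) {
  seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = n)
}

# One copy of `data` per combination of the `at` values, with those
# variables overwritten in every row (the idea behind Stata's atlist)
sketch_build_datalist <- function(data, at) {
  grid <- expand.grid(at, stringsAsFactors = FALSE)
  lapply(seq_len(nrow(grid)), function(i) {
    d <- data
    for (v in names(grid)) d[[v]] <- grid[[v]][i]
    d
  })
}

sketch_seq_range(mtcars$hp, 5)                 # 52.00 122.75 193.50 264.25 335.00
sketch_mean_or_mode(factor(c("a", "b", "b")))  # level "b"
datalist <- sketch_build_datalist(mtcars, list(hp = c(100, 200), am = 0:1))
length(datalist)                               # 4 data frames
```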

Simple code examples

A major downside of the predict() methods for common modelling classes is that the result is not type-safe. Consider the following simple example:

library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
## [1] "numeric"
class(predict(x, se.fit = TRUE))
## [1] "list"

prediction solves this issue by providing a wrapper around predict(), called prediction(), that always returns a tidy data frame with a very simple print() method:

library("prediction")
(p <- prediction(x))
## Average prediction for 32 observations: 20.0906
class(p)
## [1] "prediction" "data.frame"
head(p)
##    mpg cyl disp  hp drat    wt  qsec vs am gear carb   fitted se.fitted
## 1 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4 21.90488 0.6927034
## 2 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4 21.10933 0.6266557
## 3 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1 25.64753 0.6652076
## 4 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1 20.04859 0.6041400
## 5 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2 17.25445 0.7436172
## 6 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1 19.53360 0.6436862

The output always contains the original data (i.e., either data found using the find_data() function or passed to the data argument to prediction()). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
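
For models fitted with a data argument, one plausible way to recover the data is to re-evaluate that argument in the environment where the model formula was created. The sketch below assumes this approach; sketch_find_data() is a hypothetical helper, and the real find_data() methods may handle more cases (for example, models fitted without a data argument).

```r
# Hypothetical sketch of the idea behind find_data() (not the package's code)
sketch_find_data <- function(model) {
  data_expr <- model$call[["data"]]
  if (is.null(data_expr)) {
    stop("model was not fitted with a data argument")
  }
  # evaluate the original `data` expression where the formula was created
  eval(data_expr, envir = environment(formula(model)))
}

fit <- lm(mpg ~ cyl * hp + wt, data = mtcars)
d <- sketch_find_data(fit)

# original data plus a fitted-value column, mirroring the layout of head(p)
with_fit <- cbind(d, fitted = unname(fitted(fit)))
head(with_fit[, c("mpg", "hp", "wt", "fitted")])
```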

Additionally, the vast majority of methods allow passing an at argument, which can be used to obtain predicted values from a modified version of the data in which selected variables are held at specific values:

prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
## Average predictions for 32 observations:
##  at(hp)  value
##    52.0 22.605
##   122.8 19.328
##   193.5 16.051
##   264.2 12.774
##   335.0  9.497
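
The averages in this output can be sketched in base R, assuming each value is the mean of predictions over all 32 rows with hp replaced by the corresponding at value. hp_values and avg_pred are illustrative names; this is not how prediction() is implemented internally.

```r
# Hypothetical base-R sketch of the `at` computation for a linear model
fit <- lm(mpg ~ cyl * hp + wt, data = mtcars)
hp_values <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 5)

avg_pred <- vapply(hp_values, function(h) {
  counterfactual <- mtcars
  counterfactual$hp <- h  # hold hp fixed for every observation
  mean(predict(fit, newdata = counterfactual))
}, numeric(1))

data.frame(hp = hp_values, value = avg_pred)
```

Because the model is linear in hp for fixed cyl and wt, the averages change by a constant amount between equally spaced hp values, which matches the even steps in the printed output.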

Supported model classes

The currently supported model classes are:

  • "lm" from stats::lm()
  • "glm" from stats::glm(), MASS::glm.nb(), glmx::glmx(), glmx::hetglm()
  • "ar" from stats::ar()
  • "Arima" from stats::arima()
  • "arima0" from stats::arima0()
  • "betareg" from betareg::betareg()
  • "clm" from ordinal::clm()
  • "coxph" from survival::coxph()
  • "crch" from crch::crch()
  • "gam" from gam::gam()
  • "gls" from nlme::gls()
  • "hxlr" from crch::hxlr()
  • "ivreg" from AER::ivreg()
  • "lda" from MASS::lda()
  • "loess" from stats::loess()
  • "naiveBayes" from e1071::naiveBayes()
  • "nls" from stats::nls()
  • "nnet" from nnet::nnet(), nnet::multinom()
  • "polr" from MASS::polr()
  • "ppr" from stats::ppr()
  • "princomp" from stats::princomp()
  • "qda" from MASS::qda()
  • "rlm" from MASS::rlm()
  • "rq" from quantreg::rq()
  • "selection" from sampleSelection::selection()
  • "survreg" from survival::survreg()
  • "svm" from e1071::svm()
  • "svyglm" from survey::svyglm()

Requirements and Installation

Project Status: Active. The project has reached a stable, usable state and is being actively developed.

The development version of this package can be installed directly from GitHub using ghit:

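# install and load ghit if it is not already available
if (!require("ghit")) {
    install.packages("ghit")
    library("ghit")
}
# install the development version from GitHub
install_github("leeper/prediction")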
