---
title: "Tidy, Type-Safe 'prediction()' Methods"
output: github_document
---

<img src="man/figures/logo.png" align="right" />

The **prediction** and **margins** packages are a combined effort to port the functionality of Stata's (closed source) [`margins`](http://www.stata.com/help.cgi?margins) command to (open source) R. **prediction** is focused on one function, `prediction()`, which provides type-safe methods for generating predictions from fitted regression models. `prediction()` is an S3 generic that always returns a `"data.frame"` class object, rather than the mix of vectors, lists, and other structures returned by the `predict()` methods for various model types. It provides a key piece of underlying infrastructure for the **margins** package. Users interested in generating marginal (partial) effects, like those generated by Stata's `margins, dydx(*)` command, should consider using `margins()` from the sibling project, [**margins**](https://cran.r-project.org/package=margins).

In addition to `prediction()`, this package provides a number of utility functions for generating useful predictions:

- `find_data()`, an S3 generic with methods that find the data frame used to estimate a regression model. This is a wrapper around `get_all_vars()` that attempts to locate the data and modify it according to the `subset` and `na.action` arguments used in the original modelling call.
- `mean_or_mode()` and `median_or_mode()`, which provide a convenient way to compute the data needed for predicted values *at means* (or *at medians*), respecting the differences between factor and numeric variables.
- `seq_range()`, which generates a vector of *n* values spanning the range (minimum to maximum) of a variable.
- `build_datalist()`, which generates a list of data frames from an input data frame and a specified set of replacement `at` values (mimicking the `atlist` option of Stata's `margins` command).
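
The ideas behind two of these helpers can be sketched in base R. The functions below, `seq_range_sketch()` and `mean_or_mode_sketch()`, are hypothetical illustrations written for this README, not the package's own code; they assume that `seq_range()` spaces values evenly from the minimum to the maximum, and that `mean_or_mode()` takes the mean of numeric columns and the most frequent level of factors.

```r
# Hypothetical sketches of the ideas behind seq_range() and mean_or_mode().
# These are NOT the package's implementations.

# n evenly spaced values from the minimum to the maximum of x
seq_range_sketch <- function(x, n = 2) {
  seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = n)
}

# Mean for numeric vectors; most frequent level (kept as a factor) otherwise
mean_or_mode_sketch <- function(x) {
  if (is.factor(x)) {
    counts <- table(x)
    factor(names(counts)[which.max(counts)], levels = levels(x))
  } else {
    mean(x, na.rm = TRUE)
  }
}

# Apply column by column to get a one-row "typical case" data frame,
# treating cyl as a factor so it is summarised by its mode
typical <- data.frame(lapply(
  transform(mtcars, cyl = factor(cyl)),
  mean_or_mode_sketch
))

seq_range_sketch(mtcars$hp, 5)
typical[, c("mpg", "cyl", "hp", "wt")]
```

Keeping the mode as a factor with the original levels means the one-row result can be passed to `predict()` as `newdata` without level mismatches.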

## Simple code examples

```{r opts, echo = FALSE}
library("knitr")
options(width = 100)
opts_knit$set(upload.fun = imgur_upload, base.url = NULL)
opts_chunk$set(fig.width=7, fig.height=4)
```

A major downside of the `predict()` methods for common modelling classes is that they are not type-safe: the class of the returned object depends on the arguments supplied. Consider the following simple example:

```{r predict}
library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
class(predict(x, se.fit = TRUE))
```

**prediction** solves this issue by providing a wrapper around `predict()`, called `prediction()`, that always returns a tidy data frame with a very simple `print()` method:

```{r prediction}
library("prediction")
(p <- prediction(x))
class(p)
head(p)
```

The output always contains the original data (either the data located by `find_data()` or the data passed to the `data` argument of `prediction()`). This makes it much simpler to pass predictions to further summary or plotting functions.
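
To see why a data-frame return type helps, here is a hypothetical base-R wrapper, `predict_df()`, that always binds fitted values (and optionally standard errors) onto the data as columns. It only illustrates the idea for an `lm` model; it is not how `prediction()` is implemented, and the column names `fitted` and `se.fitted` are choices made for this sketch.

```r
# Hypothetical sketch: a wrapper around stats::predict() whose return type
# does not depend on its arguments. Not the package's implementation.
predict_df <- function(model, data, se = FALSE) {
  if (se) {
    # With se.fit = TRUE, predict.lm() returns a list
    p <- predict(model, newdata = data, se.fit = TRUE)
    out <- data.frame(fitted = p$fit, se.fitted = p$se.fit)
  } else {
    # Without it, predict.lm() returns a numeric vector
    out <- data.frame(fitted = predict(model, newdata = data))
  }
  # Bind predictions onto the original data so callers always get one shape
  cbind(data, out)
}

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict_df(x, mtcars))
class(predict_df(x, mtcars, se = TRUE))

# Because the data travel with the predictions, a grouped summary is one call
aggregate(fitted ~ cyl, data = predict_df(x, mtcars), FUN = mean)
```

Both calls return a `"data.frame"`, so downstream code such as the `aggregate()` call does not need to branch on what `predict()` happened to return.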

Additionally, the vast majority of methods accept an `at` argument, which obtains predicted values from a modified version of `data` in which the specified variables are held at particular values:

```{r at_arg}
prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
```

This serves as a rough R port of the subset of Stata's `margins` functionality that calculates predictive margins (average predicted values with covariates held at specified values). For calculation of marginal or partial effects, see the [**margins**](https://cran.r-project.org/package=margins) package.
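
Conceptually, a predictive margin holds one variable at a chosen value for every observation and averages the resulting predictions. The base-R sketch below mimics this for `hp`: it builds one modified copy of the data per `at` value (similar to what `build_datalist()` is described as producing) and averages `predict()` over each copy. It is a hypothetical illustration for a single variable and a linear model, not the package's code; the names `datalist` and `margins_hp` are invented here.

```r
# Hypothetical sketch of "at"-style predictive margins in base R.
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)

# Five evenly spaced values across the observed range of hp
at_hp <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 5)

# One copy of the data per value, with hp fixed at that value in every row
datalist <- lapply(at_hp, function(v) {
  d <- mtcars
  d$hp <- v
  d
})

# Average predicted mpg over each modified data set
margins_hp <- data.frame(
  hp = at_hp,
  mean_fitted = vapply(datalist, function(d) mean(predict(x, newdata = d)), numeric(1))
)
margins_hp
```

Because this model is linear in `hp` (including the `cyl:hp` interaction), the averaged predictions change by a constant amount between equally spaced `hp` values; for non-linear models such as a logit `glm`, averaging predictions and predicting at averages generally differ.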

## Supported model classes

The currently supported model classes are:

- "lm" from `stats::lm()`
- "glm" from `stats::glm()`, `MASS::glm.nb()`, `glmx::glmx()`, `glmx::hetglm()`, `brglm::brglm()`
- "ar" from `stats::ar()`
- "Arima" from `stats::arima()`
- "arima0" from `stats::arima0()`
- "biglm" from `biglm::biglm()` (including `"ffdf"` backed models)
- "betareg" from `betareg::betareg()`
- "bruto" from `mda::bruto()`
- "clm" from `ordinal::clm()`
- "coxph" from `survival::coxph()`
- "crch" from `crch::crch()`
- "earth" from `earth::earth()`
- "fda" from `mda::fda()`
- "Gam" from `gam::gam()`
- "gausspr" from `kernlab::gausspr()`
- "gee" from `gee::gee()`
- "glimML" from `aod::betabin()`, `aod::negbin()`
- "glimQL" from `aod::quasibin()`, `aod::quasipois()`
- "glmnet" from `glmnet::glmnet()`
- "gls" from `nlme::gls()`
- "hurdle" from `pscl::hurdle()`
- "hxlr" from `crch::hxlr()`
- "ivreg" from `AER::ivreg()`
- "knnreg" from `caret::knnreg()`
- "kqr" from `kernlab::kqr()`
- "ksvm" from `kernlab::ksvm()`
- "lda" from `MASS::lda()`
- "lme" from `nlme::lme()`
- "loess" from `stats::loess()`
- "lqs" from `MASS::lqs()`
- "mars" from `mda::mars()`
- "mca" from `MASS::mca()`
- "mclogit" from `mclogit::mclogit()`
- "mda" from `mda::mda()`
- "merMod" from `lme4::lmer()` and `lme4::glmer()`
- "mnp" from `MNP::mnp()`
- "naiveBayes" from `e1071::naiveBayes()`
- "nlme" from `nlme::nlme()`
- "nls" from `stats::nls()`
- "nnet" from `nnet::nnet()`, `nnet::multinom()`
- "plm" from `plm::plm()`
- "polr" from `MASS::polr()`
- "ppr" from `stats::ppr()`
- "princomp" from `stats::princomp()`
- "qda" from `MASS::qda()`
- "rlm" from `MASS::rlm()`
- "rpart" from `rpart::rpart()`
- "rq" from `quantreg::rq()`
- "selection" from `sampleSelection::selection()`
- "speedglm" from `speedglm::speedglm()`
- "speedlm" from `speedglm::speedlm()`
- "survreg" from `survival::survreg()`
- "svm" from `e1071::svm()`
- "svyglm" from `survey::svyglm()`
- "tobit" from `AER::tobit()`
- "train" from `caret::train()`
- "truncreg" from `truncreg::truncreg()`
- "zeroinfl" from `pscl::zeroinfl()`

## Requirements and Installation

[![CRAN](https://www.r-pkg.org/badges/version/prediction)](https://cran.r-project.org/package=prediction)
![Downloads](https://cranlogs.r-pkg.org/badges/prediction)
[![Build Status](https://travis-ci.org/leeper/prediction.svg?branch=master)](https://travis-ci.org/leeper/prediction)
[![Build status](https://ci.appveyor.com/api/projects/status/a4tebeoa98cq07gy/branch/master?svg=true)](https://ci.appveyor.com/project/leeper/prediction/branch/master)
[![codecov.io](https://codecov.io/github/leeper/prediction/coverage.svg?branch=master)](https://codecov.io/github/leeper/prediction?branch=master)
[![Project Status: Active - The project has reached a stable, usable state and is being actively developed.](http://www.repostatus.org/badges/latest/active.svg)](http://www.repostatus.org/#active)

The development version of this package can be installed directly from GitHub using `remotes`:

``` r
if (!require("remotes")) {
    install.packages("remotes")
}
remotes::install_github("leeper/prediction")
```