forked from szcf-weiya/ESL-CN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
69a56fb
commit bb9e627
Showing
16 changed files
with
602 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
## ############################################################################################################################### | ||
## R Code for the simulation of Fig. 3.18 | ||
## | ||
## Refer to | ||
## 1. https://esl.hohoweiya.xyz/03-Linear-Methods-for-Regression/3.6-A-Comparison-of-the-Selection-and-Shrinkage-Methods/index.html | ||
## 2. https://esl.hohoweiya.xyz//notes/linear-reg/sim-3-18/index.html | ||
## for more details | ||
## | ||
## ############################################################################################################################### | ||
|
||
## ############################################################################################################### | ||
## generate simulated data | ||
## ############################################################################################################### | ||
## Simulate a noise-free linear model with two correlated predictors.
##
## Arguments:
##   rho   correlation between the two predictors (both have unit variance)
##   N     number of samples to draw
##   beta  length-2 vector of true coefficients
##
## Value: a list with
##   X  the N x 2 predictor matrix drawn from a zero-mean bivariate normal
##   Y  the response X[, 1] * beta[1] + X[, 2] * beta[2] (no noise is added)
genXY <- function(rho = 0.5,      # correlation
                  N = 100,        # number of samples
                  beta = c(4, 2)) # true coefficients
{
  library(MASS) # provides mvrnorm()
  ## covariance matrix: unit variances, covariance rho
  Sigma <- matrix(c(1, rho,
                    rho, 1), nrow = 2, ncol = 2)
  ## mvrnorm() drops to a plain vector when N == 1, which would break the
  ## column indexing below, so force the draw back into an N x 2 matrix.
  X <- matrix(mvrnorm(n = N, mu = c(0, 0), Sigma = Sigma), nrow = N, ncol = 2)
  Y <- X[, 1] * beta[1] + X[, 2] * beta[2]
  list(X = X, Y = Y)
}
|
||
## ############################################################################################################### | ||
## main function | ||
## returns the coefficients computed by 6 methods (ols, ridge, lasso, pcr (plus mypcr, written from scratch), pls, subset) | ||
## | ||
## ############################################################################################################### | ||
|
||
## Fit every selection/shrinkage method compared in the simulation and collect
## the coefficient estimates along each method's path.
##
## Arguments:
##   X  predictor matrix; the column labels set below assume two columns
##   Y  response vector
##
## Value: a list of class "selectORshrink" holding one coefficient matrix per
## method (rows = models along the path, columns = predictors):
##   ols     least squares without intercept (a single row)
##   ridge   glmnet path with alpha = 0 over the lambda grid
##   lasso   glmnet path with alpha = 1 over the lambda grid
##   pcr     pcr() with zero, one and two components
##   mypcr   the same PCR path rebuilt from prcomp() and lm()
##   pls     plsr() with zero, one and two components
##   subset  regsubsets() best subsets of size zero, one and two
select.vs.shrink <- function(X, Y) {
  library(glmnet)
  library(pls)
  library(leaps)

  ## ---- least squares (no intercept), stored as a 1 x p row matrix ----
  ols_model <- lm(Y ~ 0 + X)
  ols_coef <- as.matrix(t(coef(ols_model)))

  ## penalty values shared by the lasso and ridge paths
  lambda_grid <- 10^seq(10, -2, length.out = 100)

  ## ---- lasso ----
  lasso_path <- glmnet(X, Y, alpha = 1, lambda = lambda_grid)
  ## densify the sparse coefficient matrix and put one fit per row
  lasso_coef <- t(as.matrix(lasso_path$beta))
  dimnames(lasso_coef) <- list(NULL, c("X1", "X2"))

  ## ---- ridge regression ----
  ridge_path <- glmnet(X, Y, alpha = 0, lambda = lambda_grid)
  ridge_coef <- t(as.matrix(ridge_path$beta))
  dimnames(ridge_coef) <- list(NULL, c("X1", "X2"))

  ## ---- principal component regression via the pls package ----
  pcr_model <- pcr(Y ~ X, scale = FALSE)
  pcr_arr <- pcr_model$coefficients
  ## the leading c(0, 0) row stands for the model with zero components
  pcr_coef <- rbind(c(0, 0), pcr_arr[, , 1], pcr_arr[, , 2])

  ## ---- principal component regression written from scratch ----
  pca <- prcomp(X, scale. = FALSE)
  rotation <- pca$rotation # principal directions
  scores <- pca$x          # principal component scores
  fit_one_pc <- lm(Y ~ 0 + scores[, 1])
  fit_two_pc <- lm(Y ~ 0 + scores)
  ## map the score-space coefficients back to the original predictors
  beta_one_pc <- coef(fit_one_pc) * rotation[, 1]
  beta_two_pc <- t(rotation %*% coef(fit_two_pc))
  mypcr_coef <- rbind(c(0, 0), beta_one_pc, beta_two_pc)
  dimnames(mypcr_coef) <- list(NULL, c("X1", "X2"))

  ## ---- partial least squares ----
  pls_model <- plsr(Y ~ X, scale = FALSE)
  pls_arr <- pls_model$coefficients
  pls_coef <- rbind(c(0, 0), pls_arr[, , 1], pls_arr[, , 2])

  ## ---- best subset ----
  subset_search <- regsubsets(x = X, y = Y, intercept = FALSE)
  ## pad the size-one model with a zero for whichever predictor was left out
  if (summary(subset_search)$which[1, 1]) {
    one_var <- c(coef(subset_search, 1), 0)
  } else {
    one_var <- c(0, coef(subset_search, 1))
  }
  subset_coef <- rbind(c(0, 0), one_var, coef(subset_search, 2))
  dimnames(subset_coef) <- list(NULL, c("X1", "X2"))

  structure(
    list(ols = ols_coef,
         ridge = ridge_coef,
         lasso = lasso_coef,
         pcr = pcr_coef,
         mypcr = mypcr_coef,
         pls = pls_coef,
         subset = subset_coef),
    class = "selectORshrink"
  )
}
## ####################################################################### | ||
## plot function | ||
## ####################################################################### | ||
## Plot method for "selectORshrink" objects: draws the coefficient profile of
## each method in the (beta1, beta2) plane.
##
## Arguments:
##   obj  list with elements ridge, lasso, pcr, pls, subset and ols, each a
##        two-column matrix whose columns are used as x and y coordinates
##   rho  correlation value shown in the plot title
##
## The line width and text size set for drawing are restored when the function
## exits, so the call does not leave the graphics device in a changed state.
plot.selectORshrink <- function(obj, rho = 0.5)
{
  plot(0, 0,
       type = "n",
       xlab = expression(beta[1]),
       ylab = expression(beta[2]),
       ## only `r` is substituted; `rho` stays a symbol in the plotmath title
       main = substitute(paste(rho, "=", r), list(r = rho)),
       xlim = c(0, 6),
       ylim = c(-1, 3))
  ## thicker lines for the profiles; put the previous settings back on exit
  old_par <- par(lwd = 3, cex = 1)
  on.exit(par(old_par), add = TRUE)
  lines(obj$ridge, col = "red")
  lines(obj$lasso, col = "green")
  lines(obj$pcr, col = "purple")
  lines(obj$pls, col = "orange")
  lines(obj$subset, col = "blue")
  points(obj$ols, col = "black", pch = 16)
  abline(h = 0, lty = 2)
  abline(v = 0, lty = 2)
  legend(4.8, 3,
         c("Ridge", "Lasso", "PCR", "PLS", "Best Subset", "Least Squares"),
         col = c("red", "green", "purple", "orange", "blue", "black"),
         lty = c(1, 1, 1, 1, 1, NA),
         pch = c(NA, NA, NA, NA, NA, 16),
         box.col = "white",
         box.lwd = 0,
         bg = "transparent")
}
|
||
## ################################################################################### | ||
## results | ||
## ################################################################################### | ||
|
||
## case 1: genXY() defaults, i.e. rho = 0.5
set.seed(1234)
data <- genXY()
X <- data[["X"]]
Y <- data[["Y"]]
res1 <- select.vs.shrink(X, Y)
png("res_rho_05.png", width = 640, height = 480)
plot(res1, rho = 0.5)
dev.off()

## case 2: same seed, negatively correlated predictors (rho = -0.5)
set.seed(1234)
data2 <- genXY(rho = -0.5)
X2 <- data2[["X"]]
Y2 <- data2[["Y"]]
res2 <- select.vs.shrink(X2, Y2)
png("res_rho_-05.png", width = 640, height = 480)
plot(res2, rho = -0.5)
dev.off()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
## ############################################################################################## | ||
## file LDA/sim-fig-4-2.R | ||
## copyright (C) 2018 Lijun Wang <[email protected]> | ||
## | ||
## This program reproduces figures 4.2 and 4.3 of the book The Elements of Statistical Learning. | ||
## The translated documents are available at https://esl.hohoweiya.xyz | ||
## | ||
## ############################################################################################# | ||
|
||
## Simulate three Gaussian classes whose centroids lie on the diagonal of the
## unit square, then draw them as in figure 4.2.
mu <- c(0.25, 0.5, 0.75)  # each centroid is (mu[k], mu[k])
sigma <- 0.005 * diag(2)  # common spherical covariance
library(MASS)
set.seed(1650)
N <- 100                  # observations per class
X1 <- mvrnorm(n = N, mu = rep(mu[1], 2), Sigma = sigma)
X2 <- mvrnorm(n = N, mu = rep(mu[2], 2), Sigma = sigma)
X3 <- mvrnorm(n = N, mu = rep(mu[3], 2), Sigma = sigma)
## stack the classes in the order 1, 2, 3
X <- rbind(X1, X2, X3)

## scatter plot; each point is drawn as the digit of its class
png("reproduce-fig-4-2.png")
plot(X1[, 1], X1[, 2],
     col = "orange", pch = "1",
     xlim = c(0, 1), ylim = c(0, 1),
     xlab = expression(X[1]), ylab = expression(X[2]))
points(X2[, 1], X2[, 2], col = "blue", pch = "2")
points(X3[, 1], X3[, 2], col = "green", pch = "3")
dev.off()
|
||
## project X onto the line joining the three centroids
X.proj <- rowMeans(X) # if necessary, multiply by sqrt(2)

## fit as in figure 4.3: one 0/1 indicator response per class. The rows of X
## are stacked as class 1 (orange), class 2 (blue), class 3 (green).
Y1 <- rep(c(1, 0, 0), each = N) # orange
Y2 <- rep(c(0, 1, 0), each = N) # blue
Y3 <- rep(c(0, 0, 1), each = N) # green

## linear regression of each indicator on the projected coordinate
m1 <- lm(Y1 ~ X.proj)
m2 <- lm(Y2 ~ X.proj)
m3 <- lm(Y3 ~ X.proj)

## Training error rate: assign every observation to the class with the largest
## fitted value and compare with its true label. Counting mismatches directly
## does not rely on the sorted projection being perfectly ordered by class and
## stays well defined when a class is never predicted.
cls.true <- rep(1:3, each = N)
cls.lin <- max.col(cbind(fitted(m1), fitted(m2), fitted(m3)),
                   ties.method = "first")
err1 <- mean(cls.lin != cls.true)

## reproduce figure 4.3 left
png("reproduce-fig-4-3l.png")
plot(0, 0, type = "n",
     xlim = c(0, 1), ylim = c(0, 1), xlab = "", ylab = "",
     main = paste0("Degree = 1; Error = ", round(err1, digits = 4)))
## fitted regression lines
abline(coef(m1), col = "orange")
abline(coef(m2), col = "blue")
abline(coef(m3), col = "green")
## fitted values, drawn as the digit of the indicator being fitted
points(X.proj, fitted(m1), pch = "1", col = "orange")
points(X.proj, fitted(m2), pch = "2", col = "blue")
points(X.proj, fitted(m3), pch = "3", col = "green")
## rug of the projected observations, coloured by true class
rug(X.proj[1:N], col = "orange")
rug(X.proj[(N + 1):(2 * N)], col = "blue")
rug(X.proj[(2 * N + 1):(3 * N)], col = "green")
abline(h = c(0.0, 0.5, 1.0), lty = 5, lwd = 0.4)
## vertical guides at the N-th and 2N-th smallest projected values
abline(v = c(sort(X.proj)[N], sort(X.proj)[N * 2]), lwd = 0.4)
dev.off()
|
||
## polynomial (degree 2) regression of each class indicator on the projection
pm1 <- lm(Y1 ~ X.proj + I(X.proj^2))
pm2 <- lm(Y2 ~ X.proj + I(X.proj^2))
pm3 <- lm(Y3 ~ X.proj + I(X.proj^2))

## Error rate for figure 4.3 right: classify each observation by the largest
## fitted value and compare with its true class (rows are stacked as N
## observations of class 1, then class 2, then class 3). Counting mismatches
## directly does not assume the sorted projection is ordered by class.
cls.quad <- max.col(cbind(fitted(pm1), fitted(pm2), fitted(pm3)),
                    ties.method = "first")
err2 <- mean(cls.quad != rep(1:3, each = N))

## reproduce figure 4.3 right
png("reproduce-fig-4-3r.png")
plot(0, 0, type = "n",
     xlim = c(0, 1), ylim = c(-1, 2), xlab = "", ylab = "",
     main = paste0("Degree = 2; Error = ", round(err2, digits = 4)))
## fitted curves, drawn in increasing order of the projected coordinate
lines(sort(X.proj), fitted(pm1)[order(X.proj)], col = "orange", type = "o", pch = "1")
lines(sort(X.proj), fitted(pm2)[order(X.proj)], col = "blue", type = "o", pch = "2")
lines(sort(X.proj), fitted(pm3)[order(X.proj)], col = "green", type = "o", pch = "3")
abline(h = c(0.0, 0.5, 1.0), lty = 5, lwd = 0.4)
## add rug, coloured by true class
rug(X.proj[1:N], col = "orange")
rug(X.proj[(N + 1):(2 * N)], col = "blue")
rug(X.proj[(2 * N + 1):(3 * N)], col = "green")
## vertical guides at the N-th and 2N-th smallest projected values
abline(v = c(sort(X.proj)[N], sort(X.proj)[N * 2]), lwd = 0.4)
dev.off()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
## simulation settings
sigma <- diag(2)    # 2 x 2 identity matrix
mu1 <- c(-1, 0)
mu2 <- c(1, 0)
mu3 <- c(0, 1.7)
N <- 30
Oops, something went wrong.