Skip to content

Commit

Permalink
add intercept and standardize functions
Browse files Browse the repository at this point in the history
  • Loading branch information
stephens999 committed May 10, 2018
1 parent 74cb314 commit 338453a
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 9 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: susieR
Type: Package
Title: Fit Sum of Single Effects linear regression model
Version: 0.1.4
Version: 0.1.5
Author: Matthew Stephens
Maintainer: <[email protected]>
Description: Fits a sparse regression model with up to $L$ effects, where $L$ is user-specified.
Expand All @@ -10,3 +10,4 @@ License: MIT
Encoding: UTF-8
LazyData: true
RoxygenNote: 6.0.1.9000
Suggests: testthat
8 changes: 4 additions & 4 deletions R/predict.susie.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#' @title extract regression coefficients from susie fit
#' @description Combines the intercept stored on the fit with the per-variable
#'   coefficient estimates, after dividing each estimate by the column scale
#'   factor recorded on the fit.
#' @param s a susie fit; this method reads \code{s$intercept}, \code{s$alpha},
#'   \code{s$mu} and \code{s$X_column_scale_factors}
#' @param ... further arguments, accepted for consistency with the
#'   \code{coef} generic and otherwise ignored
#' @return a p+1 vector, the first element being an intercept, and the remaining p elements being estimated regression coefficients
#' @method coef susie
#' @export
coef.susie = function(s, ...){
  # Column sums of alpha*mu give one estimate per column of X; dividing by the
  # stored column scale factors undoes any scaling applied to X before fitting.
  c(s$intercept, colSums(s$alpha*s$mu)/s$X_column_scale_factors)
}

#' @title predict future observations or extract coefficients from susie fit
Expand All @@ -22,7 +22,7 @@ predict.susie = function(s,newx = NULL,type=c("response","coefficients")){
return(coef(s))
}

if(missing(newx)){return(s$Xr)}
if(missing(newx)){return(s$fitted)}

return(newx %*% coef(s))
return(s$intercept + newx %*% coef(s)[-1])
}
29 changes: 28 additions & 1 deletion R/susie.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#' @param L maximum number of non-zero effects
#' @param prior_variance the scaled prior variance (a vector of length L, or a scalar, in which case it is repeated L times)
#' @param residual_variance the residual variance (defaults to variance of Y)
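#' @param standardize logical flag for whether to scale the columns of X prior to fitting (default=TRUE). When intercept=TRUE the columns are centered first, so they are scaled to unit variance; when intercept=FALSE X is not centered and each column is divided by its root mean square instead.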
#' @param intercept Should intercept be fitted (default=TRUE) or set to zero (FALSE)
#' @param max_iter maximum number of iterations to perform
#' @param tol convergence tolerance
#' @param estimate_residual_variance indicates whether to estimate residual variance
Expand Down Expand Up @@ -39,12 +41,26 @@
#' coef(res)
#' plot(y,predict(res))
#' @export
susie = function(X,Y,L=10,prior_variance=1,residual_variance=NULL,max_iter=100,tol=1e-2,estimate_residual_variance=TRUE,estimate_prior_variance = FALSE, s_init = NULL, verbose=FALSE){
susie = function(X,Y,L=10,prior_variance=1,residual_variance=NULL,standardize=TRUE,intercept=TRUE,max_iter=100,tol=1e-2,estimate_residual_variance=TRUE,estimate_prior_variance = FALSE, s_init = NULL, verbose=FALSE){
# Check input X.
if (!is.double(X) || !is.matrix(X))
stop("Input X must be a double-precision matrix")
p = ncol(X)
n = nrow(X)
mean_y = mean(Y)

if(intercept){ # center Y and X
Y = Y-mean_y
X = scale(X,center=TRUE, scale = FALSE)
} else {
attr(X,"scaled:center")=rep(0,p)
}

if(standardize){
X = scale(X,center=FALSE, scale=TRUE)
} else {
attr(X,"scaled:scale")=rep(1,p)
}

# initialize susie fit
if(!is.null(s_init)){
Expand Down Expand Up @@ -101,5 +117,16 @@ susie = function(X,Y,L=10,prior_variance=1,residual_variance=NULL,max_iter=100,t
}
elbo = elbo[1:(i+1)] #remove trailing NAs
s$elbo <- elbo

if(intercept){
s$intercept = mean_y - sum(attr(X,"scaled:center")* (colSums(s$alpha*s$mu)/attr(X,"scaled:scale")))# estimate intercept (unshrunk)
s$fitted = s$Xr + mean_y
} else {
s$intercept = 0
s$fitted = s$Xr
}

s$X_column_scale_factors = attr(X,"scaled:scale")

return(s)
}
2 changes: 1 addition & 1 deletion man/coef.susie.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions man/susie.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions tests/testthat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Test runner script: attaches testthat and the susieR package, then hands
# control to testthat to run the package's test suite.
library(testthat)
library(susieR)

test_check("susieR")
19 changes: 19 additions & 0 deletions tests/testthat/test_intercept_standardize.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
test_that("scaling and intercept works as expected", {
  # Simulate a design matrix with non-zero column means and non-unit variance,
  # so that both centering and scaling actually change X.
  set.seed(1)
  n_obs <- 200
  n_var <- 10
  design <- matrix(rnorm(n_obs * n_var, mean = 3, sd = 4), ncol = n_var)
  effects <- rnorm(n_var)
  response <- design %*% effects + rnorm(n_obs, mean = 0, sd = 0.1)

  # Fit every combination of the intercept / standardize flags.
  fit_both <- susie(design, response, intercept = TRUE, standardize = TRUE)
  fit_neither <- susie(design, response, intercept = FALSE, standardize = FALSE)
  fit_intercept_only <- susie(design, response, intercept = TRUE, standardize = FALSE)
  fit_standardize_only <- susie(design, response, intercept = FALSE, standardize = TRUE)

  # Predictions without newx must agree with predictions recomputed from the
  # training matrix, whatever preprocessing was applied.
  fits_to_check <- list(
    fit_neither,
    fit_standardize_only,
    fit_both,
    fit_intercept_only
  )
  for (fit in fits_to_check) {
    expect_equal(predict(fit), predict(fit, design))
  }

  # With intercept = FALSE the stored intercept is zero.
  expect_equal(fit_neither$intercept, 0)
  expect_equal(fit_standardize_only$intercept, 0)
})

0 comments on commit 338453a

Please sign in to comment.