diff --git a/client-common/src/main/java/com/cloudera/livy/client/common/HttpMessages.java b/client-common/src/main/java/com/cloudera/livy/client/common/HttpMessages.java index 42c393ed7..7a6272dd2 100644 --- a/client-common/src/main/java/com/cloudera/livy/client/common/HttpMessages.java +++ b/client-common/src/main/java/com/cloudera/livy/client/common/HttpMessages.java @@ -56,15 +56,19 @@ private CreateClientRequest() { public static class SessionInfo implements ClientMessage { public final int id; + public final String owner; + public final String proxyUser; public final String state; - public SessionInfo(int id, String state) { + public SessionInfo(int id, String owner, String proxyUser, String state) { this.id = id; + this.owner = owner; + this.proxyUser = proxyUser; this.state = state; } private SessionInfo() { - this(-1, null); + this(-1, null, null, null); } } diff --git a/client-http/src/test/scala/com/cloudera/livy/client/http/HttpClientSpec.scala b/client-http/src/test/scala/com/cloudera/livy/client/http/HttpClientSpec.scala index 6808578af..0e8ebbd23 100644 --- a/client-http/src/test/scala/com/cloudera/livy/client/http/HttpClientSpec.scala +++ b/client-http/src/test/scala/com/cloudera/livy/client/http/HttpClientSpec.scala @@ -269,6 +269,7 @@ private class HttpClientTestBootstrap extends LifeCycle { val id = sessionManager.nextId() when(session.id).thenReturn(id) when(session.state).thenReturn(SessionState.Idle()) + when(session.proxyUser).thenReturn(None) when(session.stop()).thenReturn(Future { }) require(HttpClientSpec.session == null, "Session already created?") HttpClientSpec.session = session diff --git a/client-local/src/main/java/com/cloudera/livy/client/local/LocalClient.java b/client-local/src/main/java/com/cloudera/livy/client/local/LocalClient.java index 1bfed37c7..d5f38912c 100644 --- a/client-local/src/main/java/com/cloudera/livy/client/local/LocalClient.java +++ b/client-local/src/main/java/com/cloudera/livy/client/local/LocalClient.java @@ -317,6 +317,11 @@ public void run() { argv.add("--class"); argv.add(RemoteDriver.class.getName()); + if (conf.get(PROXY_USER) != null) { + argv.add("--proxy-user"); + argv.add(conf.get(PROXY_USER)); + } + String jar = "spark-internal"; String livyJars = conf.get(LIVY_JARS); if (livyJars == null) { diff --git a/client-local/src/main/java/com/cloudera/livy/client/local/LocalConf.java b/client-local/src/main/java/com/cloudera/livy/client/local/LocalConf.java index 5c4a48e53..fcf4c1ebc 100644 --- a/client-local/src/main/java/com/cloudera/livy/client/local/LocalConf.java +++ b/client-local/src/main/java/com/cloudera/livy/client/local/LocalConf.java @@ -46,6 +46,8 @@ public static enum Entry implements ConfEntry { LIVY_JARS("jars", null), + PROXY_USER("proxy_user", null), + RPC_SERVER_ADDRESS("rpc.server.address", null), RPC_CLIENT_HANDSHAKE_TIMEOUT("server.connect.timeout", "90000ms"), RPC_CLIENT_CONNECT_TIMEOUT("client.connect.timeout", "10000ms"), diff --git a/client-local/src/test/java/com/cloudera/livy/client/local/TestSparkClient.java b/client-local/src/test/java/com/cloudera/livy/client/local/TestSparkClient.java index b396e1f1c..64a2373f9 100644 --- a/client-local/src/test/java/com/cloudera/livy/client/local/TestSparkClient.java +++ b/client-local/src/test/java/com/cloudera/livy/client/local/TestSparkClient.java @@ -35,6 +35,7 @@ import com.google.common.base.Objects; import com.google.common.base.Strings; import com.google.common.io.ByteStreams; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.spark.SparkFiles; import org.apache.spark.api.java.JavaFutureAction; import org.apache.spark.api.java.JavaRDD; @@ -247,7 +248,7 @@ void call(LivyClient client) throws Exception { } @Test - public void testStreamingContext() throws Exception{ + public void testStreamingContext() throws Exception { runTest(true, new TestFunction() { @Override void call(LivyClient client) throws Exception { @@ -258,6 +259,25 @@ void call(LivyClient client) throws Exception { }); } + @Test + public void testImpersonation() throws Exception { + final String PROXY = "__proxy__"; + + runTest(false, new TestFunction() { + @Override + void config(Properties conf) { + conf.put(LocalConf.Entry.PROXY_USER.key(), PROXY); + } + + @Override + void call(LivyClient client) throws Exception { + JobHandle handle = client.submit(new GetCurrentUserJob()); + String userName = handle.get(TIMEOUT, TimeUnit.SECONDS); + assertEquals(PROXY, userName); + } + }); + } + @Test public void testBypass() throws Exception { runBypassTest(false); @@ -518,6 +538,15 @@ public String call(JobContext jc) { } + private static class GetCurrentUserJob implements Job { + + @Override + public String call(JobContext jc) throws Exception { + return UserGroupInformation.getCurrentUser().getUserName(); + } + + } + private abstract static class TestFunction { abstract void call(LivyClient client) throws Exception; void config(Properties conf) { } diff --git a/core/src/main/scala/com/cloudera/livy/LivyConf.scala b/core/src/main/scala/com/cloudera/livy/LivyConf.scala index 70bd991ce..e1ab221ee 100644 --- a/core/src/main/scala/com/cloudera/livy/LivyConf.scala +++ b/core/src/main/scala/com/cloudera/livy/LivyConf.scala @@ -43,6 +43,7 @@ object LivyConf { val IMPERSONATION_ENABLED = Entry("livy.impersonation.enabled", false) val LIVY_HOME = Entry("livy.home", null) val FILE_UPLOAD_MAX_SIZE = Entry("livy.file.upload.max.size", 100L * 1024 * 1024) + val SUPERUSERS = Entry("livy.superusers", null) lazy val TEST_LIVY_HOME = Files.createTempDirectory("livyTemp").toUri.toString } @@ -55,6 +56,8 @@ class LivyConf(loadDefaults: Boolean) extends ClientConf[LivyConf](null) { import LivyConf._ + private lazy val _superusers = Option(get(SUPERUSERS)).map(_.split("[, ]+").toSeq).getOrElse(Nil) + /** * Create a LivyConf that loads defaults from the system properties and the classpath. * @return @@ -92,6 +95,9 @@ class LivyConf(loadDefaults: Boolean) extends ClientConf[LivyConf](null) { .getOrElse("spark-submit") } + /** Return the list of superusers. */ + def superusers(): Seq[String] = _superusers + private def loadFromMap(map: Iterable[(String, String)]): Unit = { map.foreach { case (k, v) => if (k.startsWith("livy.")) { diff --git a/server/src/main/scala/com/cloudera/livy/server/JsonServlet.scala b/server/src/main/scala/com/cloudera/livy/server/JsonServlet.scala index cf67f77b7..4b3de6225 100644 --- a/server/src/main/scala/com/cloudera/livy/server/JsonServlet.scala +++ b/server/src/main/scala/com/cloudera/livy/server/JsonServlet.scala @@ -77,18 +77,20 @@ abstract class JsonServlet extends ScalatraServlet with ApiFormats with FutureSu } } - override protected def renderResponse(actionResult: Any): Unit = { + override protected def renderResponseBody(actionResult: Any): Unit = { val result = actionResult match { - case async: AsyncResult => - async case ActionResult(status, body, headers) if format == "json" => ActionResult(status, toJson(body), headers) + case str: String if format == "json" => + // This should be changed when we implement LIVY-54. For now, just create a dummy + // JSON object when a raw string is being returned. + toJson(Map("msg" -> str)) case other if format == "json" => - Ok(toJson(other)) + toJson(other) case other => other } - super.renderResponse(result) + super.renderResponseBody(result) } protected def bodyAs[T: ClassTag](req: HttpServletRequest) diff --git a/server/src/main/scala/com/cloudera/livy/server/Main.scala b/server/src/main/scala/com/cloudera/livy/server/Main.scala index 98fc940ad..96aef796d 100644 --- a/server/src/main/scala/com/cloudera/livy/server/Main.scala +++ b/server/src/main/scala/com/cloudera/livy/server/Main.scala @@ -98,6 +98,10 @@ object Main extends Logging { livyConf.get(KERBEROS_NAME_RULES)) server.context.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType])) info(s"SPNEGO auth enabled (principal = $principal)") + if (!livyConf.getBoolean(LivyConf.IMPERSONATION_ENABLED)) { + info(s"Enabling impersonation since auth type is $authType.") + livyConf.set(LivyConf.IMPERSONATION_ENABLED, true) + } case null => // Nothing to do. diff --git a/server/src/main/scala/com/cloudera/livy/server/SessionServlet.scala b/server/src/main/scala/com/cloudera/livy/server/SessionServlet.scala index 6afc4b662..114236863 100644 --- a/server/src/main/scala/com/cloudera/livy/server/SessionServlet.scala +++ b/server/src/main/scala/com/cloudera/livy/server/SessionServlet.scala @@ -35,7 +35,6 @@ object SessionServlet extends Logging * * Type parameters: * S: the session type - * R: the type representing the session create parameters. */ abstract class SessionServlet[S <: Session](livyConf: LivyConf) extends JsonServlet @@ -159,6 +158,36 @@ abstract class SessionServlet[S <: Session](livyConf: LivyConf) */ protected def remoteUser(req: HttpServletRequest): String = req.getRemoteUser() + /** + * Checks that the request's user can impersonate the target user. + * + * If the user does not have permission to impersonate, then halt execution. + * + * @return The user that should be impersonated. That can be the target user if defined, the + * request's user - which may not be defined - otherwise, or `None` if impersonation is + * disabled. + */ + protected def checkImpersonation( + target: Option[String], + req: HttpServletRequest): Option[String] = { + if (livyConf.getBoolean(LivyConf.IMPERSONATION_ENABLED)) { + if (!target.map(hasAccess(_, req)).getOrElse(true)) { + halt(Forbidden(s"User '${remoteUser(req)}' not allowed to impersonate '$target'.")) + } + target.orElse(Option(remoteUser(req))) + } else { + None + } + } + + /** + * Check that the request's user has access to resources owned by the given target user. + */ + protected def hasAccess(target: String, req: HttpServletRequest): Boolean = { + val user = remoteUser(req) + user == null || user == target || livyConf.superusers().contains(user) + } + /** * Performs an operation on the session, without checking for ownership. Operations executed * via this method must not modify the session in any way, or return potentially sensitive @@ -176,7 +205,7 @@ abstract class SessionServlet[S <: Session](livyConf: LivyConf) val sessionId = params("id").toInt sessionManager.get(sessionId) match { case Some(session) => - if (allowAll || isOwner(session, request)) { + if (allowAll || hasAccess(session.owner, request)) { fn(session) } else { Forbidden() @@ -186,13 +215,6 @@ abstract class SessionServlet[S <: Session](livyConf: LivyConf) } } - /** - * Returns whether the current request's user is the owner of the given session. - */ - protected def isOwner(session: Session, req: HttpServletRequest): Boolean = { - session.owner == remoteUser(req) - } - private def serializeLogs(session: S, fromOpt: Option[Int], sizeOpt: Option[Int]) = { val lines = session.logLines() diff --git a/server/src/main/scala/com/cloudera/livy/server/batch/BatchSession.scala b/server/src/main/scala/com/cloudera/livy/server/batch/BatchSession.scala index ae0b3ebc1..139c59fee 100644 --- a/server/src/main/scala/com/cloudera/livy/server/batch/BatchSession.scala +++ b/server/src/main/scala/com/cloudera/livy/server/batch/BatchSession.scala @@ -26,7 +26,12 @@ import com.cloudera.livy.LivyConf import com.cloudera.livy.sessions.{Session, SessionState} import com.cloudera.livy.utils.SparkProcessBuilder -class BatchSession(id: Int, owner: String, livyConf: LivyConf, request: CreateBatchRequest) +class BatchSession( + id: Int, + owner: String, + proxyUser: Option[String], + livyConf: LivyConf, + request: CreateBatchRequest) extends Session(id, owner) { private val process = { @@ -34,7 +39,7 @@ class BatchSession(id: Int, owner: String, livyConf: LivyConf, request: CreateBa val builder = new SparkProcessBuilder(livyConf) builder.conf(request.conf) - request.proxyUser.foreach(builder.proxyUser) + proxyUser.foreach(builder.proxyUser) request.className.foreach(builder.className) request.jars.foreach(builder.jar) request.pyFiles.foreach(builder.pyFile) diff --git a/server/src/main/scala/com/cloudera/livy/server/batch/BatchSessionServlet.scala b/server/src/main/scala/com/cloudera/livy/server/batch/BatchSessionServlet.scala index 1fe364467..2a7c37afb 100644 --- a/server/src/main/scala/com/cloudera/livy/server/batch/BatchSessionServlet.scala +++ b/server/src/main/scala/com/cloudera/livy/server/batch/BatchSessionServlet.scala @@ -31,12 +31,13 @@ class BatchSessionServlet(livyConf: LivyConf) override protected def createSession(req: HttpServletRequest): BatchSession = { val createRequest = bodyAs[CreateBatchRequest](req) - new BatchSession(sessionManager.nextId(), remoteUser(req), livyConf, createRequest) + val proxyUser = checkImpersonation(createRequest.proxyUser, req) + new BatchSession(sessionManager.nextId(), remoteUser(req), proxyUser, livyConf, createRequest) } override protected def clientSessionView(session: BatchSession, req: HttpServletRequest): Any = { val logs = - if (isOwner(session, req)) { + if (hasAccess(session.owner, req)) { val lines = session.logLines() val size = 10 diff --git a/server/src/main/scala/com/cloudera/livy/server/client/ClientSession.scala b/server/src/main/scala/com/cloudera/livy/server/client/ClientSession.scala index 5b48830c6..0d956abb1 100644 --- a/server/src/main/scala/com/cloudera/livy/server/client/ClientSession.scala +++ b/server/src/main/scala/com/cloudera/livy/server/client/ClientSession.scala @@ -34,10 +34,15 @@ import org.apache.hadoop.fs.{FileSystem, Path} import com.cloudera.livy.{LivyClientBuilder, Logging} import com.cloudera.livy.client.common.HttpMessages._ -import com.cloudera.livy.client.local.LocalClient +import com.cloudera.livy.client.local.{LocalClient, LocalConf} import com.cloudera.livy.sessions.{Session, SessionState} -class ClientSession(id: Int, owner: String, createRequest: CreateClientRequest, livyHome: String) +class ClientSession( + id: Int, + owner: String, + val proxyUser: Option[String], + createRequest: CreateClientRequest, + livyHome: String) extends Session(id, owner) with Logging { implicit val executionContext = ExecutionContext.global @@ -47,13 +52,15 @@ class ClientSession(id: Int, owner: String, createRequest: CreateClientRequest, private val client = { info(s"Creating LivyClient for sessionId: $id") - new LivyClientBuilder() + val builder = new LivyClientBuilder() .setConf("spark.app.name", s"livy-session-$id") .setConf("spark.master", "yarn-cluster") .setAll(Option(createRequest.conf).getOrElse(new JHashMap())) .setURI(new URI("local:spark")) .setConf("livy.client.sessionId", id.toString) - .build() + + proxyUser.foreach(builder.setConf(LocalConf.Entry.PROXY_USER.key(), _)) + builder.build() }.asInstanceOf[LocalClient] private val fs = FileSystem.get(new Configuration()) diff --git a/server/src/main/scala/com/cloudera/livy/server/client/ClientSessionServlet.scala b/server/src/main/scala/com/cloudera/livy/server/client/ClientSessionServlet.scala index 5c550779f..afbd8ff51 100644 --- a/server/src/main/scala/com/cloudera/livy/server/client/ClientSessionServlet.scala +++ b/server/src/main/scala/com/cloudera/livy/server/client/ClientSessionServlet.scala @@ -26,6 +26,7 @@ import org.scalatra.servlet.{FileUploadSupport, MultipartConfig} import com.cloudera.livy.{JobHandle, LivyConf} import com.cloudera.livy.client.common.HttpMessages._ +import com.cloudera.livy.client.local.LocalConf import com.cloudera.livy.server.SessionServlet import com.cloudera.livy.sessions.SessionManager @@ -39,7 +40,15 @@ class ClientSessionServlet(livyConf: LivyConf) override protected def createSession(req: HttpServletRequest): ClientSession = { val id = sessionManager.nextId() val createRequest = bodyAs[CreateClientRequest](req) - new ClientSession(id, remoteUser(req), createRequest, livyConf.livyHome) + val user = remoteUser(req) + val requestedProxy = + if (createRequest.conf != null) { + Option(createRequest.conf.get(LocalConf.Entry.PROXY_USER.key())) + } else { + None + } + val proxyUser = checkImpersonation(requestedProxy, req) + new ClientSession(id, user, proxyUser, createRequest, livyConf.livyHome) } jpost[SerializedJob]("/:id/submit-job") { req => @@ -123,7 +132,8 @@ class ClientSessionServlet(livyConf: LivyConf) } override protected def clientSessionView(session: ClientSession, req: HttpServletRequest): Any = { - new SessionInfo(session.id, session.state.toString) + new SessionInfo(session.id, session.owner, session.proxyUser.getOrElse(null), + session.state.toString) } } diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala index 14817d6ca..4bc801737 100644 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala +++ b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala @@ -52,6 +52,7 @@ object InteractiveSession { class InteractiveSession( id: Int, owner: String, + _proxyUser: Option[String], livyConf: LivyConf, request: CreateInteractiveRequest) extends Session(id, owner) { @@ -83,7 +84,7 @@ class InteractiveSession( val jars = request.jars ++ livyJars(livyConf) jars.foreach(builder.jar) - request.proxyUser.foreach(builder.proxyUser) + _proxyUser.foreach(builder.proxyUser) request.queue.foreach(builder.queue) request.name.foreach(builder.name) @@ -220,7 +221,7 @@ class InteractiveSession( def kind: Kind = request.kind - def proxyUser: Option[String] = request.proxyUser + def proxyUser: Option[String] = _proxyUser def url: Option[URL] = _url diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala index f2391d88c..822f9e936 100644 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala +++ b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSessionServlet.scala @@ -44,14 +44,16 @@ class InteractiveSessionServlet(livyConf: LivyConf) override protected def createSession(req: HttpServletRequest): InteractiveSession = { val createRequest = bodyAs[CreateInteractiveRequest](req) - new InteractiveSession(sessionManager.nextId(), remoteUser(req), livyConf, createRequest) + val proxyUser = checkImpersonation(createRequest.proxyUser, req) + new InteractiveSession(sessionManager.nextId(), remoteUser(req), proxyUser, livyConf, + createRequest) } override protected def clientSessionView( session: InteractiveSession, req: HttpServletRequest): Any = { val logs = - if (isOwner(session, req)) { + if (hasAccess(session.owner, req)) { val lines = session.logLines() val size = 10 diff --git a/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala index 58bc41195..127ff17d5 100644 --- a/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/BaseJsonServletSpec.scala @@ -99,18 +99,21 @@ abstract class BaseJsonServletSpec extends ScalatraSuite with FunSpecLike { private def doTest[R: ClassTag](expectedStatus: Int, fn: R => Unit) (implicit klass: ClassTag[R]): Unit = { status should be (expectedStatus) - val result = - if (header("Content-Type").startsWith("application/json")) { - if (header("Content-Length").toInt > 0) { - mapper.readValue(response.inputStream, klass.runtimeClass) + // Only try to parse the body if response is in the "OK" range (20x). + if ((status / 100) * 100 == SC_OK) { + val result = + if (header("Content-Type").startsWith("application/json")) { + if (header("Content-Length").toInt > 0) { + mapper.readValue(response.inputStream, klass.runtimeClass) + } else { + null + } } else { - null + assert(klass.runtimeClass == classOf[Unit]) + () } - } else { - assert(klass.runtimeClass == classOf[Unit]) - () - } - fn(result.asInstanceOf[R]) + fn(result.asInstanceOf[R]) + } } private def toJson(obj: Any): Array[Byte] = mapper.writeValueAsBytes(obj) diff --git a/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala index a40be6eaf..9471dbe5f 100644 --- a/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/BaseSessionServletSpec.scala @@ -18,14 +18,41 @@ package com.cloudera.livy.server +import javax.servlet.http.HttpServletRequest + import org.scalatest.BeforeAndAfterAll +import com.cloudera.livy.LivyConf import com.cloudera.livy.sessions.Session +object BaseSessionServletSpec { + + /** Header used to override the user remote user in tests. */ + val REMOTE_USER_HEADER = "X-Livy-SessionServlet-User" + +} + abstract class BaseSessionServletSpec[S <: Session] extends BaseJsonServletSpec with BeforeAndAfterAll { + /** Name of the admin user. */ + protected val ADMIN = "__admin__" + + /** Create headers that identify a specific user in tests. */ + protected def makeUserHeaders(user: String): Map[String, String] = { + defaultHeaders ++ Map(BaseSessionServletSpec.REMOTE_USER_HEADER -> user) + } + + protected val adminHeaders = makeUserHeaders(ADMIN) + + /** Create a LivyConf with impersonation enabled and a superuser. */ + protected def createConf(): LivyConf = { + new LivyConf() + .set(LivyConf.IMPERSONATION_ENABLED, true) + .set(LivyConf.SUPERUSERS, ADMIN) + } + override def afterAll(): Unit = { super.afterAll() servlet.shutdown() @@ -40,3 +67,12 @@ abstract class BaseSessionServletSpec[S <: Session] protected def toJson(msg: AnyRef): Array[Byte] = mapper.writeValueAsBytes(msg) } + +trait RemoteUserOverride { + this: SessionServlet[_] => + + override protected def remoteUser(req: HttpServletRequest): String = { + req.getHeader(BaseSessionServletSpec.REMOTE_USER_HEADER) + } + +} diff --git a/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala index cb018ccb5..f0bb94362 100644 --- a/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/SessionServletSpec.scala @@ -29,7 +29,7 @@ import com.cloudera.livy.sessions.{Session, SessionState} object SessionServletSpec { - val REMOTE_USER_HEADER = "X-Livy-SessionServlet-User" + val PROXY_USER = "proxyUser" class MockSession(id: Int, owner: String) extends Session(id, owner) { @@ -51,24 +51,24 @@ class SessionServletSpec import SessionServletSpec._ override def createServlet(): SessionServlet[Session] = { - new SessionServlet[Session](new LivyConf()) { + new SessionServlet[Session](createConf()) with RemoteUserOverride { override protected def createSession(req: HttpServletRequest): Session = { + val params = bodyAs[Map[String, String]](req) + checkImpersonation(params.get(PROXY_USER), req) new MockSession(sessionManager.nextId(), remoteUser(req)) } - override protected def clientSessionView(session: Session, req: HttpServletRequest): Any = { - val logs = if (isOwner(session, req)) session.logLines() else Nil + override protected def clientSessionView( + session: Session, + req: HttpServletRequest): Any = { + val logs = if (hasAccess(session.owner, req)) session.logLines() else Nil MockSessionView(session.id, session.owner, logs) } - - override protected def remoteUser(req: HttpServletRequest): String = { - req.getHeader(REMOTE_USER_HEADER) - } } } - private val aliceHeaders = defaultHeaders ++ Map(REMOTE_USER_HEADER -> "alice") - private val bobHeaders = defaultHeaders ++ Map(REMOTE_USER_HEADER -> "bob") + private val aliceHeaders = makeUserHeaders("alice") + private val bobHeaders = makeUserHeaders("bob") private def delete(id: Int, headers: Map[String, String], expectedStatus: Int): Unit = { jdelete[Map[String, Any]](s"/$id", headers = headers, expectedStatus = expectedStatus) { _ => @@ -102,6 +102,28 @@ class SessionServletSpec } } + it("should allow admins to access all sessions") { + jpost[MockSessionView]("/", Map(), headers = aliceHeaders) { res => + jget[MockSessionView](s"/${res.id}", headers = adminHeaders) { res => + assert(res.owner === "alice") + assert(res.logs === IndexedSeq("log")) + } + delete(res.id, adminHeaders, SC_OK) + } + } + + it("should not allow regular users to impersonate others") { + jpost[MockSessionView]("/", Map(PROXY_USER -> "bob"), headers = aliceHeaders, + expectedStatus = SC_FORBIDDEN) { _ => } + } + + it("should allow admins to impersonate anyone") { + jpost[MockSessionView]("/", Map(PROXY_USER -> "bob"), headers = adminHeaders) { res => + delete(res.id, bobHeaders, SC_FORBIDDEN) + delete(res.id, adminHeaders, SC_OK) + } + } + } } diff --git a/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala b/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala index ced7dd84f..12dab1fe8 100644 --- a/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/batch/BatchSessionSpec.scala @@ -55,7 +55,7 @@ class BatchSessionSpec req.file = script.toString req.conf = Map("spark.driver.extraClassPath" -> sys.props("java.class.path")) - val batch = new BatchSession(0, null, new LivyConf(), req) + val batch = new BatchSession(0, null, None, new LivyConf(), req) Utils.waitUntil({ () => !batch.state.isActive }, Duration(10, TimeUnit.SECONDS)) (batch.state match { diff --git a/server/src/test/scala/com/cloudera/livy/server/client/ClientServletSpec.scala b/server/src/test/scala/com/cloudera/livy/server/client/ClientServletSpec.scala index c26d67cd4..ace6a2907 100644 --- a/server/src/test/scala/com/cloudera/livy/server/client/ClientServletSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/client/ClientServletSpec.scala @@ -31,6 +31,7 @@ import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps +import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.api.java.function.VoidFunction import org.scalatest.concurrent.Eventually._ @@ -38,12 +39,16 @@ import com.cloudera.livy.{Job, JobContext, JobHandle, LivyConf} import com.cloudera.livy.client.common.{BufferUtils, Serializer} import com.cloudera.livy.client.common.HttpMessages._ import com.cloudera.livy.client.local.LocalConf -import com.cloudera.livy.server.BaseSessionServletSpec +import com.cloudera.livy.server.{BaseSessionServletSpec, RemoteUserOverride} class ClientServletSpec extends BaseSessionServletSpec[ClientSession] { - override def createServlet(): ClientSessionServlet = new ClientSessionServlet(new LivyConf()) + private val PROXY = "__proxy__" + + override def createServlet(): ClientSessionServlet = { + new ClientSessionServlet(createConf()) with RemoteUserOverride + } private var sessionId: Int = -1 @@ -57,15 +62,7 @@ class ClientServletSpec describe("Client Servlet") { it("should create client sessions") { - val classpath = sys.props("java.class.path") - val conf = new HashMap[String, String] - conf.put("spark.master", "local") - conf.put("livy.local.jars", "") - conf.put("spark.driver.extraClassPath", classpath) - conf.put("spark.executor.extraClassPath", classpath) - conf.put(LocalConf.Entry.CLIENT_IN_PROCESS.key(), "true") - - jpost[SessionInfo]("/", new CreateClientRequest(10000L, conf)) { data => + jpost[SessionInfo]("/", createRequest()) { data => header("Location") should equal("/0") data.id should equal (0) sessionId = data.id @@ -130,6 +127,52 @@ class ClientServletSpec } } } + + it("should support user impersonation") { + val headers = makeUserHeaders(PROXY) + jpost[SessionInfo]("/", createRequest(inProcess = false), headers = headers) { data => + try { + data.owner should be (PROXY) + data.proxyUser should be (PROXY) + val user = runJob(data.id, new GetUserJob(), headers = headers) + user should be (PROXY) + } finally { + deleteSession(data.id) + } + } + } + + it("should honor impersonation requests") { + val request = createRequest(inProcess = false) + request.conf.put(LocalConf.Entry.PROXY_USER.key(), PROXY) + jpost[SessionInfo]("/", request, headers = adminHeaders) { data => + try { + data.owner should be (ADMIN) + data.proxyUser should be (PROXY) + val user = runJob(data.id, new GetUserJob(), headers = adminHeaders) + user should be (PROXY) + } finally { + deleteSession(data.id) + } + } + } + } + + private def deleteSession(id: Int): Unit = { + jdelete[Map[String, Any]](s"/$id", headers = adminHeaders) { _ => } + } + + private def createRequest(inProcess: Boolean = true): CreateClientRequest = { + val classpath = sys.props("java.class.path") + val conf = new HashMap[String, String] + conf.put("spark.master", "local") + conf.put("livy.local.jars", "") + conf.put("spark.driver.extraClassPath", classpath) + conf.put("spark.executor.extraClassPath", classpath) + if (inProcess) { + conf.put(LocalConf.Entry.CLIENT_IN_PROCESS.key(), "true") + } + new CreateClientRequest(10000L, conf) } private def testResourceUpload(cmd: String, sessionId: Int): Unit = { @@ -145,21 +188,34 @@ class ClientServletSpec Source.fromFile(resultFile).mkString should be("Test data") } } + private def testJobSubmission(sid: Int, sync: Boolean): Unit = { + val result = runJob(sid, new TestJob(), sync = sync) + result should be (42) + } + + private def runJob[T]( + sid: Int, + job: Job[T], + sync: Boolean = false, + headers: Map[String, String] = defaultHeaders): T = { val ser = new Serializer() - val job = BufferUtils.toByteArray(ser.serialize(new TestJob())) + val jobData = BufferUtils.toByteArray(ser.serialize(job)) val route = if (sync) s"/$sid/submit-job" else s"/$sid/run-job" var jobId: Long = -1L - jpost[JobStatus](route, new SerializedJob(job)) { data => + jpost[JobStatus](route, new SerializedJob(jobData), headers = headers) { data => jobId = data.id } + + var result: Option[T] = None eventually(timeout(1 minute), interval(100 millis)) { jget[JobStatus](s"/$sid/jobs/$jobId") { status => status.id should be (jobId) - val result = ser.deserialize(ByteBuffer.wrap(status.result)) - result should be (42) + status.state should be (JobHandle.State.SUCCEEDED) + result = Some(ser.deserialize(ByteBuffer.wrap(status.result)).asInstanceOf[T]) } } + result.getOrElse(throw new IllegalStateException()) } } @@ -183,3 +239,9 @@ class AsyncTestJob extends Job[Int] { } } + +class GetUserJob extends Job[String] { + + override def call(jc: JobContext): String = UserGroupInformation.getCurrentUser().getUserName() + +} diff --git a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala index a90a64926..c2fdb3415 100644 --- a/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala +++ b/server/src/test/scala/com/cloudera/livy/server/interactive/InteractiveSessionSpec.scala @@ -44,7 +44,7 @@ class InteractiveSessionSpec extends FunSpec with Matchers with BeforeAndAfterAl val req = new CreateInteractiveRequest() req.kind = PySpark() - new InteractiveSession(0, null, livyConf, req) + new InteractiveSession(0, null, None, livyConf, req) } override def afterAll(): Unit = {