Skip to content

Commit

Permalink
Added disabling of signatures when key is blank
Browse files Browse the repository at this point in the history
  • Loading branch information
Chip Senkbeil committed Jun 26, 2015
1 parent 04ef8ac commit bf234cb
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.ibm.spark.communication.security.{SecurityActorType, SignatureManager
import com.ibm.spark.kernel.protocol.v5.SocketType
import com.ibm.spark.kernel.protocol.v5.client.ActorLoader
import com.ibm.spark.kernel.protocol.v5.client.socket._
import com.ibm.spark.utils.LogLike
import com.typesafe.config.Config

/**
Expand All @@ -48,7 +49,7 @@ trait SystemInitialization {
/**
* Represents the standard implementation of SystemInitialization.
*/
trait StandardSystemInitialization extends SystemInitialization {
trait StandardSystemInitialization extends SystemInitialization with LogLike {
/**
* Initializes the system-related client objects.
*
Expand All @@ -67,6 +68,7 @@ trait StandardSystemInitialization extends SystemInitialization {
val commRegistrar = new CommRegistrar(commStorage)

val (heartbeat, stdin, shell, ioPub) = initializeSystemActors(
config = config,
actorSystem = actorSystem,
actorLoader = actorLoader,
socketFactory = socketFactory,
Expand All @@ -80,27 +82,30 @@ trait StandardSystemInitialization extends SystemInitialization {
}

private def initializeSystemActors(
actorSystem: ActorSystem, actorLoader: ActorLoader,
config: Config, actorSystem: ActorSystem, actorLoader: ActorLoader,
socketFactory: SocketFactory, commRegistrar: CommRegistrar,
commStorage: CommStorage
) = {
val signatureEnabled = config.getString("key").nonEmpty

val heartbeatClient = actorSystem.actorOf(
Props(classOf[HeartbeatClient], socketFactory, actorLoader),
Props(classOf[HeartbeatClient],
socketFactory, actorLoader, signatureEnabled),
name = SocketType.HeartbeatClient.toString
)

val stdinClient = actorSystem.actorOf(
Props(classOf[StdinClient], socketFactory, actorLoader),
Props(classOf[StdinClient], socketFactory, actorLoader, signatureEnabled),
name = SocketType.StdInClient.toString
)

val shellClient = actorSystem.actorOf(
Props(classOf[ShellClient], socketFactory, actorLoader),
Props(classOf[ShellClient], socketFactory, actorLoader, signatureEnabled),
name = SocketType.ShellClient.toString
)

val ioPubClient = actorSystem.actorOf(
Props(classOf[IOPubClient], socketFactory, actorLoader,
Props(classOf[IOPubClient], socketFactory, actorLoader, signatureEnabled,
commRegistrar, commStorage),
name = SocketType.IOPubClient.toString
)
Expand All @@ -111,14 +116,21 @@ trait StandardSystemInitialization extends SystemInitialization {
private def initializeSecurityActors(
config: Config,
actorSystem: ActorSystem
): ActorRef = {
): Option[ActorRef] = {
val key = config.getString("key")
val signatureScheme = config.getString("signature_scheme").replace("-", "")

val signatureManager = actorSystem.actorOf(
Props(classOf[SignatureManagerActor], key, signatureScheme),
name = SecurityActorType.SignatureManager.toString
)
var signatureManager: Option[ActorRef] = None

if (key.nonEmpty) {
logger.debug(s"Initializing client signatures with key '$key'!")
signatureManager = Some(actorSystem.actorOf(
Props(classOf[SignatureManagerActor], key, signatureScheme),
name = SecurityActorType.SignatureManager.toString
))
} else {
logger.debug(s"Signatures disabled for client!")
}

signatureManager
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ object HeartbeatMessage {}
*
* @param socketFactory A factory to create the ZeroMQ socket connection
* @param actorLoader The loader used to retrieve actors
* @param signatureEnabled Whether or not to check and provide signatures
*/
class HeartbeatClient(socketFactory : SocketFactory, actorLoader: ActorLoader)
extends Actor with LogLike
{
class HeartbeatClient(
socketFactory : SocketFactory,
actorLoader: ActorLoader,
signatureEnabled: Boolean
) extends Actor with LogLike {
logger.debug("Created new Heartbeat Client actor")
implicit val timeout = Timeout(1.minute)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import scala.util.Failure
*/
class IOPubClient(
socketFactory: SocketFactory, actorLoader: ActorLoader,
signatureEnabled: Boolean,
commRegistrar: CommRegistrar, commStorage: CommStorage
) extends Actor with LogLike {
private val PARENT_HEADER_NULL_MESSAGE = "Parent Header was null in Kernel Message."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,21 @@ import com.ibm.spark.kernel.protocol.v5.client.execution.{DeferredExecution, Def
import com.ibm.spark.kernel.protocol.v5.content.ExecuteReply

import com.ibm.spark.utils.LogLike
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.pattern.ask

/**
* The client endpoint for Shell messages specified in the IPython Kernel Spec
* @param socketFactory A factory to create the ZeroMQ socket connection
* @param actorLoader The loader used to retrieve actors
* @param signatureEnabled Whether or not to check and provide signatures
*/
class ShellClient(socketFactory: SocketFactory, actorLoader: ActorLoader)
extends Actor with LogLike
{
class ShellClient(
socketFactory: SocketFactory,
actorLoader: ActorLoader,
signatureEnabled: Boolean
) extends Actor with LogLike {
logger.debug("Created shell client actor")
implicit val timeout = Timeout(21474835.seconds)

Expand All @@ -59,6 +63,9 @@ class ShellClient(socketFactory: SocketFactory, actorLoader: ActorLoader)
case message: ZMQMessage =>
logger.debug("Received shell kernel message.")
val kernelMessage: KernelMessage = message

// TODO: Validate incoming message signature

logger.trace(s"Kernel message is ${kernelMessage}")
receiveExecuteReply(message.parentHeader.msg_id,kernelMessage)

Expand All @@ -68,14 +75,15 @@ class ShellClient(socketFactory: SocketFactory, actorLoader: ActorLoader)
val signatureManager =
actorLoader.load(SecurityActorType.SignatureManager)

// TODO: Validate incoming message signature
val messageWithSignature = signatureManager ? message

import scala.concurrent.ExecutionContext.Implicits.global
messageWithSignature.map(_.asInstanceOf[KernelMessage]).foreach(kernelMessage => {
val zmqMessage: ZMQMessage = kernelMessage
val messageWithSignature = if (signatureEnabled) {
val signatureMessage = signatureManager ? message
Await.result(signatureMessage, 100.milliseconds)
.asInstanceOf[KernelMessage]
} else message

val zMQMessage: ZMQMessage = messageWithSignature

socket ! zmqMessage
})
socket ! zMQMessage
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import play.api.libs.json.Json
import StdinClient._
import akka.pattern.ask

import scala.concurrent.duration._
import scala.concurrent.Await

object StdinClient {
type ResponseFunction = (String, Boolean) => String
val EmptyResponseFunction: ResponseFunction = (_, _) => ""
Expand All @@ -38,10 +41,12 @@ object StdinClient {
* The client endpoint for Stdin messages specified in the IPython Kernel Spec
* @param socketFactory A factory to create the ZeroMQ socket connection
* @param actorLoader The loader used to retrieve actors
* @param signatureEnabled Whether or not to check and provide signatures
*/
class StdinClient(
socketFactory: SocketFactory,
actorLoader: ActorLoader
actorLoader: ActorLoader,
signatureEnabled: Boolean
) extends Actor with LogLike {
logger.debug("Created stdin client actor")

Expand Down Expand Up @@ -80,16 +85,18 @@ class StdinClient(
.withContentString(inputReply)
.build

val signatureManager =
actorLoader.load(SecurityActorType.SignatureManager)
val messageWithSignature = signatureManager ? newKernelMessage

import scala.concurrent.ExecutionContext.Implicits.global
messageWithSignature.map(_.asInstanceOf[KernelMessage]).foreach(kernelMessage => {
val responseZmqMessage: ZMQMessage = kernelMessage
val messageWithSignature = if (signatureEnabled) {
val signatureManager =
actorLoader.load(SecurityActorType.SignatureManager)
val signatureMessage = signatureManager ? newKernelMessage
Await.result(signatureMessage, 100.milliseconds)
.asInstanceOf[KernelMessage]
} else newKernelMessage

val zmqMessage: ZMQMessage = messageWithSignature

socket ! responseZmqMessage
})
socket ! zmqMessage
} else {
logger.debug(s"Unknown message of type $messageType")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HeartbeatClientSpec extends TestKit(ActorSystem("HeartbeatActorSpec"))
when(socketFactory.HeartbeatClient(any(classOf[ActorSystem]), any(classOf[ActorRef]))).thenReturn(probe.ref)

val heartbeatClient = system.actorOf(Props(
classOf[HeartbeatClient], socketFactory, mockActorLoader
classOf[HeartbeatClient], socketFactory, mockActorLoader, true
))

describe("send heartbeat") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class IOPubClientSpec extends TestKit(ActorSystem(
timeout = scaled(Span(200, Milliseconds)),
interval = scaled(Span(5, Milliseconds))
)
private val SignatureEnabled = true

private var clientSocketProbe: TestProbe = _
private var mockClientSocketFactory: SocketFactory = _
Expand Down Expand Up @@ -93,7 +94,7 @@ class IOPubClientSpec extends TestKit(ActorSystem(
// Construct the object we will test against
ioPubClient = system.actorOf(Props(
classOf[IOPubClient], mockClientSocketFactory, mockActorLoader,
mockCommRegistrar, spyCommStorage
SignatureEnabled, mockCommRegistrar, spyCommStorage
))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import play.api.libs.json.Json

class ShellClientSpec extends TestKit(ActorSystem("ShellActorSpec"))
with ImplicitSender with FunSpecLike with Matchers with MockitoSugar {
private val SignatureEnabled = true

describe("ShellClientActor") {
val socketFactory = mock[SocketFactory]
Expand All @@ -47,7 +48,7 @@ class ShellClientSpec extends TestKit(ActorSystem("ShellActorSpec"))
.when(mockActorLoader).load(SecurityActorType.SignatureManager)

val shellClient = system.actorOf(Props(
classOf[ShellClient], socketFactory, mockActorLoader
classOf[ShellClient], socketFactory, mockActorLoader, SignatureEnabled
))

describe("send execute request") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class StdinClientSpec extends TestKit(ActorSystem("StdinActorSpec"))
with ImplicitSender with FunSpecLike with Matchers with MockitoSugar
with BeforeAndAfter
{
private val SignatureEnabled = true
private val TestReplyString = "some value"
private val TestResponseFunc: ResponseFunction = (_, _) => TestReplyString

Expand All @@ -57,7 +58,7 @@ class StdinClientSpec extends TestKit(ActorSystem("StdinActorSpec"))
.StdinClient(any[ActorSystem], any[ActorRef])

stdinClient = system.actorOf(Props(
classOf[StdinClient], mockSocketFactory, mockActorLoader
classOf[StdinClient], mockSocketFactory, mockActorLoader, SignatureEnabled
))

// Set the response function for our client socket
Expand Down
9 changes: 5 additions & 4 deletions client/src/test/scala/test/utils/SparkClientDeployer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@ object SparkClientDeployer extends LogLike{
config: Config, actorSystem: ActorSystem, actorLoader: ActorLoader,
socketFactory: SocketFactory):
(ActorRef, ActorRef, ActorRef, ActorRef, CommRegistrar, CommStorage) = {
val signatureEnabled = config.getString("key").nonEmpty
val commStorage = new CommStorage()
val commRegistrar = new CommRegistrar(commStorage)

heartbeatProbe = new TestProbe(actorSystem)
val heartbeatClient = actorSystem.actorOf(
Props(classOf[HeartbeatClient], socketFactory, actorLoader)
Props(classOf[HeartbeatClient], socketFactory, actorLoader, signatureEnabled)
)
val heartbeatInterceptor = actorSystem.actorOf(
Props(new ActorInterceptor(heartbeatProbe, heartbeatClient)),
Expand All @@ -89,7 +90,7 @@ object SparkClientDeployer extends LogLike{

stdinProbe = new TestProbe(actorSystem)
val stdinClient = actorSystem.actorOf(
Props(classOf[StdinClient], socketFactory, actorLoader)
Props(classOf[StdinClient], socketFactory, actorLoader, signatureEnabled)
)
val stdinInterceptor = actorSystem.actorOf(
Props(new ActorInterceptor(stdinProbe, stdinClient)),
Expand All @@ -98,7 +99,7 @@ object SparkClientDeployer extends LogLike{

shellProbe = new TestProbe(actorSystem)
val shellClient = actorSystem.actorOf(
Props(classOf[ShellClient], socketFactory, actorLoader)
Props(classOf[ShellClient], socketFactory, actorLoader, signatureEnabled)
)
val shellInterceptor = actorSystem.actorOf(
Props(new ActorInterceptor(shellProbe, shellClient)),
Expand All @@ -107,7 +108,7 @@ object SparkClientDeployer extends LogLike{

ioPubProbe = new TestProbe(actorSystem)
val ioPubClient = actorSystem.actorOf(
Props(classOf[IOPubClient], socketFactory, actorLoader,
Props(classOf[IOPubClient], socketFactory, actorLoader, signatureEnabled,
commRegistrar, commStorage)
)
val ioPubInterceptor = actorSystem.actorOf(
Expand Down

0 comments on commit bf234cb

Please sign in to comment.