Skip to content

Commit

Permalink
Kryo register FileInputStream and addDefaultSerializer for InputStream (
Browse files Browse the repository at this point in the history
corda#471)

HashCheckingStream CordaSerializable and FileInputStream Kryo register
  • Loading branch information
Konstantinos Chalkias authored Apr 6, 2017
1 parent d6403ce commit f7dd273
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import net.i2p.crypto.eddsa.EdDSAPublicKey
import org.objenesis.strategy.StdInstantiatorStrategy
import org.slf4j.Logger
import java.io.BufferedInputStream
import java.io.FileInputStream
import java.io.InputStream
import java.util.*

object DefaultKryoCustomizer {
Expand Down Expand Up @@ -53,6 +55,7 @@ object DefaultKryoCustomizer {
ImmutableMapSerializer.registerSerializers(this)
ImmutableMultimapSerializer.registerSerializers(this)

// InputStream subclasses whitelisting, required for attachments.
register(BufferedInputStream::class.java, InputStreamSerializer)
register(Class.forName("sun.net.www.protocol.jar.JarURLConnection\$JarURLInputStream"), InputStreamSerializer)

Expand Down Expand Up @@ -81,6 +84,11 @@ object DefaultKryoCustomizer {

addDefaultSerializer(Logger::class.java, LoggerSerializer)

register(FileInputStream::class.java, InputStreamSerializer)
// Required for HashCheckingStream (de)serialization.
// Note that return type should be specifically set to InputStream, otherwise it may not work, i.e. val aStream : InputStream = HashCheckingStream(...).
addDefaultSerializer(InputStream::class.java, InputStreamSerializer)

val customization = KryoSerializationCustomization(this)
pluginRegistries.forEach { it.customizeSerialization(customization) }
}
Expand Down
14 changes: 12 additions & 2 deletions core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider
import org.junit.After
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import org.slf4j.LoggerFactory
import java.io.InputStream
Expand All @@ -19,6 +17,8 @@ import java.time.Instant
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import net.corda.node.services.persistence.NodeAttachmentService
import java.io.ByteArrayInputStream

class KryoTests {

Expand Down Expand Up @@ -132,6 +132,16 @@ class KryoTests {
assertTrue(logger === logger2)
}

@Test
fun `HashCheckingStream (de)serialize`() {
val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() })
val readRubbishStream : InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(kryo).deserialize(kryo)
for (i in 0 .. 12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
assertEquals(-1, readRubbishStream.read())
}

@CordaSerializable
private data class Person(val name: String, val birthday: Instant?)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ class NodeAttachmentService(override var storePath: Path, dataSourceProperties:
* inside it, we haven't read the whole file, so we can't check the hash. But when copying it over the network
* this will provide an additional safety check against user error.
*/
private class HashCheckingStream(val expected: SecureHash.SHA256,
val expectedSize: Int,
input: InputStream,
private val counter: CountingInputStream = CountingInputStream(input),
private val stream: HashingInputStream = HashingInputStream(Hashing.sha256(), counter)) : FilterInputStream(stream) {
@VisibleForTesting @CordaSerializable
class HashCheckingStream(val expected: SecureHash.SHA256,
val expectedSize: Int,
input: InputStream,
private val counter: CountingInputStream = CountingInputStream(input),
private val stream: HashingInputStream = HashingInputStream(Hashing.sha256(), counter)) : FilterInputStream(stream) {
override fun close() {
super.close()

Expand All @@ -86,7 +87,7 @@ class NodeAttachmentService(override var storePath: Path, dataSourceProperties:
private val checkOnLoad: Boolean) : Attachment {
override fun open(): InputStream {

var stream = ByteArrayInputStream(attachment)
val stream = ByteArrayInputStream(attachment)

// This is just an optional safety check. If it slows things down too much it can be disabled.
if (id is SecureHash.SHA256 && checkOnLoad)
Expand Down

0 comments on commit f7dd273

Please sign in to comment.