Skip to content

Commit

Permalink
Updated DMABuffer to reverse incoming data and Controller to reverse …
Browse files Browse the repository at this point in the history
…blocks. Updated accelerator test to reflect changes
  • Loading branch information
TsaiAnson committed Apr 14, 2021
1 parent 9784b66 commit 82a0307
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/main/scala/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class AESController(addrBits: Int, beatBytes: Int)(implicit p: Parameters) exten
when (mState === MemState.sWriteIntoMem) {
// Read AES result out to DMA
io.aesCoreIO.cs := 1.U
io.aesCoreIO.address := AESAddr.RESULT + 3.U - counter_reg
io.aesCoreIO.address := AESAddr.RESULT + counter_reg
cStateWire := CtrlState.sDataWrite
} .elsewhen (data_wr_done) {
when (blks_remain_reg > 0.U) {
Expand Down Expand Up @@ -344,10 +344,10 @@ class AESController(addrBits: Int, beatBytes: Int)(implicit p: Parameters) exten
}
}
when (cState === CtrlState.sKeySetup) {
io.aesCoreIO.address := AESAddr.KEY + mem_target_reg - 1.U - counter_reg
io.aesCoreIO.address := AESAddr.KEY + counter_reg
io.testAESWriteData.bits := ((AESAddr.KEY + mem_target_reg - 1.U - counter_reg) << 32) + dequeue.io.dataOut.bits
} .elsewhen (cState === CtrlState.sDataSetup) {
io.aesCoreIO.address := AESAddr.TEXT + 3.U - counter_reg
io.aesCoreIO.address := AESAddr.TEXT + counter_reg
io.testAESWriteData.bits := ((AESAddr.TEXT + 3.U - counter_reg) << 32) + dequeue.io.dataOut.bits
}
}
Expand Down
16 changes: 14 additions & 2 deletions src/main/scala/DMABuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class DMAInputBuffer (addrBits: Int = 32, beatBytes: Int) extends Module {
val startWrite = RegInit(false.B)
// Delay done by a cycle to account for request to propagate to DMA
val doneReg = RegInit(false.B)
// Data to-be-reversed
val toReverse = Wire(UInt(32.W))
// Data with bytes reversed
val reverse = Wire(UInt(32.W))

// Start writing when we have an entire block of data (128b)
when (bitsFilled === 128.U) {
Expand All @@ -36,7 +40,7 @@ class DMAInputBuffer (addrBits: Int = 32, beatBytes: Int) extends Module {
}
// NOTE: The two statements below will NEVER concurrently fire (fire conditions prevent)
when (dataQueue.io.deq.fire()) {
wideData := wideData | (dataQueue.io.deq.bits << bitsFilled).asUInt()
wideData := wideData | (reverse << bitsFilled).asUInt()
bitsFilled := bitsFilled + 32.U
}
when (io.dmaOutput.fire()) {
Expand All @@ -59,6 +63,10 @@ class DMAInputBuffer (addrBits: Int = 32, beatBytes: Int) extends Module {
}
doneReg := bitsFilled === 0.U
io.done := doneReg

// Reversing bytes
toReverse := dataQueue.io.deq.bits
reverse := (toReverse(7,0) << 24).asUInt() | (toReverse(15,8) << 16).asUInt() | (toReverse(23,16) << 8).asUInt() | toReverse(31,24).asUInt()
}

// Outputs data from DMA in 32bit chunks (for AES core)
Expand All @@ -73,6 +81,7 @@ class DMAOutputBuffer (beatBytes: Int) extends Module {
val bitsFilled = RegInit(0.U(log2Ceil(256 + 1).W))
val wideData = RegInit(0.U(256.W)) // Max data we will ever read at a time is 256b
val dataQueue = Module(new Queue(UInt(32.W), 8))
val reverse = Wire(UInt(32.W)) // Used to carry data with bytes reversed

// NOTE: These two statements will NEVER concurrently fire (fire conditions prevent)
when (dataQueue.io.enq.fire()) {
Expand All @@ -88,5 +97,8 @@ class DMAOutputBuffer (beatBytes: Int) extends Module {
io.dataOut <> dataQueue.io.deq

dataQueue.io.enq.valid := bitsFilled >= 32.U
dataQueue.io.enq.bits := wideData(31,0)
dataQueue.io.enq.bits := reverse

// Reversing bytes
reverse := (wideData(7,0) << 24).asUInt() | (wideData(15,8) << 16).asUInt() | (wideData(23,16) << 8).asUInt() | wideData(31,24).asUInt()
}
86 changes: 70 additions & 16 deletions src/test/scala/AccelTopTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.scalatest.flatspec.AnyFlatSpec
import freechips.rocketchip.config.Parameters
import freechips.rocketchip.diplomacy.LazyModule
import freechips.rocketchip.tile.{RoCCCommand, RoCCResponse}
import verif.TLMemoryModel.WordAddr
import verif.TLMemoryModel.{State, WordAddr, read}
import verif._


Expand All @@ -20,7 +20,8 @@ class AccelTopTest extends AnyFlatSpec with ChiselScalatestTester {
val r = new scala.util.Random

// Temporary storage for keys in key-reusal case
var prev_key:BigInt = 0
var prev_key: BigInt = 0
var prev_key_addr: BigInt = 0

def testAccelerator(dut: AESAccelStandaloneBlock, clock: Clock, keySize: Int, encdec: Int, interrupt: Int, rounds: Int, reuse_key_en: Boolean): Boolean = {
assert(keySize == 128 || keySize == 256 || keySize == -1, s"KeySize must be 128, 256, or -1 (random). Given: $keySize")
Expand Down Expand Up @@ -55,17 +56,22 @@ class AccelTopTest extends AnyFlatSpec with ChiselScalatestTester {
else encrypt = encdec == 1
if (interrupt == -1) interrupt_en = r.nextInt(2)

val stim = genAESStim(actualKeySize, r.nextInt(10) + 1, destructive = destructive, beatBytes, r)
val stim = genAESStim(actualKeySize, r.nextInt(10) + 1, destructive = destructive, if (reuse_key) prev_key_addr else BigInt(-1), beatBytes, r)
slaveModel.state = stim._6

// Randomize reuse_key (cannot reuse on first enc/dec, key expansion required)
if (i != 0 && reuse_key_en) reuse_key = r.nextBoolean()
if (!reuse_key) prev_key = stim._2
if (!reuse_key) {
prev_key_addr = stim._1
prev_key = stim._2
}

// // Debug Printing
// println(s"Debug key size: $actualKeySize")
// println(s"Debug encdec: $encrypt")
// println(s"Debug: key size: $actualKeySize")
// println(s"Debug: encdec: $encrypt")
// println(s"Debug: $stim")
// println(s"Debug: Reuse key $reuse_key")
// println(s"Debug: Destructive $destructive")

var inputCmd = Seq[DecoupledTX[RoCCCommand]]()
// Key load instruction
Expand Down Expand Up @@ -135,14 +141,14 @@ class AccelTopTest extends AnyFlatSpec with ChiselScalatestTester {
allPass
}

// Basic sanity test: elaborate to see if anything structure-wise broke
it should "elaborate the accelerator" in {
val dut = LazyModule(new AESAccelStandaloneBlock(beatBytes))
// Requires verilator backend! (For verilog blackbox files)
test(dut.module).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { c =>
assert(true)
}
}
// // Elaborate to see if anything structure-wise broke
// it should "elaborate the accelerator" in {
// val dut = LazyModule(new AESAccelStandaloneBlock(beatBytes))
// // Requires verilator backend! (For verilog blackbox files)
// test(dut.module).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { c =>
// assert(true)
// }
// }

it should "Test 128b AES Encryption" in {
val dut = LazyModule(new AESAccelStandaloneBlock(beatBytes))
Expand All @@ -155,7 +161,7 @@ class AccelTopTest extends AnyFlatSpec with ChiselScalatestTester {
it should "Test 128b AES Decryption" in {
val dut = LazyModule(new AESAccelStandaloneBlock(beatBytes))
test(dut.module).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { c =>
val result = testAccelerator(dut, c.clock, keySize = 128, encdec = -1, interrupt = 0, rounds = 20, reuse_key_en = true)
val result = testAccelerator(dut, c.clock, keySize = 128, encdec = 0, interrupt = 0, rounds = 20, reuse_key_en = true)
assert(result)
}
}
Expand All @@ -171,7 +177,7 @@ class AccelTopTest extends AnyFlatSpec with ChiselScalatestTester {
it should "Test 256b AES Decryption" in {
val dut = LazyModule(new AESAccelStandaloneBlock(beatBytes))
test(dut.module).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { c =>
val result = testAccelerator(dut, c.clock, keySize = 256, encdec = -1, interrupt = 0, rounds = 20, reuse_key_en = true)
val result = testAccelerator(dut, c.clock, keySize = 256, encdec = 0, interrupt = 0, rounds = 20, reuse_key_en = true)
assert(result)
}
}
Expand All @@ -184,4 +190,52 @@ class AccelTopTest extends AnyFlatSpec with ChiselScalatestTester {
assert(result)
}
}


// // Debug sanity test
// // 128B Key: 256'h2b7e151628aed2a6abf7158809cf4f3c00000000000000000000000000000000
// // plaintext: 128'h6bc1bee22e409f96e93d7e117393172a
// // ciphertext: 128'h3ad77bb40d7a3660a89ecaf32466ef97
// it should "debug AES sanity check" in {
// val dut = LazyModule(new AESAccelStandaloneBlock(beatBytes))
// test(dut.module).withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { c =>
// // RoCCCommand driver + RoCCResponse receiver
// val driver = new DecoupledDriverMaster[RoCCCommand](c.clock, dut.module.io.cmd)
// val txProto = new DecoupledTX(new RoCCCommand())
// val monitor = new DecoupledMonitor[RoCCResponse](c.clock, dut.module.io.resp)
// val receiver = new DecoupledDriverSlave[RoCCResponse](c.clock, dut.module.io.resp, 0)
//
// // Mock Memory
// val slaveFn = new TLMemoryModel(dut.to_mem.params)
// val slaveModel = new TLDriverSlave(c.clock, dut.to_mem, slaveFn, TLMemoryModel.State.empty())
// val slaveMonitor = new TLMonitor(c.clock, dut.to_mem)
//
// val keyData = BigInt("2b7e151628aed2a6abf7158809cf4f3c00000000000000000000000000000000", 16)
// var initMem: Map[WordAddr, BigInt] = Map ()
// val keyDataRev = BigInt(keyData.toByteArray.reverse)
// initMem = initMem + (0.toLong -> (keyDataRev & BigInt("1" * 128, 2)))
// initMem = initMem + (1.toLong -> (keyDataRev >> 128))
// val textData = BigInt("6bc1bee22e409f96e93d7e117393172a", 16)
// initMem = initMem + (2.toLong -> BigInt(textData.toByteArray.reverse))
// slaveModel.state = State.init(initMem, beatBytes)
//
// val input = Seq(
// txProto.tx(keyLoad128(0x0)),
// txProto.tx(addrLoad(0x20, 0x30)),
// txProto.tx(encBlock(1, 0))
// )
// driver.push(input)
// c.clock.step(50)
//
// while(!finishedWriting(slaveModel.state, 0x30, 1, 0, beatBytes)) {
// c.clock.step()
// }
// c.clock.step(5) // Few cycles delay for AccessAck to propagate back and busy to be de-asserted
// assert(!dut.module.io.busy.peek().litToBoolean, "Accelerator is still busy when data was written back.")
//
// val actual = read(slaveModel.state.mem, 3.toLong, beatBytes, -1)
// println(s"RESULT: ${actual.toString(16)}")
// assert(actual == BigInt(Array(0.toByte) ++ BigInt("3ad77bb40d7a3660a89ecaf32466ef97", 16).toByteArray.takeRight(16).reverse))
// }
// }
}
34 changes: 19 additions & 15 deletions src/test/scala/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ package object AESTestUtils {
// Generates random (but legal) addresses for key, source text, dest text
// Legal key address: anything beatByte aligned
// Legal data address: anything beatByte aligned and != key addr, and if src/dest overlap it must directly overlap (same address, destructive)
// If prevKeyAddr != -1, then use that address
// (keyAddr: BigInt, srcAddr: BigInt, destAddr: BigInt)
def getRandomKeyTextAddr(textBlocks: Int, destructive: Boolean, beatBytes: Int, r: Random): (BigInt, BigInt, BigInt) = {
def getRandomKeyTextAddr(textBlocks: Int, destructive: Boolean, prevKeyAddr: BigInt, beatBytes: Int, r: Random): (BigInt, BigInt, BigInt) = {
assert(textBlocks > 0, s"# of text block must be at least 1. Given: $textBlocks")
assert(beatBytes == 16, "TEMP: beatBytes must be 16 (see accelerator test for more details)")

val keyAddr = BigInt(32, r) & ~(beatBytes - 1)
val keyAddr = if (prevKeyAddr != -1) prevKeyAddr else BigInt(32, r) & ~(beatBytes - 1)
var srcAddr = BigInt(32, r) & ~(beatBytes - 1)
while (keyAddr <= srcAddr && srcAddr <= (keyAddr + 32)) {
srcAddr = BigInt(32, r) & ~(beatBytes - 1)
Expand All @@ -54,30 +55,33 @@ package object AESTestUtils {
// Helper method for creating TLMemoryModel state with initial AES Data (key, text)
// Assumes all addresses are legal (beatBytes aligned)
// Assumes keyData is 256 bits, and textData is 128 bits per block
// NOTE: Reverses data to mimic how it will be stored in memory
def getTLMemModelState(keyAddr: BigInt, keyData: BigInt, textAddr: BigInt, textData: Seq[BigInt], beatBytes: Int): State = {
assert(beatBytes == 16, "TEMP: beatBytes must be 16 (see accelerator test for more details)")

var initMem: Map[WordAddr, BigInt] = Map ()
val keyWordAddr = keyAddr / beatBytes
initMem = initMem + (keyWordAddr.toLong -> (keyData & BigInt("1" * 128, 2)))
initMem = initMem + ((keyWordAddr + 1).toLong -> (keyData >> 128))
val keyDataRev = BigInt(keyData.toByteArray.takeRight(32).reverse.padTo(32, 0.toByte))
initMem = initMem + (keyWordAddr.toLong -> (keyDataRev & BigInt("1" * 128, 2)))
initMem = initMem + ((keyWordAddr + 1).toLong -> (keyDataRev >> 128))

val txtWordAddr = textAddr / beatBytes
for (i <- textData.indices) {
// Blocks are not in little-endian order, but doesn't matter as each block is independent
initMem = initMem + ((txtWordAddr + i).toLong -> textData(i))
// NOTE: Prepend 0 byte s.t. it is interpreted as a positive (bug where if the last byte is FF it will be removed due to negative interpretation)
initMem = initMem + ((txtWordAddr + i).toLong -> BigInt(Array(0.toByte) ++ textData(i).toByteArray.takeRight(16).reverse.padTo(16, 0.toByte)))
}

State.init(initMem, beatBytes)
}

// Generates random stimulus for AES Accelerator
// (keyAddr: BigInt, keyData (256b post-padded): BigInt, srcAddr: BigInt, textData: Seq[BigInt], destAddr: BigInt, memState: TLMemModel.State)
def genAESStim(keySize: Int, textBlocks: Int, destructive: Boolean, beatBytes: Int, r: Random): (BigInt, BigInt, BigInt, Seq[BigInt], BigInt, State) = {
def genAESStim(keySize: Int, textBlocks: Int, destructive: Boolean, prevKeyAddr: BigInt, beatBytes: Int, r: Random):
(BigInt, BigInt, BigInt, Seq[BigInt], BigInt, State) = {
assert(beatBytes == 16, "TEMP: beatBytes must be 16 (see accelerator test for more details)")

// Generate Addresses
val addresses = getRandomKeyTextAddr(textBlocks, destructive, beatBytes, r)
val addresses = getRandomKeyTextAddr(textBlocks, destructive, prevKeyAddr, beatBytes, r)
val keyAddr = addresses._1
val srcAddr = addresses._2
val dstAddr = addresses._3
Expand All @@ -97,27 +101,27 @@ package object AESTestUtils {
// Generate State
val state = getTLMemModelState(keyAddr, keyData, srcAddr, srcData, beatBytes)

if (keySize == 128) {
(keyAddr + 16, keyData, srcAddr, srcData, dstAddr, state)
} else {
(keyAddr, keyData, srcAddr, srcData, dstAddr, state)
}
(keyAddr, keyData, srcAddr, srcData, dstAddr, state)
}

// Conditional if last block of result data has been written
// InitData should be 0 unless destructive
def finishedWriting(state: State, destAddr: BigInt, txtBlocks: Int, initData: BigInt, beatBytes: Int): Boolean = {
assert(beatBytes == 16, "TEMP: beatBytes must be 16 (see accelerator test for more details)")

read(state.mem, (destAddr/beatBytes + txtBlocks*(16/beatBytes) - 1).toLong, beatBytes, -1) != initData
// Reversing since data is stored in reverse
BigInt(Array(0.toByte) ++ read(state.mem, (destAddr/beatBytes + txtBlocks*(16/beatBytes) - 1).toLong, beatBytes, -1)
.toByteArray.takeRight(16).reverse) != initData
}

// Checks if output matches standard AES library (ECB)
def checkResult(keySize: Int, key: BigInt, srcData: Seq[BigInt], dstAddr: BigInt, encrypt: Boolean, state: State, beatBytes: Int): Boolean = {
assert(beatBytes == 16, "TEMP: beatBytes must be 16 (see accelerator test for more details)")

val cipher = AESECBCipher(keySize, key, encrypt)
val results = srcData.map(x => BigInt(Array(0.toByte) ++ cipher.doFinal(x.toByteArray.reverse.padTo(16, 0.toByte).reverse.takeRight(16))))
// NOTE: Additional reverse to match how data will be stored in memory
// NOTE: Prepending a 0 byte in front so that results are interpreted as positives
val results = srcData.map(x => BigInt(Array(0.toByte) ++ cipher.doFinal(x.toByteArray.reverse.padTo(16, 0.toByte).reverse.takeRight(16)).reverse))

var matched = true
for (i <- results.indices) {
Expand Down

0 comments on commit 82a0307

Please sign in to comment.