Skip to content

Commit

Permalink
inline version of initial f1(x) writes
Browse files Browse the repository at this point in the history
  • Loading branch information
AWice committed Sep 11, 2020
1 parent d7630c0 commit aa59247
Showing 1 changed file with 57 additions and 16 deletions.
73 changes: 57 additions & 16 deletions src/plotter_disk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "../lib/include/filesystem.hh"
namespace fs = ghc::filesystem;
#include "calculate_bucket.hpp"
#include "chacha8.h"
#include "encoding.hpp"
#include "pos_constants.hpp"
#include "sort_on_disk.hpp"
Expand Down Expand Up @@ -460,28 +461,68 @@ class DiskPlotter {

// Instead of computing f1(1), f1(2), etc, for each x, we compute them in batches
// to increase CPU efficency.
for (uint64_t lp = 0; lp <= (((uint64_t)1) << (k - kBatchSizes)); lp++) {
// For each pair x, y in the batch
for (auto kv : f1.CalculateBuckets(Bits(x, k), 2 << (kBatchSizes - 1))) {
// TODO(mariano): fix inefficient memory alloc here
(std::get<0>(kv) + std::get<1>(kv)).ToBytes(buf);

// We write the x, y pair
tmp_1_disks[1].Write(plot_file, (buf), entry_size_bytes);
plot_file += entry_size_bytes;
{ // Inline version of F1Calc.CalculateBuckets for performance
struct chacha8_ctx enc_ctx;
uint8_t enc_key[32];
enc_key[0] = 1;
memcpy(enc_key + 1, id, 31);
chacha8_keysetup(&enc_ctx, enc_key, 256, NULL);
uint64_t num_eval = 2 << (kBatchSizes - 1);
std::vector<Bits> blocks((num_eval + 1) * k / kF1BlockSizeBits + 2);
uint8_t ciphertext_bytes[kF1BlockSizeBits / 8];
int t = 0;

for (uint64_t lp = 0; lp < (((uint64_t)1) << (k - kBatchSizes)); lp++) {
uint64_t count0 = (x * (uint128_t)k) / kF1BlockSizeBits;
uint64_t count1 = (x + (uint128_t)num_eval + 1) * k / kF1BlockSizeBits;
uint16_t start_bit = (x * (uint128_t)k) % kF1BlockSizeBits;
uint64_t x2 = x + num_eval;
t = 0;
while (count0 <= count1) {
chacha8_get_keystream(&enc_ctx, count0++, 1, ciphertext_bytes);
Bits ciphertext(ciphertext_bytes, kF1BlockSizeBits / 8, kF1BlockSizeBits);
blocks[t++] = std::move(ciphertext);
}

bucket_sizes[SortOnDiskUtils::ExtractNum(
buf, entry_size_bytes, 0, kLogNumSortBuckets)] += 1;
t = 0;
for (; x < x2; x++) {
Bits L_bits = Bits(x, k);
// Takes the first kExtraBits bits from the input, and adds zeroes if it's not
// enough
Bits extra_data = L_bits.Slice(0, kExtraBits);
if (extra_data.GetSize() < kExtraBits) {
extra_data = extra_data + Bits(0, kExtraBits - extra_data.GetSize());
}

if (x + 1 > max_value) {
break;
if (start_bit + k < kF1BlockSizeBits) {
// Everything can be sliced from the current block
(blocks[t].Slice(start_bit, start_bit + k) + extra_data + L_bits)
.ToBytes(buf);
tmp_1_disks[1].Write(plot_file, (buf), entry_size_bytes);
plot_file += entry_size_bytes;
bucket_sizes[SortOnDiskUtils::ExtractNum(
buf, entry_size_bytes, 0, kLogNumSortBuckets)] += 1;
} else {
// Must move forward one block
((blocks[t].Slice(start_bit) +
blocks[t + 1].Slice(0, k - kF1BlockSizeBits + start_bit)) +
(extra_data + L_bits))
.ToBytes(buf);
tmp_1_disks[1].Write(plot_file, (buf), entry_size_bytes);
plot_file += entry_size_bytes;
bucket_sizes[SortOnDiskUtils::ExtractNum(
buf, entry_size_bytes, 0, kLogNumSortBuckets)] += 1;
t++;
}

// Start bit of the output slice in the current block
start_bit += k;
start_bit %= kF1BlockSizeBits;
}
++x;
}
if (x + 1 > max_value) {
break;
}
}

// A zero entry is the end of table symbol.
memset(buf, 0x00, entry_size_bytes);
tmp_1_disks[1].Write(plot_file, buf, entry_size_bytes);
Expand Down

0 comments on commit aa59247

Please sign in to comment.