Skip to content

Commit

Permalink
src|python-bindings: Use std::vector<uint8_t> for id and memo (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xdustinface authored Aug 10, 2021
1 parent cd8e0bb commit 97aa105
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 29 deletions.
14 changes: 4 additions & 10 deletions python-bindings/chiapos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,14 @@ PYBIND11_MODULE(chiapos, m)
.def(
"get_memo",
[](DiskProver &dp) {
uint8_t *memo = new uint8_t[dp.GetMemoSize()];
dp.GetMemo(memo);
py::bytes ret = py::bytes(reinterpret_cast<char *>(memo), dp.GetMemoSize());
delete[] memo;
return ret;
const std::vector<uint8_t>& memo = dp.GetMemo();
return py::bytes(reinterpret_cast<const char*>(memo.data()), memo.size());
})
.def(
"get_id",
[](DiskProver &dp) {
uint8_t *id = new uint8_t[kIdLen];
dp.GetId(id);
py::bytes ret = py::bytes(reinterpret_cast<char *>(id), kIdLen);
delete[] id;
return ret;
const std::vector<uint8_t>& id = dp.GetId();
return py::bytes(reinterpret_cast<const char*>(id.data()), id.size());
})
.def("get_size", [](DiskProver &dp) { return dp.GetSize(); })
.def("get_filename", [](DiskProver &dp) { return dp.GetFilename(); })
Expand Down
7 changes: 3 additions & 4 deletions src/cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,12 @@ int main(int argc, char *argv[]) try {
Verifier verifier = Verifier();

uint32_t success = 0;
uint8_t id_bytes[32];
prover.GetId(id_bytes);
std::vector<uint8_t> id_bytes = prover.GetId();
k = prover.GetSize();

for (uint32_t num = 0; num < iterations; num++) {
vector<unsigned char> hash_input = intToBytes(num, 4);
hash_input.insert(hash_input.end(), &id_bytes[0], &id_bytes[32]);
hash_input.insert(hash_input.end(), id_bytes.begin(), id_bytes.end());

vector<unsigned char> hash(picosha2::k_digest_size);
picosha2::hash256(hash_input.begin(), hash_input.end(), hash.begin(), hash.end());
Expand All @@ -269,7 +268,7 @@ int main(int argc, char *argv[]) try {
cout << "challenge: 0x" << Util::HexStr(hash.data(), 256 / 8) << endl;
cout << "proof: 0x" << Util::HexStr(proof_data, k * 8) << endl;
LargeBits quality =
verifier.ValidateProof(id_bytes, k, hash.data(), proof_data, k * 8);
verifier.ValidateProof(id_bytes.data(), k, hash.data(), proof_data, k * 8);
if (quality.GetSize() == 256 && quality == qualities[i]) {
cout << "quality: " << quality << endl;
cout << "Proof verification succeeded. k = " << static_cast<int>(k) << endl;
Expand Down
24 changes: 9 additions & 15 deletions src/prover_disk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DiskProver {
public:
// The constructor opens the file, and reads the contents of the file header. The table pointers
// will be used to find and seek to all seven tables, at the time of proving.
explicit DiskProver(const std::string& filename)
explicit DiskProver(const std::string& filename) : id(kIdLen)
{
struct plot_header header{};
this->filename = filename;
Expand Down Expand Up @@ -80,16 +80,14 @@ class DiskProver {
} else {
throw std::invalid_argument("Invalid plot file format");
}

memcpy(this->id, header.id, sizeof(header.id));
memcpy(id.data(), header.id, sizeof(header.id));
this->k = header.k;
SafeSeek(disk_file, offsetof(struct plot_header, fmt_desc) + fmt_desc_len);

uint8_t size_buf[2];
SafeRead(disk_file, size_buf, 2);
this->memo_size = Util::TwoBytesToInt(size_buf);
this->memo = new uint8_t[this->memo_size];
SafeRead(disk_file, this->memo, this->memo_size);
memo.resize(Util::TwoBytesToInt(size_buf));
SafeRead(disk_file, memo.data(), memo.size());

this->table_begin_pointers = std::vector<uint64_t>(11, 0);
this->C2 = std::vector<uint64_t>();
Expand Down Expand Up @@ -122,18 +120,15 @@ class DiskProver {
~DiskProver()
{
std::lock_guard<std::mutex> l(_mtx);
delete[] this->memo;
for (int i = 0; i < 6; i++) {
Encoding::ANSFree(kRValues[i]);
}
Encoding::ANSFree(kC3R);
}

void GetMemo(uint8_t* buffer) { memcpy(buffer, memo, this->memo_size); }

uint32_t GetMemoSize() const noexcept { return this->memo_size; }
const std::vector<uint8_t>& GetMemo() { return memo; }

void GetId(uint8_t* buffer) { memcpy(buffer, id, kIdLen); }
const std::vector<uint8_t>& GetId() { return id; }

std::string GetFilename() const noexcept { return filename; }

Expand Down Expand Up @@ -241,9 +236,8 @@ class DiskProver {
private:
mutable std::mutex _mtx;
std::string filename;
uint32_t memo_size;
uint8_t* memo;
uint8_t id[kIdLen]{}; // Unique plot id
std::vector<uint8_t> memo;
std::vector<uint8_t> id; // Unique plot id
uint8_t k;
std::vector<uint64_t> table_begin_pointers;
std::vector<uint64_t> C2;
Expand Down Expand Up @@ -577,7 +571,7 @@ class DiskProver {
// Where a < b is defined as: max(b) > max(a) where a and b are lists of k bit elements
std::vector<LargeBits> ReorderProof(const std::vector<Bits>& xs_input) const
{
F1Calculator f1(k, id);
F1Calculator f1(k, id.data());
std::vector<std::pair<Bits, Bits> > results;
LargeBits xs;

Expand Down

0 comments on commit 97aa105

Please sign in to comment.