Skip to content

Commit

Permalink
fix: prover server for different circuits
Browse files Browse the repository at this point in the history
  • Loading branch information
daveroga committed Mar 7, 2023
1 parent cbe090f commit 3c75053
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 120 deletions.
194 changes: 95 additions & 99 deletions src/fullprover.cpp
Original file line number Diff line number Diff line change
@@ -1,58 +1,68 @@
#include <nlohmann/json.hpp>
using json = nlohmann::json;

#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>

#include "zkey_utils.hpp"

#include "fullprover.hpp"
#include "fr.hpp"

#include "logger.hpp"
#include "wtns_utils.hpp"

using namespace CPlusPlusLogging;

FullProver::FullProver(std::string datFileName, std::string zkeyFileName) {
std::string getfilename(std::string path)
{
path = path.substr(path.find_last_of("/\\") + 1);
size_t dot_i = path.find_last_of('.');
return path.substr(0, dot_i);
}

FullProver::FullProver(std::string zkeyFileNames[], int size) {
pendingInput="";
wtns=NULL;
canceled = false;

circuit = loadCircuit(datFileName);

// open output
calcWit = new Circom_CalcWit(circuit);

zkey = BinFileUtils::openExisting(zkeyFileName, "zkey", 1);
auto zkeyHeader = ZKeyUtils::loadHeader(zkey.get());

prover = Groth16::makeProver<AltBn128::Engine>(
zkeyHeader->nVars,
zkeyHeader->nPublic,
zkeyHeader->domainSize,
zkeyHeader->nCoefs,
zkeyHeader->vk_alpha1,
zkeyHeader->vk_beta1,
zkeyHeader->vk_beta2,
zkeyHeader->vk_delta1,
zkeyHeader->vk_delta2,
zkey->getSectionData(4), // Coefs
zkey->getSectionData(5), // pointsA
zkey->getSectionData(6), // pointsB1
zkey->getSectionData(7), // pointsB2
zkey->getSectionData(8), // pointsC
zkey->getSectionData(9) // pointsH1
);

wtns = new AltBn128::FrElement[circuit->NVars];
mpz_init(altBbn128r);
mpz_set_str(altBbn128r, "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10);

for(int i = 0; i < size; i++) {
std::string circuitHash = getfilename(zkeyFileNames[i]);
zKeys[circuitHash] = BinFileUtils::openExisting(zkeyFileNames[i], "zkey", 1);
zkHeaders[circuitHash] = ZKeyUtils::loadHeader(zKeys[circuitHash].get());

std::string proofStr;
if (mpz_cmp(zkHeaders[circuitHash]->rPrime, altBbn128r) != 0) {
throw std::invalid_argument( "zkey curve not supported" );
}

std::ostringstream ss1;
ss1 << "circuit: " << circuitHash;
LOG_DEBUG(ss1);

provers[circuitHash] = Groth16::makeProver<AltBn128::Engine>(
zkHeaders[circuitHash]->nVars,
zkHeaders[circuitHash]->nPublic,
zkHeaders[circuitHash]->domainSize,
zkHeaders[circuitHash]->nCoefs,
zkHeaders[circuitHash]->vk_alpha1,
zkHeaders[circuitHash]->vk_beta1,
zkHeaders[circuitHash]->vk_beta2,
zkHeaders[circuitHash]->vk_delta1,
zkHeaders[circuitHash]->vk_delta2,
zKeys[circuitHash]->getSectionData(4), // Coefs
zKeys[circuitHash]->getSectionData(5), // pointsA
zKeys[circuitHash]->getSectionData(6), // pointsB1
zKeys[circuitHash]->getSectionData(7), // pointsB2
zKeys[circuitHash]->getSectionData(8), // pointsC
zKeys[circuitHash]->getSectionData(9) // pointsH1
);
}

status = ready;
}

FullProver::~FullProver() {
delete calcWit;
delete wtns;
mpz_clear(altBbn128r);
}

void FullProver::startProve(std::string input) {
Expand All @@ -67,7 +77,6 @@ void FullProver::startProve(std::string input) {
LOG_TRACE("FullProver::startProve end");
}


void FullProver::checkPending() {
LOG_TRACE("FullProver::checkPending begin");
if (status != busy) {
Expand All @@ -88,29 +97,71 @@ void FullProver::checkPending() {

void FullProver::thread_calculateProve() {
LOG_TRACE("FullProver::thread_calculateProve start");

try {
calcWit->calculateProve(wtns, executingInput, [this](){ return /* TODO isCanceled() */ false; });
LOG_TRACE(executingInput);
// Generate witness
json j = json::parse(executingInput);
std::string circuitHash = j["circuitHash"];

std::ofstream file("./build/input_"+ circuitHash +".json");
file << j;
file.close();

std::string witnessFile("./build/" + circuitHash + ".wtns");
std::string command("./build/" + circuitHash + " ./build/input_"+ circuitHash +".json " + witnessFile);
LOG_TRACE(command);
std::array<char, 128> buffer;
std::string result;

// std::cout << "Opening reading pipe" << std::endl;
FILE* pipe = popen(command.c_str(), "r");
if (!pipe)
{
std::cerr << "Couldn't start command." << std::endl;
}
while (fgets(buffer.data(), 128, pipe) != NULL) {
// std::cout << "Reading..." << std::endl;
result += buffer.data();
}
auto returnCode = pclose(pipe);

std::cout << result << std::endl;
std::cout << returnCode << std::endl;

// Load witness
auto wtns = BinFileUtils::openExisting(witnessFile, "wtns", 2);
auto wtnsHeader = WtnsUtils::loadHeader(wtns.get());

if (mpz_cmp(wtnsHeader->prime, altBbn128r) != 0) {
throw std::invalid_argument( "different wtns curve" );
}

AltBn128::FrElement *wtnsData = (AltBn128::FrElement *)wtns->getSectionData(2);

pubData.clear();
LOG_TRACE("FullProver::thread_calculateProve calculating prove");
for (int i=1; i<=circuit->NPublic; i++) {
AltBn128::FrElement aux;
AltBn128::Fr.toMontgomery(aux, wtns[i]);
AltBn128::FrElement aux;
for (int i=1; i<=zkHeaders[circuitHash]->nPublic; i++) {
AltBn128::Fr.toMontgomery(aux, wtnsData[i]);
pubData.push_back(AltBn128::Fr.toString(aux));
}

if (!isCanceled()) {
proof = prover->prove(wtns)->toJson();
proof = provers[circuitHash]->prove(wtnsData)->toJson();
} else {
LOG_TRACE("AVOIDING prove");
proof = {};
}


calcFinished();
} catch (std::runtime_error e) {
if (!isCanceled()) {
errString = e.what();
}
calcFinished();
}
}

LOG_TRACE("FullProver::thread_calculateProve end");
}

Expand Down Expand Up @@ -182,58 +233,3 @@ json FullProver::getStatus() {
LOG_TRACE("FullProver::getStatus end");
return st;
}


#define ADJ_P(a) *((void **)&a) = (void *)(((char *)circuit)+ (uint64_t)(a))

Circom_Circuit *FullProver::loadCircuit(std::string const &datFileName) {
LOG_TRACE("FullProver::loadCircuit start");
Circom_Circuit *circuitF;
Circom_Circuit *circuit;

int fd;
struct stat sb;

fd = open(datFileName.c_str(), O_RDONLY);
if (fd == -1) {
std::ostringstream ss;
ss << ".dat file not found: " << datFileName << "\n";
LOG_ERROR(ss);
throw std::system_error(errno, std::generic_category(), "open");
}

if (fstat(fd, &sb) == -1) { /* To obtain file size */
throw std::system_error(errno, std::generic_category(), "fstat");
}

circuitF = (Circom_Circuit *)mmap(NULL, sb.st_size, PROT_READ , MAP_PRIVATE, fd, 0);
close(fd);

circuit = (Circom_Circuit *)malloc(sb.st_size);
memcpy((void *)circuit, (void *)circuitF, sb.st_size);

munmap(circuitF, sb.st_size);

ADJ_P(circuit->wit2sig);
ADJ_P(circuit->components);
ADJ_P(circuit->mapIsInput);
ADJ_P(circuit->constants);
ADJ_P(circuit->P);
ADJ_P(circuit->componentEntries);

for (int i=0; i<circuit->NComponents; i++) {
ADJ_P(circuit->components[i].hashTable);
ADJ_P(circuit->components[i].entries);
circuit->components[i].fn = _functionTable[ (uint64_t)circuit->components[i].fn];
}

for (int i=0; i<circuit->NComponentEntries; i++) {
ADJ_P(circuit->componentEntries[i].sizes);
}

LOG_TRACE("FullProver::loadCircuit end");
return circuit;
}



20 changes: 9 additions & 11 deletions src/fullprover.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
#include <nlohmann/json.hpp>
using json = nlohmann::json;

#include <mutex>
#include "alt_bn128.hpp"
#include "groth16.hpp"
#include "calcwit.hpp"
#include "binfile_utils.hpp"
#include "zkey_utils.hpp"

class FullProver {
enum Status {aborted = -2, busy = -1, failed = 0, success = 1, unverified =2, uninitialized=3, initializing=5, ready=6 };
Expand All @@ -17,21 +18,18 @@ class FullProver {
std::string pendingInput;
std::string executingInput;

std::map<std::string, std::unique_ptr<Groth16::Prover<AltBn128::Engine>>> provers;
std::map<std::string, std::unique_ptr<ZKeyUtils::Header>> zkHeaders;
std::map<std::string, std::unique_ptr<BinFileUtils::BinFile>> zKeys;

mpz_t altBbn128r;

json proof;
json pubData;
std::string errString;

Circom_Circuit *circuit;
Circom_CalcWit *calcWit;

AltBn128::FrElement *wtns;
bool canceled;

std::unique_ptr<BinFileUtils::BinFile> zkey;

std::unique_ptr<Groth16::Prover<AltBn128::Engine>> prover;

Circom_Circuit *loadCircuit(std::string const &datFileName);
bool isCanceled();
void calcFinished();
void thread_calculateProve();
Expand All @@ -40,7 +38,7 @@ class FullProver {


public:
FullProver(std::string datFileName, std::string zkeyFileName);
FullProver(std::string zkeyFileNames[], int size);
~FullProver();
void startProve(std::string input);
void abort();
Expand Down
19 changes: 15 additions & 4 deletions src/main_proofserver.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <pistache/router.h>
#include <pistache/endpoint.h>

#include "proverapi.hpp"
#include "fullprover.hpp"
#include "logger.hpp"
Expand All @@ -10,14 +9,25 @@ using namespace Pistache;
using namespace Pistache::Rest;

int main(int argc, char **argv) {
if (argc < 3) {
std::cerr << "Invalid number of parameters:\n";
std::cerr << "Usage: proverServer <port> <circuit1.zkey> <circuit2.zkey> ... <circuitN.zkey> \n";
return -1;
}

Logger::getInstance()->enableConsoleLogging();
Logger::getInstance()->updateLogLevel(LOG_LEVEL_DEBUG);
LOG_INFO("Initializing server...");
int port = std::stoi(argv[1]); // parse port
// parse the zkeys
std::string zkeyFileNames[argc - 2];
for (int i = 0; i < argc - 2; i++) {
zkeyFileNames[i] = argv[i + 2];
}

FullProver fullProver(argv[1], argv[2]);
FullProver fullProver(zkeyFileNames, argc - 2);
ProverAPI proverAPI(fullProver);
Address addr(Ipv4::any(), Port(9080));
Address addr(Ipv4::any(), Port(port));

auto opts = Http::Endpoint::options().threads(1).maxRequestSize(128000000);
Http::Endpoint server(addr);
Expand All @@ -29,6 +39,7 @@ int main(int argc, char **argv) {
Routes::Post(router, "/input", Routes::bind(&ProverAPI::postInput, &proverAPI));
Routes::Post(router, "/cancel", Routes::bind(&ProverAPI::postCancel, &proverAPI));
server.setHandler(router.handler());
LOG_INFO("Server ready on port 9080...");
std::string serverReady("Server ready on port " + std::to_string(port) + "...");
LOG_INFO(serverReady);
server.serve();
}
2 changes: 1 addition & 1 deletion src/main_prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int main(int argc, char **argv) {

if (argc != 5) {
std::cerr << "Invalid number of parameters:\n";
std::cerr << "Usage: prove <circuit.zkey> <witness.wtns> <proof.json> <public.json>\n";
std::cerr << "Usage: prover <circuit.zkey> <witness.wtns> <proof.json> <public.json>\n";
return -1;
}

Expand Down
6 changes: 1 addition & 5 deletions tasksfile.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,19 @@ function buildPistche() {


function buildProverServer() {
sh("cp " + process.argv[3] + " build/circuit.cpp", {cwd: ".", nopipe: true});
sh("g++" +
" -I."+
" -I../src"+
" -I../depends/pistache/include"+
" -I../depends/json/single_include"+
" -I../depends/ffiasm/c"+
" -I../depends/circom_runtime/c"+
" ../src/main_proofserver.cpp"+
" ../src/proverapi.cpp"+
" ../src/fullprover.cpp"+
" ../src/binfile_utils.cpp"+
" ../src/wtns_utils.cpp"+
" ../src/zkey_utils.cpp"+
" ../src/logger.cpp"+
" ../depends/circom_runtime/c/calcwit.cpp"+
" ../depends/circom_runtime/c/utils.cpp"+
" ../depends/ffiasm/c/misc.cpp"+
" ../depends/ffiasm/c/naf.cpp"+
" ../depends/ffiasm/c/splitparstr.cpp"+
Expand All @@ -56,7 +53,6 @@ function buildProverServer() {
" fq.o"+
" fr.cpp"+
" fr.o"+
" circuit.cpp"+
" -L../depends/pistache/build/src -lpistache"+
" -o proverServer"+
" -fmax-errors=5 -pthread -std=c++17 -fopenmp -lgmp -lsodium -g -DSANITY_CHECK", {cwd: "build", nopipe: true}
Expand Down

0 comments on commit 3c75053

Please sign in to comment.