Skip to content

Commit

Permalink
Various improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Aug 24, 2020
1 parent cf1719b commit ad583af
Show file tree
Hide file tree
Showing 149 changed files with 2,879 additions and 851 deletions.
2 changes: 1 addition & 1 deletion BMR/CommonParty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ CommonFakeParty::~CommonFakeParty()

CommonParty::~CommonParty()
{
cerr << "Total time: " << timer.elapsed() << endl;
#ifdef VERBOSE
cerr << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl;
cerr << "CPU time: " << cpu_timer.elapsed() << endl;
cerr << "First phase time: " << timers[0].elapsed() << endl;
cerr << "Second phase time: " << timers[1].elapsed() << endl;
cerr << "Number of gates: " << gate_counter << endl;
#endif
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
}

void CommonParty::check(int n_parties)
Expand Down
2 changes: 2 additions & 0 deletions BMR/ProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "Party.h"

#include "GC/ShareSecret.hpp"

template<class T>
ProgramPartySpec<T>* ProgramPartySpec<T>::singleton = 0;

Expand Down
26 changes: 26 additions & 0 deletions BMR/RealGarbleWire.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,32 @@
#include "Register.h"

template<class T> class RealProgramParty;
template<class T> class RealGarbleWire;

template<class T>
class GarbleInputter
{
public:
RealProgramParty<T>& party;

Bundle<octetStream> oss;
PointerVector<pair<RealGarbleWire<T>*, int>> tuples;

GarbleInputter();
void exchange();
};

template<class T>
class RealGarbleWire : public PRFRegister
{
friend class RealProgramParty<T>;
friend class GarbleInputter<T>;

T mask;

public:
typedef GarbleInputter<T> Input;

static void store(NoMemory& dest,
const vector<GC::WriteAccess<GC::Secret<RealGarbleWire>>>& accesses);
static void load(vector<GC::ReadAccess<GC::Secret<RealGarbleWire>>>& accesses,
Expand All @@ -26,6 +43,11 @@ class RealGarbleWire : public PRFRegister
static void convcbit(Integer& dest, const GC::Clear& source,
GC::Processor<GC::Secret<RealGarbleWire>>& processor);

static void inputb(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
const vector<int>& args);
static void inputbvec(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args);

RealGarbleWire(const Register& reg) : PRFRegister(reg) {}

void garble(PRFOutputs& prf_output, const RealGarbleWire<T>& left,
Expand All @@ -37,6 +59,10 @@ class RealGarbleWire : public PRFRegister
void public_input(bool value);
void random();
void output();

void my_input(Input& Inputter, bool value, int n_bits);
void other_input(Input& Inputter, int from);
void finalize_input(Input&, int from, int);
};

template<class T>
Expand Down
111 changes: 88 additions & 23 deletions BMR/RealGarbleWire.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,53 +94,118 @@ void RealGarbleWire<T>::XOR(const RealGarbleWire<T>& left, const RealGarbleWire<
}

template<class T>
void RealGarbleWire<T>::input(party_id_t from, char input)
void RealGarbleWire<T>::inputb(
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
const vector<int>& args)
{
GarbleInputter<T> inputter;
processor.inputb(inputter, processor, args,
inputter.party.P->my_num());
}

template<class T>
void RealGarbleWire<T>::inputbvec(
GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args)
{
GarbleInputter<T> inputter;
processor.inputbvec(inputter, input_processor, args,
inputter.party.P->my_num());
}

template<class T>
GarbleInputter<T>::GarbleInputter() :
party(RealProgramParty<T>::s()), oss(*party.P)
{
}

template<class T>
void RealGarbleWire<T>::my_input(Input& inputter, bool, int n_bits)
{
assert(n_bits == 1);
inputter.tuples.push_back({this, inputter.party.P->my_num()});
}

template<class T>
void RealGarbleWire<T>::other_input(Input& inputter, int from)
{
inputter.tuples.push_back({this, from});
}

template<class T>
void GarbleInputter<T>::exchange()
{
PRFRegister::input(from, input);
auto& party = RealProgramParty<T>::s();
assert(party.shared_proc != 0);
auto& inputter = party.shared_proc->input;
inputter.reset(from - 1);
if (from == party.get_id())
inputter.reset_all(*party.P);
for (auto& tuple : tuples)
{
char my_mask;
my_mask = party.prng.get_bit();
party.input_masks.serialize(my_mask);
inputter.add_mine(my_mask);
inputter.send_mine();
mask = inputter.finalize_mine();
int from = tuple.second;
party_id_t from_id = from + 1;
tuple.first->PRFRegister::input(from_id, -1);
if (from_id == party.get_id())
{
char my_mask;
my_mask = party.prng.get_bit();
party.garble_input_masks.serialize(my_mask);
inputter.add_mine(my_mask);
#ifdef DEBUG_MASK
cout << "my mask: " << (int)my_mask << endl;
cout << "my mask: " << (int)my_mask << endl;
#endif
}
else
{
inputter.add_other(from);
}
}
else
{
inputter.add_other(from - 1);
octetStream os;
party.P->receive_player(from - 1, os, true);
inputter.finalize_other(from - 1, mask, os);
}

inputter.exchange();

for (auto& tuple : tuples)
tuple.first->mask = (inputter.finalize(tuple.second));

// important to make sure that mask is a bit
try
{
mask.force_to_bit();
for (auto& tuple : tuples)
tuple.first->mask.force_to_bit();
}
catch (not_implemented& e)
{
assert(party.P != 0);
assert(party.MC != 0);
auto& protocol = party.shared_proc->protocol;
protocol.init_mul(party.shared_proc);
protocol.prepare_mul(mask, T::constant(1, party.P->my_num(), party.mac_key) - mask);
for (auto& tuple : tuples)
protocol.prepare_mul(tuple.first->mask,
T::constant(1, party.P->my_num(), party.mac_key)
- tuple.first->mask);
protocol.exchange();
if (party.MC->open(protocol.finalize_mul(), *party.P) != 0)
vector<T> to_check;
to_check.reserve(tuples.size());
for (size_t i = 0; i < tuples.size(); i++)
{
to_check.push_back(protocol.finalize_mul());
}
try
{
party.MC->CheckFor(0, to_check, *party.P);
}
catch (mac_fail&)
{
throw runtime_error("input mask not a bit");
}
}
#ifdef DEBUG_MASK
cout << "shared mask: " << party.MC->POpen(mask, *party.P) << endl;
#endif
}

template<class T>
void RealGarbleWire<T>::finalize_input(GarbleInputter<T>&, int, int)
{
}

template<class T>
void RealGarbleWire<T>::public_input(bool value)
{
Expand Down Expand Up @@ -169,7 +234,7 @@ void RealGarbleWire<T>::output()
assert(party.MC != 0);
assert(party.P != 0);
auto m = party.MC->open(mask, *party.P);
party.output_masks.push_back(m.get_bit(0));
party.garble_output_masks.push_back(m.get_bit(0));
party.taint();
#ifdef DEBUG_MASK
cout << "output mask: " << m << endl;
Expand Down
7 changes: 7 additions & 0 deletions BMR/RealProgramParty.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class RealProgramParty : public ProgramPartySpec<T>
typedef typename T::Input Inputter;

friend class RealGarbleWire<T>;
friend class GarbleInputter<T>;
friend class GarbleJob<T>;

static RealProgramParty* singleton;
Expand All @@ -40,9 +41,15 @@ class RealProgramParty : public ProgramPartySpec<T>

GC::BreakType next;

bool one_shot;

size_t data_sent;

public:
static RealProgramParty& s();

LocalBuffer garble_input_masks, garble_output_masks;

RealProgramParty(int argc, const char** argv);
~RealProgramParty();

Expand Down
36 changes: 29 additions & 7 deletions BMR/RealProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,20 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
"-N", // Flag token.
"--nparties" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Evaluate only after garbling.", // Help description.
"-O", // Flag token.
"--oneshot" // Flag token.
);
opt.parse(argc, argv);
int nparties;
opt.get("-N")->getInt(nparties);
this->check(nparties);
one_shot = opt.isSet("-O");

NetworkOptions network_opts(opt, argc, argv);
OnlineOptions& online_opts = OnlineOptions::singleton;
Expand Down Expand Up @@ -90,7 +100,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
mac_key.randomize(prng);
if (T::needs_ot)
BaseMachine::s().ot_setups.push_back({*P, true});
prep = Preprocessing<T>::get_live_prep(0, usage);
prep = new typename T::TriplePrep(0, usage);
}
else
{
Expand Down Expand Up @@ -122,14 +132,22 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
for (int i = 0; i < SPDZ_OP_N; i++)
this->spdz_wires[i].push_back({});

this->timer.reset();
do
{
next = GC::TIME_BREAK;
garble();
try
{
this->online_timer.start();
this->start_online_round();
if (one_shot)
this->start_online_round();
else
{
this->load_garbled_circuit();
next = this->second_phase(program, this->processor,
this->machine, this->dynamic_memory);
}
this->online_timer.stop();
}
catch (needs_cleaning& e)
Expand All @@ -139,6 +157,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);

MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();

if (server)
delete server;
Expand All @@ -152,7 +171,7 @@ void RealProgramParty<T>::garble()
auto& program = this->program;
auto& MC = this->MC;

while (next == GC::TIME_BREAK)
do
{
garble_jobs.clear();
garble_inputter->reset_all(*P);
Expand All @@ -178,13 +197,15 @@ void RealProgramParty<T>::garble()
vector<typename T::clear> opened;
MC->POpen(opened, wires, *P);

LocalBuffer garbled_circuit;
for (auto& x : opened)
this->garbled_circuit.serialize(x);
garbled_circuit.serialize(x);

this->garbled_circuits.push_and_clear(this->garbled_circuit);
this->input_masks_store.push_and_clear(this->input_masks);
this->output_masks_store.push_and_clear(this->output_masks);
this->garbled_circuits.push_and_clear(garbled_circuit);
this->input_masks_store.push_and_clear(garble_input_masks);
this->output_masks_store.push_and_clear(garble_output_masks);
}
while (one_shot and next == GC::TIME_BREAK);
}

template<class T>
Expand All @@ -194,6 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
delete prep;
delete garble_inputter;
delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
}

template<class T>
Expand Down
Loading

0 comments on commit ad583af

Please sign in to comment.