Skip to content

Commit

Permalink
Half-gate garbling, native 2D convolution, TensorFlow inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Jun 15, 2020
1 parent 31e43a6 commit 3f9f3be
Show file tree
Hide file tree
Showing 138 changed files with 2,672 additions and 1,199 deletions.
31 changes: 0 additions & 31 deletions BMR/AuthValue.cpp

This file was deleted.

6 changes: 3 additions & 3 deletions BMR/CommonParty.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ class CommonParty
LocalBuffer wires;
ReceivedMsgStore wire_storage;

template<class T, class U>
GC::BreakType first_phase(GC::Program<U>& program, GC::Processor<T>& processor,
template<class T>
GC::BreakType first_phase(GC::Program& program, GC::Processor<T>& processor,
GC::Machine<T>& machine);
template<class T, class U>
GC::BreakType second_phase(GC::Program<T>& program, GC::Processor<T>& processor,
GC::BreakType second_phase(GC::Program& program, GC::Processor<T>& processor,
GC::Machine<T>& machine, U& dynamic_memory);

public:
Expand Down
8 changes: 4 additions & 4 deletions BMR/CommonParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#include "CommonParty.h"

template <class T, class U>
GC::BreakType CommonParty::first_phase(GC::Program<U>& program,
template <class T>
GC::BreakType CommonParty::first_phase(GC::Program& program,
GC::Processor<T>& processor, GC::Machine<T>& machine)
{
(void)machine;
Expand All @@ -20,7 +20,7 @@ GC::BreakType CommonParty::first_phase(GC::Program<U>& program,
GC::BreakType next;
try
{
next = (reinterpret_cast<GC::Program<T>*>(&program))->execute(processor, dynamic_memory);
next = program.execute(processor, dynamic_memory);
}
catch (needs_cleaning& e)
{
Expand All @@ -44,7 +44,7 @@ GC::BreakType CommonParty::first_phase(GC::Program<U>& program,
}

template<class T, class U>
GC::BreakType CommonParty::second_phase(GC::Program<T>& program,
GC::BreakType CommonParty::second_phase(GC::Program& program,
GC::Processor<T>& processor, GC::Machine<T>& machine,
U& dynamic_memory)
{
Expand Down
2 changes: 1 addition & 1 deletion BMR/Key.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Key {
void serialize(SendBuffer& output) const { output.serialize(r); }
void serialize_no_allocate(SendBuffer& output) const { output.serialize_no_allocate(r); }

bool get_signal() const { return _mm_cvtsi128_si64(r) & 1; }
bool get_signal() const { return _mm_cvtsi128_si32(r) & 1; }
void set_signal(bool signal);

Key doubling(int i) const;
Expand Down
103 changes: 0 additions & 103 deletions BMR/Machine.cpp

This file was deleted.

7 changes: 0 additions & 7 deletions BMR/Party.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,6 @@ void FakeProgramParty::receive_spdz_wires(ReceivedMsg& msg)
void ProgramParty::store_wire(const Register& reg)
{
wires.serialize(reg.key(get_id(), 0));
#ifndef FREE_XOR
wires.serialize(reg.key(get_id(), 1));
#endif
#ifdef DEBUG
cout << "storing wire" << endl;
reg.print();
Expand All @@ -394,11 +391,7 @@ void ProgramParty::store_wire(const Register& reg)
void ProgramParty::load_wire(Register& reg)
{
wires.unserialize(reg.key(get_id(), 0));
#ifdef FREE_XOR
reg.key(get_id(), 1) = reg.key(get_id(), 0) ^ get_delta();
#else
wires.unserialize(reg.key(get_id(), 1));
#endif
#ifdef DEBUG
cout << "loading wire" << endl;
reg.print();
Expand Down
6 changes: 1 addition & 5 deletions BMR/Party.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ProgramParty : virtual public CommonParty, virtual public PartyProperties,

GC::Machine< GC::Secret<EvalRegister> > machine;
GC::Processor<GC::Secret<EvalRegister> > processor;
GC::Program<GC::Secret<EvalRegister> > program;
GC::Program program;

GC::Machine< GC::Secret<PRFRegister> > prf_machine;
GC::Processor<GC::Secret<PRFRegister> > prf_processor;
Expand Down Expand Up @@ -170,11 +170,7 @@ class ProgramPartySpec : public ProgramParty
void get_spdz_wire(SpdzOp op, DualWire<T>& spdz_wire);
};

#ifdef SPDZ_AUTH
typedef ProgramPartySpec<Share<gf2n_long>> FakeProgramPartySuper;
#else
typedef ProgramPartySpec<GC::Memory<AuthValue>> FakeProgramPartySuper;
#endif

class FakeProgramParty : virtual public BaseParty, virtual public FakeProgramPartySuper
{
Expand Down
4 changes: 2 additions & 2 deletions BMR/ProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ void ProgramPartySpec<T>::load(string progname)
program.parse(progname + "-0");
machine.reset(program, dynamic_memory);
processor.reset(program);
prf_machine.reset(*reinterpret_cast<GC::Program<GC::Secret<PRFRegister> >* >(&program));
prf_processor.reset(*reinterpret_cast<GC::Program<GC::Secret<PRFRegister> >* >(&program));
prf_machine.reset(program);
prf_processor.reset(program);
}

template<class T>
Expand Down
7 changes: 0 additions & 7 deletions BMR/RealProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
P = new CryptoPlayer(N, 0);

delta = prng.get_doubleword();
#ifdef KEY_SIGNAL
delta.set_signal(1);
#endif
#ifdef VERBOSE
cerr << "delta: " << delta << endl;
#endif
Expand Down Expand Up @@ -201,16 +199,11 @@ RealProgramParty<T>::~RealProgramParty()
template<class T>
void RealProgramParty<T>::receive_keys(Register& reg)
{
#ifndef FREE_XOR
#error not implemented
#endif
auto& _id = this->_id;
auto& _N = this->_N;
reg.init(_N);
reg.keys[0][_id - 1] = this->prng.get_doubleword();
#ifdef KEY_SIGNAL
reg.keys[0][_id - 1].set_signal(0);
#endif
reg.keys[1][_id - 1] = reg.keys[0][_id - 1] ^ this->get_delta();
}

Expand Down
61 changes: 0 additions & 61 deletions BMR/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,10 @@ void Register::init(int rfd, int n_parties) {
mask = mask>0 ? 1 : 0;
keys.init(n_parties);
keys.randomize();
#ifdef KEY_SIGNAL
for (int i = 0; i < 2; i++)
for (size_t j = 0; j < keys[i].size(); j++)
if (keys[i][j].get_signal() != i)
keys[i][j] ^= Key(1);
#endif
}

void Register::set_eval_keys()
Expand Down Expand Up @@ -284,21 +282,7 @@ void Register::eval(const Register& left, const Register& right, GarbledGate& ga
// }
// std::cout << std::endl;

#ifdef KEY_SIGNAL
external = garbled_entry[my_id - 1].get_signal();
#else
if(garbled_entry[my_id-1] == key(my_id, 0)) {
external = 0;
} else if (garbled_entry[my_id-1] == key(my_id, 1)) {
external = 1;
} else {
printf("\nERROR!!!\n");
cout << "got key: " << garbled_entry[my_id - 1] << endl;
cout << "possibilities: " << key(my_id, 0) << " " << key(my_id, 1) << endl;
throw std::invalid_argument("result key doesn't fit any of my keys");
// return NO_SIGNAL;
}
#endif

#ifdef DEBUG_MASK
cout << "output signal: " << (int)external << endl;
Expand Down Expand Up @@ -680,9 +664,7 @@ void RandomRegister::randomize()
party.random_timer.start();
init(party.randomfd, party._N);
party.random_timer.stop();
#ifdef FREE_XOR
keys[1] = keys[0] ^ party.get_deltas();
#endif
party.add_keys(*this);
}

Expand Down Expand Up @@ -764,16 +746,13 @@ void EvalRegister::output()
ProgramParty& party = ProgramParty::s();
party.load_wire(*this);
set_mask(party.output_masks.pop_front());
#ifdef KEY_SIGNAL
#ifdef DEBUG_REGS
cout << "check " << get_id() << endl;
#endif
check_signal_key(party.get_id(), garbled_entry);
#endif
party.taint();
}

#ifdef FREE_XOR
void RandomRegister::XOR(const Register& left, const Register& right)
{
mask = left.get_mask() ^ right.get_mask();
Expand Down Expand Up @@ -824,46 +803,6 @@ void EvalRegister::XOR(const Register& left, const Register& right)
<< " ^ " << right.get_garbled_entry()[i] << endl;
#endif
}
#endif

void EvalRegister::check(const int128& value, word share, int128 mac)
{
#ifdef DEBUG_DYNAMIC
cout << "check result " << value << endl;
#endif
if (value != 0)
{
cout << "MAC check: " << value << " " << share<< " " << mac << endl;
throw runtime_error("MAC check failed");
}
}

void EvalRegister::get_dyn_mask(GC::Mask& mask, int length, int mac_length)
{
mask.share = CommonParty::s().prng.get_word() & ((1ULL << length) - 1);
mask.mac = int128(CommonParty::s().prng.get_doubleword())
& int128::ones(mac_length);
#ifdef DEBUG_DYNAMIC
cout << "mask " << hex << mask.share << " " << mask.mac << " ";
cout << ((1ULL << length) - 1) << " " << int128::ones(mac_length) << endl;
#endif
}

void EvalRegister::unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share,
word masked, int128 masked_mac)
{
dest.share = mask_share;
dest.mac = mac_mask_share;
if (ProgramParty::s()._id == 1)
{
dest.share ^= masked;
dest.mac ^= masked_mac;
}
#ifdef DEBUG_DYNAMIC
cout << dest.share << " ?= " << mask_share << " ^ " << masked << endl;
cout << dest.mac << " ?= " << mac_mask_share << " ^ " << masked_mac << endl;
#endif
}

template <>
void RandomRegister::store(NoMemory& mem,
Expand Down
Loading

0 comments on commit 3f9f3be

Please sign in to comment.