Skip to content

Commit

Permalink
Avoid race condition depending of the order of client connections.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Nov 25, 2019
1 parent 470b075 commit 9a83cfe
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 130 deletions.
28 changes: 13 additions & 15 deletions ExternalIO/bankers-bonus-client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an increasing id to identify the client, starting with 0
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
* - sends an integer input (bonus value to compare)
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
Expand All @@ -22,13 +22,13 @@
* To run with 2 parties / SPDZ engines:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./compile.py bankers_bonus
* ./Scripts/run-online bankers_bonus to run the engines.
* ./Scripts/run-online.sh bankers_bonus to run the engines.
*
* ./bankers-bonus-client.x 123 2 100 0
* ./bankers-bonus-client.x 456 2 200 0
* ./bankers-bonus-client.x 789 2 50 1
* ./bankers-bonus-client.x 0 2 100 0
* ./bankers-bonus-client.x 1 2 200 0
* ./bankers-bonus-client.x 2 2 50 1
*
* Expect winner to be second client with id 456.
* Expect winner to be second client with id 1.
*/

#include "Math/gfp.h"
Expand All @@ -46,7 +46,7 @@
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties)
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties)
{
int num_inputs = values.size();
octetStream os;
Expand Down Expand Up @@ -172,17 +172,15 @@ int main(int argc, char** argv)
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send(sockets[i], (octet*) &my_client_id, sizeof(int));
octetStream os;
os.store(finish);
os.Send(sockets[i]);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;

// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);

// Run the commputation
send_private_inputs(input_values_gfp, sockets, nparties);
send_private_inputs({salary_value}, sockets, nparties);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;

// Get the result back (client_id of winning client)
Expand Down
23 changes: 11 additions & 12 deletions ExternalIO/bankers-bonus-commsec-client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - sends an increasing id to identify the client, starting with 0
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
Expand All @@ -24,7 +24,7 @@
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients.
* ./compile.py bankers_bonus_commsec
* ./Scripts/run-online bankers_bonus_commsec to run the engines.
* ./Scripts/run-online.sh bankers_bonus_commsec to run the engines.
*
* ./bankers-bonus-commsec-client.x 0 2 100 0
* ./bankers-bonus-commsec-client.x 1 2 200 0
Expand Down Expand Up @@ -139,7 +139,7 @@ pair< vector<octet>, vector<octet> > sts_initiator_role(sign_key_container_t key
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties,
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties,
commsec_t commsec, vector<octet*>& keys)
{
int num_inputs = values.size();
Expand Down Expand Up @@ -380,20 +380,19 @@ int main(int argc, char** argv)
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send(sockets[i], (octet*) &my_client_id, sizeof(int));
octetStream os;
os.store(finish);
os.Send(sockets[i]);

send_public_key(sts_key.client_publickey_ints, sockets[i]);
send_public_key(client_public_key_ints, sockets[i]);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;

// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);

// Send the inputs to the SPDZ Engines
send_private_inputs(input_values_gfp, sockets, nparties, commseckey, session_keys);
send_private_inputs({salary_value}, sockets, nparties, commseckey, session_keys);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;

// Get the result back
Expand Down
28 changes: 8 additions & 20 deletions Networking/ServerSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ void ServerSocket::accept_clients()
}

data_signal.lock();
process_client(client_id);
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
Expand Down Expand Up @@ -157,8 +158,6 @@ void* anonymous_accept_thread(void* server_socket)
return 0;
}

int AnonymousServerSocket::global_client_socket_count = 0;

void AnonymousServerSocket::init()
{
pthread_create(&thread, 0, anonymous_accept_thread, this);
Expand All @@ -169,22 +168,12 @@ int AnonymousServerSocket::get_connection_count()
return num_accepted_clients;
}

void AnonymousServerSocket::accept_clients()
void AnonymousServerSocket::process_client(int client_id)
{
while (true)
{
struct sockaddr dest;
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
int socksize = sizeof(dest);
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }

data_signal.lock();
client_connection_queue.push(consocket);
num_accepted_clients++;
data_signal.broadcast();
data_signal.unlock();
}
if (clients.find(client_id) != clients.end())
close_client_socket(clients[client_id]);
num_accepted_clients++;
client_connection_queue.push(client_id);
}

int AnonymousServerSocket::get_connection_socket(int& client_id)
Expand All @@ -195,10 +184,9 @@ int AnonymousServerSocket::get_connection_socket(int& client_id)
while (client_connection_queue.empty())
data_signal.wait();

client_id = global_client_socket_count;
global_client_socket_count++;
int client_socket = client_connection_queue.front();
client_id = client_connection_queue.front();
client_connection_queue.pop();
int client_socket = clients[client_id];
data_signal.unlock();
return client_socket;
}
12 changes: 4 additions & 8 deletions Networking/ServerSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class ServerSocket
// disable copying
ServerSocket(const ServerSocket& other);

// receive id from client
int assign_client_id(int consocket);
virtual void process_client(int) {}

public:
ServerSocket(int Portnum);
Expand All @@ -55,23 +54,20 @@ class ServerSocket
class AnonymousServerSocket : public ServerSocket
{
private:
// Global no. of client sockets that have been returned - used to create identifiers
static int global_client_socket_count;
// No. of accepted connections in this instance
int num_accepted_clients;
queue<int> client_connection_queue;

void process_client(int client_id);

public:
AnonymousServerSocket(int Portnum) :
ServerSocket(Portnum), num_accepted_clients(0) { };
// override so clients do not send id
void accept_clients();
void init();

virtual int get_connection_count();

// Get socket for the last client who connected
// Writes a unique client identifier (i.e. a counter) to client_id
// Get socket and id for the last client who connected
int get_connection_socket(int& client_id);
};

Expand Down
10 changes: 10 additions & 0 deletions Processor/ExternalClients.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ int ExternalClients::get_client_connection(int portnum_base)
int client_id, socket;
socket = client_connection_servers[portnum_base]->get_connection_socket(client_id);
external_client_sockets[client_id] = socket;
if (symmetric_client_keys[client_id] != 0)
delete symmetric_client_keys[client_id];
symmetric_client_commsec_send_keys.erase(client_id);
symmetric_client_commsec_recv_keys.erase(client_id);
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
return client_id;
}
Expand Down Expand Up @@ -175,3 +179,9 @@ int ExternalClients::get_party_num()
return party_num;
}

int ExternalClients::get_socket(int id)
{
if (external_client_sockets.find(id) == external_client_sockets.end())
throw runtime_error("external connection not found for id " + to_string(id));
return external_client_sockets[id];
}
5 changes: 3 additions & 2 deletions Processor/ExternalClients.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ class ExternalClients
bool server_keys_loaded = false;
bool ed25519_keys_loaded = false;

// Maps holding per client values (indexed by unique 32-bit id)
std::map<int,int> external_client_sockets;

public:

unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES];
unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES];

// Maps holding per client values (indexed by unique 32-bit id)
std::map<int,int> external_client_sockets;
std::map<int,octet*> symmetric_client_keys;
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_send_keys;
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_recv_keys;
Expand Down
52 changes: 10 additions & 42 deletions Processor/Processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs,
int socket_id, int message_type, const vector<int>& registers)
{
if (socket_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << socket_id << endl;
return;
}
int m = registers.size();
socket_stream.reset_write_head();

Expand Down Expand Up @@ -144,7 +139,7 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
// Apply STS commsec encryption if session keys have been created.
try {
maybe_encrypt_sequence(socket_id);
socket_stream.Send(external_clients.external_client_sockets[socket_id]);
socket_stream.Send(external_clients.get_socket(socket_id));
}
catch (bad_value& e) {
cerr << "Send error thrown when writing " << m << " values of type " << reg_type << " to socket id "
Expand All @@ -157,15 +152,9 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>& registers)
{
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
return;
}

int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.external_client_sockets[client_id]);
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
for (int i = 0; i < m; i++)
{
Expand All @@ -179,15 +168,9 @@ void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>&
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>& registers)
{
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
return;
}

int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.external_client_sockets[client_id]);
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
for (int i = 0; i < m; i++)
{
Expand All @@ -199,14 +182,9 @@ void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_private(int client_id, const vector<int>& registers, bool read_macs)
{
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
return;
}
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.external_client_sockets[client_id]);
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);

map<int,octet*>::iterator it = external_clients.symmetric_client_keys.find(client_id);
Expand Down Expand Up @@ -251,11 +229,6 @@ void Processor<sint, sgf2n>::init_secure_socket_internal(int client_id, const ve
if(registers.size() != 8) {
throw "Invalid call to init_secure_socket.";
}
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
throw "No socket connection exists for client";
}

// Extract client long term public key into bytes
vector<int> client_public_key (registers.size(), 0);
Expand All @@ -269,15 +242,15 @@ void Processor<sint, sgf2n>::init_secure_socket_internal(int client_id, const ve
m1 = ke.send_msg1();
socket_stream.reset_write_head();
socket_stream.append(m1.bytes, sizeof m1.bytes);
socket_stream.Send(external_clients.external_client_sockets[client_id]);
socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id],
socket_stream.Send(external_clients.get_socket(client_id));
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
96);
socket_stream.consume(m2.pubkey, sizeof m2.pubkey);
socket_stream.consume(m2.sig, sizeof m2.sig);
m3 = ke.recv_msg2(m2);
socket_stream.reset_write_head();
socket_stream.append(m3.bytes, sizeof m3.bytes);
socket_stream.Send(external_clients.external_client_sockets[client_id]);
socket_stream.Send(external_clients.get_socket(client_id));

// Use results of STS to generate send and receive keys.
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
Expand Down Expand Up @@ -323,11 +296,6 @@ void Processor<sint, sgf2n>::resp_secure_socket_internal(int client_id, const ve
if(registers.size() != 8) {
throw "Invalid call to init_secure_socket.";
}
if (client_id >= (int)external_clients.external_client_sockets.size())
{
cerr << "No socket connection exists for client id " << client_id << endl;
throw "No socket connection exists for client";
}
vector<int> client_public_key (registers.size(), 0);
for(unsigned int i = 0; i < registers.size(); i++) {
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
Expand All @@ -337,16 +305,16 @@ void Processor<sint, sgf2n>::resp_secure_socket_internal(int client_id, const ve
// Start Station to Station Protocol for the responder
STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519);
socket_stream.reset_read_head();
socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id],
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
32);
socket_stream.consume(m1.bytes, sizeof m1.bytes);
m2 = ke.recv_msg1(m1);
socket_stream.reset_write_head();
socket_stream.append(m2.pubkey, sizeof m2.pubkey);
socket_stream.append(m2.sig, sizeof m2.sig);
socket_stream.Send(external_clients.external_client_sockets[client_id]);
socket_stream.Send(external_clients.get_socket(client_id));

socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id],
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
64);
socket_stream.consume(m3.bytes, sizeof m3.bytes);
ke.recv_msg3(m3);
Expand Down
Loading

0 comments on commit 9a83cfe

Please sign in to comment.