Skip to content

Commit

Permalink
added special arrays for secret fixed points and fixed an old multipl…
Browse files Browse the repository at this point in the history
…ication bug
  • Loading branch information
rdragos committed Nov 16, 2017
1 parent 9b03d49 commit a4ce145
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 54 deletions.
34 changes: 16 additions & 18 deletions Compiler/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def block():
print_str('-')
cint((positive_left - right + 1) >> val.f).print_reg_plain()
x = 0
max_dec_base = 6 # max 32-bit precision
max_dec_base = 8 # max 32-bit precision
last_nonzero = 0
for i,b in enumerate(reversed(right.bit_decompose(val.f))):
x += b * int(10**max_dec_base / 2**(i + 1))
Expand Down Expand Up @@ -830,7 +830,7 @@ def map_reduce_single(n_parallel, n_loops, initializer, reducer, mem_state=None)
n_parallel = n_parallel or 1
if mem_state is None:
# default to list of MemValues to allow varying types
mem_state = [MemValue(x) for x in initializer()]
mem_state = [type(x).MemValue(x) for x in initializer()]
use_array = False
else:
# use Arrays for multithread version
Expand Down Expand Up @@ -1146,7 +1146,7 @@ def stop_timer(timer_id=0):
# Fixed point ops

from math import ceil, log
from floatingpoint import PreOR, TruncPr, two_power
from floatingpoint import PreOR, TruncPr, two_power, shift_two

def approximate_reciprocal(divisor, k, f, theta):
"""
Expand Down Expand Up @@ -1189,17 +1189,15 @@ def block():
q = MemValue(two_power(k))
e = MemValue(twos_complement(normalized_divisor.read()))

@for_range(theta)
def block(i):
qread = q.read()
eread = e.read()
qread += (qread * eread) >> k
eread = (eread * eread) >> k
qr = q.read()
er = e.read()

q.write(qread)
e.write(eread)
for i in range(theta):
qr = qr + shift_two(qr * er, k)
er = shift_two(er * er, k)

res = q >> (2*k - 2*f - cnt_leading_zeros)
q = qr
res = shift_two(q, (2*k - 2*f - cnt_leading_zeros))

return res

Expand All @@ -1221,19 +1219,18 @@ def cint_cint_division(a, b, k, f):
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)

A = Array(theta, cint)
B = Array(theta, cint)
W = Array(theta, cint)

A[0] = absolute_a
B[0] = absolute_b
W[0] = w0
@for_range(1, theta)
def block(i):
A[i] = (A[i - 1] * W[i - 1]) >> f
B[i] = (B[i - 1] * W[i - 1]) >> f
for i in range(1, theta):
A[i] = shift_two(A[i - 1] * W[i - 1], f)
B[i] = shift_two(B[i - 1] * W[i - 1], f)
W[i] = two - B[i]

return (sign_a * sign_b) * A[theta - 1]

from Compiler.program import Program
Expand All @@ -1257,10 +1254,11 @@ def sint_cint_division(a, b, k, f, kappa):
B[0] = absolute_b
W[0] = w0


@for_range(1, theta)
def block(i):
A[i] = TruncPr(A[i - 1] * W[i - 1], 2*k, f, kappa)
temp = (B[i - 1] * W[i - 1]) >> f
temp = shift_two(B[i - 1] * W[i - 1], f)
# no reading and writing to the same variable in a for loop.
W[i] = two - temp
B[i] = temp
Expand Down
111 changes: 77 additions & 34 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,9 @@ def set_precision(cls, f, k = None):
else:
cls.k = k

def conv(self):
return self.v

@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
res = []
Expand Down Expand Up @@ -1635,16 +1638,11 @@ def __init__(self, v=None, size=None):
self.v = v.v
elif isinstance(v, MemValue):
self.v = v
elif v == None:
self.v = 0

@vectorize
def load_int(self, v):
self.v = cint(v) * (2 ** self.f)

def conv(self):
return self

def store_in_mem(self, address):
self.v.store_in_mem(address)

Expand Down Expand Up @@ -1779,6 +1777,9 @@ def set_precision(cls, f, k = None):
else:
cls.k = k

def conv(self):
return self.v

@classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of n values input by a client.
Expand All @@ -1792,6 +1793,12 @@ def load_mem(cls, address, mem_type=None):
res.append(sint.load_mem(address))
return sfix(*res)

@classmethod
def load_sint(cls, v):
res = cls()
res.load_int(v)
return res

@vectorize_init
def __init__(self, _v=None, size=None):
self.size = get_global_vector_size()
Expand All @@ -1810,19 +1817,16 @@ def __init__(self, _v=None, size=None):
self.v = (1-2*_v.s)*a
elif isinstance(_v, sfix):
self.v = _v.v
elif isinstance(_v, MemValue):
elif isinstance(_v, MemFix):
#this is a memvalue object
self.v = _v
elif _v == None:
self.v = sint(0)
self.v = _v.v
# elif _v == None:
# self.v = sint(0)
self.kappa = sfix.kappa

@vectorize
def load_int(self, v):
self.v = sint(v) << self.f

def conv(self):
return self
self.v = sint(v) * (2**self.f)

def store_in_mem(self, address):
self.v.store_in_mem(address)
Expand All @@ -1844,12 +1848,9 @@ def add(self, other):
@vectorize
def mul(self, other):
other = parse_type(other)
if isinstance(other, sfix):
if isinstance(other, (sfix, cfix)):
val = floatingpoint.TruncPr(self.v * other.v, self.k * 2, self.f, self.kappa)
return sfix(val)
elif isinstance(other, cfix):
res = sfix((self.v * other.v) >> sfix.f)
return res
elif isinstance(other, cfix.scalars):
scalar_fix = cfix(other)
return self * scalar_fix
Expand Down Expand Up @@ -1940,8 +1941,11 @@ def reveal(self):
# (precision n1) 41 + (precision n2) 41 + (stat_sec) 40 = 82 + 40 = 122 <= 128
# with statistical security of 40

sfix.set_precision(20, 41)
cfix.set_precision(20, 41)
fixed_lower = 20
fixed_upper = 40

sfix.set_precision(fixed_lower, fixed_upper)
cfix.set_precision(fixed_lower, fixed_upper)

class sfloat(_number):
""" Shared floating point data type, representing (1 - 2s)*(1 - z)*v*2^p.
Expand Down Expand Up @@ -2356,6 +2360,9 @@ def f(i):
self[i].assign_all(value)
return self

def get_address(self):
return self.address


class SubMultiArray(object):
def __init__(self, sizes, value_type, address, index):
Expand Down Expand Up @@ -2425,6 +2432,37 @@ def __init__(self, rows, columns):
def __getitem__(self, index):
return sfloatArray(self.columns, self.multi_array[index].address)

class sfixArray(Array):
def __init__(self, length, address=None):
self.array = Array(length, sint, address)
self.length = length
self.value_type = sfix

def __getitem__(self, index):
if isinstance(index, slice):
return Array.__getitem__(self, index)
return sfix(*self.array[index])

def __setitem__(self, index, value):
if isinstance(index, slice):
return Array.__setitem__(self, index, value)
self.array[index] = value.v

def get_address(self, index):
return self.array.get_address(index)

class sfixMatrix(Matrix):
def __init__(self, rows, columns, address=None):
self.rows = rows
self.columns = columns
self.multi_array = Matrix(rows, columns, sint, address)

def __getitem__(self, index):
return sfixArray(self.columns, self.multi_array[index].address)

def get_address(self):
return self.multi_array.get_address()

class _mem(_number):
__add__ = lambda self,other: self.read() + other
__sub__ = lambda self,other: self.read() - other
Expand Down Expand Up @@ -2578,28 +2616,22 @@ def read(self):

class MemFix(_mem):
def __init__(self, *args):
arg_type = type(*args)
if arg_type == sfix:
value = sfix(*args)
elif arg_type == cfix:
value = cfix(*args)
else:
raise CompilerError('MemFix init argument error')
self.reg_type = value.reg_type
value = sfix(*args)
self.v = MemValue(value.v)

def write(self, *args):
if self.reg_type == 's':
value = sfix(*args)
else:
value = cfix(*args)
value = sfix(*args)
self.v.write(value.v)

def reveal(self):
return cfix(self.v.reveal())

def read(self):
if self.reg_type == 's':
return sfix(self.v)
val = self.v.read()
if isinstance(val, sint):
return sfix(val)
else:
return cfix(self.v)
return cfix(val)

def getNamedTupleType(*names):
class NamedTuple(object):
Expand Down Expand Up @@ -2641,4 +2673,15 @@ def reveal(self):
sfloat.Matrix = sfloatMatrix
sfloat.MemValue = MemFloat

sfix.Array = sfixArray
sfix.Matrix = sfixMatrix

sfix.MemValue = MemFix

cint.MemValue = MemValue
sint.MemValue = MemValue

sgf2n.MemValue = MemValue
cgf2n.MemValue = MemValue

import library
4 changes: 2 additions & 2 deletions Programs/Source/fixed_point_tutorial.mpc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ n = 10
m = 5

# array of fixed points
A = Array(n, sfix)
A = sfixArray(n)

for i in range(n):
A[i] = sfix(i)
Expand All @@ -18,7 +18,7 @@ for i in range(n):
print_ln('%s', A[i].reveal())

# matrix of fixed points
M = Matrix(n, m, sfix)
M = sfixMatrix(n, m)

for i in range(n):
for j in range(m):
Expand Down

0 comments on commit a4ce145

Please sign in to comment.