forked from data61/MP-SPDZ
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tutorial.mpc
133 lines (91 loc) · 2.73 KB
/
tutorial.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# sint is the secret-integer type.
# Public Python constants can be turned into secret values directly.
a, b = sint(1), sint(2)
def test(actual, expected):
    """Print a secret result next to the value it is expected to equal.

    *actual* is a secret value; it is revealed (opened) first so that it can
    be printed. *expected* is printed as given.
    """
    print_ln('expected %s, got %s', expected, actual.reveal())
# Addition, multiplication and subtraction of secret integers work as
# expected. Each expression is evaluated right before it is tested.
for compute, want in (
    (lambda: a + b, 3),
    (lambda: a * b, 2),
    (lambda: a - b, -1),
):
    test(compute(), want)
# Division is the exception: it does not work as expected here, so avoid
# writing the following:
# test(b / a, 2)
# Comparisons produce a secret 1 for true and 0 for false.
for compute, want in (
    (lambda: a < b, 1),
    (lambda: a <= b, 1),
    (lambda: a >= b, 0),
    (lambda: a > b, 0),
    (lambda: a == b, 0),
    (lambda: a != b, 1),
):
    test(compute(), want)
# Instead of branching on a secret condition, select with if_else();
# here it picks the larger of the two numbers.
test((a < b).if_else(b, a), 2)
# Arrays and loops: the body decorated with for_range is run for every
# index 0..99, and each result is written into a secret Array.
a = Array(100, sint)


@for_range(100)
def _fill(k):
    a[k] = sint(k) * sint(k - 1)


test(a[99], 99 * 98)
# When looping, keep results in an Array. Do NOT write something like:
# @for_range(100)
# def f(i):
#     a = sint(i)
#     test(a, 99)
# sfix is the secret fixed-point type.
# Precision after the point is 16, total precision is 32.
sfix.set_precision(16, 32)
# Print fixed-point output with 4 decimal digits.
print_float_precision(4)
# sfix supports all basic arithmetic, division included, plus comparisons
# and if_else(). Each expression is evaluated right before it is tested.
a, b = sfix(2), sfix(-0.1)
for compute, want in (
    (lambda: a + b, 1.9),
    (lambda: a - b, 2.1),
    (lambda: a * b, -0.2),
    (lambda: a / b, -20),
    (lambda: a < b, 0),
    (lambda: a <= b, 0),
    (lambda: a >= b, 1),
    (lambda: a > b, 1),
    (lambda: a == b, 0),
    (lambda: a != b, 1),
    (lambda: (a < b).if_else(a, b), -0.1),
):
    test(compute(), want)
# Now a computation with private inputs: a weighted mean.
# Party 0 supplies three weights and party 1 supplies three values;
# the result is sum(weight * value) / sum(weight).
print_ln('Party 0: please input three numbers not adding up to zero')
print_ln('Party 1: please input any three numbers')
# Row i holds party 0's i-th input in column 0 (the weight) and
# party 1's i-th input in column 1 (the value).
data = Matrix(3, 2, sfix)
# Use plain Python loops here for compile-time optimization.
for i in range(3):
    for j in range(2):
        data[i][j] = sfix.from_sint(sint.get_input_from(j))
# Compute the weighted average. The weight total is computed once and
# reused below for the validity check instead of being summed again.
weight_total = sum(point[0] for point in data)
result = sum(point[0] * point[1] for point in data) / weight_total
# The following only works with arithmetic circuits:
# @if_e((weight_total != 0).reveal())
# def _():
#     print_ln('weighted average: %s', result.reveal())
# @else_
# def _():
#     print_ln('your inputs made no sense')
# So we output the result even when it is invalid (weights adding up to zero)...
print_ln('weighted average: %s', result.reveal())
# ...but we warn the user. Only the outcome of the comparison is revealed,
# never the weight total itself.
print_ln_if((weight_total == 0).reveal(),
            'but the inputs were invalid (weights add up to zero)')
# A 2x2 permutation matrix: 1 off the diagonal, 0 on it.
# Filled column by column: [0][0], [1][0], [0][1], [1][1].
M = Matrix(2, 2, sfix)
for col in range(2):
    for row in range(2):
        M[row][col] = int(row != col)
# Matrix multiplication: (data * M)[r][c] == data[r][1 - c], i.e. the two
# columns of data come out swapped.
swapped = data * M
test(swapped[0][0], data[0][1].reveal())
test(swapped[1][1], data[1][0].reveal())