forked from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtesting.py
146 lines (113 loc) · 2.76 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import minitorch.operators as operators
class MathTest:
    """Base collection of math checks, one static method per check.

    Each public static method is a reference implementation built from plain
    arithmetic operators and helpers in ``minitorch.operators``. Subclasses
    override individual methods with versions written for another value type;
    :meth:`_tests` pairs every base method with the same-named attribute on
    the class it is invoked on.

    Method names drive the grouping done by :meth:`_tests`: a trailing ``2``
    marks a two-argument check, a trailing ``red`` marks a reduction, and
    every other public method is a one-argument check.
    """

    # -- one-argument checks ------------------------------------------------
    # Keep the constant on the left in ``5 + a`` and ``5 * a``: for a
    # non-numeric ``a`` Python then falls back to ``a.__radd__`` /
    # ``a.__rmul__``, so the reflected operators are exercised as well.

    @staticmethod
    def neg(a):
        """Return ``-a``."""
        return -a

    @staticmethod
    def addConstant(a):
        """Return ``5 + a`` (constant on the left)."""
        return 5 + a

    @staticmethod
    def subConstant(a):
        """Return ``a - 5``."""
        return a - 5

    @staticmethod
    def mult(a):
        """Return ``5 * a`` (constant on the left)."""
        return 5 * a

    @staticmethod
    def div(a):
        """Return ``a / 5``."""
        return a / 5

    @staticmethod
    def inv(a):
        """Return ``operators.inv(a + 3.5)``.

        The +3.5 shift presumably keeps typical test inputs away from the
        pole at zero -- TODO confirm against the input generators.
        """
        return operators.inv(a + 3.5)

    @staticmethod
    def sig(a):
        """Return ``operators.sigmoid(a)``."""
        return operators.sigmoid(a)

    @staticmethod
    def log(a):
        """Return ``operators.log(a + 100000)``.

        The large shift presumably keeps the log argument positive for the
        test input range -- TODO confirm.
        """
        return operators.log(a + 100000)

    @staticmethod
    def relu(a):
        """Return ``operators.relu(a + 5.5)``."""
        return operators.relu(a + 5.5)

    @staticmethod
    def exp(a):
        """Return ``operators.exp(a - 200)``.

        Shifting down by 200 presumably keeps the result small and finite for
        the test input range -- TODO confirm.
        """
        return operators.exp(a - 200)

    # -- two-argument checks (names end in "2") -----------------------------

    @staticmethod
    def add2(a, b):
        """Return ``a + b``."""
        return a + b

    @staticmethod
    def mul2(a, b):
        """Return ``a * b``."""
        return a * b

    @staticmethod
    def div2(a, b):
        """Return ``a / (b + 5.5)``."""
        return a / (b + 5.5)

    @staticmethod
    def gt2(a, b):
        """Compare ``a + 1.2 > b``, written as ``operators.lt(b, a + 1.2)``."""
        return operators.lt(b, a + 1.2)

    @staticmethod
    def lt2(a, b):
        """Compare ``a + 1.2 < b`` via ``operators.lt``."""
        return operators.lt(a + 1.2, b)

    @staticmethod
    def eq2(a, b):
        """Compare ``a`` with ``b + 5.5`` via ``operators.eq``."""
        return operators.eq(a, (b + 5.5))

    # -- reductions (names end in "red") ------------------------------------

    @staticmethod
    def sum_red(a):
        """Return ``operators.sum(a)``."""
        return operators.sum(a)

    @staticmethod
    def mean_red(a):
        """Return the mean of ``a`` as ``operators.sum(a) / len(a)``.

        ``a`` must support ``len`` and be non-empty (the divisor is
        ``float(len(a))``).
        """
        return operators.sum(a) / float(len(a))

    @staticmethod
    def mean_full_red(a):
        """Return the mean of ``a``; same computation as :meth:`mean_red`.

        ``a`` must support ``len`` and be non-empty (the divisor is
        ``float(len(a))``).
        """
        return operators.sum(a) / float(len(a))

    @classmethod
    def _tests(cls):
        """Collect every public math check, grouped by arity.

        Walks the public callables of the base ``MathTest`` in ``dir`` order
        and pairs each with the attribute of the same name looked up on
        ``cls`` (which may be a subclass override).

        Returns:
            A tuple ``(one_arg, two_arg, red_arg)``; each element is a list of
            ``(name, base_fn, cls_fn)`` triples.
        """
        one_arg, two_arg, red_arg = [], [], []
        for name in dir(MathTest):
            reference = getattr(MathTest, name)
            # Skip private/dunder names and non-callable attributes.
            if name.startswith("_") or not callable(reference):
                continue
            entry = (name, reference, getattr(cls, name))
            if name.endswith("2"):
                bucket = two_arg
            elif name.endswith("red"):
                bucket = red_arg
            else:
                bucket = one_arg
            bucket.append(entry)
        return one_arg, two_arg, red_arg
class MathTestVariable(MathTest):
    """Method/operator-based variants of selected ``MathTest`` checks.

    Overrides the checks whose base versions call ``operators`` helpers,
    replacing them with calls to methods on the argument itself
    (``.sigmoid()``, ``.log()``, ``.relu()``, ``.exp()``, ``.sum``, ``.mean``)
    or with Python operators (``/``, ``==``, ``>``, ``<``). The constant
    offsets are the same ones ``MathTest`` uses. All other checks are
    inherited unchanged.
    """

    # -- one-argument overrides ---------------------------------------------

    @staticmethod
    def inv(a):
        """Reciprocal of ``a + 3.5`` computed with the ``/`` operator."""
        shifted = a + 3.5
        return 1.0 / shifted

    @staticmethod
    def sig(x):
        """Call ``x.sigmoid()``."""
        return x.sigmoid()

    @staticmethod
    def log(x):
        """Call ``.log()`` on ``x + 100000``."""
        shifted = x + 100000
        return shifted.log()

    @staticmethod
    def relu(x):
        """Call ``.relu()`` on ``x + 5.5``."""
        shifted = x + 5.5
        return shifted.relu()

    @staticmethod
    def exp(a):
        """Call ``.exp()`` on ``a - 200``."""
        shifted = a - 200
        return shifted.exp()

    # -- reduction overrides ------------------------------------------------

    @staticmethod
    def sum_red(a):
        """Call ``a.sum(0)`` (presumably a sum over dimension 0 -- confirm)."""
        return a.sum(0)

    @staticmethod
    def mean_red(a):
        """Call ``a.mean(0)`` (presumably a mean over dimension 0 -- confirm)."""
        return a.mean(0)

    @staticmethod
    def mean_full_red(a):
        """Call ``a.mean()`` with no argument (presumably a full reduction)."""
        return a.mean()

    # -- two-argument comparison overrides ----------------------------------

    @staticmethod
    def eq2(a, b):
        """Compare ``a == b + 5.5`` with the ``==`` operator."""
        target = b + 5.5
        return a == target

    @staticmethod
    def gt2(a, b):
        """Compare ``a + 1.2 > b`` with the ``>`` operator."""
        lhs = a + 1.2
        return lhs > b

    @staticmethod
    def lt2(a, b):
        """Compare ``a + 1.2 < b`` with the ``<`` operator."""
        lhs = a + 1.2
        return lhs < b