forked from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
45 lines (35 loc) · 779 Bytes
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from dataclasses import dataclass
import random
def make_pts(N: int) -> list:
    """Sample ``N`` points uniformly at random from the unit square.

    Each point is an ``(x_1, x_2)`` tuple whose coordinates are both drawn
    with :func:`random.random`, so each lies in the half-open interval
    [0.0, 1.0). Draws come from the module-level ``random`` generator; call
    ``random.seed`` beforehand for a reproducible dataset.

    Args:
        N: Number of points to generate. ``0`` or a negative value yields
            an empty list.

    Returns:
        A list of ``N`` ``(x_1, x_2)`` float tuples.
    """
    # A tuple display is evaluated left to right, so x_1 is drawn before x_2
    # for every point -- the same draw order as an explicit loop, which keeps
    # seeded datasets identical.
    return [(random.random(), random.random()) for _ in range(N)]
@dataclass
class Graph:
    """Container for a labelled 2-D point dataset.

    The generator functions in this module build it as ``Graph(N, X, y)``,
    where ``X`` was sampled with the same ``N`` and ``y`` holds one 0/1 label
    per point. The dataclass itself does not check that these sizes agree.
    """

    N: int  # number of points the dataset was generated with
    X: list  # sampled points, each an (x_1, x_2) tuple
    y: list  # labels aligned with X: y[i] is the label of X[i]
def simple(N: int) -> "Graph":
    """Build a dataset split by one vertical boundary at ``x_1 = 0.5``.

    Points are sampled with ``make_pts(N)``. A point is labelled ``1`` when
    ``x_1 < 0.5`` and ``0`` otherwise; ``x_2`` never affects the label.

    Args:
        N: Number of points to sample.

    Returns:
        ``Graph(N, X, y)`` with the sampled points ``X`` and their labels ``y``.
    """
    X = make_pts(N)
    # Only the first coordinate decides the class, so the second is discarded.
    y = [1 if x_1 < 0.5 else 0 for x_1, _ in X]
    return Graph(N, X, y)
def split(N: int) -> "Graph":
    """Build a dataset whose positive class occupies two outer vertical bands.

    Points are sampled with ``make_pts(N)``. A point is labelled ``1`` when
    ``x_1 < 0.2`` or ``x_1 > 0.8``, and ``0`` when it lies in the middle band
    ``0.2 <= x_1 <= 0.8``. ``x_2`` never affects the label.

    Args:
        N: Number of points to sample.

    Returns:
        ``Graph(N, X, y)`` with the sampled points ``X`` and their labels ``y``.
    """
    X = make_pts(N)
    # Only the first coordinate decides the class, so the second is discarded.
    y = [1 if x_1 < 0.2 or x_1 > 0.8 else 0 for x_1, _ in X]
    return Graph(N, X, y)
def xor(N: int) -> "Graph":
    """Build an XOR-pattern dataset on the unit square.

    Points are sampled with ``make_pts(N)``. A point is labelled ``1`` when
    ``x_1 < 0.5 and x_2 > 0.5`` or ``x_1 > 0.5 and x_2 < 0.5`` (two opposite
    quadrants), and ``0`` otherwise.

    Args:
        N: Number of points to sample.

    Returns:
        ``Graph(N, X, y)`` with the sampled points ``X`` and their labels ``y``.
    """
    X = make_pts(N)
    # Strict comparisons on both sides: a point with a coordinate exactly
    # equal to 0.5 is in neither positive quadrant and is labelled 0.
    y = [
        1 if (x_1 < 0.5 and x_2 > 0.5) or (x_1 > 0.5 and x_2 < 0.5) else 0
        for x_1, x_2 in X
    ]
    return Graph(N, X, y)