forked from ddbourgin/numpy-ml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtesting.py
143 lines (112 loc) · 4.21 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Utilities for writing unit tests"""
import numbers
import numpy as np
#######################################################################
# Assertions #
#######################################################################
def is_symmetric(X):
    """
    Check that an array `X` is symmetric along its main diagonal.

    Parameters
    ----------
    X : array_like
        Array to test. It is converted with ``np.asarray``.

    Returns
    -------
    bool
        True if `X` and its transpose have the same shape and are equal
        within ``np.allclose`` tolerance; False otherwise, including for
        non-square matrices.
    """
    X = np.asarray(X)
    # np.allclose broadcasts its arguments, so without this guard a (1, n)
    # row vector would be compared against its (n, 1) transpose and reported
    # as symmetric, and a general non-square matrix would raise ValueError.
    if X.shape != X.T.shape:
        return False
    return np.allclose(X, X.T)
def is_symmetric_positive_definite(X):
    """
    Check whether the matrix `X` is symmetric and positive definite.

    Symmetry is tested with ``is_symmetric``. Positive definiteness is
    tested by attempting a Cholesky factorization.

    Returns
    -------
    bool
        True when `X` is symmetric and ``np.linalg.cholesky`` succeeds on it;
        False otherwise.
    """
    if not is_symmetric(X):
        return False
    # The Cholesky factorization only exists for symmetric/Hermitian
    # positive-definite matrices, so a LinAlgError means X is not one.
    try:
        np.linalg.cholesky(X)
    except np.linalg.LinAlgError:
        return False
    return True
def is_stochastic(X):
    """
    Assert that every row of `X` is a probability distribution.

    Each entry must lie in [0, 1], and the entries of each row (i.e. summed
    across the columns, ``axis=1``) must add up to 1 within ``np.allclose``
    tolerance.

    Parameters
    ----------
    X : array_like
        2-D array to check. It is converted with ``np.asarray``.

    Returns
    -------
    bool
        Always True; a non-stochastic array raises instead.

    Raises
    ------
    AssertionError
        If any entry is outside [0, 1] or any row does not sum to 1.
    """
    X = np.asarray(X)
    msg = "Array should be stochastic along the columns"
    # Raise explicitly rather than using `assert`: `python -O` strips assert
    # statements, which would make this check silently return True.
    if np.any(X < 0) or np.any(X > 1):
        raise AssertionError(msg)
    if not np.allclose(np.sum(X, axis=1), np.ones(X.shape[0])):
        raise AssertionError(msg)
    return True
def is_number(a):
    """Return True when `a` is an instance of any ``numbers.Number`` type."""
    is_numeric = isinstance(a, numbers.Number)
    return is_numeric
def is_one_hot(x):
    """
    Assert that `x` is a one-hot matrix.

    Every entry must equal its boolean cast (i.e. be 0 or 1), and each row
    must sum to 1, so each row contains exactly one 1.

    Parameters
    ----------
    x : array_like
        2-D array to check. It is converted with ``np.asarray``.

    Returns
    -------
    bool
        Always True; a non-one-hot array raises instead.

    Raises
    ------
    AssertionError
        If `x` contains non-binary values or a row does not sum to 1.
    """
    x = np.asarray(x)
    msg = "Matrix should be one-hot binary"
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not np.array_equal(x, x.astype(bool)):
        raise AssertionError(msg)
    if not np.allclose(np.sum(x, axis=1), np.ones(x.shape[0])):
        raise AssertionError(msg)
    return True
def is_binary(x):
    """
    Assert that `x` contains only binary (0/1 or False/True) values.

    Parameters
    ----------
    x : array_like
        Array to check. It is converted with ``np.asarray``.

    Returns
    -------
    bool
        Always True; a non-binary array raises instead.

    Raises
    ------
    AssertionError
        If any entry differs from its boolean cast (i.e. is not 0 or 1).
    """
    x = np.asarray(x)
    msg = "Matrix must be binary"
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not np.array_equal(x, x.astype(bool)):
        raise AssertionError(msg)
    return True
#######################################################################
# Data Generators #
#######################################################################
def random_one_hot_matrix(n_examples, n_classes):
    """
    Build a random one-hot matrix of shape (`n_examples`, `n_classes`).

    Each row selects one class uniformly at random (via
    ``np.random.choice``) and is the corresponding row of the identity
    matrix.
    """
    labels = np.random.choice(n_classes, n_examples)
    identity = np.eye(n_classes)
    return identity[labels]
def random_stochastic_matrix(n_examples, n_classes):
    """
    Build a random row-stochastic matrix of shape (`n_examples`, `n_classes`).

    Entries are uniform draws from [0, 1) rescaled so that each row sums
    to 1.
    """
    raw = np.random.rand(n_examples, n_classes)
    row_totals = raw.sum(axis=1)[:, np.newaxis]
    return raw / row_totals
def random_tensor(shape, standardize=False):
    """
    Build a random real-valued tensor of shape `shape`.

    Each entry is an integer offset drawn from [-300, 300) plus uniform
    noise from [0, 1). When `standardize` is True, the result is centered
    and scaled along axis 0 (per column for 2-D input) to mean 0 and std 1.
    A machine-epsilon term in the denominator guards against division by
    zero.
    """
    # The integer draw must happen before the uniform draw to keep the
    # global RNG stream (and therefore seeded outputs) unchanged.
    integer_part = np.random.randint(-300, 300, shape)
    noise = np.random.rand(*shape)
    tensor = noise + integer_part
    if not standardize:
        return tensor
    mu = tensor.mean(axis=0)
    sigma = tensor.std(axis=0) + np.finfo(float).eps
    return (tensor - mu) / sigma
def random_binary_tensor(shape, sparsity=0.5):
    """
    Build a random binary (0.0 / 1.0) float tensor of shape `shape`.

    Each entry is independently 1.0 when a uniform [0, 1) draw is at least
    ``1 - sparsity``. For `sparsity` in [0, 1] this is 1.0 with probability
    `sparsity`. NOTE: despite the name, a larger `sparsity` therefore yields
    more 1s, not more 0s.
    """
    threshold = 1 - sparsity
    draws = np.random.rand(*shape)
    mask = draws >= threshold
    return mask.astype(float)
def random_paragraph(n_words, vocab=None):
    """
    Return a list of `n_words` words drawn uniformly at random.

    Words are sampled with replacement (via ``np.random.choice``) from
    `vocab` when it is given. Otherwise they come from a built-in list of
    26 Latin filler words.
    """
    if vocab is None:
        vocab = [
            "at", "stet", "accusam", "aliquyam", "clita", "lorem", "ipsum",
            "dolor", "dolore", "dolores", "sit", "amet", "consetetur",
            "sadipscing", "elitr", "sed", "diam", "nonumy", "eirmod", "duo",
            "ea", "eos", "erat", "est", "et", "gubergren",
        ]
    draw = np.random.choice
    return [draw(vocab) for _word_idx in range(n_words)]
#######################################################################
# Custom Warnings #
#######################################################################
class DependencyWarning(RuntimeWarning):
    """
    ``RuntimeWarning`` subclass defined for dependency-related warnings.

    NOTE(review): where it is issued lives outside this file; presumably it
    flags a missing/optional dependency — confirm against callers.
    """