Skip to content

Commit

Permalink
Update Util
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Apr 10, 2017
1 parent 118f198 commit 4244458
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions Util/Util.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,35 +113,29 @@ def gen_spin(size=50, n=7, n_class=7, scale=4, one_hot=True):
ys[ix] = i % n_class
if not one_hot:
return xs, ys
z = []
for yy in ys:
z.append([0 if i != yy else 1 for i in range(n_class)])
return xs, np.array(z)
return xs, np.array(ys[..., None] == np.arange(n_class), dtype=np.int8)

@staticmethod
def gen_random(size=100, n_dim=2, n_class=2, one_hot=True):
xs = np.random.randn(size, n_dim)
xs = np.random.randn(size, n_dim).astype(np.float32)
ys = np.random.randint(n_class, size=size).astype(np.int8)
if not one_hot:
return xs, ys
z = []
for yy in ys:
z.append([0 if i != yy else 1 for i in range(n_class)])
return xs, np.array(z, dtype=np.int8)
return xs, np.array(ys == np.arange(n_class), dtype=np.int8)

@staticmethod
def gen_two_clusters(size=100, n_dim=2, center=0, dis=2, scale=1, one_hot=True):
center1 = (np.random.random(n_dim) + center - 0.5) * scale + dis
center2 = (np.random.random(n_dim) + center - 0.5) * scale - dis
cluster1 = (np.random.randn(size, n_dim) + center1) * scale
cluster2 = (np.random.randn(size, n_dim) + center2) * scale
data = np.vstack((cluster1, cluster2))
data = np.vstack((cluster1, cluster2)).astype(np.float32)
labels = np.array([1] * size + [0] * size)
_indices = np.random.permutation(size * 2)
data, labels = data[_indices], labels[_indices]
if not one_hot:
return data, labels
labels = np.array([[0, 1] if label == 1 else [1, 0] for label in labels])
labels = np.array([[0, 1] if label == 1 else [1, 0] for label in labels], dtype=np.int8)
return data, labels

@staticmethod
Expand Down

0 comments on commit 4244458

Please sign in to comment.