Skip to content

Commit

Permalink
pep8 + working tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed May 30, 2018
1 parent fde3d59 commit 06eabe7
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions ot/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import scipy as sp
from .utils import check_random_state


def get_1D_gauss(n, m, s):
Expand Down Expand Up @@ -60,7 +61,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
n samples drawn from N(m,sigma)
"""

generator = check_random_state(random_state)
if np.isscalar(sigma):
sigma = np.array([sigma, ])
Expand Down Expand Up @@ -98,9 +99,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
labels of the samples
"""

generator = check_random_state(random_state)

if dataset.lower() == '3gauss':
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
x = np.zeros((n, 2))
Expand Down Expand Up @@ -140,8 +141,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
n2 = np.sum(y == 2)
x = np.zeros((n, 2))

x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator)
x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator)
x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)

x = x.dot(rot)

Expand Down

0 comments on commit 06eabe7

Please sign in to comment.