Commit

Update algorithms

carefree0910 committed Jun 20, 2017
1 parent 73cd923 commit f3f274e
Showing 3 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion NN/TF/Networks.py
@@ -826,7 +826,7 @@ def fit(self,
sub_bar = ProgressBar(max_value=train_repeat * record_period - 1, name="Iteration")
else:
sub_bar = None
self.batch_training(
self._batch_training(
x_train, y_train, batch_size, train_repeat,
self._loss, self._train_step, sub_bar, counter, *args)
self._handle_animation(
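The change above (and its twins in Util/Bases.py and g_CNN/Networks.py below) renames the per-epoch routine from `batch_training` to `_batch_training`, marking it as an internal hook that `fit` drives rather than part of the public API. A minimal sketch of that hook pattern, with a hypothetical `BaseNet` class standing in for the repository's network bases:

```python
import numpy as np

class BaseNet:
    def _batch_training(self, x, y, batch_size, train_repeat, *args):
        """Internal hook: run `train_repeat` mini-batch updates, return the mean loss."""
        raise NotImplementedError

    def fit(self, x, y, batch_size=128, epoch=10):
        # `fit` owns the epoch loop; subclasses only supply the update rule.
        train_repeat = max(1, len(x) // batch_size)
        for counter in range(epoch):
            epoch_loss = self._batch_training(x, y, batch_size, train_repeat)
            print("epoch {:>3d}  loss {:.6f}".format(counter + 1, epoch_loss))
```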
33 changes: 16 additions & 17 deletions Util/Bases.py
@@ -284,7 +284,7 @@ def _get_train_repeat(x, batch_size):
def _batch_work(self, *args):
pass

def batch_training(self, x, y, batch_size, train_repeat, *args):
def _batch_training(self, x, y, batch_size, train_repeat, *args):
pass

def get_metrics(self, metrics):
@@ -592,24 +592,24 @@ def _extra(_ax, axis0, axis1, _c, _ex0, _ex1):


class GDBase(ClassifierBase):
GDBaseTiming = Timing()

def __init__(self, **kwargs):
super(GDBase, self).__init__(**kwargs)
self._optimizer = self._model_parameters = self._model_grads = None

def _loss(self, y, y_pred, sample_weight):
pass

def _get_grads(self, x_batch, y_batch, y_pred, sample_weight_batch, *args):
pass
return 0

def _update_model_params(self):
for i, (param, grad) in enumerate(zip(self._model_parameters, self._model_grads)):
if grad is not None:
param -= self._optimizer.run(i, grad)

def batch_training(self, x, y, batch_size, train_repeat, *args, **kwargs):
@GDBaseTiming.timeit(level=1, prefix="[Core] ")
def _batch_training(self, x, y, batch_size, train_repeat, *args, **kwargs):
sample_weight, *args = args
epoch_cost = 0
epoch_loss = 0
for i in range(train_repeat):
if train_repeat != 1:
batch = np.random.permutation(len(x))[:batch_size]
@@ -618,12 +618,10 @@ def batch_training(self, x, y, batch_size, train_repeat, *args, **kwargs):
else:
x_batch, y_batch, sample_weight_batch = x, y, sample_weight
y_pred = self.predict(x_batch, get_raw_results=True, **kwargs)
local_loss = self._loss(y_batch, y_pred, sample_weight_batch)
epoch_cost += local_loss
self._get_grads(x_batch, y_batch, y_pred, sample_weight_batch, *args)
epoch_loss += self._get_grads(x_batch, y_batch, y_pred, sample_weight_batch, *args)
self._update_model_params()
self._batch_work(i, *args)
return epoch_cost / train_repeat
return epoch_loss / train_repeat


class TFClassifierBase(ClassifierBase):
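Taken together, the GDBase hunks above make three changes: a `GDBaseTiming` timer now decorates the renamed `_batch_training`, the separate `_loss` call disappears because `_get_grads` now returns the batch loss itself, and the loop accumulates `epoch_loss` and returns its average over `train_repeat`. A self-contained sketch of the resulting loop, with a plain SGD stand-in for the project's optimizer objects and an illustrative linear/MSE model (sample weights and the timing decorator omitted for brevity):

```python
import numpy as np

class PlainSGD:
    """Stand-in for the project's optimizers: run(i, grad) returns the update step."""
    def __init__(self, lr=0.01):
        self.lr = lr

    def run(self, i, grad):
        return self.lr * grad


class TinyGDModel:
    def __init__(self, n_features, lr=0.01):
        self._w = np.zeros(n_features)
        self._b = np.zeros(1)
        self._model_parameters = [self._w, self._b]
        self._model_grads = None
        self._optimizer = PlainSGD(lr)

    def predict(self, x):
        return x.dot(self._w) + self._b

    def _get_grads(self, x_batch, y_batch, y_pred):
        # MSE loss; like the updated _get_grads, this now *returns* the batch loss.
        err = y_pred - y_batch
        self._model_grads = [x_batch.T.dot(err) / len(x_batch), np.array([err.mean()])]
        return 0.5 * (err ** 2).mean()

    def _update_model_params(self):
        for i, (param, grad) in enumerate(zip(self._model_parameters, self._model_grads)):
            if grad is not None:
                param -= self._optimizer.run(i, grad)   # in-place parameter update

    def _batch_training(self, x, y, batch_size, train_repeat):
        epoch_loss = 0
        for i in range(train_repeat):
            if train_repeat != 1:
                batch = np.random.permutation(len(x))[:batch_size]
                x_batch, y_batch = x[batch], y[batch]
            else:
                x_batch, y_batch = x, y
            y_pred = self.predict(x_batch)
            epoch_loss += self._get_grads(x_batch, y_batch, y_pred)
            self._update_model_params()
        return epoch_loss / train_repeat
```

Calling `_batch_training(x, y, 64, len(x) // 64)` once per epoch from a `fit` loop mirrors how the networks above drive the hook.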
@@ -656,7 +654,7 @@ def f1_score(y, y_pred_arg):
return 2 * tp / (2 * tp + fn + fp)

@clf_timing.timeit(level=2, prefix="[Core] ")
def batch_training(self, x, y, batch_size, train_repeat, *args):
def _batch_training(self, x, y, batch_size, train_repeat, *args):
loss, train_step, *args = args
epoch_cost = 0
for i in range(train_repeat):
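This hunk only shows the rename and the head of the TensorFlow variant; the rest of the body is collapsed. A plausible method-body sketch of how such a loop typically finishes in TF 1.x, assuming hypothetical placeholders `self._tfx` / `self._tfy` and an open session `self._sess` (none of these names are confirmed by the diff):

```python
def _batch_training(self, x, y, batch_size, train_repeat, *args):
    loss, train_step, *args = args          # loss tensor and training op, unpacked as above
    epoch_cost = 0
    for i in range(train_repeat):
        if train_repeat != 1:
            batch = np.random.permutation(len(x))[:batch_size]
            x_batch, y_batch = x[batch], y[batch]
        else:
            x_batch, y_batch = x, y
        # One optimizer step and the batch loss in a single session call.
        _, local_cost = self._sess.run(
            [train_step, loss], feed_dict={self._tfx: x_batch, self._tfy: y_batch})
        epoch_cost += local_cost
    return epoch_cost / train_repeat
```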
@@ -902,14 +900,15 @@ class GDKernelBase(KernelBase, GDBase):

def __init__(self, **kwargs):
super(GDKernelBase, self).__init__(**kwargs)
self._fit_args, self._fit_args_names = [1e-3], ["tol"]
self._batch_size = kwargs.get("batch_size", 128)
self._optimizer = kwargs.get("optimizer", "Adam")
self._train_repeat = 0

def _prepare(self, sample_weight, **kwargs):
lr = kwargs.get("lr", self._params["lr"])
self._alpha = np.zeros(len(self._x), dtype=np.float32)
self._b = np.zeros(1, dtype=np.float32)
self._alpha = np.random.random(len(self._x)).astype(np.float32)
self._b = np.random.random(1).astype(np.float32)
self._model_parameters = [self._alpha, self._b]
self._optimizer = OptFactory().get_optimizer_by_name(
self._optimizer, self._model_parameters, lr, self._params["epoch"]
@@ -919,7 +918,7 @@ def _prepare(self, sample_weight, **kwargs):
def _fit(self, sample_weight, tol):
if self._train_repeat == 0:
self._train_repeat = self._get_train_repeat(self._x, self._batch_size)
l = self.batch_training(
l = self._batch_training(
self._gram, self._y, self._batch_size, self._train_repeat,
sample_weight, gram_provided=True
)
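Two things change in GDKernelBase's setup and fitting path: `alpha` and `b` now start from uniform random values instead of zeros (presumably to give the solver a non-trivial starting point), and `_fit` runs one epoch of `_batch_training` on the precomputed Gram matrix, with the new `tol` fit argument suggesting a convergence check on the returned loss. A hedged sketch of that outer loop; the RBF Gram matrix and the stopping rule are assumptions, not the repository's exact code:

```python
import numpy as np

def rbf_gram(x, gamma=1.0):
    # Pairwise RBF kernel matrix: K[i, j] = exp(-gamma * ||x_i - x_j||^2).
    sq = np.sum(x ** 2, axis=1)
    return np.exp(-gamma * (sq[:, None] + sq[None, :] - 2 * x.dot(x.T)))

def fit_kernel_model(model, x, y, epoch=200, batch_size=128, tol=1e-3):
    gram = rbf_gram(x)
    # Random (rather than zero) starting point for the dual coefficients and bias.
    model._alpha = np.random.random(len(x)).astype(np.float32)
    model._b = np.random.random(1).astype(np.float32)
    prev_loss = np.inf
    train_repeat = max(1, len(x) // batch_size)
    for _ in range(epoch):
        loss = model._batch_training(gram, y, batch_size, train_repeat)
        if abs(prev_loss - loss) < tol:   # assumed convergence test for the `tol` argument
            break
        prev_loss = loss
    return model
```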
@@ -929,8 +928,8 @@ def _fit(self, sample_weight, tol):
@GDKernelBaseTiming.timeit(level=1, prefix="[API] ")
def predict(self, x, get_raw_results=False, gram_provided=False):
if not gram_provided:
x = self._kernel(np.atleast_2d(x), self._x)
y_pred = (x.dot(self._alpha) + self._b).ravel()
x = self._kernel(self._x, np.atleast_2d(x))
y_pred = (self._alpha.dot(x) + self._b).ravel()
if not get_raw_results:
return np.sign(y_pred)
return y_pred
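The prediction change swaps the kernel's argument order and, correspondingly, which side `alpha` multiplies: with K of shape (n_train, n_query), `alpha.dot(K) + b` computes the same decision values `sum_i alpha_i * K(x_i, x) + b` that the previous `x.dot(alpha)` form produced with K transposed. A small check with an RBF kernel (the kernel choice and shapes are assumptions consistent with the diff):

```python
import numpy as np

def rbf_kernel(a, b, gamma=0.5):
    # K[i, j] = exp(-gamma * ||a_i - b_j||^2), shape (len(a), len(b)).
    d2 = np.sum(a ** 2, axis=1)[:, None] + np.sum(b ** 2, axis=1)[None, :] - 2 * a.dot(b.T)
    return np.exp(-gamma * d2)

rng = np.random.RandomState(0)
x_train, x_query = rng.randn(6, 3), rng.randn(4, 3)
alpha = rng.randn(6).astype(np.float32)
b = rng.randn(1).astype(np.float32)

# Old form: K has shape (n_query, n_train), so alpha multiplies on the right.
old = rbf_kernel(x_query, x_train).dot(alpha) + b
# New form: K has shape (n_train, n_query), so alpha multiplies on the left.
new = alpha.dot(rbf_kernel(x_train, x_query)) + b

assert np.allclose(old, new)   # both give sum_i alpha_i * K(x_i, x) + b
print(np.sign(new))            # class predictions when raw results aren't requested
```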
4 changes: 2 additions & 2 deletions g_CNN/Networks.py
@@ -319,8 +319,8 @@ def fit(self, x, y, lr=None, epoch=None, batch_size=None, train_rate=None,
sub_bar = ProgressBar(max_value=train_repeat * record_period - 1, name="Iteration")
else:
sub_bar = None
self.batch_training(x_train, y_train, batch_size, train_repeat,
self._loss, self._train_step, sub_bar, *args[0])
self._batch_training(x_train, y_train, batch_size, train_repeat,
self._loss, self._train_step, sub_bar, *args[0])
if (counter + 1) % record_period == 0:
self._batch_work(*args[1])
if self.verbose >= NNVerbose.EPOCH:
