@@ -573,9 +573,9 @@ def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
573
573
pred_len , prefixes ):
574
574
"""Train an RNN model and predict the next item in the sequence."""
575
575
if is_random_iter :
576
- data_iter_fn = gb . data_iter_random
576
+ data_iter_fn = data_iter_random
577
577
else :
578
- data_iter_fn = gb . data_iter_consecutive
578
+ data_iter_fn = data_iter_consecutive
579
579
params = get_params ()
580
580
loss = gloss .SoftmaxCrossEntropyLoss ()
581
581
@@ -598,7 +598,7 @@ def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
598
598
l = loss (outputs , y ).mean ()
599
599
l .backward ()
600
600
grad_clipping (params , clipping_theta , ctx )
601
- gb . sgd (params , lr , 1 )
601
+ sgd (params , lr , 1 )
602
602
l_sum += l .asscalar () * y .size
603
603
n += y .size
604
604
@@ -623,7 +623,7 @@ def train_and_predict_rnn_gluon(model, num_hiddens, vocab_size, ctx,
623
623
624
624
for epoch in range (num_epochs ):
625
625
l_sum , n , start = 0.0 , 0 , time .time ()
626
- data_iter = gb . data_iter_consecutive (
626
+ data_iter = data_iter_consecutive (
627
627
corpus_indices , batch_size , num_steps , ctx )
628
628
state = model .begin_state (batch_size = batch_size , ctx = ctx )
629
629
for X , Y in data_iter :
@@ -635,7 +635,7 @@ def train_and_predict_rnn_gluon(model, num_hiddens, vocab_size, ctx,
635
635
l = loss (output , y ).mean ()
636
636
l .backward ()
637
637
params = [p .data () for p in model .collect_params ().values ()]
638
- gb . grad_clipping (params , clipping_theta , ctx )
638
+ grad_clipping (params , clipping_theta , ctx )
639
639
trainer .step (1 )
640
640
l_sum += l .asscalar () * y .size
641
641
n += y .size
@@ -660,7 +660,7 @@ def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
660
660
l = loss (y_hat , y ).sum ()
661
661
l .backward ()
662
662
if trainer is None :
663
- gb . sgd (params , lr , batch_size )
663
+ sgd (params , lr , batch_size )
664
664
else :
665
665
trainer .step (batch_size )
666
666
y = y .astype ('float32' )