Skip to content

Commit

Permalink
fix unittest of test_RecurrentGradientMachine, and some tiny doc updates
Browse files Browse the repository at this point in the history
Change-Id: I028e402c964ca4f4431cbf8153bea4379dd4df70
  • Loading branch information
luotao1 authored and reyoung committed Sep 8, 2016
1 parent d6d9122 commit dbaabc9
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 24 deletions.
2 changes: 1 addition & 1 deletion doc/demo/imagenet_model/resnet_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ We provide both C++ and Python interfaces to extract features. The following exa

### C++ Interface

First, specify the image data list in `define_py_data_sources` in the config; see the example `demo/model_zoo/resnet/resnet.py`.
First, specify the image data list in `define_py_data_sources2` in the config; see the example `demo/model_zoo/resnet/resnet.py`.

```
train_list = 'train.list' if not is_test else None
Expand Down
2 changes: 1 addition & 1 deletion doc/demo/rec/ml_regression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ In these network, we use several api in `trainer_config_helpers
* Text Convolution Pooling Layer, `text_conv_pool
<../../ui/api/trainer_config_helpers/networks.html
#trainer_config_helpers.networks.text_conv_pool>`_
* Declare Python Data Sources, `define_py_data_sources
* Declare Python Data Sources, `define_py_data_sources2
<../../ui/api/trainer_config_helpers/data_sources.html>`_

Data Provider
Expand Down
30 changes: 18 additions & 12 deletions paddle/gserver/tests/sequenceGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,33 @@
import os
import sys

from paddle.trainer.PyDataProviderWrapper import *
from paddle.trainer.PyDataProvider2 import *

@init_hook_wrapper
def hook(obj, dict_file, **kwargs):
obj.word_dict = dict_file
obj.slots = [IndexSlot(len(obj.word_dict)), IndexSlot(3)]
obj.logger.info('dict len : %d' % (len(obj.word_dict)))
def hook(settings, dict_file, **kwargs):
settings.word_dict = dict_file
settings.input_types = [integer_value_sequence(len(settings.word_dict)),
integer_value_sequence(3)]
settings.logger.info('dict len : %d' % (len(settings.word_dict)))

@provider(use_seq=True, init_hook=hook)
def process(obj, file_name):
@provider(init_hook=hook)
def process(settings, file_name):
with open(file_name, 'r') as fdata:
for line in fdata:
label, comment = line.strip().split('\t')
label = int(''.join(label.split()))
words = comment.split()
word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
yield word_slot, [label]

## for hierarchical sequence network
@provider(use_seq=True, init_hook=hook)
def process2(obj, file_name):
def hook2(settings, dict_file, **kwargs):
settings.word_dict = dict_file
settings.input_types = [integer_value_sub_sequence(len(settings.word_dict)),
integer_value_sub_sequence(3)]
settings.logger.info('dict len : %d' % (len(settings.word_dict)))

@provider(init_hook=hook2)
def process2(settings, file_name):
with open(file_name) as fdata:
label_list = []
word_slot_list = []
Expand All @@ -47,7 +53,7 @@ def process2(obj, file_name):
label,comment = line.strip().split('\t')
label = int(''.join(label.split()))
words = comment.split()
word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
label_list.append([label])
word_slot_list.append(word_slot)
else:
Expand Down
10 changes: 5 additions & 5 deletions paddle/gserver/tests/sequence_layer_group.conf
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ dict_file = dict()
for line_count, line in enumerate(open(dict_path, "r")):
dict_file[line.strip()] = line_count

define_py_data_sources(train_list='gserver/tests/Sequence/train.list',
test_list=None,
module='sequenceGen',
obj='process',
args={"dict_file":dict_file})
define_py_data_sources2(train_list='gserver/tests/Sequence/train.list',
test_list=None,
module='sequenceGen',
obj='process',
args={"dict_file":dict_file})

settings(batch_size=5)
######################## network configure ################################
Expand Down
10 changes: 5 additions & 5 deletions paddle/gserver/tests/sequence_nest_layer_group.conf
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ dict_file = dict()
for line_count, line in enumerate(open(dict_path, "r")):
dict_file[line.strip()] = line_count

define_py_data_sources(train_list='gserver/tests/Sequence/train.list.nest',
test_list=None,
module='sequenceGen',
obj='process2',
args={"dict_file":dict_file})
define_py_data_sources2(train_list='gserver/tests/Sequence/train.list.nest',
test_list=None,
module='sequenceGen',
obj='process2',
args={"dict_file":dict_file})

settings(batch_size=2)
######################## network configure ################################
Expand Down

0 comments on commit dbaabc9

Please sign in to comment.