Skip to content

Commit

Permalink
fit_a_line v2 api
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 committed Mar 2, 2017
1 parent a2a5f4a commit 465878a
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 1 deletion.
59 changes: 59 additions & 0 deletions demo/introduction/api_train_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing


def main():
# init
paddle.init(use_gpu=False, trainer_count=1)

# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'),
size=1,
act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b'))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.regression_cost(input=y_predict, label=y)

# create parameters
parameters = paddle.parameters.create(cost)

# create optimizer
optimizer = paddle.optimizer.Momentum(momentum=0)

trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)

# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)

if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
uci_housing.test(), batch_size=2),
reader_dict={'x': 0,
'y': 1})
if event.pass_id % 10 == 0:
print "Test %d, Cost %f, %s" % (event.pass_id, event.cost,
result.metrics)

# training
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(
uci_housing.train(), buf_size=500),
batch_size=2),
reader_dict={'x': 0,
'y': 1},
event_handler=event_handler,
num_passes=30)


if __name__ == '__main__':
main()
3 changes: 2 additions & 1 deletion python/paddle/v2/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mnist
import uci_housing

__all__ = ['mnist']
__all__ = ['mnist', 'uci_housing']
86 changes: 86 additions & 0 deletions python/paddle/v2/dataset/uci_housing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
from common import download

__all__ = ['train', 'test']

URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data'
MD5 = 'd4accdce7a25600298819f8e28e8d593'
feature_names = [
'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
'PTRATIO', 'B', 'LSTAT'
]

UCI_TRAIN_DATA = None
UCI_TEST_DATA = None


def feature_range(maximums, minimums):
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
feature_num = len(maximums)
ax.bar(range(feature_num), maximums - minimums, color='r', align='center')
ax.set_title('feature scale')
plt.xticks(range(feature_num), feature_names)
plt.xlim([-1, feature_num])
fig.set_figheight(6)
fig.set_figwidth(10)
if not os.path.exists('./image'):
os.makedirs('./image')
fig.savefig('image/ranges.png', dpi=48)
plt.close(fig)


def load_data(filename, feature_num=14, ratio=0.8):
global UCI_TRAIN_DATA, UCI_TEST_DATA
if UCI_TRAIN_DATA is not None and UCI_TEST_DATA is not None:
return

data = np.fromfile(filename, sep=' ')
data = data.reshape(data.shape[0] / feature_num, feature_num)
maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
axis=0) / data.shape[0]
feature_range(maximums[:-1], minimums[:-1])
for i in xrange(feature_num - 1):
data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
offset = int(data.shape[0] * ratio)
UCI_TRAIN_DATA = data[:offset]
UCI_TEST_DATA = data[offset:]


def train():
global UCI_TRAIN_DATA
load_data(download(URL, 'uci_housing', MD5))

def reader():
for d in UCI_TRAIN_DATA:
yield d[:-1], d[-1:]

return reader


def test():
global UCI_TEST_DATA
load_data(download(URL, 'uci_housing', MD5))

def reader():
for d in UCI_TEST_DATA:
yield d[:-1], d[-1:]

return reader

0 comments on commit 465878a

Please sign in to comment.