diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000000000..6bbd46d0ff9565 --- /dev/null +++ b/.clang-format @@ -0,0 +1,29 @@ +# This file is used by clang-format to autoformat paddle source code +# +# clang-format is part of the llvm toolchain. +# You need to install llvm and clang to format source code style. +# +# The basic usage is, +# clang-format -i -style=file PATH/TO/SOURCE/CODE +# +# The -style=file option implicitly uses the ".clang-format" file located in +# one of the parent directories. +# The -i flag means in-place change. +# +# The documentation of clang-format is at +# http://clang.llvm.org/docs/ClangFormat.html +# http://clang.llvm.org/docs/ClangFormatStyleOptions.html +# +# TODO(yuyang18): Add python and other language code style +--- +Language: Cpp +BasedOnStyle: Google +IndentWidth: 2 +TabWidth: 2 +ContinuationIndentWidth: 4 +AccessModifierOffset: -2 # The private/protected/public has no indent in class +PointerAlignment: Left # int* p/int& p, not int *p/int &p +Standard: Cpp11 +AllowAllParametersOfDeclarationOnNextLine: true +...
+ diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000000000..00368ede67d3d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.DS_Store +build/ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000000000..cb991cc9cfccf5 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,156 @@ +cmake_minimum_required(VERSION 2.8) + +project(paddle CXX C) +set(PADDLE_MAJOR_VERSION 0) +set(PADDLE_MINOR_VERSION 8) +set(PADDLE_PATCH_VERSION 0b) +set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION}) + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake") +set(PROJ_ROOT ${CMAKE_SOURCE_DIR}) +include(package) +include(swig) +find_package(CUDA QUIET) +find_package(Protobuf REQUIRED) +find_package(PythonLibs 2.7 REQUIRED) +find_package(PythonInterp 2.7 REQUIRED) +find_package(NumPy) +find_package(Threads REQUIRED) +find_package(Glog) +find_package(Gflags QUIET) +find_package(GTest) +find_package(Sphinx) +find_package(Doxygen) +include(cblas) +find_program(M4_EXECUTABLE m4) +###################### Configurations ########################### +option(WITH_DSO "Compile PaddlePaddle with dynamic linked libraries" ON) +option(WITH_GPU "Compile PaddlePaddle with gpu" ${CUDA_FOUND}) +option(WITH_DOUBLE "Compile PaddlePaddle with double precision, otherwise use single precision" OFF) +option(WITH_AVX "Compile PaddlePaddle with avx instructs" ON) # TODO(yuyang18): Check AVX is supported or not as default value +option(WITH_PYTHON "Compile PaddlePaddle with python interpretor" ON) +option(WITH_STYLE_CHECK "Style Check for PaddlePaddle" ${PYTHONINTERP_FOUND}) +option(WITH_RDMA "Compile PaddlePaddle with rdma support" OFF) +option(WITH_GLOG "Compile PaddlePaddle use glog, otherwise use a log implement internally" ${LIBGLOG_FOUND}) +option(WITH_GFLAGS "Compile PaddlePaddle use gflags, otherwise use a flag implement internally" ${GFLAGS_FOUND}) +option(WITH_TIMER "Compile PaddlePaddle use timer" OFF) 
+option(WITH_TESTING "Compile and run unittest for PaddlePaddle" ${GTEST_FOUND}) +option(WITH_DOC "Compile PaddlePaddle with documentation" OFF) +option(WITH_DOC_CN "Compile PaddlePaddle with Chinese documentation" OFF) +option(WITH_SWIG_PY "Compile PaddlePaddle with py PaddlePaddle predict api" ${SWIG_FOUND}) +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING + "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" + FORCE) +endif() + +include(enableCXX11) +include(cpplint) +include(ccache) +include(util) +include(flags) +include(cudnn) +include(FindPythonModule) +include(check_packages) + +# add PaddlePaddle version +if(DEFINED ENV{PADDLE_VERSION}) + add_definitions(-DPADDLE_VERSION=\"$ENV{PADDLE_VERSION}\") +else() + if(EXISTS ${PROJ_ROOT}/.svn/) + find_package(Subversion REQUIRED) + if(SUBVERSION_FOUND) + Subversion_WC_INFO(${PROJ_ROOT} Project) + add_definitions(-DPADDLE_VERSION=${Project_WC_REVISION}) + endif() + endif() +endif() + + +if(NOT WITH_GPU) + add_definitions(-DPADDLE_ONLY_CPU) + add_definitions(-DHPPL_STUB_FUNC) + list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu) +else() + # TODO(yuyang18): Change it to remove std=c++11 in cuda compile. 
+ set(CUDA_PROPAGATE_HOST_FLAGS OFF) + if(NOT CUDNN_FOUND) + message(FATAL_ERROR "Paddle need cudnn to compile") + endif() + + if(WITH_DSO) + set(CUDA_LIBRARIES "") + add_definitions(-DPADDLE_USE_DSO) + endif(WITH_DSO) + + # Include cuda and cudnn + include_directories(${CUDNN_INCLUDE_DIR}) + include_directories(${CUDA_TOOLKIT_INCLUDE}) +endif(NOT WITH_GPU) + +if(WITH_DOUBLE) + add_definitions(-DPADDLE_TYPE_DOUBLE -DHPPL_TYPE_DOUBLE) + set(ACCURACY double) +else(WITH_DOUBLE) + set(ACCURACY float) +endif(WITH_DOUBLE) + +if(NOT WITH_TIMER) + add_definitions(-DPADDLE_DISABLE_TIMER) +endif(NOT WITH_TIMER) + +if(WITH_AVX) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx") +else(WITH_AVX) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse3") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse3") +endif(WITH_AVX) + +if(WITH_PYTHON) + include_directories(${PYTHON_INCLUDE_DIR}) + include_directories(${PYTHON_NUMPY_INCLUDE_DIR}) +else(WITH_PYTHON) + add_definitions(-DPADDLE_NO_PYTHON) +endif(WITH_PYTHON) + +if(NOT WITH_RDMA) + add_definitions(-DPADDLE_DISABLE_RDMA) +endif() + +if(WITH_GLOG) + add_definitions(-DPADDLE_USE_GLOG) +endif() + +if(WITH_GFLAGS) + add_definitions(-DPADDLE_USE_GFLAGS) + add_definitions(-DGFLAGS_NS=${GFLAGS_NAMESPACE}) + include_directories(${GFLAGS_INCLUDE_DIRS}) +endif() + +if(WITH_TESTING) + enable_testing() + include_directories(${GTEST_INCLUDE_DIRS}) +endif() + +include_directories("${CBLAS_INC_DIR}") +include_directories("${PROJ_ROOT}") +include_directories("${PROJ_ROOT}/paddle/cuda/include") +include_directories(${PROTOBUF_INCLUDE_DIRS}) +include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") +if(EXISTS "${PROJ_ROOT}/paddle/internals/CMakeLists.txt") + set(PADDLE_WITH_INTERNAL ON) + include(paddle/internals/CMakeLists.txt) +else() + set(PADDLE_WITH_INTERNAL OFF) + set(INTERNAL_PROTO_PATH "") +endif() +add_subdirectory(proto) +add_subdirectory(paddle) +add_subdirectory(python) +if(WITH_DOC) + 
add_subdirectory(doc) +endif() +if(WITH_DOC_CN) + add_subdirectory(doc_cn) +endif() diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000000000..2ff3140db0d702 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ +Copyright (c) 2016 Baidu, Inc. All Rights Reserved + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 00000000000000..e7d337866c4ba9 --- /dev/null +++ b/README.md @@ -0,0 +1,84 @@ +# PaddlePaddle + +[![Documentation Status](https://readthedocs.org/projects/ctcspeechrecognition/badge/?version=latest)](http://ctcspeechrecognition.readthedocs.io/en/latest/?badge=latest) + +PaddlePaddle (PArallel Distributed Deep LEarning) is an easy-to-use, +efficient, flexible and scalable deep learning platform, which is originally +developed by Baidu scientists and engineers for the purpose of applying deep +learning to many products at Baidu. + +## Features + +- **Flexibility** + + PaddlePaddle supports a wide range of neural network architectures and + optimization algorithms. 
It is easy to configure complex models such as + neural machine translation model with attention mechanism or complex memory + connection. + +- **Efficiency** + + In order to unleash the power of heterogeneous computing resource, + optimization occurs at different levels of PaddlePaddle, including + computing, memory, architecture and communication. The following are some + examples: + 1. Optimized math operations through SSE/AVX intrinsics, BLAS libraries + (e.g. MKL, ATLAS, cuBLAS) or customized CPU/GPU kernels. + 2. Highly optimized recurrent networks which can handle **variable-length** + sequence without padding. + 3. Optimized local and distributed training for models with high dimensional + sparse data. + +- **Scalability** + + With PaddlePaddle, it is easy to use many CPUs/GPUs and machines to speed + up your training. PaddlePaddle can achieve high throughput and performance + via optimized communication. + +- **Connected to Products** + + In addition, PaddlePaddle is also designed to be easily deployable. At Baidu, + PaddlePaddle has been deployed into products or service with a vast number + of users, including ad click-through rate (CTR) prediction, large-scale image + classification, optical character recognition(OCR), search ranking, computer + virus detection, recommendation, etc. It is widely utilized in products at + Baidu and it has achieved a significant impact. We hope you can also exploit + the capability of PaddlePaddle to make a huge impact for your product. + +## Installation +See [installation guide]() to build and install from the source code or install +the Docker Image. + +## Documentation +- [Quick Start]()
+ You can follow the quick start tutorial to learn how to use PaddlePaddle + step-by-step. + +- [Example and Demo]()
+ We provide five demos, including: image classification, sentiment analysis, + sequence to sequence model, recommendation, semantic role labelling. + +- [Distributed Training]()
+ This system supports training deep learning models on multiple machines + with data parallelism. + +- [Python API]()
+ PaddlePaddle supports using either the Python interface or C++ to build your + system. We also use SWIG to wrap C++ source code to create a user-friendly + interface for Python. You can also use SWIG to create an interface for your + favorite programming language. + +- [How to Contribute]()
+ We sincerely appreciate your interest and contributions. If you’d like to + contribute, please read the contribution guide. + +- [Source Code Documents]()
+ +## Ask Questions + +If you want to ask questions and discuss about methods and models, welcome +to send email to paddle-dev@baidu.com. Framework development discussions and +bug reports are collected on [Issues](https://github.com/paddle/paddle/issues). + +## Copyright and License +PaddlePaddle is provided under the [Apache-2.0 license](LICENSE). diff --git a/authors b/authors new file mode 100644 index 00000000000000..ab4d3118ff1f7e --- /dev/null +++ b/authors @@ -0,0 +1,53 @@ +Cao, Ying +Cheng, Yujuan +Dang, Qingqing +Dong, Tengfei +Du, Dalong +Feng, Shouqiang +Gao, Haoyuan +Han, Baochang +Han, Jinchen +Hao, Nanyu +He, Daoyuan +He, Zhengyan +Hou, Jue +Huang, Chang +Huang, Zhiheng +Hu, Na +Kong, Qi +Liao, Gang +Li, Bo +Li, Jiajie +Li, Jing +Li, Lei +Li, Peng +Liu, Sheng +Liu, Yuan +Li, Yuze +Luo, Heng +Luo, Tao +Lyu, Qin +Mao, Hongyue +Qian, Xiaojun +Qi, Jun +Qin, Duohao +Shen, Guolong +Shi, Guangchuan +Song, Xiang +Wang, Jiang +Wang, Yanfei +Wang, Yong +Weng, Renliang +Xu, Tianbing +Xu, Wei +Xu, Xingyu +Yan, Chong +Yan, Chunwei +Yang, Yi +Yu, Yang +Yu, Yinan +Zhang, Jian +Zhang, Ruiqing +Zhang, Weide +Zhao, Liang +Zhou, Jie diff --git a/cmake/FindGflags.cmake b/cmake/FindGflags.cmake new file mode 100644 index 00000000000000..6587089ba382dc --- /dev/null +++ b/cmake/FindGflags.cmake @@ -0,0 +1,582 @@ +# Ceres Solver - A fast non-linear least squares minimizer +# Copyright 2015 Google Inc. All rights reserved. +# http://ceres-solver.org/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# * Neither the name of Google Inc. nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: alexs.mac@gmail.com (Alex Stewart) +# + +# FindGflags.cmake - Find Google gflags logging library. +# +# This module will attempt to find gflags, either via an exported CMake +# configuration (generated by gflags >= 2.1 which are built with CMake), or +# by performing a standard search for all gflags components. The order of +# precedence for these two methods of finding gflags is controlled by: +# GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION. +# +# This module defines the following variables: +# +# GFLAGS_FOUND: TRUE iff gflags is found. +# GFLAGS_INCLUDE_DIRS: Include directories for gflags. +# GFLAGS_LIBRARIES: Libraries required to link gflags. +# GFLAGS_NAMESPACE: The namespace in which gflags is defined. In versions of +# gflags < 2.1, this was google, for versions >= 2.1 it is +# by default gflags, although can be configured when building +# gflags to be something else (i.e. google for legacy +# compatibility). 
+# +# The following variables control the behaviour of this module when an exported +# gflags CMake configuration is not found. +# +# GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION: TRUE/FALSE, iff TRUE then +# then prefer using an exported CMake configuration +# generated by gflags >= 2.1 over searching for the +# gflags components manually. Otherwise (FALSE) +# ignore any exported gflags CMake configurations and +# always perform a manual search for the components. +# Default: TRUE iff user does not define this variable +# before we are called, and does NOT specify either +# GFLAGS_INCLUDE_DIR_HINTS or GFLAGS_LIBRARY_DIR_HINTS +# otherwise FALSE. +# GFLAGS_INCLUDE_DIR_HINTS: List of additional directories in which to +# search for gflags includes, e.g: /timbuktu/include. +# GFLAGS_LIBRARY_DIR_HINTS: List of additional directories in which to +# search for gflags libraries, e.g: /timbuktu/lib. +# +# The following variables are also defined by this module, but in line with +# CMake recommended FindPackage() module style should NOT be referenced directly +# by callers (use the plural variables detailed above instead). These variables +# do however affect the behaviour of the module via FIND_[PATH/LIBRARY]() which +# are NOT re-called (i.e. search for library is not repeated) if these variables +# are set with valid values _in the CMake cache_. This means that if these +# variables are set directly in the cache, either by the user in the CMake GUI, +# or by the user passing -DVAR=VALUE directives to CMake when called (which +# explicitly defines a cache variable), then they will be used verbatim, +# bypassing the HINTS variables and other hard-coded search locations. +# +# GFLAGS_INCLUDE_DIR: Include directory for gflags, not including the +# include directory of any dependencies. +# GFLAGS_LIBRARY: gflags library, not including the libraries of any +# dependencies. 
+ +# Reset CALLERS_CMAKE_FIND_LIBRARY_PREFIXES to its value when FindGflags was +# invoked, necessary for MSVC. +macro(GFLAGS_RESET_FIND_LIBRARY_PREFIX) + if (MSVC) + set(CMAKE_FIND_LIBRARY_PREFIXES "${CALLERS_CMAKE_FIND_LIBRARY_PREFIXES}") + endif (MSVC) +endmacro(GFLAGS_RESET_FIND_LIBRARY_PREFIX) + +# Called if we failed to find gflags or any of it's required dependencies, +# unsets all public (designed to be used externally) variables and reports +# error message at priority depending upon [REQUIRED/QUIET/] argument. +macro(GFLAGS_REPORT_NOT_FOUND REASON_MSG) + unset(GFLAGS_FOUND) + unset(GFLAGS_INCLUDE_DIRS) + unset(GFLAGS_LIBRARIES) + # Do not use unset, as we want to keep GFLAGS_NAMESPACE in the cache, + # but simply clear its value. + set(GFLAGS_NAMESPACE "" CACHE STRING + "gflags namespace (google or gflags)" FORCE) + + # Make results of search visible in the CMake GUI if gflags has not + # been found so that user does not have to toggle to advanced view. + mark_as_advanced(CLEAR GFLAGS_INCLUDE_DIR + GFLAGS_LIBRARY + GFLAGS_NAMESPACE) + + gflags_reset_find_library_prefix() + + # Note _FIND_[REQUIRED/QUIETLY] variables defined by FindPackage() + # use the camelcase library name, not uppercase. + if (Gflags_FIND_QUIETLY) + message(STATUS "Failed to find gflags - " ${REASON_MSG} ${ARGN}) + elseif (Gflags_FIND_REQUIRED) + message(FATAL_ERROR "Failed to find gflags - " ${REASON_MSG} ${ARGN}) + else() + # Neither QUIETLY nor REQUIRED, use no priority which emits a message + # but continues configuration and allows generation. + message("-- Failed to find gflags - " ${REASON_MSG} ${ARGN}) + endif () + return() +endmacro(GFLAGS_REPORT_NOT_FOUND) + +# Verify that all variable names passed as arguments are defined (can be empty +# but must be defined) or raise a fatal error. 
+macro(GFLAGS_CHECK_VARS_DEFINED) + foreach(CHECK_VAR ${ARGN}) + if (NOT DEFINED ${CHECK_VAR}) + message(FATAL_ERROR "Ceres Bug: ${CHECK_VAR} is not defined.") + endif() + endforeach() +endmacro(GFLAGS_CHECK_VARS_DEFINED) + +# Use check_cxx_source_compiles() to compile trivial test programs to determine +# the gflags namespace. This works on all OSs except Windows. If using Visual +# Studio, it fails because msbuild forces check_cxx_source_compiles() to use +# CMAKE_BUILD_TYPE=Debug for the test project, which usually breaks detection +# because MSVC requires that the test project use the same build type as gflags, +# which would normally be built in Release. +# +# Defines: GFLAGS_NAMESPACE in the caller's scope with the detected namespace, +# which is blank (empty string, will test FALSE is CMake conditionals) +# if detection failed. +function(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_TRY_COMPILE) + # Verify that all required variables are defined. + gflags_check_vars_defined( + GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY) + # Ensure that GFLAGS_NAMESPACE is always unset on completion unless + # we explicitly set if after having the correct namespace. + set(GFLAGS_NAMESPACE "" PARENT_SCOPE) + + include(CheckCXXSourceCompiles) + # Setup include path & link library for gflags for CHECK_CXX_SOURCE_COMPILES. + set(CMAKE_REQUIRED_INCLUDES ${GFLAGS_INCLUDE_DIR}) + set(CMAKE_REQUIRED_LIBRARIES ${GFLAGS_LIBRARY} ${GFLAGS_LINK_LIBRARIES}) + # First try the (older) google namespace. Note that the output variable + # MUST be unique to the build type as otherwise the test is not repeated as + # it is assumed to have already been performed. + check_cxx_source_compiles( + "#include + int main(int argc, char * argv[]) { + google::ParseCommandLineFlags(&argc, &argv, true); + return 0; + }" + GFLAGS_IN_GOOGLE_NAMESPACE) + if (GFLAGS_IN_GOOGLE_NAMESPACE) + set(GFLAGS_NAMESPACE google PARENT_SCOPE) + return() + endif() + + # Try (newer) gflags namespace instead. 
# Note that the output variable MUST be unique to the build type as
# otherwise the test is not repeated as it is assumed to have already
# been performed.
  set(CMAKE_REQUIRED_INCLUDES ${GFLAGS_INCLUDE_DIR})
  set(CMAKE_REQUIRED_LIBRARIES ${GFLAGS_LIBRARY} ${GFLAGS_LINK_LIBRARIES})
  # NOTE(review): the <gflags/gflags.h> header name was stripped by text
  # extraction; restored here from the canonical form of this module.
  check_cxx_source_compiles(
    "#include <gflags/gflags.h>
     int main(int argc, char * argv[]) {
       gflags::ParseCommandLineFlags(&argc, &argv, true);
       return 0;
     }"
    GFLAGS_IN_GFLAGS_NAMESPACE)
  if (GFLAGS_IN_GFLAGS_NAMESPACE)
    set(GFLAGS_NAMESPACE gflags PARENT_SCOPE)
    return()
  endif()
endfunction(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_TRY_COMPILE)

# Use regex on the gflags headers to attempt to determine the gflags namespace.
# Checks both gflags.h (contained namespace on versions < 2.1.2) and
# gflags_declare.h, which contains the namespace on versions >= 2.1.2.
# In general, this method should only be used when
# GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_TRY_COMPILE() cannot be used, or has
# failed.
#
# Defines: GFLAGS_NAMESPACE in the caller's scope with the detected namespace,
# which is blank (empty string, will test FALSE in CMake conditionals)
# if detection failed.
function(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_REGEX)
  # Verify that all required variables are defined.
  gflags_check_vars_defined(GFLAGS_INCLUDE_DIR)
  # Ensure that GFLAGS_NAMESPACE is always undefined on completion unless
  # we explicitly set it after having found the correct namespace.
  set(GFLAGS_NAMESPACE "" PARENT_SCOPE)

  # Scan gflags.h to identify what namespace gflags was built with.  On
  # versions of gflags < 2.1.2, gflags.h was configured with the namespace
  # directly; on >= 2.1.2 gflags.h uses the GFLAGS_NAMESPACE #define which
  # is defined in gflags_declare.h.  We try each location in turn.
  set(GFLAGS_HEADER_FILE ${GFLAGS_INCLUDE_DIR}/gflags/gflags.h)
  if (NOT EXISTS ${GFLAGS_HEADER_FILE})
    gflags_report_not_found(
      "Could not find file: ${GFLAGS_HEADER_FILE} "
      "containing namespace information in gflags install located at: "
      "${GFLAGS_INCLUDE_DIR}.")
  endif()
  file(READ ${GFLAGS_HEADER_FILE} GFLAGS_HEADER_FILE_CONTENTS)

  string(REGEX MATCH "namespace [A-Za-z]+"
    GFLAGS_NAMESPACE "${GFLAGS_HEADER_FILE_CONTENTS}")
  string(REGEX REPLACE "namespace ([A-Za-z]+)" "\\1"
    GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}")

  if (NOT GFLAGS_NAMESPACE)
    gflags_report_not_found(
      "Failed to extract gflags namespace from header file: "
      "${GFLAGS_HEADER_FILE}.")
  endif()

  if (GFLAGS_NAMESPACE STREQUAL "google" OR
      GFLAGS_NAMESPACE STREQUAL "gflags")
    # Found valid gflags namespace from gflags.h.
    set(GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}" PARENT_SCOPE)
    return()
  endif()

  # Failed to find gflags namespace from gflags.h; gflags is likely a newer
  # version, so check gflags_declare.h, which in versions >= 2.1.2 contains
  # the GFLAGS_NAMESPACE #define that gflags.h references.
  set(GFLAGS_DECLARE_FILE ${GFLAGS_INCLUDE_DIR}/gflags/gflags_declare.h)
  if (NOT EXISTS ${GFLAGS_DECLARE_FILE})
    gflags_report_not_found(
      "Could not find file: ${GFLAGS_DECLARE_FILE} "
      "containing namespace information in gflags install located at: "
      "${GFLAGS_INCLUDE_DIR}.")
  endif()
  file(READ ${GFLAGS_DECLARE_FILE} GFLAGS_DECLARE_FILE_CONTENTS)

  string(REGEX MATCH "#define GFLAGS_NAMESPACE [A-Za-z]+"
    GFLAGS_NAMESPACE "${GFLAGS_DECLARE_FILE_CONTENTS}")
  string(REGEX REPLACE "#define GFLAGS_NAMESPACE ([A-Za-z]+)" "\\1"
    GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}")

  if (NOT GFLAGS_NAMESPACE)
    gflags_report_not_found(
      "Failed to extract gflags namespace from declare file: "
      "${GFLAGS_DECLARE_FILE}.")
  endif()

  if (GFLAGS_NAMESPACE STREQUAL "google" OR
      GFLAGS_NAMESPACE STREQUAL "gflags")
    # Found valid gflags namespace from gflags_declare.h.
    set(GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}" PARENT_SCOPE)
    return()
  endif()
endfunction(GFLAGS_CHECK_GFLAGS_NAMESPACE_USING_REGEX)

# -----------------------------------------------------------------
# By default, if the user has expressed no preference for using an exported
# gflags CMake configuration over performing a search for the installed
# components, and has not specified any hints for the search locations, then
# prefer a gflags exported configuration if available.
if (NOT DEFINED GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION
    AND NOT GFLAGS_INCLUDE_DIR_HINTS
    AND NOT GFLAGS_LIBRARY_DIR_HINTS)
  message(STATUS "No preference for use of exported gflags CMake configuration "
    "set, and no hints for include/library directories provided. "
    "Defaulting to preferring an installed/exported gflags CMake configuration "
    "if available.")
  set(GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION TRUE)
endif()

if (GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION)
  # Try to find an exported CMake configuration for gflags, as generated by
  # gflags versions >= 2.1.
+ # + # We search twice, s/t we can invert the ordering of precedence used by + # find_package() for exported package build directories, and installed + # packages (found via CMAKE_SYSTEM_PREFIX_PATH), listed as items 6) and 7) + # respectively in [1]. + # + # By default, exported build directories are (in theory) detected first, and + # this is usually the case on Windows. However, on OS X & Linux, the install + # path (/usr/local) is typically present in the PATH environment variable + # which is checked in item 4) in [1] (i.e. before both of the above, unless + # NO_SYSTEM_ENVIRONMENT_PATH is passed). As such on those OSs installed + # packages are usually detected in preference to exported package build + # directories. + # + # To ensure a more consistent response across all OSs, and as users usually + # want to prefer an installed version of a package over a locally built one + # where both exist (esp. as the exported build directory might be removed + # after installation), we first search with NO_CMAKE_PACKAGE_REGISTRY which + # means any build directories exported by the user are ignored, and thus + # installed directories are preferred. If this fails to find the package + # we then research again, but without NO_CMAKE_PACKAGE_REGISTRY, so any + # exported build directories will now be detected. + # + # To prevent confusion on Windows, we also pass NO_CMAKE_BUILDS_PATH (which + # is item 5) in [1]), to not preferentially use projects that were built + # recently with the CMake GUI to ensure that we always prefer an installed + # version if available. + # + # [1] http://www.cmake.org/cmake/help/v2.8.11/cmake.html#command:find_package + find_package(gflags QUIET + NO_MODULE + NO_CMAKE_PACKAGE_REGISTRY + NO_CMAKE_BUILDS_PATH) + if (gflags_FOUND) + message(STATUS "Found installed version of gflags: ${gflags_DIR}") + else(gflags_FOUND) + # Failed to find an installed version of gflags, repeat search allowing + # exported build directories. 
+ message(STATUS "Failed to find installed gflags CMake configuration, " + "searching for gflags build directories exported with CMake.") + # Again pass NO_CMAKE_BUILDS_PATH, as we know that gflags is exported and + # do not want to treat projects built with the CMake GUI preferentially. + find_package(gflags QUIET + NO_MODULE + NO_CMAKE_BUILDS_PATH) + if (gflags_FOUND) + message(STATUS "Found exported gflags build directory: ${gflags_DIR}") + endif(gflags_FOUND) + endif(gflags_FOUND) + + set(FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION ${gflags_FOUND}) + + # gflags v2.1 - 2.1.2 shipped with a bug in their gflags-config.cmake [1] + # whereby gflags_LIBRARIES = "gflags", but there was no imported target + # called "gflags", they were called: gflags[_nothreads]-[static/shared]. + # As this causes linker errors when gflags is not installed in a location + # on the current library paths, detect if this problem is present and + # fix it. + # + # [1] https://github.com/gflags/gflags/issues/110 + if (gflags_FOUND) + # NOTE: This is not written as additional conditions in the outer + # if (gflags_FOUND) as the NOT TARGET "${gflags_LIBRARIES}" + # condition causes problems if gflags is not found. + if (${gflags_VERSION} VERSION_LESS 2.1.3 AND + NOT TARGET "${gflags_LIBRARIES}") + message(STATUS "Detected broken gflags install in: ${gflags_DIR}, " + "version: ${gflags_VERSION} <= 2.1.2 which defines gflags_LIBRARIES = " + "${gflags_LIBRARIES} which is not an imported CMake target, see: " + "https://github.com/gflags/gflags/issues/110. Attempting to fix by " + "detecting correct gflags target.") + # Ordering here expresses preference for detection, specifically we do not + # want to use the _nothreads variants if the full library is available. 
+ list(APPEND CHECK_GFLAGS_IMPORTED_TARGET_NAMES + gflags-shared gflags-static + gflags_nothreads-shared gflags_nothreads-static) + foreach(CHECK_GFLAGS_TARGET ${CHECK_GFLAGS_IMPORTED_TARGET_NAMES}) + if (TARGET ${CHECK_GFLAGS_TARGET}) + message(STATUS "Found valid gflags target: ${CHECK_GFLAGS_TARGET}, " + "updating gflags_LIBRARIES.") + set(gflags_LIBRARIES ${CHECK_GFLAGS_TARGET}) + break() + endif() + endforeach() + if (NOT TARGET ${gflags_LIBRARIES}) + message(STATUS "Failed to fix detected broken gflags install in: " + "${gflags_DIR}, version: ${gflags_VERSION} <= 2.1.2, none of the " + "imported targets for gflags: ${CHECK_GFLAGS_IMPORTED_TARGET_NAMES} " + "are defined. Will continue with a manual search for gflags " + "components. We recommend you build/install a version of gflags > " + "2.1.2 (or master).") + set(FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION FALSE) + endif() + endif() + endif() + + if (FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION) + message(STATUS "Detected gflags version: ${gflags_VERSION}") + set(GFLAGS_FOUND ${gflags_FOUND}) + set(GFLAGS_INCLUDE_DIR ${gflags_INCLUDE_DIR}) + set(GFLAGS_LIBRARY ${gflags_LIBRARIES}) + + # gflags does not export the namespace in their CMake configuration, so + # use our function to determine what it should be, as it can be either + # gflags or google dependent upon version & configuration. + # + # NOTE: We use the regex method to determine the namespace here, as + # check_cxx_source_compiles() will not use imported targets, which + # is what gflags will be in this case. 
+ gflags_check_gflags_namespace_using_regex() + + if (NOT GFLAGS_NAMESPACE) + gflags_report_not_found( + "Failed to determine gflags namespace using regex for gflags " + "version: ${gflags_VERSION} exported here: ${gflags_DIR} using CMake.") + endif (NOT GFLAGS_NAMESPACE) + else (FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION) + message(STATUS "Failed to find an installed/exported CMake configuration " + "for gflags, will perform search for installed gflags components.") + endif (FOUND_INSTALLED_GFLAGS_CMAKE_CONFIGURATION) +endif(GFLAGS_PREFER_EXPORTED_GFLAGS_CMAKE_CONFIGURATION) + +if (NOT GFLAGS_FOUND) + # Either failed to find an exported gflags CMake configuration, or user + # told us not to use one. Perform a manual search for all gflags components. + + # Handle possible presence of lib prefix for libraries on MSVC, see + # also GFLAGS_RESET_FIND_LIBRARY_PREFIX(). + if (MSVC) + # Preserve the caller's original values for CMAKE_FIND_LIBRARY_PREFIXES + # s/t we can set it back before returning. + set(CALLERS_CMAKE_FIND_LIBRARY_PREFIXES "${CMAKE_FIND_LIBRARY_PREFIXES}") + # The empty string in this list is important, it represents the case when + # the libraries have no prefix (shared libraries / DLLs). + set(CMAKE_FIND_LIBRARY_PREFIXES "lib" "" "${CMAKE_FIND_LIBRARY_PREFIXES}") + endif (MSVC) + + # Search user-installed locations first, so that we prefer user installs + # to system installs where both exist. + list(APPEND GFLAGS_CHECK_INCLUDE_DIRS + /usr/local/include + /usr/local/homebrew/include # Mac OS X + /opt/local/var/macports/software # Mac OS X. + /opt/local/include + /usr/include) + list(APPEND GFLAGS_CHECK_PATH_SUFFIXES + gflags/include # Windows (for C:/Program Files prefix). + gflags/Include ) # Windows (for C:/Program Files prefix). + + list(APPEND GFLAGS_CHECK_LIBRARY_DIRS + /usr/local/lib + /usr/local/homebrew/lib # Mac OS X. + /opt/local/lib + /usr/lib) + list(APPEND GFLAGS_CHECK_LIBRARY_SUFFIXES + gflags/lib # Windows (for C:/Program Files prefix). 
+ gflags/Lib ) # Windows (for C:/Program Files prefix). + + # Search supplied hint directories first if supplied. + find_path(GFLAGS_INCLUDE_DIR + NAMES gflags/gflags.h + PATHS ${GFLAGS_INCLUDE_DIR_HINTS} + ${GFLAGS_CHECK_INCLUDE_DIRS} + PATH_SUFFIXES ${GFLAGS_CHECK_PATH_SUFFIXES}) + if (NOT GFLAGS_INCLUDE_DIR OR + NOT EXISTS ${GFLAGS_INCLUDE_DIR}) + gflags_report_not_found( + "Could not find gflags include directory, set GFLAGS_INCLUDE_DIR " + "to directory containing gflags/gflags.h") + endif (NOT GFLAGS_INCLUDE_DIR OR + NOT EXISTS ${GFLAGS_INCLUDE_DIR}) + + find_library(GFLAGS_LIBRARY NAMES gflags + PATHS ${GFLAGS_LIBRARY_DIR_HINTS} + ${GFLAGS_CHECK_LIBRARY_DIRS} + PATH_SUFFIXES ${GFLAGS_CHECK_LIBRARY_SUFFIXES}) + if (NOT GFLAGS_LIBRARY OR + NOT EXISTS ${GFLAGS_LIBRARY}) + gflags_report_not_found( + "Could not find gflags library, set GFLAGS_LIBRARY " + "to full path to libgflags.") + endif (NOT GFLAGS_LIBRARY OR + NOT EXISTS ${GFLAGS_LIBRARY}) + + # gflags typically requires a threading library (which is OS dependent), note + # that this defines the CMAKE_THREAD_LIBS_INIT variable. If we are able to + # detect threads, we assume that gflags requires it. + find_package(Threads QUIET) + set(GFLAGS_LINK_LIBRARIES ${CMAKE_THREAD_LIBS_INIT}) + # On Windows (including MinGW), the Shlwapi library is used by gflags if + # available. + if (WIN32) + include(CheckIncludeFileCXX) + check_include_file_cxx("shlwapi.h" HAVE_SHLWAPI) + if (HAVE_SHLWAPI) + list(APPEND GFLAGS_LINK_LIBRARIES shlwapi.lib) + endif(HAVE_SHLWAPI) + endif (WIN32) + + # Mark internally as found, then verify. GFLAGS_REPORT_NOT_FOUND() unsets + # if called. + set(GFLAGS_FOUND TRUE) + + # Identify what namespace gflags was built with. 
+ if (GFLAGS_INCLUDE_DIR AND NOT GFLAGS_NAMESPACE) + # To handle Windows peculiarities / CMake bugs on MSVC we try two approaches + # to detect the gflags namespace: + # + # 1) Try to use check_cxx_source_compiles() to compile a trivial program + # with the two choices for the gflags namespace. + # + # 2) [In the event 1) fails] Use regex on the gflags headers to try to + # determine the gflags namespace. Whilst this is less robust than 1), + # it does avoid any interaction with msbuild. + gflags_check_gflags_namespace_using_try_compile() + + if (NOT GFLAGS_NAMESPACE) + # Failed to determine gflags namespace using check_cxx_source_compiles() + # method, try and obtain it using regex on the gflags headers instead. + message(STATUS "Failed to find gflags namespace using using " + "check_cxx_source_compiles(), trying namespace regex instead, " + "this is expected on Windows.") + gflags_check_gflags_namespace_using_regex() + + if (NOT GFLAGS_NAMESPACE) + gflags_report_not_found( + "Failed to determine gflags namespace either by " + "check_cxx_source_compiles(), or namespace regex.") + endif (NOT GFLAGS_NAMESPACE) + endif (NOT GFLAGS_NAMESPACE) + endif (GFLAGS_INCLUDE_DIR AND NOT GFLAGS_NAMESPACE) + + # Make the GFLAGS_NAMESPACE a cache variable s/t the user can view it, and could + # overwrite it in the CMake GUI. + set(GFLAGS_NAMESPACE "${GFLAGS_NAMESPACE}" CACHE STRING + "gflags namespace (google or gflags)" FORCE) + + # gflags does not seem to provide any record of the version in its + # source tree, thus cannot extract version. + + # Catch case when caller has set GFLAGS_NAMESPACE in the cache / GUI + # with an invalid value. 
+ if (GFLAGS_NAMESPACE AND + NOT GFLAGS_NAMESPACE STREQUAL "google" AND + NOT GFLAGS_NAMESPACE STREQUAL "gflags") + gflags_report_not_found( + "Caller defined GFLAGS_NAMESPACE:" + " ${GFLAGS_NAMESPACE} is not valid, not google or gflags.") + endif () + # Catch case when caller has set GFLAGS_INCLUDE_DIR in the cache / GUI and + # thus FIND_[PATH/LIBRARY] are not called, but specified locations are + # invalid, otherwise we would report the library as found. + if (GFLAGS_INCLUDE_DIR AND + NOT EXISTS ${GFLAGS_INCLUDE_DIR}/gflags/gflags.h) + gflags_report_not_found( + "Caller defined GFLAGS_INCLUDE_DIR:" + " ${GFLAGS_INCLUDE_DIR} does not contain gflags/gflags.h header.") + endif (GFLAGS_INCLUDE_DIR AND + NOT EXISTS ${GFLAGS_INCLUDE_DIR}/gflags/gflags.h) + # TODO: This regex for gflags library is pretty primitive, we use lowercase + # for comparison to handle Windows using CamelCase library names, could + # this check be better? + string(TOLOWER "${GFLAGS_LIBRARY}" LOWERCASE_GFLAGS_LIBRARY) + if (GFLAGS_LIBRARY AND + NOT "${LOWERCASE_GFLAGS_LIBRARY}" MATCHES ".*gflags[^/]*") + gflags_report_not_found( + "Caller defined GFLAGS_LIBRARY: " + "${GFLAGS_LIBRARY} does not match gflags.") + endif (GFLAGS_LIBRARY AND + NOT "${LOWERCASE_GFLAGS_LIBRARY}" MATCHES ".*gflags[^/]*") + + gflags_reset_find_library_prefix() + +endif(NOT GFLAGS_FOUND) + +# Set standard CMake FindPackage variables if found. +if (GFLAGS_FOUND) + set(GFLAGS_INCLUDE_DIRS ${GFLAGS_INCLUDE_DIR}) + set(GFLAGS_LIBRARIES ${GFLAGS_LIBRARY} ${GFLAGS_LINK_LIBRARIES}) +endif (GFLAGS_FOUND) + +# Handle REQUIRED / QUIET optional arguments. +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Gflags DEFAULT_MSG + GFLAGS_INCLUDE_DIRS GFLAGS_LIBRARIES GFLAGS_NAMESPACE) + +# Only mark internal variables as advanced if we found gflags, otherwise +# leave them visible in the standard GUI for the user to set manually. 
+if (GFLAGS_FOUND) + mark_as_advanced(FORCE GFLAGS_INCLUDE_DIR + GFLAGS_LIBRARY + GFLAGS_NAMESPACE + gflags_DIR) # Autogenerated by find_package(gflags) +endif (GFLAGS_FOUND) diff --git a/cmake/FindGlog.cmake b/cmake/FindGlog.cmake new file mode 100644 index 00000000000000..142e2ca96ba76d --- /dev/null +++ b/cmake/FindGlog.cmake @@ -0,0 +1,24 @@ +# +# Find libglog +# +# LIBGLOG_INCLUDE_DIR - where to find glog/logging.h, etc. +# LIBGLOG_LIBRARY - List of libraries when using libglog. +# LIBGLOG_FOUND - True if libglog found. +# +# from https://github.com/facebook/hhvm/blob/master/CMake/FindGlog.cmake + +IF (LIBGLOG_INCLUDE_DIR) + # Already in cache, be silent + SET(LIBGLOG_FIND_QUIETLY TRUE) +ENDIF () + +FIND_PATH(LIBGLOG_INCLUDE_DIR glog/logging.h) + +FIND_LIBRARY(LIBGLOG_LIBRARY glog) + +# handle the QUIETLY and REQUIRED arguments and set LIBGLOG_FOUND to TRUE if +# all listed variables are TRUE +INCLUDE(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(LIBGLOG DEFAULT_MSG LIBGLOG_LIBRARY LIBGLOG_INCLUDE_DIR) + +MARK_AS_ADVANCED(LIBGLOG_LIBRARY LIBGLOG_INCLUDE_DIR) \ No newline at end of file diff --git a/cmake/FindNumPy.cmake b/cmake/FindNumPy.cmake new file mode 100644 index 00000000000000..8cdd642ac01315 --- /dev/null +++ b/cmake/FindNumPy.cmake @@ -0,0 +1,38 @@ +# Find the Python NumPy package +# PYTHON_NUMPY_INCLUDE_DIR +# NUMPY_FOUND +# will be set by this script + +cmake_minimum_required(VERSION 2.6) + +if(NOT PYTHON_EXECUTABLE) + if(NumPy_FIND_QUIETLY) + find_package(PythonInterp QUIET) + else() + find_package(PythonInterp) + set(_numpy_out 1) + endif() +endif() + +if (PYTHON_EXECUTABLE) + # write a python script that finds the numpy path + file(WRITE ${PROJECT_BINARY_DIR}/FindNumpyPath.py + "try: import numpy; print(numpy.get_include())\nexcept:pass\n") + + # execute the find script + exec_program("${PYTHON_EXECUTABLE}" ${PROJECT_BINARY_DIR} + ARGS "FindNumpyPath.py" + OUTPUT_VARIABLE NUMPY_PATH) +elseif(_numpy_out) + message(STATUS 
"Python executable not found.") +endif(PYTHON_EXECUTABLE) + +find_path(PYTHON_NUMPY_INCLUDE_DIR numpy/arrayobject.h + HINTS "${NUMPY_PATH}" "${PYTHON_INCLUDE_PATH}") + +if(PYTHON_NUMPY_INCLUDE_DIR) + set(PYTHON_NUMPY_FOUND 1 CACHE INTERNAL "Python numpy found") +endif(PYTHON_NUMPY_INCLUDE_DIR) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NumPy DEFAULT_MSG PYTHON_NUMPY_INCLUDE_DIR) diff --git a/cmake/FindPythonModule.cmake b/cmake/FindPythonModule.cmake new file mode 100644 index 00000000000000..2eb3441428e829 --- /dev/null +++ b/cmake/FindPythonModule.cmake @@ -0,0 +1,30 @@ +# Find if a Python module is installed +# Found at http://www.cmake.org/pipermail/cmake/2011-January/041666.html +# To use do: find_python_module(PyQt4 REQUIRED) +function(find_python_module module) + string(TOUPPER ${module} module_upper) + if(NOT PY_${module_upper}) + if(ARGC GREATER 1 AND ARGV1 STREQUAL "REQUIRED") + set(${module}_FIND_REQUIRED TRUE) + else() + set(${module}_FIND_REQUIRED FALSE) + endif() + # A module's location is usually a directory, but for binary modules + # it's a .so file. 
+ execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" + "import re, ${module}; print(re.compile('/__init__.py.*').sub('',${module}.__file__))" + RESULT_VARIABLE _${module}_status + OUTPUT_VARIABLE _${module}_location + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _${module}_status) + set(PY_${module_upper} ${_${module}_location} CACHE STRING + "Location of Python module ${module}") + endif(NOT _${module}_status) + endif(NOT PY_${module_upper}) + find_package_handle_standard_args(PY_${module} DEFAULT_MSG PY_${module_upper}) + if(NOT PY_${module_upper}_FOUND AND ${module}_FIND_REQUIRED) + message(FATAL_ERROR "python module ${module} is not found") + endif() + set(PY_${module_upper}_FOUND ${PY_${module_upper}_FOUND} PARENT_SCOPE) +endfunction(find_python_module) diff --git a/cmake/FindSphinx.cmake b/cmake/FindSphinx.cmake new file mode 100644 index 00000000000000..6702f45a168bf0 --- /dev/null +++ b/cmake/FindSphinx.cmake @@ -0,0 +1,146 @@ +# - This module looks for Sphinx +# Find the Sphinx documentation generator +# +# This modules defines +# SPHINX_EXECUTABLE +# SPHINX_FOUND + +find_program(SPHINX_EXECUTABLE + NAMES sphinx-build + PATHS + /usr/bin + /usr/local/bin + /opt/local/bin + DOC "Sphinx documentation generator" +) + +if( NOT SPHINX_EXECUTABLE ) + set(_Python_VERSIONS + 2.7 2.6 2.5 2.4 2.3 2.2 2.1 2.0 1.6 1.5 + ) + + foreach( _version ${_Python_VERSIONS} ) + set( _sphinx_NAMES sphinx-build-${_version} ) + + find_program( SPHINX_EXECUTABLE + NAMES ${_sphinx_NAMES} + PATHS + /usr/bin + /usr/local/bin + /opt/loca/bin + DOC "Sphinx documentation generator" + ) + endforeach() +endif() + +include(FindPackageHandleStandardArgs) + +find_package_handle_standard_args(Sphinx DEFAULT_MSG + SPHINX_EXECUTABLE +) + + +option( SPHINX_HTML_OUTPUT "Build a single HTML with the whole content." ON ) +option( SPHINX_DIRHTML_OUTPUT "Build HTML pages, but with a single directory per document." 
OFF ) +option( SPHINX_HTMLHELP_OUTPUT "Build HTML pages with additional information for building a documentation collection in htmlhelp." OFF ) +option( SPHINX_QTHELP_OUTPUT "Build HTML pages with additional information for building a documentation collection in qthelp." OFF ) +option( SPHINX_DEVHELP_OUTPUT "Build HTML pages with additional information for building a documentation collection in devhelp." OFF ) +option( SPHINX_EPUB_OUTPUT "Build HTML pages with additional information for building a documentation collection in epub." OFF ) +option( SPHINX_LATEX_OUTPUT "Build LaTeX sources that can be compiled to a PDF document using pdflatex." OFF ) +option( SPHINX_MAN_OUTPUT "Build manual pages in groff format for UNIX systems." OFF ) +option( SPHINX_TEXT_OUTPUT "Build plain text files." OFF ) + + +mark_as_advanced( + SPHINX_EXECUTABLE + SPHINX_HTML_OUTPUT + SPHINX_DIRHTML_OUTPUT + SPHINX_HTMLHELP_OUTPUT + SPHINX_QTHELP_OUTPUT + SPHINX_DEVHELP_OUTPUT + SPHINX_EPUB_OUTPUT + SPHINX_LATEX_OUTPUT + SPHINX_MAN_OUTPUT + SPHINX_TEXT_OUTPUT +) + +function( Sphinx_add_target target_name builder conf cache source destination ) + add_custom_target( ${target_name} ALL + COMMAND ${SPHINX_EXECUTABLE} -b ${builder} + -d ${cache} + -c ${conf} + ${source} + ${destination} + COMMENT "Generating sphinx documentation: ${builder}" + ) + + set_property( + DIRECTORY APPEND PROPERTY + ADDITIONAL_MAKE_CLEAN_FILES + ${destination} + ) +endfunction() + +# Target dependencies can be optionally listed at the end. 
# Create one documentation target per enabled Sphinx output format.  Extra
# arguments after base_destination are treated as dependencies of every
# generated target.  (NOTE(review): no branch exists for
# SPHINX_HTMLHELP_OUTPUT in the original; omission preserved.)
function(Sphinx_add_targets target_base_name conf source base_destination)

  set(_dependencies)

  foreach(arg IN LISTS ARGN)
    set(_dependencies ${_dependencies} ${arg})
  endforeach()

  # FIX: the original used if( ${VAR} ), which expands to the invalid
  # "if()" and aborts configuration whenever VAR is undefined.  Plain
  # if(VAR) auto-dereferences and is safe; BUILD_TESTING in particular is
  # only defined once CTest has been included.
  if(SPHINX_HTML_OUTPUT)
    Sphinx_add_target(${target_base_name}_html html ${conf} ${source} ${base_destination}/html)

    add_dependencies(${target_base_name}_html ${_dependencies})
  endif()

  if(SPHINX_DIRHTML_OUTPUT)
    Sphinx_add_target(${target_base_name}_dirhtml dirhtml ${conf} ${source} ${base_destination}/dirhtml)

    add_dependencies(${target_base_name}_dirhtml ${_dependencies})
  endif()

  if(SPHINX_QTHELP_OUTPUT)
    Sphinx_add_target(${target_base_name}_qthelp qthelp ${conf} ${source} ${base_destination}/qthelp)

    add_dependencies(${target_base_name}_qthelp ${_dependencies})
  endif()

  if(SPHINX_DEVHELP_OUTPUT)
    Sphinx_add_target(${target_base_name}_devhelp devhelp ${conf} ${source} ${base_destination}/devhelp)

    add_dependencies(${target_base_name}_devhelp ${_dependencies})
  endif()

  if(SPHINX_EPUB_OUTPUT)
    Sphinx_add_target(${target_base_name}_epub epub ${conf} ${source} ${base_destination}/epub)

    add_dependencies(${target_base_name}_epub ${_dependencies})
  endif()

  if(SPHINX_LATEX_OUTPUT)
    Sphinx_add_target(${target_base_name}_latex latex ${conf} ${source} ${base_destination}/latex)

    add_dependencies(${target_base_name}_latex ${_dependencies})
  endif()

  if(SPHINX_MAN_OUTPUT)
    Sphinx_add_target(${target_base_name}_man man ${conf} ${source} ${base_destination}/man)

    add_dependencies(${target_base_name}_man ${_dependencies})
  endif()

  if(SPHINX_TEXT_OUTPUT)
    Sphinx_add_target(${target_base_name}_text text ${conf} ${source} ${base_destination}/text)

    add_dependencies(${target_base_name}_text ${_dependencies})
  endif()

  if(BUILD_TESTING)
    sphinx_add_target(${target_base_name}_linkcheck linkcheck ${conf} ${source}
      ${base_destination}/linkcheck)

    add_dependencies(${target_base_name}_linkcheck ${_dependencies})
  endif()
endfunction()

# ---------------------------------------------------------------------------
# cmake/cblas.cmake
# Find the CBlas libraries
#
# It will search MKL, atlas, OpenBlas, reference-cblas in order.
#
# If any cblas implementation found, the following variable will be set.
#    CBLAS_PROVIDER  # one of MKL, ATLAS, OPENBLAS, REFERENCE
#    CBLAS_INC_DIR   # the include directory for cblas.
#    CBLAS_LIBS      # a list of libraries should be linked by paddle.
#                    # Each library should be full path to object file.
#
# User should set one of MKL_ROOT, ATLAS_ROOT, OPENBLAS_ROOT, REFERENCE_CBLAS_ROOT
# during cmake. If none of them set, it will try to find cblas implementation in
# system paths.

## Find MKL First.
set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder contains MKL")

find_path(MKL_INCLUDE_DIR mkl.h PATHS ${MKL_ROOT}/include)
find_library(MKL_CORE_LIB NAMES mkl_core PATHS ${MKL_ROOT}/lib)
find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS ${MKL_ROOT}/lib)
find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS ${MKL_ROOT}/lib)

if(MKL_INCLUDE_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64)
  set(CBLAS_PROVIDER MKL)
  set(CBLAS_INC_DIR ${MKL_INCLUDE_DIR})
  set(CBLAS_LIBS ${MKL_INTEL_LP64}
    ${MKL_SEQUENTIAL_LIB}
    ${MKL_CORE_LIB})
  add_definitions(-DPADDLE_USE_MKL)
  return() # return file.
endif()

## Then find atlas.
set(ATLAS_ROOT $ENV{ATLAS_ROOT} CACHE PATH "Folder contains Atlas")
set(ATLAS_INCLUDE_SEARCH_PATHS
  ${ATLAS_ROOT}/include
  /usr/include
  /usr/include/atlas)
set(ATLAS_LIB_SEARCH_PATHS
  ${ATLAS_ROOT}/lib
  /usr/lib
  /usr/lib/blas/atlas
  /usr/lib/atlas
  /usr/lib/atlas-base # special for ubuntu 14.04.
+ ) +find_path(ATLAS_INC_DIR NAMES cblas.h + PATHS ${ATLAS_INCLUDE_SEARCH_PATHS}) +find_library(ATLAS_CBLAS_LIB NAMES cblas libcblas.so.3 + PATHS ${ATLAS_LIB_SEARCH_PATHS}) +find_library(ATLAS_LIB NAMES atlas libatlas.so.3 + PATHS ${ATLAS_LIB_SEARCH_PATHS}) + +if(ATLAS_INC_DIR AND ATLAS_CBLAS_LIB AND ATLAS_LIB) + set(CBLAS_PROVIDER ATLAS) + set(CBLAS_INC_DIR ${ATLAS_INC_DIR}) + set(CBLAS_LIBS ${ATLAS_LIB} ${ATLAS_CBLAS_LIB}) + return() +endif() + +## Then find openblas. +set(OPENBLAS_ROOT $ENV{OPENBLAS_ROOT} CACHE PATH "Folder contains Openblas") +set(OPENBLAS_INCLUDE_SEARCH_PATHS + ${OPENBLAS_ROOT}/include + /usr/include + /usr/include/openblas) +set(OPENBLAS_LIB_SEARCH_PATHS + ${OPENBLAS_ROOT}/lib + /usr/lib + /usr/lib/blas/openblas + /usr/lib/openblas) + +find_path(OPENBLAS_INC_DIR NAMES cblas.h + PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) +find_library(OPENBLAS_LIB NAMES openblas + PATHS ${OPENBLAS_LIB_SEARCH_PATHS}) + +if(OPENBLAS_INC_DIR AND OPENBLAS_LIB) + set(CBLAS_PROVIDER OPENBLAS) + set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR}) + set(CBLAS_LIBS ${OPENBLAS_LIB}) + return() +endif() + + +## Then find the reference-cblas. www.netlib.org/blas/ + + +set(REFERENCE_CBLAS_ROOT $ENV{REFERENCE_CBLAS_ROOT} CACHE PATH + "Folder contains reference-cblas") +set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS + ${REFERENCE_CBLAS_ROOT}/include + /usr/include + /usr/include/cblas +) + +set(REFERENCE_CBLAS_LIB_SEARCH_PATHS + ${REFERENCE_CBLAS_ROOT}/lib + /usr/lib + /usr/lib/blas/reference/ + /usr/lib/reference/ +) + +find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS + ${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS}) +find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS + ${REFERENCE_CBLAS_LIB_SEARCH_PATHS}) + +if (REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY) + set(CBLAS_PROVIDER REFERENCE) + set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR}) + set(CBLAS_LIBS ${REFERENCE_CBLAS_LIBRARY}) + return() +endif() + +message(FATAL_ERROR "CBlas must be set. 
Paddle support MKL, ATLAS, OpenBlas, reference-cblas." + " Try set MKL_ROOT, ATLAS_ROOT, OPENBLAS_ROOT or REFERENCE_CBLAS_ROOT.") diff --git a/cmake/ccache.cmake b/cmake/ccache.cmake new file mode 100644 index 00000000000000..968d41801d73c4 --- /dev/null +++ b/cmake/ccache.cmake @@ -0,0 +1,9 @@ +# Use ccache if found ccache program + +find_program(CCACHE_FOUND ccache) + +if(CCACHE_FOUND) + message(STATUS "Ccache is founded, use ccache to speed up compile.") + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache) +endif(CCACHE_FOUND) \ No newline at end of file diff --git a/cmake/check_packages.cmake b/cmake/check_packages.cmake new file mode 100644 index 00000000000000..3bc0c1fd18448e --- /dev/null +++ b/cmake/check_packages.cmake @@ -0,0 +1,45 @@ +# Check package for each cmake option + +if(WITH_GPU) + find_package(CUDA REQUIRED) # CUDA is required when use gpu +endif() + +if(WITH_PYTHON) + find_package(PythonLibs 2.6 REQUIRED) + find_package(PythonInterp REQUIRED) + find_package(NumPy REQUIRED) +endif() + +if(WITH_STYLE_CHECK) + find_package(PythonInterp REQUIRED) +endif() + +if(WITH_GLOG) + find_package(Glog REQUIRED) +endif() + +if(WITH_GFLAGS) + find_package(Gflags REQUIRED) +endif() + +if(WITH_TESTING) + find_package(GTest REQUIRED) +endif() + +if(WITH_DOC) + find_package(Sphinx REQUIRED) + find_package(Doxygen REQUIRED) + find_python_module(recommonmark REQUIRED) + find_python_module(breathe REQUIRED) +endif() + +if(WITH_SWIG_PY) + if(NOT SWIG_FOUND) + message(FATAL_ERROR "SWIG is not found. 
Please install swig or disable WITH_SWIG_PY")
  endif()
  find_python_module(wheel REQUIRED) # package wheel
endif()

if(NOT M4_EXECUTABLE)
  message(FATAL_ERROR "Paddle need m4 to generate proto file.")
endif()

# ---------------------------------------------------------------------------
# cmake/cpplint.cmake
# util to check C++ file style
# * it basically use google cpplint.py.
# * It provide "add_style_check_target" for cmake.
#   Usage see add_style_check_target's document
#
# TODO(yuyang18): Add python style check.

set(STYLE_FILTER)

# disable unwanted filters   (typo fixed: was "diable")

# paddle do not indent public/protected/private in class
set(STYLE_FILTER "${STYLE_FILTER}-whitespace/indent,")
# paddle use mutable reference. BUT IT IS NOT RECOMMENDED
set(STYLE_FILTER "${STYLE_FILTER}-runtime/references,")
# paddle use relative path for include.
set(STYLE_FILTER "${STYLE_FILTER}-build/include,")
# paddle use C++11 headers such as <thread>, <mutex>, etc.
# (NOTE(review): the original header names were lost in extraction; confirm.)
set(STYLE_FILTER "${STYLE_FILTER}-build/c++11,")
# paddle use c style casting. BUT IT IS NOT RECOMMENDED
set(STYLE_FILTER "${STYLE_FILTER}-readability/casting")


# IGNORE SOME FILES
set(IGNORE_PATTERN
  .*ImportanceSampler.*
  .*cblas\\.h.*
  .*LtrDataProvider.*
  .*MultiDataProvider.*)

# add_style_check_target
#
# attach check code style step for target.
#
# first argument: target name to attach
# rest arguments: source list to check code style.
#
# NOTE: If WITH_STYLE_CHECK is OFF, then this macro just do nothing.
macro(add_style_check_target TARGET_NAME)
  if(WITH_STYLE_CHECK)
    set(SOURCES_LIST ${ARGN})
    list(REMOVE_DUPLICATES SOURCES_LIST)
    list(SORT SOURCES_LIST)

    foreach(filename ${SOURCES_LIST})
      set(LINT ON)
      foreach(pattern ${IGNORE_PATTERN})
        if(filename MATCHES ${pattern})
          # ${filename} restored here and below: the extraction had replaced
          # it with a "$(unknown)" placeholder.
          message(STATUS "DROP LINT ${filename}")
          set(LINT OFF)
        endif()
      endforeach()
      if(LINT MATCHES ON)
        add_custom_command(TARGET ${TARGET_NAME}
          PRE_BUILD
          COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py"
                  "--filter=${STYLE_FILTER}" ${filename}
          WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
      endif()
    endforeach()
  endif()
endmacro()

# ---------------------------------------------------------------------------
# cmake/cudnn.cmake
set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT")
find_path(CUDNN_INCLUDE_DIR cudnn.h
  PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include
  $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
  NO_DEFAULT_PATH)

get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)

list(APPEND CUDNN_CHECK_LIBRARY_DIRS
  ${CUDNN_ROOT}
  ${CUDNN_ROOT}/lib64
  ${CUDNN_ROOT}/lib
  $ENV{CUDNN_ROOT}
  $ENV{CUDNN_ROOT}/lib64
  $ENV{CUDNN_ROOT}/lib
  /usr/lib)
find_library(CUDNN_LIBRARY NAMES libcudnn.so # libcudnn_static.a
  PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
  NO_DEFAULT_PATH
  DOC "Path to cuDNN library.")


if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
  set(CUDNN_FOUND ON)
else()
  set(CUDNN_FOUND OFF)
endif()

if(CUDNN_FOUND)
  file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)

  get_filename_component(CUDNN_LIB_PATH ${CUDNN_LIBRARY} DIRECTORY)

  string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)"
    CUDNN_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
  string(REGEX REPLACE "define CUDNN_VERSION +([0-9]+)" "\\1"
    CUDNN_VERSION "${CUDNN_VERSION}")

  if("${CUDNN_VERSION}" STREQUAL "2000")
    message(STATUS
"Current cuDNN version is v2. ") + else() + string(REGEX MATCH "define CUDNN_MAJOR +([0-9]+)" CUDNN_MAJOR_VERSION + "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR +([0-9]+)" "\\1" + CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}") + string(REGEX MATCH "define CUDNN_MINOR +([0-9]+)" CUDNN_MINOR_VERSION + "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR +([0-9]+)" "\\1" + CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL +([0-9]+)" + CUDNN_PATCHLEVEL_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL +([0-9]+)" "\\1" + CUDNN_PATCHLEVEL_VERSION "${CUDNN_PATCHLEVEL_VERSION}") + + if(NOT CUDNN_MAJOR_VERSION) + set(CUDNN_VERSION "???") + else() + math(EXPR CUDNN_VERSION + "${CUDNN_MAJOR_VERSION} * 1000 + + ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}") + endif() + + message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. " + "Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ") + + endif() +endif() diff --git a/cmake/enableCXX11.cmake b/cmake/enableCXX11.cmake new file mode 100644 index 00000000000000..dc8cc3371aa6e5 --- /dev/null +++ b/cmake/enableCXX11.cmake @@ -0,0 +1,13 @@ +# Enable C++ 11 for GCC. +# NOTE: It's only tested for gcc. 
+include(CheckCXXCompilerFlag) +CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORT_CXX11) +CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORT_CXX0X) + +if(COMPILER_SUPPORT_CXX11) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +elseif(COMPILER_SUPPORT_CXX0X) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x") +else() + message(FATAL_ERROR "Your compiler must support c++11") +endif() \ No newline at end of file diff --git a/cmake/flags.cmake b/cmake/flags.cmake new file mode 100644 index 00000000000000..351af42ee6f6ad --- /dev/null +++ b/cmake/flags.cmake @@ -0,0 +1,86 @@ +# Setting Paddle Compile Flags +include(CheckCXXCompilerFlag) +include(CheckCCompilerFlag) +include(CheckCXXSymbolExists) +# safe_set_flag +# +# Set a compile flag only if compiler is support +# is_c: is C flag or C++ flag, bool type. +# src_list: The list name which the flag name will be append to. +# flag_name: the flag name for compiler, such as '-Werror' '-Wall' etc +# rest arguments: not used. +function(safe_set_flag is_c src_list flag_name) + string(REPLACE "-" "_" safe_name ${flag_name}) + string(REPLACE "=" "_" safe_name ${safe_name}) + if(is_c) + CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) + else() + CHECK_CXX_COMPILER_FLAG(${flag_name} CXX_COMPILER_SUPPORT_FLAG_${safe_name}) + set(safe_name CXX_COMPILER_SUPPORT_FLAG_${safe_name}) + endif() + if(${safe_name}) + set(${src_list} "${${src_list}} ${flag_name}" PARENT_SCOPE) + if(is_c) + set(CUDA_NVCC_FLAGS + --compiler-options;${flag_name} + ${CUDA_NVCC_FLAGS} + PARENT_SCOPE) + endif() + endif() +endfunction() + +# helper macro to set cflag +macro(safe_set_cflag src_list flag_name) + safe_set_flag(ON ${src_list} ${flag_name}) +endmacro() + +# helper macro to set cxxflag +macro(safe_set_cxxflag src_list flag_name) + safe_set_flag(OFF ${src_list} ${flag_name}) +endmacro() + +CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS) +if(NOT 
UINT64_MAX_EXISTS) + set(CMAKE_REQUIRED_DEFINITIONS -D__STDC_LIMIT_MACROS) + CHECK_CXX_SYMBOL_EXISTS(UINT64_MAX "stdint.h" UINT64_MAX_EXISTS_HERE) + if(UINT64_MAX_EXISTS_HERE) + set(CMAKE_REQUIRED_DEFINITIONS) + add_definitions(-D__STDC_LIMIT_MACROS) + else() + message(FATAL_ERROR "Cannot find symbol UINT64_MAX") + endif() +endif() + +# Common flags. the compiler flag used for C/C++ sources whenever release or debug +# Do not care if this flag is support for gcc. +set(COMMON_FLAGS + -fPIC + -fno-omit-frame-pointer + -Wall + -Wextra + -Werror + -Wnon-virtual-dtor + -Wdelete-non-virtual-dtor + -Wno-unused-parameter + -Wno-error=literal-suffix + -Wno-error=unused-local-typedefs) + +foreach(flag ${COMMON_FLAGS}) + safe_set_cflag(CMAKE_C_FLAGS ${flag}) + safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag}) +endforeach() + +# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. +# So, don't set these flags here. + +foreach(capability 30 35 50) + list(APPEND __arch_flags "-gencode arch=compute_${capability},code=sm_${capability}") +endforeach() + +if (CUDA_VERSION VERSION_GREATER "7.0") + list(APPEND __arch_flags "-gencode arch=compute_52,code=sm_52") +endif() + +set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) + diff --git a/cmake/package.cmake b/cmake/package.cmake new file mode 100644 index 00000000000000..211593f358eb34 --- /dev/null +++ b/cmake/package.cmake @@ -0,0 +1,21 @@ +set(CPACK_PACKAGE_NAME paddle) +set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "") +set(CPACK_PACKAGE_VERSION_MAJOR ${PADDLE_MAJOR_VERSION}) +set(CPACK_PACKAGE_VERSION_MINOR ${PADDLE_MINOR_VERSION}) +set(CPACK_PACKAGE_VERSION_PATCH ${PADDLE_PATCH_VERSION}) +set(CPACK_PACKAGE_VERSION ${PADDLE_VERSION}) +## DEB Settings +set(CPACK_DEBIAN_PACKAGE_NAME paddle) +set(CPACK_DEBIAN_PACKAGE_ARCHITECTURE amd64) +set(CPACK_DEBIAN_PACKAGE_MAINTAINER PaddlePaddle Dev ) +set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "Paddle") +set(CPACK_PACKAGE_DESCRIPTION "") +set(CPACK_DEBIAN_PACKAGE_DEPENDS "libatlas3-base, 
libgflags2, libgoogle-glog0, libprotobuf8, libpython2.7, libstdc++6, python-numpy, python-pip, python-pip-whl, python-protobuf") +set(CPACK_DEBIAN_PACKAGE_SECTION Devel) +set(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${PROJ_ROOT}/paddle/scripts/deb/postinst") +#set(CPACK_GENERATOR "DEB") +# Start cpack +include (CMakePackageConfigHelpers) +include (CPack) + + diff --git a/cmake/swig.cmake b/cmake/swig.cmake new file mode 100644 index 00000000000000..f5c1bcc79b3dc0 --- /dev/null +++ b/cmake/swig.cmake @@ -0,0 +1,36 @@ +find_program( + SWIG_BINARY_PATH + swig) + +if(${SWIG_BINARY_PATH} STREQUAL "SWIG_BINARY_PATH-NOTFOUND") + set(SWIG_FOUND OFF) +else() + set(SWIG_FOUND ON) +endif() + +set(MIN_SWIG_VERSION 2) +if(SWIG_FOUND) + execute_process(COMMAND sh -c "${SWIG_BINARY_PATH} -version | grep Version | cut -f3 -d' '" + OUTPUT_VARIABLE _SWIG_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(${_SWIG_VERSION} VERSION_LESS ${MIN_SWIG_VERSION}) + message("swig version ${MIN_SWIG_VERSION} or greater is needed for generating python api. " + "Only version ${_SWIG_VERSION} is found. 
Set SWIG_FOUND to FALSE") + set(SWIG_FOUND FALSE) + endif(${_SWIG_VERSION} VERSION_LESS ${MIN_SWIG_VERSION}) +endif(SWIG_FOUND) + +function(generate_python_api target_name) + add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py + ${PROJ_ROOT}/paddle/Paddle_wrap.cxx + ${PROJ_ROOT}/paddle/Paddle_wrap.h + COMMAND swig -python -c++ -outcurrentdir -I../ api/Paddle.swig + && mv ${PROJ_ROOT}/paddle/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py + DEPENDS ${PROJ_ROOT}/paddle/api/Paddle.swig + WORKING_DIRECTORY ${PROJ_ROOT}/paddle + COMMENT "Generate Python API from swig") + add_custom_target(${target_name} ALL DEPENDS + ${PROJ_ROOT}/paddle/Paddle_wrap.cxx + ${PROJ_ROOT}/paddle/Paddle_wrap.h + ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py) +endfunction(generate_python_api) diff --git a/cmake/util.cmake b/cmake/util.cmake new file mode 100644 index 00000000000000..e0e372fed0b049 --- /dev/null +++ b/cmake/util.cmake @@ -0,0 +1,147 @@ +# Some common routine for paddle compile. + + +# target_circle_link_libraries +# Link libraries to target which has circle dependencies. +# +# First Argument: target name want to be linked with libraries +# Rest Arguments: libraries which link together. +function(target_circle_link_libraries TARGET_NAME) + target_link_libraries(${TARGET_NAME} + -Wl,--start-group + ${ARGN} + -Wl,--end-group) +endfunction() + +# compile_cu_as_cpp +# Make a cu file compiled as C++ +# Arguments: Source files +macro(compile_cu_as_cpp) + foreach(s ${ARGN}) + set_source_files_properties(${s} PROPERTIES LANGUAGE CXX) + set_source_files_properties(${s} PROPERTIES COMPILE_FLAGS "-x c++") + endforeach() +endmacro() + +# link_paddle_exe +# add paddle library for a paddle executable, such as trainer, pserver. +# +# It will handle WITH_PYTHON/WITH_GLOG etc. 
+function(link_paddle_exe TARGET_NAME) + if(WITH_METRIC) + if(WITH_GPU) + set(METRIC_LIBS paddle_metric_learning paddle_dserver_lib metric metric_cpu) + else() + set(METRIC_LIBS paddle_metric_learning paddle_dserver_lib metric_cpu) + endif() + else() + set(METRIC_LIBS "") + endif() + + if(PADDLE_WITH_INTERNAL) + set(INTERAL_LIBS paddle_internal_gserver paddle_internal_parameter) + target_circle_link_libraries(${TARGET_NAME} + -Wl,--whole-archive + paddle_internal_gserver + paddle_internal_owlqn + -Wl,--no-whole-archive + paddle_internal_parameter) + else() + set(INTERAL_LIBS "") + endif() + + target_circle_link_libraries(${TARGET_NAME} + -Wl,--whole-archive + paddle_gserver + ${METRIC_LIBS} + -Wl,--no-whole-archive + paddle_pserver + paddle_trainer_lib + paddle_network + paddle_math + paddle_utils + paddle_parameter + paddle_proto + paddle_cuda + ${METRIC_LIBS} + ${PROTOBUF_LIBRARY} + ${CMAKE_THREAD_LIBS_INIT} + ${CBLAS_LIBS} + ${CMAKE_DL_LIBS} + ${INTERAL_LIBS} + -lz) + + if(WITH_PYTHON) + target_link_libraries(${TARGET_NAME} + ${PYTHON_LIBRARIES}) + endif() + + if(WITH_GLOG) + target_link_libraries(${TARGET_NAME} + ${LIBGLOG_LIBRARY}) + endif() + + if(WITH_GFLAGS) + target_link_libraries(${TARGET_NAME} + ${GFLAGS_LIBRARIES}) + endif() + + if(WITH_GPU) + if(NOT WITH_DSO OR WITH_METRIC) + target_link_libraries(${TARGET_NAME} + ${CUDNN_LIBRARY} + ${CUDA_curand_LIBRARY}) + CUDA_ADD_CUBLAS_TO_TARGET(${TARGET_NAME}) + endif() + + check_library_exists(rt clock_gettime "time.h" HAVE_CLOCK_GETTIME ) + if(HAVE_CLOCK_GETTIME) + target_link_libraries(${TARGET_NAME} rt) + endif() + endif() +endfunction() + +# link_paddle_test +# Link a paddle unittest for target +# TARGET_NAME: the unittest target name +# Rest Arguemnts: not used. +function(link_paddle_test TARGET_NAME) + link_paddle_exe(${TARGET_NAME}) + target_link_libraries(${TARGET_NAME} ${GTEST_MAIN_LIBRARIES} + ${GTEST_LIBRARIES}) +endfunction() + +# add_unittest_without_exec +# +# create a paddle unittest. 
not specifically define how to run this unittest. +# TARGET_NAME: the unittest target name, same as executable file name +# Rest Arguments: the source files to compile this unittest. +macro(add_unittest_without_exec TARGET_NAME) + add_executable(${TARGET_NAME} ${ARGN}) + link_paddle_test(${TARGET_NAME}) + add_style_check_target(${TARGET_NAME} ${ARGN}) +endmacro() + +# add_unittest +# create a paddle unittest and just to execute this binary to make unittest. +# +# TARGET_NAME: the unittest target name, same as executable file name +# Rest Arguments: the source files to compile this unittest. +macro(add_unittest TARGET_NAME) + add_unittest_without_exec(${TARGET_NAME} ${ARGN}) + add_test(${TARGET_NAME} ${TARGET_NAME}) +endmacro() + +# add_simple_unittest +# create a paddle unittest with file name. It just compile ${TARGET_NAME}.cpp to +# ${TARGET_NAME} and then execute it. +macro(add_simple_unittest TARGET_NAME) + add_unittest(${TARGET_NAME} ${TARGET_NAME}.cpp) +endmacro() + +macro(add_paddle_culib TARGET_NAME) + set(NVCC_FLAG ${CUDA_NVCC_FLAGS}) + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--use_fast_math) + cuda_add_library(${TARGET_NAME} STATIC ${ARGN}) + set(CUDA_NVCC_FLAGS ${NVCC_FLAG}) +endmacro() diff --git a/demo/image_classification/.gitignore b/demo/image_classification/.gitignore new file mode 100644 index 00000000000000..76961dd1436f85 --- /dev/null +++ b/demo/image_classification/.gitignore @@ -0,0 +1,7 @@ +data/cifar-10-batches-py +data/cifar-out +cifar_vgg_model/* +plot.png +train.log +image_provider_copy_1.py +*pyc diff --git a/demo/image_classification/classify.py b/demo/image_classification/classify.py new file mode 120000 index 00000000000000..fefce7086ae7a6 --- /dev/null +++ b/demo/image_classification/classify.py @@ -0,0 +1 @@ +../model_zoo/resnet/classify.py \ No newline at end of file diff --git a/demo/image_classification/classify.sh b/demo/image_classification/classify.sh new file mode 100755 index 00000000000000..f797631346f5a3 --- /dev/null +++ 
b/demo/image_classification/classify.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +python classify.py \ + --job=predict \ + --conf=vgg_16_cifar.py \ + --model=./cifar_vgg_model/pass-00299 \ + --multi_crop \ + --data=./example/test.list diff --git a/demo/image_classification/data/download_cifar.sh b/demo/image_classification/data/download_cifar.sh new file mode 100644 index 00000000000000..ca9b0b5c905254 --- /dev/null +++ b/demo/image_classification/data/download_cifar.sh @@ -0,0 +1,20 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +tar zxf cifar-10-python.tar.gz +rm cifar-10-python.tar.gz +rm -rf cifar-out/* +echo Converting CIFAR data to images..... 
+python process_cifar.py ./cifar-10-batches-py ./cifar-out diff --git a/demo/image_classification/data/process_cifar.py b/demo/image_classification/data/process_cifar.py new file mode 100644 index 00000000000000..b766118eb00737 --- /dev/null +++ b/demo/image_classification/data/process_cifar.py @@ -0,0 +1,77 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import sys +import os +import PIL.Image as Image + +""" + Usage: python process_cifar input_dir output_dir +""" + + +def mkdir_not_exist(path): + """ + Make dir if the path does not exist. + path: the path to be created. + """ + if not os.path.exists(path): + os.mkdir(path) + +def create_dir_structure(output_dir): + """ + Create the directory structure for the directory. + output_dir: the direcotry structure path. + """ + mkdir_not_exist(os.path.join(output_dir)) + mkdir_not_exist(os.path.join(output_dir, "train")) + mkdir_not_exist(os.path.join(output_dir, "test")) + +def convert_batch(batch_path, label_set, label_map, + output_dir, data_split): + """ + Convert CIFAR batch to the structure of Paddle format. + batch_path: the batch to be converted. + label_set: the set of labels. + output_dir: the output path. + data_split: whether it is training or testing data. 
+ """ + data = np.load(batch_path) + for data, label, filename in zip(data['data'], data['labels'], + data['filenames']): + data = data.reshape((3, 32, 32)) + data = np.transpose(data, (1, 2, 0)) + label = label_map[label] + output_dir_this = os.path.join(output_dir, data_split, str(label)) + output_filename = os.path.join(output_dir_this, filename) + if not label in label_set: + label_set[label] = True + mkdir_not_exist(output_dir_this) + Image.fromarray(data).save(output_filename) + + +if __name__ == '__main__': + input_dir = sys.argv[1] + output_dir = sys.argv[2] + num_batch = 5 + create_dir_structure(output_dir) + label_map = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", + 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + labels = {} + for i in range(1, num_batch + 1): + convert_batch(os.path.join(input_dir, "data_batch_%d" % i), labels, + label_map, output_dir, "train") + convert_batch(os.path.join(input_dir, "test_batch"), {}, + label_map, output_dir, "test") \ No newline at end of file diff --git a/demo/image_classification/image_predictor.py b/demo/image_classification/image_predictor.py new file mode 100644 index 00000000000000..002cb412aa3563 --- /dev/null +++ b/demo/image_classification/image_predictor.py @@ -0,0 +1,27 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import numpy as np +from optparse import OptionParser + +from py_paddle import swig_paddle, util, DataProviderWrapperConverter +from paddle.trainer.PyDataProviderWrapper import DenseSlot +from paddle.trainer.config_parser import parse_config + + + +""" +Will merge predictor from Qingqing. +""" diff --git a/demo/image_classification/image_provider.py b/demo/image_classification/image_provider.py new file mode 100644 index 00000000000000..9e2f8b8949b39b --- /dev/null +++ b/demo/image_classification/image_provider.py @@ -0,0 +1,81 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import io +import random + +import paddle.utils.image_util as image_util +from paddle.trainer.PyDataProvider2 import * + + +# +# {'img_size': 32, +# 'settings': , +# 'color': True, +# 'mean_img_size': 32, +# 'meta': './data/cifar-out/batches/batches.meta', +# 'num_classes': 10, +# 'file_list': ('./data/cifar-out/batches/train_batch_000',), +# 'use_jpeg': True} +def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg, + is_train, **kwargs): + settings.mean_img_size = mean_img_size + settings.img_size = img_size + settings.num_classes = num_classes + settings.color = color + settings.is_train = is_train + + if settings.color: + settings.img_raw_size = settings.img_size * settings.img_size * 3 + else: + settings.img_raw_size = settings.img_size * settings.img_size + + settings.meta_path = meta + settings.use_jpeg = use_jpeg + + settings.img_mean = image_util.load_meta(settings.meta_path, + settings.mean_img_size, + settings.img_size, + settings.color) + + settings.logger.info('Image size: %s', settings.img_size) + settings.logger.info('Meta path: %s', settings.meta_path) + settings.input_types = [ + dense_vector(settings.img_raw_size), # image feature + integer_value(settings.num_classes)] # labels + + settings.logger.info('DataProvider Initialization finished') + + +@provider(init_hook=hook) +def processData(settings, file_name): + """ + The main function for loading data. + Load the batch, iterate all the images and labels in this batch. + file_name: the batch file name. 
+ """ + data = cPickle.load(io.open(file_name, 'rb')) + indexes = list(range(len(data['images']))) + if settings.is_train: + random.shuffle(indexes) + for i in indexes: + if settings.use_jpeg == 1: + img = image_util.decode_jpeg(data['images'][i]) + else: + img = data['images'][i] + img_feat = image_util.preprocess_img(img, settings.img_mean, + settings.img_size, settings.is_train, + settings.color) + label = data['labels'][i] + yield img_feat.tolist(), int(label) diff --git a/demo/image_classification/image_util.py b/demo/image_classification/image_util.py new file mode 100644 index 00000000000000..c545d16aafbc74 --- /dev/null +++ b/demo/image_classification/image_util.py @@ -0,0 +1,207 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from PIL import Image +from cStringIO import StringIO + +def resize_image(img, target_size): + """ + Resize an image so that the shorter edge has length target_size. + img: the input image to be resized. + target_size: the target resized image size. + """ + percent = (target_size/float(min(img.size[0], img.size[1]))) + resized_size = int(round(img.size[0] * percent)), int(round(img.size[1] * percent)) + img = img.resize(resized_size, Image.ANTIALIAS) + return img + +def flip(im): + """ + Return the flipped image. + Flip an image along the horizontal direction. 
+ im: input image, (H x W x K) ndarrays + """ + if len(im.shape) == 3: + return im[:, :, ::-1] + else: + return im[:, ::-1] + +def crop_img(im, inner_size, color=True, test=True): + """ + Return cropped image. + The size of the cropped image is inner_size * inner_size. + im: (K x H x W) ndarrays + inner_size: the cropped image size. + color: whether it is color image. + test: whether in test mode. + If False, does random cropping and flipping. + If True, crop the center of images. + """ + if color: + height, width = max(inner_size, im.shape[1]), max(inner_size, im.shape[2]) + padded_im = np.zeros((3, height, width)) + startY = (height - im.shape[1]) / 2 + startX = (width - im.shape[2]) / 2 + endY, endX = startY + im.shape[1], startX + im.shape[2] + padded_im[:, startY: endY, startX: endX] = im + else: + im = im.astype('float32') + height, width = max(inner_size, im.shape[0]), max(inner_size, im.shape[1]) + padded_im = np.zeros((height, width)) + startY = (height - im.shape[0]) / 2 + startX = (width - im.shape[1]) / 2 + endY, endX = startY + im.shape[0], startX + im.shape[1] + padded_im[startY: endY, startX: endX] = im + if test: + startY = (height - inner_size) / 2 + startX = (width - inner_size) / 2 + else: + startY = np.random.randint(0, height - inner_size + 1) + startX = np.random.randint(0, width - inner_size + 1) + endY, endX = startY + inner_size, startX + inner_size + if color: + pic = padded_im[:, startY: endY, startX: endX] + else: + pic = padded_im[startY: endY, startX: endX] + if (not test) and (np.random.randint(2) == 0): + pic = flip(pic) + return pic + +def decode_jpeg(jpeg_string): + np_array = np.array(Image.open(StringIO(jpeg_string))) + if len(np_array.shape) == 3: + np_array = np.transpose(np_array, (2, 0, 1)) + return np_array + +def preprocess_img(im, img_mean, crop_size, is_train, color=True): + """ + Does data augmentation for images. + If is_train is false, cropping the center region from the image. 
+ If is_train is true, randomly crop a region from the image, + and randomy does flipping. + im: (K x H x W) ndarrays + """ + im = im.astype('float32') + test = not is_train + pic = crop_img(im, crop_size, color, test) + pic -= img_mean + return pic.flatten() + +def load_meta(meta_path, mean_img_size, crop_size, color=True): + """ + Return the loaded meta file. + Load the meta image, which is the mean of the images in the dataset. + The mean image is subtracted from every input image so that the expected mean + of each input image is zero. + """ + mean = np.load(meta_path)['data_mean'] + border = (mean_img_size - crop_size) / 2 + if color: + assert(mean_img_size * mean_img_size * 3 == mean.shape[0]) + mean = mean.reshape(3, mean_img_size, mean_img_size) + mean = mean[:, border: border + crop_size, + border: border + crop_size].astype('float32') + else: + assert(mean_img_size * mean_img_size == mean.shape[0]) + mean = mean.reshape(mean_img_size, mean_img_size) + mean = mean[border: border + crop_size, + border: border + crop_size].astype('float32') + return mean + +def load_image(img_path, is_color=True): + """ + Load image and return. + img_path: image path. + is_color: is color image or not. + """ + img = Image.open(img_path) + img.load() + return img + +def oversample(img, crop_dims): + """ + image : iterable of (H x W x K) ndarrays + crop_dims: (height, width) tuple for the crops. + Returned data contains ten crops of input image, namely, + four corner patches and the center patch as well as their + horizontal reflections. + """ + # Dimensions and center. 
+ im_shape = np.array(img[0].shape) + crop_dims = np.array(crop_dims) + im_center = im_shape[:2] / 2.0 + + # Make crop coordinates + h_indices = (0, im_shape[0] - crop_dims[0]) + w_indices = (0, im_shape[1] - crop_dims[1]) + crops_ix = np.empty((5, 4), dtype=int) + curr = 0 + for i in h_indices: + for j in w_indices: + crops_ix[curr] = (i, j, i + crop_dims[0], j + crop_dims[1]) + curr += 1 + crops_ix[4] = np.tile(im_center, (1, 2)) + np.concatenate([ + -crop_dims / 2.0, + crop_dims / 2.0 + ]) + crops_ix = np.tile(crops_ix, (2, 1)) + + # Extract crops + crops = np.empty((10 * len(img), crop_dims[0], crop_dims[1], + im_shape[-1]), dtype=np.float32) + ix = 0 + for im in img: + for crop in crops_ix: + crops[ix] = im[crop[0]:crop[2], crop[1]:crop[3], :] + ix += 1 + crops[ix-5:ix] = crops[ix-5:ix, :, ::-1, :] # flip for mirrors + return crops + +class ImageTransformer: + def __init__(self, transpose = None, + channel_swap = None, mean = None, is_color = True): + self.transpose = transpose + self.channel_swap = None + self.mean = None + self.is_color = is_color + + def set_transpose(self, order): + if self.is_color: + assert 3 == len(order) + self.transpose = order + + def set_channel_swap(self, order): + if self.is_color: + assert 3 == len(order) + self.channel_swap = order + + def set_mean(self, mean): + # mean value, may be one value per channel + if mean.ndim == 1: + mean = mean[:, np.newaxis, np.newaxis] + else: + # elementwise mean + if self.is_color: + assert len(mean.shape) == 3 + self.mean = mean + + def transformer(self, data): + if self.transpose is not None: + data = data.transpose(self.transpose) + if self.channel_swap is not None: + data = data[self.channel_swap, :, :] + if self.mean is not None: + data -= self.mean + return data diff --git a/demo/image_classification/preprocess.py b/demo/image_classification/preprocess.py new file mode 100755 index 00000000000000..0286a5d7e9dc8d --- /dev/null +++ b/demo/image_classification/preprocess.py @@ -0,0 +1,40 @@ +# 
Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.utils.preprocess_img import ImageClassificationDatasetCreater +from optparse import OptionParser + + +def option_parser(): + parser = OptionParser(usage="usage: python preprcoess.py "\ + "-i data_dir [options]") + parser.add_option("-i", "--input", action="store", + dest="input", help="Input data directory.") + parser.add_option("-s", "--size", action="store", + dest="size", help="Processed image size.") + parser.add_option("-c", "--color", action="store", + dest="color", help="whether to use color images.") + return parser.parse_args() + +if __name__ == '__main__': + options, args = option_parser() + data_dir = options.input + processed_image_size = int(options.size) + color = options.color == "1" + data_creator = ImageClassificationDatasetCreater(data_dir, + processed_image_size, + color) + data_creator.num_per_batch = 1000 + data_creator.overwrite = True + data_creator.create_batches() diff --git a/demo/image_classification/preprocess.sh b/demo/image_classification/preprocess.sh new file mode 100755 index 00000000000000..fe89c8f4bb9464 --- /dev/null +++ b/demo/image_classification/preprocess.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +export PYTHONPATH=$PYTHONPATH:../../ + +data_dir=./data/cifar-out + +python preprocess.py -i $data_dir -s 32 -c 1 diff --git a/demo/image_classification/train.sh b/demo/image_classification/train.sh new file mode 100755 index 00000000000000..ed9b5220fff6a4 --- /dev/null +++ b/demo/image_classification/train.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +config=vgg_16_cifar.py +output=./cifar_vgg_model +log=train.log + +paddle train \ +--config=$config \ +--dot_period=10 \ +--log_period=100 \ +--test_all_data_in_one_period=1 \ +--use_gpu=1 \ +--trainer_count=1 \ +--num_passes=200 \ +--save_dir=$output \ +2>&1 | tee $log + +python -m paddle.utils.plotcurve -i $log > plot.png diff --git a/demo/image_classification/upload_hadoop.sh b/demo/image_classification/upload_hadoop.sh new file mode 100755 index 00000000000000..34d3a8b7ce00f6 --- /dev/null +++ b/demo/image_classification/upload_hadoop.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2016 Baidu, Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +hadoop fs -Dhadoop.job.ugi=paddle_demo,paddle_demo -put data/cifar-out/batches/train_batch_* /app/idl/idl-dl/paddle/demo/image_classification/train/ +hadoop fs -Dhadoop.job.ugi=paddle_demo,paddle_demo -put data/cifar-out/batches/test_batch_* /app/idl/idl-dl/paddle/demo/image_classification/test/ +hadoop fs -Dhadoop.job.ugi=paddle_demo,paddle_demo -put data/cifar-out/batches/batches.meta /app/idl/idl-dl/paddle/demo/image_classification/train_meta +hadoop fs -Dhadoop.job.ugi=paddle_demo,paddle_demo -put data/cifar-out/batches/batches.meta /app/idl/idl-dl/paddle/demo/image_classification/test_meta diff --git a/demo/image_classification/vgg_16_cifar.py b/demo/image_classification/vgg_16_cifar.py new file mode 100644 index 00000000000000..238608c3cbede1 --- /dev/null +++ b/demo/image_classification/vgg_16_cifar.py @@ -0,0 +1,56 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +is_predict = get_config_arg("is_predict", bool, False) + +####################Data Configuration ################## +if not is_predict: + data_dir='data/cifar-out/batches/' + meta_path=data_dir+'batches.meta' + + args = {'meta':meta_path,'mean_img_size': 32, + 'img_size': 32,'num_classes': 10, + 'use_jpeg': 1,'color': "color"} + + define_py_data_sources2(train_list=data_dir+"train.list", + test_list=data_dir+'test.list', + module='image_provider', + obj='processData', + args=args) + +######################Algorithm Configuration ############# +settings( + batch_size = 128, + learning_rate = 0.1 / 128.0, + learning_method = MomentumOptimizer(0.9), + regularization = L2Regularization(0.0005 * 128) +) + +#######################Network Configuration ############# +data_size=3*32*32 +label_size=10 +img = data_layer(name='image', + size=data_size) +# small_vgg is predefined in trainer_config_helpers.network +predict = small_vgg(input_image=img, + num_channels=3, + num_classes=label_size) + +if not is_predict: + lbl = data_layer(name="label", size=label_size) + outputs(classification_cost(input=predict, label=lbl)) +else: + outputs(predict) diff --git a/demo/model_zoo/embedding/.gitignore b/demo/model_zoo/embedding/.gitignore new file mode 100644 index 00000000000000..908f5a3fb2f7c3 --- /dev/null +++ b/demo/model_zoo/embedding/.gitignore @@ -0,0 +1,2 @@ +baidu.dict +model_*.emb diff --git a/demo/model_zoo/embedding/extract_para.py b/demo/model_zoo/embedding/extract_para.py new file mode 100755 index 00000000000000..17067792fc38d0 --- /dev/null +++ b/demo/model_zoo/embedding/extract_para.py @@ -0,0 +1,96 @@ +#!/bin/env python +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: + python extract_para.py --preModel PREMODEL --preDict PREDICT \ + --usrModel USRMODEL --usrDict USRDICT -d DIM + +Options: + -h, --help show this help message and exit + --preModel PREMODEL the name of pretrained embedding model + --preDict PREDICT the name of pretrained dictionary + --usrModel usrModel the name of output usr embedding model + --usrDict usrDict the name of user specified dictionary + -d DIM dimension of parameter +""" +from optparse import OptionParser +import struct + +def get_row_index(preDict, usrDict): + """ + Get the row positions for all words in user dictionary from pre-trained dictionary. 
+ return: a list of row positions + Example: preDict='a\nb\nc\n', usrDict='a\nc\n', then return [0,2] + """ + pos = [] + index = dict() + with open(preDict, "r") as f: + for line_index, line in enumerate(f): + word = line.strip().split()[0] + index[word] = line_index + with open(usrDict, "r") as f: + for line in f: + word = line.strip().split()[0] + pos.append(index[word]) + return pos + +def extract_parameters_by_usrDict(preModel, preDict, usrModel, usrDict, paraDim): + """ + Extract desired parameters from a pretrained embedding model based on user dictionary + """ + if paraDim not in [32, 64, 128, 256]: + raise RuntimeError("We only support 32, 64, 128, 256 dimensions now") + + fi = open(preModel, "rb") + fo = open(usrModel, "wb") + + # write filehead + rowIndex = get_row_index(preDict, usrDict) + newHead = struct.pack("iil", 0, 4, len(rowIndex) * paraDim) + fo.write(newHead) + bytes = 4 * paraDim + for i in range(0, len(rowIndex)): + # find the absolute position of input file + fi.seek(rowIndex[i] * bytes + 16, 0) + fo.write(fi.read(bytes)) + + print "extract parameters finish, total", len(rowIndex), "lines" + fi.close() + +def main(): + """ + Main entry for running paraconvert.py + """ + usage = "usage: \n" \ + "python %prog --preModel PREMODEL --preDict PREDICT" \ + " --usrModel USRMODEL --usrDict USRDICT -d DIM" + parser = OptionParser(usage) + parser.add_option("--preModel", action="store", dest="preModel", + help="the name of pretrained embedding model") + parser.add_option("--preDict", action="store", dest="preDict", + help="the name of pretrained dictionary") + parser.add_option("--usrModel", action="store", dest="usrModel", + help="the name of output usr embedding model") + parser.add_option("--usrDict", action="store", dest="usrDict", + help="the name of user specified dictionary") + parser.add_option("-d", action="store", dest="dim", + help="dimension of parameter") + (options, args) = parser.parse_args() + 
extract_parameters_by_usrDict(options.preModel, options.preDict, + options.usrModel, options.usrDict, int(options.dim)) + +if __name__ == '__main__': + main() diff --git a/demo/model_zoo/embedding/paraconvert.py b/demo/model_zoo/embedding/paraconvert.py new file mode 100755 index 00000000000000..523412303617a3 --- /dev/null +++ b/demo/model_zoo/embedding/paraconvert.py @@ -0,0 +1,151 @@ +#!/bin/env python +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: + python paraconvert.py --b2t -i INPUT -o OUTPUT -d DIM + python paraconvert.py --t2b -i INPUT -o OUTPUT + +Options: + -h, --help show this help message and exit + --b2t convert parameter file of embedding model from binary to text + --t2b convert parameter file of embedding model from text to binary + -i INPUT input parameter file name + -o OUTPUT output parameter file name + -d DIM dimension of parameter +""" +from optparse import OptionParser +import struct + +def binary2text(input, output, paraDim): + """ + Convert a binary parameter file of embedding model to be a text file. 
+ input: the name of input binary parameter file, the format is: + 1) the first 16 bytes is filehead: + version(4 bytes): version of paddle, default = 0 + floatSize(4 bytes): sizeof(float) = 4 + paraCount(8 bytes): total number of parameter + 2) the next (paraCount * 4) bytes is parameters, each has 4 bytes + output: the name of output text parameter file, for example: + 0,4,32156096 + -0.7845433,1.1937413,-0.1704215,... + 0.0000909,0.0009465,-0.0008813,... + ... + the format is: + 1) the first line is filehead: + version=0, floatSize=4, paraCount=32156096 + 2) other lines print the paramters + a) each line prints paraDim paramters splitted by ',' + b) there is paraCount/paraDim lines (embedding words) + paraDim: dimension of parameters + """ + fi = open(input, "rb") + fo = open(output, "w") + """ + """ + version, floatSize, paraCount = struct.unpack("iil", fi.read(16)) + newHead = ','.join([str(version), str(floatSize), str(paraCount)]) + print >> fo, newHead + + bytes = 4 * int(paraDim) + format = "%df" % int(paraDim) + context = fi.read(bytes) + line = 0 + + while context: + numbers = struct.unpack(format, context) + lst = [] + for i in numbers: + lst.append('%8.7f' % i) + print >> fo, ','.join(lst) + context = fi.read(bytes) + line += 1 + fi.close() + fo.close() + print "binary2text finish, total", line, "lines" + +def get_para_count(input): + """ + Compute the total number of embedding parameters in input text file. + input: the name of input text file + """ + numRows = 1 + paraDim = 0 + with open(input) as f: + line = f.readline() + paraDim = len(line.split(",")) + for line in f: + numRows += 1 + return numRows * paraDim + +def text2binary(input, output, paddle_head=True): + """ + Convert a text parameter file of embedding model to be a binary file. + input: the name of input text parameter file, for example: + -0.7845433,1.1937413,-0.1704215,... + 0.0000909,0.0009465,-0.0008813,... + ... 
+ the format is: + 1) it doesn't have filehead + 2) each line stores the same dimension of parameters, + the separator is commas ',' + output: the name of output binary parameter file, the format is: + 1) the first 16 bytes is filehead: + version(4 bytes), floatSize(4 bytes), paraCount(8 bytes) + 2) the next (paraCount * 4) bytes is parameters, each has 4 bytes + """ + fi = open(input, "r") + fo = open(output, "wb") + + newHead = struct.pack("iil", 0, 4, get_para_count(input)) + fo.write(newHead) + + count = 0 + for line in fi: + line = line.strip().split(",") + for i in range(0, len(line)): + binary_data = struct.pack("f", float(line[i])) + fo.write(binary_data) + count += 1 + fi.close() + fo.close() + print "text2binary finish, total", count, "lines" + +def main(): + """ + Main entry for running paraconvert.py + """ + usage = "usage: \n" \ + "python %prog --b2t -i INPUT -o OUTPUT -d DIM \n" \ + "python %prog --t2b -i INPUT -o OUTPUT" + parser = OptionParser(usage) + parser.add_option("--b2t", action="store_true", + help="convert parameter file of embedding model from binary to text") + parser.add_option("--t2b", action="store_true", + help="convert parameter file of embedding model from text to binary") + parser.add_option("-i", action="store", dest="input", + help="input parameter file name") + parser.add_option("-o", action="store", dest="output", + help="output parameter file name") + parser.add_option("-d", action="store", dest="dim", + help="dimension of parameter") + (options, args) = parser.parse_args() + if options.b2t: + binary2text(options.input, options.output, options.dim) + if options.t2b: + text2binary(options.input, options.output) + +if __name__ == '__main__': + main() diff --git a/demo/model_zoo/embedding/pre_DictAndModel.sh b/demo/model_zoo/embedding/pre_DictAndModel.sh new file mode 100755 index 00000000000000..7821850fb25cc5 --- /dev/null +++ b/demo/model_zoo/embedding/pre_DictAndModel.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2016 
Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +set -x + +# download the dictionary and pretrained model +for file in baidu.dict model_32.emb model_64.emb model_128.emb model_256.emb +do + # following is the google drive address + # you can also directly download from https://pan.baidu.com/s/1o8q577s + wget https://www.googledrive.com/host/0B7Q8d52jqeI9ejh6Q1RpMTFQT1k/embedding/$file --no-check-certificate +done diff --git a/demo/model_zoo/resnet/.gitignore b/demo/model_zoo/resnet/.gitignore new file mode 100644 index 00000000000000..7a64209b62340a --- /dev/null +++ b/demo/model_zoo/resnet/.gitignore @@ -0,0 +1,5 @@ +fea_output/ +features/ +model.list +ResNet_50.dot +ResNet_50.png diff --git a/demo/model_zoo/resnet/classify.py b/demo/model_zoo/resnet/classify.py new file mode 100755 index 00000000000000..e818995fa31a92 --- /dev/null +++ b/demo/model_zoo/resnet/classify.py @@ -0,0 +1,274 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import cPickle +import logging +from PIL import Image +import numpy as np +from optparse import OptionParser + +import paddle.utils.image_util as image_util + +from py_paddle import swig_paddle, util +from py_paddle import DataProviderWrapperConverter +from paddle.trainer.PyDataProviderWrapper import DenseSlot +from paddle.trainer.config_parser import parse_config + +logging.basicConfig(format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s') +logging.getLogger().setLevel(logging.INFO) + +class ImageClassifier(): + def __init__(self, train_conf, model_dir=None, + resize_dim=256, crop_dim=224, + mean_file=None, + output_layer=None, + oversample=False, is_color=True): + """ + train_conf: network configure. + model_dir: string, directory of model. + resize_dim: int, resized image size. + crop_dim: int, crop size. + mean_file: string, image mean file. + oversample: bool, oversample means multiple crops, namely five + patches (the four corner patches and the center + patch) as well as their horizontal reflections, + ten crops in all. 
+ """ + self.train_conf = train_conf + self.model_dir = model_dir + if model_dir is None: + self.model_dir = os.path.dirname(train_conf) + + self.resize_dim = resize_dim + self.crop_dims = [crop_dim, crop_dim] + self.oversample = oversample + self.is_color = is_color + + self.output_layer = output_layer + if self.output_layer: + assert isinstance(self.output_layer, basestring) + self.output_layer = self.output_layer.split(",") + + self.transformer = image_util.ImageTransformer(is_color = is_color) + self.transformer.set_transpose((2,0,1)) + self.transformer.set_channel_swap((2,1,0)) + + self.mean_file = mean_file + if self.mean_file is not None: + mean = np.load(self.mean_file)['data_mean'] + mean = mean.reshape(3, self.crop_dims[0], self.crop_dims[1]) + self.transformer.set_mean(mean) # mean pixel + else: + # if you use three mean value, set like: + # this three mean value is calculated from ImageNet. + self.transformer.set_mean(np.array([103.939,116.779,123.68])) + + conf_args = "is_test=1,use_gpu=1,is_predict=1" + conf = parse_config(train_conf, conf_args) + swig_paddle.initPaddle("--use_gpu=1") + self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config) + assert isinstance(self.network, swig_paddle.GradientMachine) + self.network.loadParameters(self.model_dir) + + data_size = 3 * self.crop_dims[0] * self.crop_dims[1] + slots = [DenseSlot(data_size)] + is_sequence = False + self.converter = util.DataProviderWrapperConverter(is_sequence, slots) + + def get_data(self, img_path): + """ + 1. load image from img_path. + 2. resize or oversampling. + 3. transformer data: transpose, channel swap, sub mean. + return K x H x W ndarray. + + img_path: image path. + """ + image = image_util.load_image(img_path, self.is_color) + # Another way to extract oversampled features is that + # cropping and averaging from large feature map which is + # calculated by large size of image. + # This way reduces the computation. 
+ if self.oversample: + # image_util.resize_image: short side is self.resize_dim + image = image_util.resize_image(image, self.resize_dim) + image = np.array(image) + input = np.zeros((1, image.shape[0], image.shape[1], 3), + dtype=np.float32) + input[0] = image.astype(np.float32) + input = image_util.oversample(input, self.crop_dims) + else: + image = image.resize(self.crop_dims, Image.ANTIALIAS) + input = np.zeros((1, self.crop_dims[0], self.crop_dims[1], 3), + dtype=np.float32) + input[0] = np.array(image).astype(np.float32) + + data_in = [] + for img in input: + img = self.transformer.transformer(img).flatten() + data_in.append([img.tolist()]) + # paddle input: [[[]],[[]],...], [[]] is one sample. + return data_in + + def forward(self, input_data): + """ + return output arguments which are the Outputs() in network configure. + + input_data: py_paddle input data. + call forward. + """ + in_arg = self.converter(input_data) + return self.network.forwardTest(in_arg) + + def forward(self, data, output_layer): + """ + return output arguments which are the Outputs() in network configure. + + input_data: py_paddle input data. + call forward. + """ + input = self.converter(data) + self.network.forwardTest(input) + output = self.network.getLayerOutputs(output_layer) + res = {} + if isinstance(output_layer, basestring): + output_layer = [output_layer] + for name in output_layer: + # For oversampling, average predictions across crops. + # If not, the shape of output[name]: (1, class_number), + # the mean is also applicable. + res[name] = output[name].mean(0) + + return res + + def predict(self, data_file): + """ + call forward and predicting. + + data_file: input image list. 
+ """ + image_files = open(data_file, 'rb').readlines() + results = {} + if self.output_layer is None: + self.output_layer = ["output"] + for line in image_files: + image = line.split()[0] + data = self.get_data(image) + prob = self.forward(data, self.output_layer) + lab = np.argsort(-prob[self.output_layer[0]]) + results[image] = lab[0] + logging.info("Label of %s is: %d", image, lab[0]) + return results + + def extract(self, data_file, output_dir, batch_size = 10000): + """ + extract and save features of output layers, which are + specified in Outputs() in network configure. + + data_file: file name of input data. + output_dir: saved directory of extracted features. + batch_size: sample number of one batch file. + """ + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + sample_num = 0 + batch_num = 0 + image_feature = {} + image_files = open(data_file, 'rb').readlines() + for idx, line in enumerate(image_files): + image = line.split()[0] + data = self.get_data(image) + feature = self.forward(data, self.output_layer) + # save extracted features + file_name = image.split("/")[-1] + image_feature[file_name] = feature + sample_num += 1 + if sample_num == batch_size: + batch_name = os.path.join(output_dir, 'batch_%d' %(batch_num)) + self.save_file(image_feature, batch_name) + logging.info('Finish batch %d', batch_num) + batch_num += 1 + sample_num = 0 + image_feature = {} + if idx % 1000 == 0: + logging.info('%d/%d, %s', idx, len(image_files), file_name) + if sample_num > 0: + batch_name = os.path.join(output_dir, 'batch_%d' %(batch_num)) + self.save_file(image_feature, batch_name) + logging.info('Finish batch %d', batch_num) + logging.info('Done: make image feature batch') + + def save_file(self, data, file): + of = open(file, 'wb') + cPickle.dump(data, of, protocol=cPickle.HIGHEST_PROTOCOL) + +def option_parser(): + """ + Main entry for predicting + """ + usage = "%prog -c config -i data_list -w model_dir [options]" + parser = OptionParser(usage="usage: %s" 
% usage) + parser.add_option("-j", "--job", + action="store", dest="job_type", + help="job type: predict, extract\ + predict: predicting,\ + extract: extract features") + parser.add_option("-c", "--conf", + action="store", dest="train_conf", + help="network config") + parser.add_option("-i", "--data", + action="store", dest="data_file", + help="image list") + parser.add_option("-w", "--model", + action="store", dest="model_path", + default=None, help="model path") + parser.add_option("-o", "--output_dir", + action="store", dest="output_dir", + default="output", help="output path") + parser.add_option("-m", "--mean", action="store", + dest="mean", default=None, + help="mean file.") + parser.add_option("-p", "--multi_crop", action="store_true", + dest="multi_crop", default=False, + help="Wether to use multiple crops on image.") + parser.add_option("-l", "--output_layer", action="store", + dest="output_layer", default=None, + help="--job=extract, specify layers to extract "\ + "features, --job=predict, specify layer of " + "classification probability, output in resnet.py.") + return parser.parse_args() + +def main(): + """ + 1. parse input arguments. + 2. predicting or extract features according job type. 
+ """ + options, args = option_parser() + obj = ImageClassifier(options.train_conf, + options.model_path, + mean_file=options.mean, + output_layer=options.output_layer, + oversample=options.multi_crop) + if options.job_type == "predict": + obj.predict(options.data_file) + + elif options.job_type == "extract": + obj.extract(options.data_file, + options.output_dir) + +if __name__ == '__main__': + main() diff --git a/demo/model_zoo/resnet/example/.gitignore b/demo/model_zoo/resnet/example/.gitignore new file mode 100644 index 00000000000000..4a2b5962a6800f --- /dev/null +++ b/demo/model_zoo/resnet/example/.gitignore @@ -0,0 +1 @@ +*image_list_provider_copy_1.py diff --git a/demo/model_zoo/resnet/example/__init__.py b/demo/model_zoo/resnet/example/__init__.py new file mode 100644 index 00000000000000..7f9e87eee60376 --- /dev/null +++ b/demo/model_zoo/resnet/example/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/demo/model_zoo/resnet/example/cat.jpg b/demo/model_zoo/resnet/example/cat.jpg new file mode 100644 index 00000000000000..47b01db90eddc4 Binary files /dev/null and b/demo/model_zoo/resnet/example/cat.jpg differ diff --git a/demo/model_zoo/resnet/example/dog.jpg b/demo/model_zoo/resnet/example/dog.jpg new file mode 100644 index 00000000000000..b9cc33cf069da5 Binary files /dev/null and b/demo/model_zoo/resnet/example/dog.jpg differ diff --git a/demo/model_zoo/resnet/example/image_list_provider.py b/demo/model_zoo/resnet/example/image_list_provider.py new file mode 100644 index 00000000000000..ee457e1fffc7ed --- /dev/null +++ b/demo/model_zoo/resnet/example/image_list_provider.py @@ -0,0 +1,105 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.utils.image_util import * +from paddle.trainer.PyDataProvider2 import * + + +def hook(settings, image_size, crop_size, color, file_list, + is_train, **kwargs): + """ + Description: Init with a list of data file + file_list is the name list of input files. + kwargs["load_data_args"] is the value of 'load_data_args' + which can be set in config. + Each args is separated by a column. + image_size: the crop image size. + mean_meta: the path of the meta file to store the mean image. + mean_value: can be mean value, not a file. + can not set mean_meta and mean_value at the same time. + color: 'color' means a color image. Otherwise, it means a gray image. 
+ is_train: whether the data provider is used for training. + Data argumentation might be different for training and testing. + """ + settings.img_size = image_size + settings.crop_size = crop_size + settings.mean_img_size = settings.crop_size + settings.color = color # default is color + settings.is_train = is_train + + settings.is_swap_channel = kwargs.get('swap_channel', None) + if settings.is_swap_channel is not None: + settings.swap_channel = settings.is_swap_channel + settings.is_swap_channel = True + + if settings.color: + settings.img_input_size = settings.crop_size * settings.crop_size * 3 + else: + settings.img_input_size = settings.crop_size * settings.crop_size + + settings.file_list = file_list + settings.mean_meta = kwargs.get('mean_meta', None) + settings.mean_value = kwargs.get('mean_value', None) + # can not specify both mean_meta and mean_value. + assert not (settings.mean_meta and settings.mean_value) + if not settings.mean_meta: + settings.mean_value = kwargs.get('mean_value') + sz = settings.crop_size * settings.crop_size + settings.img_mean = np.zeros(sz * 3, dtype=np.single) + for idx, value in enumerate(settings.mean_value): + settings.img_mean[idx * sz: (idx + 1) * sz] = value + settings.img_mean = settings.img_mean.reshape(3, settings.crop_size, + settings.crop_size) + + else: + settings.img_mean = load_meta(settings.mean_meta, + settings.mean_img_size, + settings.crop_size, settings.color) + + settings.input_types = [ + dense_vector(settings.img_input_size), # image feature + integer_value(1)] # labels + + settings.logger.info('Image short side: %s', settings.img_size) + settings.logger.info('Crop size: %s', settings.crop_size) + settings.logger.info('Meta path: %s', settings.mean_meta) + if settings.is_swap_channel: + settings.logger.info('swap channel: %s', settings.swap_channel) + settings.logger.info('DataProvider Initialization finished') + + +@provider(init_hook=hook, should_shuffle=False) +def processData(settings, file_list): + 
""" + The main function for loading data. + Load the batch, iterate all the images and labels in this batch. + file_name: the batch file name. + """ + img_path, lab = file_list.strip().split(' ') + img = Image.open(img_path) + img.load() + img = img.resize((settings.img_size, settings.img_size), Image.ANTIALIAS) + img = np.array(img).astype(np.float32) + if len(img.shape) == 3: + img = np.swapaxes(img, 1, 2) + img = np.swapaxes(img, 1, 0) + # swap channel + if settings.is_swap_channel: + img = img[settings.swap_channel, :, :] + img_feat = preprocess_img(img, + settings.img_mean, + settings.crop_size, + settings.is_train, + settings.color) + yield img_feat.tolist(), int(lab.strip()) diff --git a/demo/model_zoo/resnet/example/test.list b/demo/model_zoo/resnet/example/test.list new file mode 100644 index 00000000000000..30bbf630b640a2 --- /dev/null +++ b/demo/model_zoo/resnet/example/test.list @@ -0,0 +1,2 @@ +example/dog.jpg 0 +example/cat.jpg 0 diff --git a/demo/model_zoo/resnet/extract_fea_c++.sh b/demo/model_zoo/resnet/extract_fea_c++.sh new file mode 100755 index 00000000000000..c7f9aea9a57df5 --- /dev/null +++ b/demo/model_zoo/resnet/extract_fea_c++.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e + +#set names of layer which you want to extract feature +#in Outputs() of resnet.py +#like: Outputs("res5_3_branch2c_conv", "res5_3_branch2c_bn") +layer_num=50 +configure=./resnet.py +model_path=./model/resnet_$layer_num +fea_dir=fea_output +#Output is text file. +#Each line is one sample's features. +#If you set N layer names in Outputs() +#each line contains N features sperated by ";". + +# create model list file. +model_list=./model.list +touch $model_list | echo $model_path > $model_list + +paddle train \ + --local=true \ + --job=test \ + --config=$configure \ + --model_list=$model_list \ + --use_gpu=1 \ + --predict_output_dir=$fea_dir \ + --config_args=is_test=1,layer_num=$layer_num diff --git a/demo/model_zoo/resnet/extract_fea_py.sh b/demo/model_zoo/resnet/extract_fea_py.sh new file mode 100755 index 00000000000000..b0ec748bb8f0f8 --- /dev/null +++ b/demo/model_zoo/resnet/extract_fea_py.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e + +python classify.py \ + --job=extract \ + --conf=resnet.py \ + --mean=model/mean_meta_224/mean.meta \ + --model=model/resnet_50 \ + --data=./example/test.list \ + --output_layer="res5_3_branch2c_conv,res5_3_branch2c_bn" \ + --output_dir=features diff --git a/demo/model_zoo/resnet/get_model.sh b/demo/model_zoo/resnet/get_model.sh new file mode 100755 index 00000000000000..89312d43edf8e4 --- /dev/null +++ b/demo/model_zoo/resnet/get_model.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd $DIR + +mkdir model +cd model + +echo "Downloading ResNet models..." + +for file in resnet_50.tar.gz resnet_101.tar.gz resnet_152.tar.gz mean_meta_224.tar.gz +do + # following is the google drive address + # you can also directly download from https://pan.baidu.com/s/1o8q577s + wget https://www.googledrive.com/host/0B7Q8d52jqeI9ejh6Q1RpMTFQT1k/imagenet/$file --no-check-certificate + tar -xvf $file + rm $file +done + +echo "Done." diff --git a/demo/model_zoo/resnet/load_feature.py b/demo/model_zoo/resnet/load_feature.py new file mode 100644 index 00000000000000..ee4930b7a17f7f --- /dev/null +++ b/demo/model_zoo/resnet/load_feature.py @@ -0,0 +1,59 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import cPickle +import logging + +logging.basicConfig(format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s') +logging.getLogger().setLevel(logging.INFO) + +def load_feature_c(file): + """ + Load feature extracted by C++ interface. + Return a list. + file: feature file. + """ + features = [] + f = open(file, 'r') + for line in f: + sample = [] + for slot in line.strip().split(";"): + fea = [float(val) for val in slot.strip().split()] + if fea: + sample.append(fea) + features.append(sample) + f.close() + return features + +def load_feature_py(feature_dir): + """ + Load feature extracted by python interface. + Return a dictionary. + feature_dir: directory of feature file. + """ + file_list = os.listdir(feature_dir) + file_list = [os.path.join(feature_dir, f) for f in file_list] + features = {} + for file_name in file_list: + with open(file_name, 'rb') as f: + feature = cPickle.load(f) + features.update(feature) + logging.info('Load feature file %s', file_name) + return features + +if __name__ == '__main__': + print load_feature_py(sys.argv[1]) + #print load_feature_c(sys.argv[1]) diff --git a/demo/model_zoo/resnet/net_diagram.sh b/demo/model_zoo/resnet/net_diagram.sh new file mode 100755 index 00000000000000..ec72432f0ad026 --- /dev/null +++ b/demo/model_zoo/resnet/net_diagram.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +:' +Visual deep residual network +1. Using make_model_diagram.py to generate dot file. +2. Using graphviz to convert dot file. + +Usage: +./net_diagram.sh +' + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd $DIR + +img_type=png +img_fileprefix=ResNet_50 +conf_filename=resnet.py +dot_filename=ResNet_50.dot +config_str="layer_num=50,data_provider=0" + +python -m paddle.utils.make_model_diagram $conf_filename $dot_filename $config_str + +# If you have installed graphviz, running like this: +# dot -Tpng -o ResNet.png ResNet.dot diff --git a/demo/model_zoo/resnet/predict.sh b/demo/model_zoo/resnet/predict.sh new file mode 100755 index 00000000000000..0375cd2e08c85d --- /dev/null +++ b/demo/model_zoo/resnet/predict.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e + +python classify.py \ + --job=predict \ + --conf=resnet.py\ + --model=model/resnet_50 \ + --multi_crop \ + --data=./example/test.list diff --git a/demo/model_zoo/resnet/resnet.py b/demo/model_zoo/resnet/resnet.py new file mode 100644 index 00000000000000..483e308ac804e1 --- /dev/null +++ b/demo/model_zoo/resnet/resnet.py @@ -0,0 +1,260 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +""" +paper: https://arxiv.org/abs/1512.03385 +""" +is_test = get_config_arg("is_test", bool, False) +is_predict = get_config_arg("is_predict", bool, False) +data_provider = get_config_arg("data_provider", bool, True) +layer_num = get_config_arg("layer_num", int, 50) + +if not is_predict and data_provider: + train_list = 'train.list' if not is_test else None + # mean.meta is mean file of ImageNet dataset. + # mean.meta size : 3 x 224 x 224. 
+ # If you use three mean value, set like: + # "mean_value:103.939,116.779,123.68;" + args={ + 'mean_meta': "model/mean_meta_224/mean.meta", + 'image_size': 224, 'crop_size': 224, + 'color': True,'swap_channel:': [2, 1, 0]} + define_py_data_sources2(train_list, + 'example/test.list', + module="example.image_list_provider", + obj="processData", + args=args) + +batch_size = 1 +learning_rate = 0.1 / batch_size +momentum = 0.9 +weight_decay = 0.0001 * batch_size +default_momentum(momentum) +default_decay_rate(weight_decay) + +Settings( + algorithm='sgd', + batch_size=batch_size, + learning_rate=learning_rate, + + # set the appropriate parameters according your schedule + learning_method='momentum', + learning_rate_decay_a=0.5, + learning_rate_decay_b=1200000 * 10, + learning_rate_schedule="discexp", +) + + +def conv_bn_layer(name, input, filter_size, num_filters, + stride, padding, channels=None, + active_type=ReluActivation()): + """ + A wrapper for conv layer with batch normalization layers. + Note: + conv layer has no activation. + """ + + tmp = img_conv_layer(name=name + "_conv", + input=input, + filter_size=filter_size, + num_channels=channels, + num_filters=num_filters, + stride=stride, + padding=padding, + act=LinearActivation(), + bias_attr=False) + return batch_norm_layer(name=name + "_bn", + input=tmp, + act=active_type, + use_global_stats=is_test) + + +def bottleneck_block(name, input, num_filters1, num_filters2): + """ + A wrapper for bottlenect building block in ResNet. + Last conv_bn_layer has no activation. + Addto layer has activation of relu. 
+ """ + last_name = conv_bn_layer(name=name + '_branch2a', + input=input, + filter_size=1, + num_filters=num_filters1, + stride=1, + padding=0) + last_name = conv_bn_layer(name=name + '_branch2b', + input=last_name, + filter_size=3, + num_filters=num_filters1, + stride=1, + padding=1) + last_name = conv_bn_layer(name=name + '_branch2c', + input=last_name, + filter_size=1, + num_filters=num_filters2, + stride=1, + padding=0, + active_type=LinearActivation()) + + return addto_layer(name=name + "_addto", + input=[input, last_name], + act=ReluActivation()) + + +def mid_projection(name, input, num_filters1, num_filters2, stride=2): + """ + A wrapper for middile projection in ResNet. + projection shortcuts are used for increasing dimensions, + and other shortcuts are identity + branch1: projection shortcuts are used for increasing + dimensions, has no activation. + branch2x: bottleneck building block, shortcuts are identity. + """ + # stride = 2 + branch1 = conv_bn_layer(name=name + '_branch1', + input=input, + filter_size=1, + num_filters=num_filters2, + stride=stride, + padding=0, + active_type=LinearActivation()) + + last_name = conv_bn_layer(name=name + '_branch2a', + input=input, + filter_size=1, + num_filters=num_filters1, + stride=stride, + padding=0) + last_name = conv_bn_layer(name=name + '_branch2b', + input=last_name, + filter_size=3, + num_filters=num_filters1, + stride=1, + padding=1) + + last_name = conv_bn_layer(name=name + '_branch2c', + input=last_name, + filter_size=1, + num_filters=num_filters2, + stride=1, + padding=0, + active_type=LinearActivation()) + + return addto_layer(name=name + "_addto", + input=[branch1, last_name], + act=ReluActivation()) + + +def deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3): + """ + A wrapper for 50,101,152 layers of ResNet. 
+ res2_num: number of blocks stacked in conv2_x + res3_num: number of blocks stacked in conv3_x + res4_num: number of blocks stacked in conv4_x + res5_num: number of blocks stacked in conv5_x + """ + # For ImageNet + # conv1: 112x112 + img = data_layer(name='input', size=224 * 224 * 3) + tmp = conv_bn_layer("conv1", img, + filter_size=7, + channels=3, + num_filters=64, + stride=2, + padding=3) + tmp = img_pool_layer(name="pool1", input=tmp, pool_size=3, stride=2) + + # conv2_x: 56x56 + tmp = mid_projection(name="res2_1", + input=tmp, + num_filters1=64, + num_filters2=256, + stride=1) + for i in xrange(2, res2_num + 1, 1): + tmp = bottleneck_block(name="res2_" + str(i), + input=tmp, + num_filters1=64, + num_filters2=256) + + # conv3_x: 28x28 + tmp = mid_projection(name="res3_1", + input=tmp, + num_filters1=128, + num_filters2=512) + for i in xrange(2, res3_num + 1, 1): + tmp = bottleneck_block(name="res3_" + str(i), + input=tmp, num_filters1=128, + num_filters2=512) + + # conv4_x: 14x14 + tmp = mid_projection(name="res4_1", input=tmp, + num_filters1=256, num_filters2=1024) + for i in xrange(2, res4_num + 1, 1): + tmp = bottleneck_block(name="res4_" + str(i), + input=tmp, + num_filters1=256, + num_filters2=1024) + + # conv5_x: 7x7 + tmp = mid_projection(name="res5_1", input=tmp, + num_filters1=512, num_filters2=2048) + for i in xrange(2, res5_num + 1, 1): + tmp = bottleneck_block(name="res5_" + str(i), + input=tmp, num_filters1=512, + num_filters2=2048) + + tmp = img_pool_layer(name='avgpool', + input=tmp, + pool_size=7, + stride=1, + pool_type=AvgPooling()) + + output = fc_layer(name='output', + input=tmp, + size=1000, + act=SoftmaxActivation()) + + if not is_predict: + classification_cost(input=output, label=data_layer(name='label', + size=1)) + + +def res_net_50(): + deep_res_net(3, 4, 6, 3) + + +def res_net_101(): + deep_res_net(3, 4, 23, 3) + + +def res_net_152(): + deep_res_net(3, 8, 36, 3) + + +if not is_predict: + Inputs("input", "label") +else: + 
Inputs("input") +# Outputs("cost-softmax" if not is_predict else "output") +Outputs("res5_3_branch2c_conv", "res5_3_branch2c_bn") + +if layer_num == 50: + res_net_50() +elif layer_num == 101: + res_net_101() +elif layer_num == 152: + res_net_152() +else: + print("Wrong layer number.") diff --git a/demo/quick_start/.gitignore b/demo/quick_start/.gitignore new file mode 100644 index 00000000000000..d6bc73105b1abf --- /dev/null +++ b/demo/quick_start/.gitignore @@ -0,0 +1,13 @@ +*.pyc +data/dict.txt +data/dict_all.txt +data/labels.list +data/mosesdecoder-master/ +data/reviews_Electronics_5.json.gz +data/test.list +data/test.txt +data/train.list +data/train.txt +dataprovider_copy_1.py +train.log +output diff --git a/demo/quick_start/data/get_data.sh b/demo/quick_start/data/get_data.sh new file mode 100755 index 00000000000000..f355d63225b28a --- /dev/null +++ b/demo/quick_start/data/get_data.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd $DIR + +echo "Downloading Amazon Electronics reviews data..." +# http://jmcauley.ucsd.edu/data/amazon/ +wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz + +echo "Downloading mosesdecoder..." +#https://github.com/moses-smt/mosesdecoder +wget https://github.com/moses-smt/mosesdecoder/archive/master.zip + +unzip master.zip +rm master.zip +echo "Done." 
diff --git a/demo/quick_start/data/pred.list b/demo/quick_start/data/pred.list new file mode 100644 index 00000000000000..d88b2b63851101 --- /dev/null +++ b/demo/quick_start/data/pred.list @@ -0,0 +1 @@ +./data/pred.txt diff --git a/demo/quick_start/data/pred.txt b/demo/quick_start/data/pred.txt new file mode 100644 index 00000000000000..6ed5f738ddaff6 --- /dev/null +++ b/demo/quick_start/data/pred.txt @@ -0,0 +1,2 @@ +the device is cute , but that 's just about all that 's good. the specs are what you 'd expect : it 's a wifi mic , with some noise filter options. the app has the option to upload your baby 's name and photo , which is a cutesy touch. but the app is otherwise unstable and useless unless you upgrade for $ 60 / year.set up involves downloading the app , turning on the mic , switching your phone to the wifi network of the mic , telling the app your wifi settings , switching your wifi back to your home router. the app is then directly connected to your mic.the app is adware ! the main screen says " cry notifications on / off : upgrade to evoz premium and receive a text message of email when your baby is crying " .but the adware points out an important limitation , this monitor is only intended to be used from your home network. if you want to access it remotely , get a webcam. this app would make a lot more sense of the premium features were included with the hardware . +don 't be fooled by my one star rating. if there was a zero , i would have selected it. this product was a waste of my money.it has never worked like the company said it supposed to. i only have one device , an iphone 4gs. after charging the the iphone mid way , the i.sound portable power max 16,000 mah is completely drained. the led light no longer lit up. when plugging the isound portable power max into a wall outlet to charge , it would charge for about 20-30 minutes and then all four battery led indicator lit up showing a full charge. 
i would leave it on to charge for the full 8 hours or more but each time with the same result upon using. don 't buy this thing. put your money to good use elsewhere . diff --git a/demo/quick_start/dataprovider_bow.py b/demo/quick_start/dataprovider_bow.py new file mode 100644 index 00000000000000..bbd3ecabaadbf5 --- /dev/null +++ b/demo/quick_start/dataprovider_bow.py @@ -0,0 +1,84 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer.PyDataProvider2 import * + +# id of the word not in dictionary +UNK_IDX = 0 + +# initializer is called by the framework during initialization. +# It allows the user to describe the data types and setup the +# necessary data structure for later use. +# `settings` is an object. initializer need to properly fill settings.input_types. +# initializer can also store other data structures needed to be used at process(). +# In this example, dictionary is stored in settings. +# `dictionay` and `kwargs` are arguments passed from trainer_config.lr.py +def initializer(settings, dictionary, **kwargs): + # Put the word dictionary into settings + settings.word_dict = dictionary + + # setting.input_types specifies what the data types the data provider + # generates. + settings.input_types = [ + # The first input is a sparse_binary_vector, + # which means each dimension of the vector is either 0 or 1. It is the + # bag-of-words (BOW) representation of the texts. 
+ sparse_binary_vector(len(dictionary)), + # The second input is an integer. It represents the category id of the + # sample. 2 means there are two labels in the dataset. + # (1 for positive and 0 for negative) + integer_value(2)] + +# Delaring a data provider. It has an initializer 'data_initialzer'. +# It will cache the generated data of the first pass in memory, so that +# during later pass, no on-the-fly data generation will be needed. +# `setting` is the same object used by initializer() +# `file_name` is the name of a file listed train_list or test_list file given +# to define_py_data_sources2(). See trainer_config.lr.py. +@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM) +def process(settings, file_name): + # Open the input data file. + with open(file_name, 'r') as f: + # Read each line. + for line in f: + # Each line contains the label and text of the comment, separated by \t. + label, comment = line.strip().split('\t') + + # Split the words into a list. + words = comment.split() + + # convert the words into a list of ids by looking them up in word_dict. + word_vector = [settings.word_dict.get(w, UNK_IDX) for w in words] + + # Return the features for the current comment. The first is a list + # of ids representing a 0-1 binary sparse vector of the text, + # the second is the integer id of the label. + yield word_vector, int(label) + + +def predict_initializer(settings, dictionary, **kwargs): + settings.word_dict = dictionary + settings.input_types = [ + sparse_binary_vector(len(dictionary)) + ] + +# Declaring a data provider for prediction. The difference with process +# is that label is not generated. 
+@provider(init_hook=predict_initializer) +def process_predict(settings, file_name): + with open(file_name, 'r') as f: + for line in f: + comment = line.strip() + word_vector = [settings.word_dict.get(w, UNK_IDX) for w in comment] + yield word_vector diff --git a/demo/quick_start/dataprovider_emb.py b/demo/quick_start/dataprovider_emb.py new file mode 100755 index 00000000000000..e9b17603818b3a --- /dev/null +++ b/demo/quick_start/dataprovider_emb.py @@ -0,0 +1,52 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer.PyDataProvider2 import * + +UNK_IDX = 0 + +def initializer(settings, dictionary, **kwargs): + settings.word_dict = dictionary + settings.input_types = [ + # Define the type of the first input as sequence of integer. 
+ # The value of the integers range from 0 to len(dictrionary)-1 + integer_value_sequence(len(dictionary)), + # Define the second input for label id + integer_value(2)] + + +@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM) +def process(settings, file_name): + with open(file_name, 'r') as f: + for line in f: + label, comment = line.strip().split('\t') + words = comment.split() + word_slot = [settings.word_dict.get(w, UNK_IDX) for w in words] + yield word_slot, int(label) + + +def predict_initializer(settings, dictionary, **kwargs): + settings.word_dict = dictionary + settings.input_types = [ + integer_value(len(dictionary), seq_type=SequenceType.SEQUENCE) + ] + + +@provider(init_hook=predict_initializer) +def process_predict(settings, file_name): + with open(file_name, 'r') as f: + for line in f: + comment = line.strip() + word_slot = [settings.word_dict.get(w, UNK_IDX) for w in comment] + yield word_slot diff --git a/demo/quick_start/predict.sh b/demo/quick_start/predict.sh new file mode 100755 index 00000000000000..f764e202446a4e --- /dev/null +++ b/demo/quick_start/predict.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e + +#cfg=trainer_config.lr.py +#cfg=trainer_config.emb.py +#cfg=trainer_config.cnn.py +cfg=trainer_config.lstm.py +model="output/pass-00003" +paddle train \ + --config=$cfg \ + --use_gpu=false \ + --job=test \ + --init_model_path=$model \ + --config_args=is_predict=1 \ + --predict_output_dir=. \ + +mv rank-00000 result.txt diff --git a/demo/quick_start/preprocess.py b/demo/quick_start/preprocess.py new file mode 100755 index 00000000000000..0ef7e65c749e75 --- /dev/null +++ b/demo/quick_start/preprocess.py @@ -0,0 +1,186 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +1. remove HTML before tokensizing +2. pos sample : rating score 5; neg sample: rating score 1-2. +3. size of pos : neg = 1:1. +4. size of testing set = min(25k, len(all_data) * 0.1), others is traning set. +5. distinct train set and test set. + +Usage: + python preprocess.py -i data_file [random seed] +''' + +import sys,os +import re +import operator +import gzip,math +import random +import numpy as np +from bs4 import BeautifulSoup +from subprocess import Popen, PIPE +from optparse import OptionParser + +def parse(path): + """ + Open .gz file. + """ + g = gzip.open(path, 'r') + for l in g: + yield eval(l) + +def clean(review): + """ + Clean input review: remove HTML, convert words to lower cases. 
+ """ + # Remove HTML + review_text = BeautifulSoup(review, "html.parser").get_text() + + # Convert words to lower case + review_text = review_text.lower() + return review_text + +def tokenize(sentences): + """ + Use tokenizer.perl to tokenize input sentences. + tokenizer.perl is tool of Moses. + sentences : a list of input sentences. + return: a list of processed text. + """ + dir = './data/mosesdecoder-master/scripts/tokenizer/tokenizer.perl' + tokenizer_cmd = [dir, '-l', 'en', '-q', '-'] + assert isinstance(sentences, list) + text = "\n".join(sentences) + tokenizer = Popen(tokenizer_cmd, stdin=PIPE, stdout=PIPE) + tok_text, _ = tokenizer.communicate(text) + toks = tok_text.split('\n')[:-1] + return toks + +def create_dict(data, data_dir): + """ + Create dictionary based on data, and saved in data_dir/dict.txt. + The first line is unk \t -1. + data: list, input data. + data_dir: path to save dict. + """ + word_count = {} + for seq in data: + try: + for w in seq.lower().split(): + if w not in word_count: + word_count[w] = 1 + else: + word_count[w] += 1 + except: + sys.stderr.write(seq+"\tERROR\n") + f = open(os.path.join(data_dir, 'dict.txt'), 'w') + f.write('%s\t%s\n' % ('unk', '-1')) + for k, v in sorted(word_count.items(), key=operator.itemgetter(1),\ + reverse=True): + f.write('%s\t%s\n' % (k, v)) + f.close() + +def save_data(data, data_dir, prefix = ""): + file_name = os.path.join(data_dir, "%s.txt" % (prefix)) + file(file_name,'w').write('\n'.join(data)+'\n') + file(os.path.join(data_dir, prefix+'.list'),'w').write('%s\n' % file_name) + +def split_data(raw_txt): + """ + Extract positive and negative sample. 
+ """ + pos = [] + neg = [] + count = 0 + dup_cnt = 0 + sys.stderr.write("extract raw data") + for l in raw_txt: + rating = l["overall"] + text = clean(l["reviewText"]) + if rating == 5.0 and text: + pos.append(text) + if rating < 3.0 and text: + neg.append(text) + count += 1 + if count % 20000==0: + sys.stderr.write(".") + sys.stderr.write("\n") + return pos, neg + +def preprocess(pos_in, neg_in, data_dir, rand_seed): + # tokenize + sys.stderr.write("tokenize...\n") + tmppos = tokenize(pos_in) + tmpneg = tokenize(neg_in) + cnt = len(tmppos) + len(tmpneg) + + # unique smaples + tmppos = list(set(tmppos)) + tmpneg = list(set(tmpneg)) + dup_cnt = cnt - len(tmppos) - len(tmpneg) + sys.stderr.write("\ntotal size of data set: %d, duplicate data: %d\n" % (cnt, dup_cnt)) + + # keep same size of positive and negative sample + min_len = min(len(tmppos), len(tmpneg)) + tmppos = tmppos[0:min_len] + tmpneg = tmpneg[0:min_len] + + # creat dictionary + sys.stderr.write("create dict with train and test data...\n") + all_data = tmppos + tmpneg + create_dict(all_data, data_dir) + + # split into train set and test set + sys.stderr.write("split data...\n") + pos = ["1\t"+i for i in tmppos] + neg = ["0\t"+i for i in tmpneg] + random.seed(rand_seed) + random.shuffle(pos) + random.shuffle(neg) + + # split into test set and train set + test_len = min(12500, int(min_len * 0.1)) + test = pos[0:test_len] + neg[0:test_len] + train = pos[test_len:] + neg[test_len:] + + # save data + sys.stderr.write("save data...\n") + save_data(train, data_dir, prefix = 'train') + save_data(test, data_dir, prefix = 'test') + file(os.path.join(data_dir,'labels.list'),'w').write('neg\t0\npos\t1\n') + +def option_parser(): + parser = OptionParser(usage="usage: python preprcoess.py "\ + "-i data_path [options]") + parser.add_option("-i", "--data", action="store", + dest="input", help="Input data path.") + parser.add_option("-s", "--seed", action="store", + dest="seed", default=1024, + help="Set random seed.") + 
return parser.parse_args() + +def main(): + reload(sys) + sys.setdefaultencoding('utf-8') + options, args = option_parser() + data=options.input + seed=options.seed + data_dir = os.path.dirname(data) + pos, neg = split_data(parse(data)) + preprocess(pos, neg, data_dir, seed) + sys.stderr.write("Done.\n") + +if __name__ == '__main__': + main() diff --git a/demo/quick_start/preprocess.sh b/demo/quick_start/preprocess.sh new file mode 100755 index 00000000000000..f4d8e647a22525 --- /dev/null +++ b/demo/quick_start/preprocess.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +python preprocess.py -i data/reviews_Electronics_5.json.gz + +# use 30k dict +mv data/dict.txt data/dict_all.txt +cat data/dict_all.txt | head -n 30001 > data/dict.txt diff --git a/demo/quick_start/requirements.txt b/demo/quick_start/requirements.txt new file mode 100644 index 00000000000000..c1f5f713cdafc4 --- /dev/null +++ b/demo/quick_start/requirements.txt @@ -0,0 +1 @@ +beautifulsoup4 diff --git a/demo/quick_start/train.sh b/demo/quick_start/train.sh new file mode 100755 index 00000000000000..1f0a137c8bd594 --- /dev/null +++ b/demo/quick_start/train.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +cfg=trainer_config.lr.py +#cfg=trainer_config.emb.py +#cfg=trainer_config.cnn.py +#cfg=trainer_config.lstm.py +paddle train \ + --config=$cfg \ + --save_dir=./output \ + --trainer_count=4 \ + --log_period=20 \ + --num_passes=15 \ + --use_gpu=false \ + --show_parameter_stats_period=100 \ + --test_all_data_in_one_period=1 \ + 2>&1 | tee 'train.log' diff --git a/demo/quick_start/trainer_config.cnn.py b/demo/quick_start/trainer_config.cnn.py new file mode 100644 index 00000000000000..253ec0aee26cf4 --- /dev/null +++ b/demo/quick_start/trainer_config.cnn.py @@ -0,0 +1,55 @@ +# edit-mode: -*- python -*- + +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddle.trainer_config_helpers import * + +dict_file = "./data/dict.txt" +word_dict = dict() +with open(dict_file, 'r') as f: + for i, line in enumerate(f): + w = line.strip().split()[0] + word_dict[w] = i + +is_predict = get_config_arg('is_predict', bool, False) +trn = 'data/train.list' if not is_predict else None +tst = 'data/test.list' if not is_predict else 'data/pred.list' +process = 'process' if not is_predict else 'process_predict' +define_py_data_sources2(train_list=trn, + test_list=tst, + module="dataprovider_emb", + obj=process, + args={"dictionary": word_dict}) + +batch_size = 128 if not is_predict else 1 +settings( + batch_size=batch_size, + learning_rate=2e-3, + learning_method=AdamOptimizer(), + regularization=L2Regularization(8e-4), + gradient_clipping_threshold=25 +) + +data = data_layer(name="word", size=len(word_dict)) +embedding = embedding_layer(input=data, size=128) +conv = sequence_conv_pool(input=embedding, context_len=3, hidden_size=512) +output = fc_layer(input=conv, size=2, act=SoftmaxActivation()) +if is_predict: + maxid = maxid_layer(output) + outputs([maxid, output]) +else: + label = data_layer(name="label", size=2) + cls = classification_cost(input=output, label=label) + outputs(cls) diff --git a/demo/quick_start/trainer_config.emb.py b/demo/quick_start/trainer_config.emb.py new file mode 100644 index 00000000000000..34dd7b96f2f142 --- /dev/null +++ b/demo/quick_start/trainer_config.emb.py @@ -0,0 +1,53 @@ +# edit-mode: -*- python -*- + +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +dict_file = "./data/dict.txt" +word_dict = dict() +with open(dict_file, 'r') as f: + for i, line in enumerate(f): + w = line.strip().split()[0] + word_dict[w] = i + +is_predict = get_config_arg('is_predict', bool, False) +trn = 'data/train.list' if not is_predict else None +tst = 'data/test.list' if not is_predict else 'data/pred.list' +process = 'process' if not is_predict else 'process_predict' +define_py_data_sources2(train_list=trn, + test_list=tst, + module="dataprovider_emb", + obj=process, + args={"dictionary": word_dict}) + +batch_size = 128 if not is_predict else 1 +settings( + batch_size=batch_size, + learning_rate=2e-3, + learning_method=AdamOptimizer() +) + +data = data_layer(name="word", size=len(word_dict)) +embedding = embedding_layer(input=data, size=128) +avg = pooling_layer(input=embedding, pooling_type=AvgPooling()) +output = fc_layer(input=avg, size=2, act=SoftmaxActivation()) +if is_predict: + maxid = maxid_layer(output) + outputs([maxid, output]) +else: + label = data_layer(name="label", size=2) + cls = classification_cost(input=output, label=label) + outputs(cls) diff --git a/demo/quick_start/trainer_config.lr.py b/demo/quick_start/trainer_config.lr.py new file mode 100644 index 00000000000000..119e3849a4b7e0 --- /dev/null +++ b/demo/quick_start/trainer_config.lr.py @@ -0,0 +1,73 @@ +# edit-mode: -*- python -*- + +# Copyright (c) 2016 Baidu, Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +dict_file = "./data/dict.txt" +word_dict = dict() +with open(dict_file, 'r') as f: + for i, line in enumerate(f): + w = line.strip().split()[0] + word_dict[w] = i + +is_predict = get_config_arg('is_predict', bool, False) +trn = 'data/train.list' if not is_predict else None +tst = 'data/test.list' if not is_predict else 'data/pred.list' +process = 'process' if not is_predict else 'process_predict' + +# define the data sources for the model. +# We need to use different process for training and prediction. +# For training, the input data includes both word IDs and labels. +# For prediction, the input data only includs word Ids. +define_py_data_sources2(train_list=trn, + test_list=tst, + module="dataprovider_bow", + obj=process, + args={"dictionary": word_dict}) + +batch_size = 128 if not is_predict else 1 +settings( + batch_size=batch_size, + learning_rate=2e-3, + learning_method=AdamOptimizer(), + regularization=L2Regularization(8e-4), + gradient_clipping_threshold=25 +) + +# Define the data for text features. The size of the data layer is the number +# of words in the dictionary. +data = data_layer(name="word", size=len(word_dict)) + +# Define a fully connected layer with logistic activation. +# (also called softmax activation). 
# Define a fully connected output layer with softmax activation over the
# two classes (logistic regression on the bag-of-words features).
output = fc_layer(input=data, size=2, act=SoftmaxActivation())

if not is_predict:
    # For training we need the label input and a cost layer.

    # Define the category id for each example.
    # The size of the data layer is the number of labels.
    label = data_layer(name="label", size=2)

    # Define cross-entropy classification loss.
    # BUG FIX: the original called classification_cost() twice; the first
    # call's result was discarded, but it still registered a duplicate,
    # unused cost layer in the generated network configuration.
    cls = classification_cost(input=output, label=label)
    outputs(cls)
else:
    # For prediction, no label is needed. We output the predicted
    # class id together with the class probabilities.
    maxid = maxid_layer(output)
    outputs([maxid, output])
+ +from paddle.trainer_config_helpers import * + +dict_file = "./data/dict.txt" +word_dict = dict() +with open(dict_file, 'r') as f: + for i, line in enumerate(f): + w = line.strip().split()[0] + word_dict[w] = i + +is_predict = get_config_arg('is_predict', bool, False) +trn = 'data/train.list' if not is_predict else None +tst = 'data/test.list' if not is_predict else 'data/pred.list' +process = 'process' if not is_predict else 'process_predict' +define_py_data_sources2(train_list=trn, + test_list=tst, + module="dataprovider_emb", + obj=process, + args={"dictionary": word_dict}) + +batch_size = 128 if not is_predict else 1 +settings( + batch_size=batch_size, + learning_rate=2e-3, + learning_method=AdamOptimizer(), + regularization=L2Regularization(8e-4), + gradient_clipping_threshold=25 +) + +bias_attr = ParamAttr(initial_std=0.,l2_rate=0.) + +data = data_layer(name="word", size=len(word_dict)) +emb = embedding_layer(input=data, size=128) +fc = fc_layer(input=emb, size=512, + act=LinearActivation(), + bias_attr=bias_attr, + layer_attr=ExtraAttr(drop_rate=0.1)) +lstm = lstmemory(input=fc, act=TanhActivation(), + bias_attr=bias_attr, + layer_attr=ExtraAttr(drop_rate=0.25)) +lstm_last = pooling_layer(input=lstm, pooling_type=MaxPooling()) +output = fc_layer(input=lstm_last, size=2, + bias_attr=bias_attr, + act=SoftmaxActivation()) +if is_predict: + maxid = maxid_layer(output) + outputs([maxid, output]) +else: + label = data_layer(name="label", size=2) + cls = classification_cost(input=output, label=label) + outputs(cls) diff --git a/demo/recommendation/.gitignore b/demo/recommendation/.gitignore new file mode 100644 index 00000000000000..aeae0f189dbbbf --- /dev/null +++ b/demo/recommendation/.gitignore @@ -0,0 +1,9 @@ +log.txt +data/meta.bin +data/ml-1m +data/ratings.dat.train +data/ratings.dat.test +data/train.list +data/test.list +dataprovider_copy_1.py +*.pyc diff --git a/demo/recommendation/common_utils.py b/demo/recommendation/common_utils.py new file mode 100755 
def meta_to_header(meta, name):
    """Yield one PyDataProvider2 input type per raw-meta field.

    :param meta: unpickled meta dict produced by meta_generator.py;
        ``meta[name]['__meta__']['raw_meta']`` is a list of field
        descriptors, each with a 'type' key.
    :param name: entity name in the meta file, e.g. 'movie' or 'user'.
    :return: generator of input-type objects, one per field:
        'id' -> integer_value sized by the field's recorded max id;
        'embedding' -> integer_value over the dictionary, emitted as a
        sequence when the field was scanned as one;
        'one_hot_dense' -> dense_vector sized by the dictionary.
    """
    metas = meta[name]['__meta__']['raw_meta']
    for each_meta in metas:
        if each_meta['type'] == 'id':
            yield integer_value(each_meta['max'])
        elif each_meta['type'] == 'embedding':
            # 'seq' is stored as the string 'sequence' for sequential fields.
            is_seq = each_meta['seq'] == 'sequence'
            yield integer_value(len(each_meta['dict']),
                                seq_type=SequenceType.SEQUENCE if is_seq
                                else SequenceType.NO_SEQUENCE)
        elif each_meta['type'] == 'one_hot_dense':
            yield dense_vector(len(each_meta['dict']))
def merge_dict(master_dict, slave_dict):
    """Merge two dicts, preferring truthy values from ``master_dict``.

    For every key present in either dict, the result holds the master's
    value unless it is missing or falsy, in which case the slave's value
    (possibly None) is used instead.
    """
    merged = {}
    for key in set(master_dict) | set(slave_dict):
        merged[key] = master_dict.get(key) or slave_dict.get(key)
    return merged
merge_dict(file_dict, DEFAULT_FILE) + + fields = [] + for pos, field_key in enumerate(val['fields']): + assert isinstance(field_key, basestring) + field = copy.deepcopy(DEFAULT_FIELD[field_key]) + field['pos'] = pos + fields.append(field) + obj[k] = { + "file": file_dict, + "fields": fields + } + meta = { + "meta": obj + } + # print meta + if fmt == 'json': + def formatter(x): + import json + return json.dumps(x, indent=2) + elif fmt == 'yaml': + def formatter(x): + import yaml + return yaml.safe_dump(x, default_flow_style=False) + else: + raise NotImplementedError("Dump format %s is not implemented" % fmt) + + print formatter(meta) + + +if __name__ == '__main__': + args = docopt.docopt(__doc__, version="0.1.0") + main(args[""], args["--output_format"]) diff --git a/demo/recommendation/data/meta_config.json b/demo/recommendation/data/meta_config.json new file mode 100644 index 00000000000000..cc6a046e271dd0 --- /dev/null +++ b/demo/recommendation/data/meta_config.json @@ -0,0 +1,81 @@ +{ + "meta": { + "movie": { + "fields": [ + { + "type": "id", + "pos": 0 + }, + { + "regex": { + "pattern": "^(.*)\\((\\d+)\\)$", + "group_id": 1, + "strip": true + }, + "type": { + "seq_type": "sequence", + "name": "embedding" + }, + "dict": { + "type": "char_based" + }, + "name": "title", + "pos": 1 + }, + { + "type": "one_hot_dense", + "dict": { + "delimiter": "|", + "type": "split" + }, + "name": "genres", + "pos": 2 + } + ], + "file": { + "delimiter": "::", + "type": "split", + "name": "movies.dat" + } + }, + "user": { + "fields": [ + { + "type": "id", + "pos": 0 + }, + { + "type": "embedding", + "dict": { + "type": "char_based" + }, + "name": "gender", + "pos": 1 + }, + { + "type": "embedding", + "dict": { + "sort": true, + "type": "whole_content" + }, + "name": "age", + "pos": 2 + }, + { + "type": "embedding", + "dict": { + "sort": "true", + "type": "whole_content" + }, + "name": "occupation", + "pos": 3 + } + ], + "file": { + "delimiter": "::", + "type": "split", + "name": 
"users.dat" + } + } + } +} diff --git a/demo/recommendation/data/meta_generator.py b/demo/recommendation/data/meta_generator.py new file mode 100644 index 00000000000000..8d1a33d02aea11 --- /dev/null +++ b/demo/recommendation/data/meta_generator.py @@ -0,0 +1,436 @@ +#!/bin/env python2 +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preprocess Movielens dataset, to get movie/user object. + +Usage: + ./preprocess.py [--config=] + ./preprocess.py -h | --help + +Options: + -h --help Show this screen. + --version Show version. + --config= Get MetaData config file [default: config.json]. 
class UniqueIDGenerator(object):
    """Assign dense, consecutive integer ids to arbitrary hashable keys.

    The first key seen gets id 0, the next new key id 1, and so on.
    Repeated keys always map back to their previously assigned id.
    """

    def __init__(self):
        # defaultdict invokes __next_id__ the first time a key is seen,
        # so lookup and id allocation happen in a single step.
        self.pool = collections.defaultdict(self.__next_id__)
        self.next_id = 0

    def __next_id__(self):
        allocated = self.next_id
        self.next_id += 1
        return allocated

    def __call__(self, k):
        return self.pool[k]

    def to_list(self):
        """Return all keys ordered by their assigned ids."""
        ordered = [None] * len(self.pool)
        for key, idx in self.pool.items():
            ordered[idx] = key
        return ordered
class SplitEmbeddingDict(object):
    """Embedding dictionary for delimiter-separated multi-value fields.

    Each distinct token receives a unique integer id in first-seen order
    (via UniqueIDGenerator); a field parses to the list of its token ids.
    """

    def __init__(self, delimiter):
        # Token -> id mapping, populated lazily during scan().
        self.__id__ = UniqueIDGenerator()
        self.delimiter = delimiter

    def scan(self, multi):
        # First pass: register every token so it receives an id.
        for val in multi.split(self.delimiter):
            self.__id__(val)

    def parse(self, multi):
        # Second pass: map each token to its previously assigned id.
        return map(self.__id__, multi.split(self.delimiter))

    def meta_field(self):
        # The dictionary is serialized as the id-ordered token list.
        return self.__id__.to_list()
self.seq_type == EmbeddingFieldParser.SEQUENCE) + elif config['dict']['type'] == 'split': + self.dict = SplitEmbeddingDict( + config['dict'].get('delimiter', ',')) + elif config['dict']['type'] == 'whole_content': + self.dict = EmbeddingFieldParser.WholeContentDict( + config['dict']['sort']) + else: + print config + assert False + + self.name = config['name'] + + def scan(self, s): + self.dict.scan(s) + + def meta_field(self): + return { + 'name': self.name, + 'dict': self.dict.meta_field(), + 'type': 'embedding', + 'seq': self.seq_type + } + + def parse(self, s): + return self.dict.parse(s) + + +class OneHotDenseFieldParser(object): + TYPE = 'one_hot_dense' + + def __init__(self, config): + if config['dict']['type'] == 'split': + self.dict = SplitEmbeddingDict(config['dict']['delimiter']) + self.name = config['name'] + + def scan(self, s): + self.dict.scan(s) + + def meta_field(self): + # print self.dict.meta_field() + return { + 'dict': self.dict.meta_field(), + 'name': self.name, + 'type': 'one_hot_dense' + } + + def parse(self, s): + ids = self.dict.parse(s) + retv = [0.0] * len(self.dict.meta_field()) + for idx in ids: + retv[idx] = 1.0 + # print retv + return retv + + +class FieldParserFactory(object): + PARSERS = [IDFieldParser, EmbeddingFieldParser, OneHotDenseFieldParser] + + @staticmethod + def create(config): + if isinstance(config['type'], basestring): + config_type = config['type'] + elif isinstance(config['type'], dict): + config_type = config['type']['name'] + + assert config_type is not None + + for each_parser_cls in FieldParserFactory.PARSERS: + if config_type == each_parser_cls.TYPE: + return each_parser_cls(config) + print config + + +class CompositeFieldParser(object): + def __init__(self, parser, extractor): + self.extractor = extractor + self.parser = parser + + def scan(self, *args, **kwargs): + self.parser.scan(self.extractor.extract(*args, **kwargs)) + + def parse(self, *args, **kwargs): + return 
class PositionContentExtractor(object):
    """Extract the field at a fixed position from a pre-split line."""

    def __init__(self, pos):
        self.pos = pos

    def extract(self, line):
        assert isinstance(line, list)
        return line[self.pos]


class RegexPositionContentExtractor(PositionContentExtractor):
    """Extract a positional field, then pull one regex group out of it.

    :param pos: field index in the split line.
    :param pattern: regex matched against the whole field; must contain
        at least ``group_id`` groups.
    :param group_id: index of the group to return.
    :param strip: if True, strip surrounding whitespace from the result.
    """

    def __init__(self, pos, pattern, group_id, strip=True):
        PositionContentExtractor.__init__(self, pos)
        pattern = pattern.strip()
        self.pattern = re.compile(pattern)
        self.group_id = group_id
        self.strip = strip

    def extract(self, line):
        line = PositionContentExtractor.extract(self, line)
        match = self.pattern.match(line)
        assert match is not None
        txt = match.group(self.group_id)
        if self.strip:
            # BUG FIX: str.strip() returns a new string; the original
            # discarded the result, so the `strip` option had no effect.
            txt = txt.strip()
        return txt
meta): idx if 'is_key' in meta and meta['is_key'] + else None, enumerate(metas)))[0] + + key_map = [] + for i in range(min(key_index, len(metas))): + key_map.append(i) + for i in range(key_index + 1, len(metas)): + key_map.append(i) + + obj = { + '__meta__': { + 'raw_meta': metas, + 'feature_map': key_map + } + } + + for each_block in reader.read(): + idx = field_parsers[key_index].parse(each_block) + val = [] + for i, each_parser in enumerate(field_parsers): + if i != key_index: + val.append(each_parser.parse(each_block)) + obj[idx] = val + ret_obj[key] = obj + self.obj = ret_obj + return ret_obj + + @staticmethod + def __field_config_mapper__(conf): + assert isinstance(conf, dict) + extrator = ContentExtractorFactory.create(conf) + field_parser = FieldParserFactory.create(conf) + assert extrator is not None + assert field_parser is not None + return CompositeFieldParser(field_parser, extrator) + + def dump(self, fp): + pickle.dump(self.obj, fp, pickle.HIGHEST_PROTOCOL) + + +def preprocess(binary_filename, dataset_dir, config, **kwargs): + assert isinstance(config, str) + with open(config, 'r') as config_file: + file_loader = None + if config.lower().endswith('.yaml'): + import yaml + file_loader = yaml + elif config.lower().endswith('.json'): + import json + file_loader = json + config = file_loader.load(config_file) + meta = MetaFile(dataset_dir) + meta.parse(config) + with open(binary_filename, 'wb') as outf: + meta.dump(outf) + + +if __name__ == '__main__': + args = docopt.docopt(__doc__, version='0.1.0') + kwargs = dict() + for key in args.keys(): + if key != '--help': + param_name = key + assert isinstance(param_name, str) + param_name = param_name.replace('<', '') + param_name = param_name.replace('>', '') + param_name = param_name.replace('--', '') + kwargs[param_name] = args[key] + preprocess(**kwargs) diff --git a/demo/recommendation/data/ml_data.sh b/demo/recommendation/data/ml_data.sh new file mode 100755 index 00000000000000..408a8723e086d3 --- 
/dev/null +++ b/demo/recommendation/data/ml_data.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -ex +cd "$(dirname "$0")" +# download the dataset +wget http://files.grouplens.org/datasets/movielens/ml-1m.zip +# unzip the dataset +unzip ml-1m.zip +# remove the unused zip file +rm ml-1m.zip diff --git a/demo/recommendation/data/split.py b/demo/recommendation/data/split.py new file mode 100644 index 00000000000000..ff1f7fab7befdb --- /dev/null +++ b/demo/recommendation/data/split.py @@ -0,0 +1,67 @@ +#!/bin/env python2 +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Separate movielens 1m dataset to train/test file. + +Usage: + ./separate.py [--test_ratio=] [--delimiter=] + ./separate.py -h | --help + +Options: + -h --help Show this screen. + --version Show version. + --test_ratio= Test ratio for separate [default: 0.1]. 
def process(test_ratio, input_file, delimiter, **kwargs):
    """Split a ratings file into train/test sets, stratified per user.

    Lines are grouped by user id (the first delimited field), shuffled
    within each user, and the first ``test_ratio`` fraction of each
    user's ratings is written to ``<input_file>.test``; the remainder
    goes to ``<input_file>.train``.

    :param test_ratio: fraction of each user's lines for the test file
        (str or float; converted with float()).
    :param input_file: path to the ratings file, e.g. ml-1m/ratings.dat.
    :param delimiter: field delimiter, '::' for MovieLens.
    :param kwargs: unused extra docopt arguments (e.g. 'version').
    """
    test_ratio = float(test_ratio)
    # user_id -> list of that user's rating lines (stripped).
    rating_dict = collections.defaultdict(list)
    with open(input_file, 'r') as f:
        for line in f:
            user_id = int(line.split(delimiter)[0])
            rating_dict[user_id].append(line.strip())

    with open(input_file + ".train", 'w') as train_file:
        with open(input_file + ".test", 'w') as test_file:
            for k in rating_dict.keys():
                lines = rating_dict[k]
                assert isinstance(lines, list)
                # Shuffle so the test split is a random sample per user.
                random.shuffle(lines)
                test_len = int(len(lines) * test_ratio)
                for line in lines[:test_len]:
                    print >> test_file, line

                for line in lines[test_len:]:
                    print >> train_file, line
+ +try: + import cPickle as pickle +except ImportError: + import pickle + +from paddle.trainer.PyDataProvider2 import * +import common_utils # parse + + +def hook(settings, meta, **kwargs): + """ + Init hook is invoked before process data. It will set obj.slots and store + data meta. + + :param obj: global object. It will passed to process routine. + :type obj: object + :param meta: the meta file object, which passed from trainer_config. Meta + file record movie/user features. + :param kwargs: unused other arguments. + """ + del kwargs # unused kwargs + + # Header define slots that used for paddle. + # first part is movie features. + # second part is user features. + # final part is rating score. + # header is a list of [USE_SEQ_OR_NOT?, SlotType] + headers = list(common_utils.meta_to_header(meta, 'movie')) + headers.extend(list(common_utils.meta_to_header(meta, 'user'))) + headers.append(dense_vector(1)) # Score + + # slot types. + settings.input_types = headers + settings.meta = meta + + +@provider(init_hook=hook, cache=CacheType.CACHE_PASS_IN_MEM) +def process(settings, filename): + with open(filename, 'r') as f: + for line in f: + # Get a rating from file. + user_id, movie_id, score = map(int, line.split('::')[:-1]) + + # Scale score to [-5, +5] + score = float(score) * 2 - 5.0 + + # Get movie/user features by movie_id, user_id + movie_meta = settings.meta['movie'][movie_id] + user_meta = settings.meta['user'][user_id] + + outputs = [movie_id - 1] + + # Then add movie features + for each_meta in movie_meta: + outputs.append(each_meta) + + # Then add user id. + outputs.append(user_id - 1) + + # Then add user features. 
+ for each_meta in user_meta: + outputs.append(each_meta) + + # Finally, add score + outputs.append([score]) + # Return data to paddle + yield outputs diff --git a/demo/recommendation/evaluate.sh b/demo/recommendation/evaluate.sh new file mode 100755 index 00000000000000..38c1562c6370dd --- /dev/null +++ b/demo/recommendation/evaluate.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +function get_best_pass() { + cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | sed -r 'N;s/Test.* cost=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' | sort | head -n 1 +} + +LOG=`get_best_pass log.txt` +LOG=(${LOG}) +echo 'Best pass is '${LOG[1]}, ' error is '${LOG[0]}, 'which means predict get error as '`echo ${LOG[0]} | python -c 'import math; print math.sqrt(float(raw_input()))/2'` + +evaluate_pass="output/pass-${LOG[1]}" + +echo 'evaluating from pass '$evaluate_pass diff --git a/demo/recommendation/prediction.py b/demo/recommendation/prediction.py new file mode 100755 index 00000000000000..1a6cfce58fe537 --- /dev/null +++ b/demo/recommendation/prediction.py @@ -0,0 +1,51 @@ +#!/bin/env python2 +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from py_paddle import swig_paddle, DataProviderWrapperConverter + +from common_utils import * +from paddle.trainer.config_parser import parse_config + +try: + import cPickle as pickle +except ImportError: + import pickle +import sys + +if __name__ == '__main__': + model_path = sys.argv[1] + swig_paddle.initPaddle('--use_gpu=0') + conf = parse_config("trainer_config.py", "is_predict=1") + network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config) + assert isinstance(network, swig_paddle.GradientMachine) + network.loadParameters(model_path) + with open('meta.bin', 'rb') as f: + meta = pickle.load(f) + headers = list(meta_to_header(meta, 'movie')) + headers.extend(list(meta_to_header(meta, 'user'))) + cvt = DataProviderWrapperConverter(True, map(lambda x: x[1], headers)) + while True: + movie_id = int(raw_input("Input movie_id: ")) + user_id = int(raw_input("Input user_id: ")) + movie_meta = meta['movie'][movie_id] # Query Data From Meta. + user_meta = meta['user'][user_id] + data = [movie_id - 1] + data.extend(movie_meta) + data.append(user_id - 1) + data.extend(user_meta) + data = map(lambda (header, val): val if header[0] else [val], + zip(headers, data)) + print "Prediction Score is %.2f" % ((network.forwardTest(cvt([ + data]))[0]['value'][0][0] + 5) / 2) diff --git a/demo/recommendation/preprocess.sh b/demo/recommendation/preprocess.sh new file mode 100755 index 00000000000000..e181d0be455589 --- /dev/null +++ b/demo/recommendation/preprocess.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +cd "$(dirname "$0")" +delimiter='::' +dir=ml-1m +cd data +echo 'generate meta config file' +python config_generator.py config.json > meta_config.json +echo 'generate meta file' +python meta_generator.py $dir meta.bin --config=meta_config.json +echo 'split train/test file' +python split.py $dir/ratings.dat --delimiter=${delimiter} --test_ratio=0.1 +echo 'shuffle train file' +shuf $dir/ratings.dat.train > ratings.dat.train +cp $dir/ratings.dat.test . +echo "./data/ratings.dat.train" > train.list +echo "./data/ratings.dat.test" > test.list diff --git a/demo/recommendation/requirements.txt b/demo/recommendation/requirements.txt new file mode 100644 index 00000000000000..1ea154584a428b --- /dev/null +++ b/demo/recommendation/requirements.txt @@ -0,0 +1,2 @@ +PyYAML +docopt diff --git a/demo/recommendation/run.sh b/demo/recommendation/run.sh new file mode 100755 index 00000000000000..846b59cec9fc50 --- /dev/null +++ b/demo/recommendation/run.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +paddle train \ + --config=trainer_config.py \ + --save_dir=./output \ + --use_gpu=false \ + --trainer_count=4\ + --test_all_data_in_one_period=true \ + --log_period=100 \ + --dot_period=1 \ + --num_passes=50 2>&1 | tee 'log.txt' diff --git a/demo/recommendation/trainer_config.py b/demo/recommendation/trainer_config.py new file mode 100755 index 00000000000000..69b9aa7a77cafd --- /dev/null +++ b/demo/recommendation/trainer_config.py @@ -0,0 +1,101 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +try: + import cPickle as pickle +except ImportError: + import pickle + +is_predict = get_config_arg('is_predict', bool, False) + +META_FILE = 'data/meta.bin' + +with open(META_FILE, 'rb') as f: + # load meta file + meta = pickle.load(f) + +settings(batch_size=1600, learning_rate=1e-3, + learning_method=RMSPropOptimizer()) + + +def construct_feature(name): + """ + Construct movie/user features. + + This method read from meta data. 
Then convert feature to neural network due + to feature type. The map relation as follow. + + * id: embedding => fc + * embedding: + is_sequence: embedding => context_projection => fc => pool + not sequence: embedding => fc + * one_hot_dense: fc => fc + + Then gather all features vector, and use a fc layer to combined them as + return. + + :param name: 'movie' or 'user' + :type name: basestring + :return: combined feature output + :rtype: LayerOutput + """ + __meta__ = meta[name]['__meta__']['raw_meta'] + fusion = [] + for each_meta in __meta__: + type_name = each_meta['type'] + slot_name = each_meta.get('name', '%s_id' % name) + if type_name == 'id': + slot_dim = each_meta['max'] + embedding = embedding_layer(input=data_layer(slot_name, + size=slot_dim), + size=256, + param_attr=ParamAttr( + sparse_update=True)) + fusion.append(fc_layer(input=embedding, + size=256)) + elif type_name == 'embedding': + is_seq = each_meta['seq'] == 'sequence' + slot_dim = len(each_meta['dict']) + din = data_layer(slot_name, slot_dim) + embedding = embedding_layer(input=din, size=256) + if is_seq: + fusion.append( + text_conv_pool(input=embedding, context_len=5, + hidden_size=256)) + else: + fusion.append(fc_layer(input=embedding, + size=256)) + elif type_name == 'one_hot_dense': + slot_dim = len(each_meta['dict']) + hidden = fc_layer(input=data_layer(slot_name, slot_dim), + size=256) + fusion.append(fc_layer(input=hidden, + size=256)) + + return fc_layer(name="%s_fusion" % name, input=fusion, size=256) + + +movie_feature = construct_feature("movie") +user_feature = construct_feature("user") +similarity = cos_sim(a=movie_feature, b=user_feature) +if not is_predict: + outputs(regression_cost(input=similarity, + label=data_layer('rating', size=1))) + + define_py_data_sources2('data/train.list', 'data/test.list', module='dataprovider', + obj='process', args={'meta': meta}) +else: + outputs(similarity) diff --git a/demo/semantic_role_labeling/data/extract_dict_feature.py 
b/demo/semantic_role_labeling/data/extract_dict_feature.py new file mode 100644 index 00000000000000..2982e54c665b41 --- /dev/null +++ b/demo/semantic_role_labeling/data/extract_dict_feature.py @@ -0,0 +1,88 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from optparse import OptionParser + + +def extract_dict_features(pair_file, feature_file, src_dict_file, + tgt_dict_file): + src_dict = set() + tgt_dict = set() + + with open(pair_file) as fin, open(feature_file, 'w') as feature_out, open( + src_dict_file, 'w') as src_dict_out, open(tgt_dict_file, + 'w') as tgt_dict_out: + for line in fin: + sentence, labels = line.strip().split('\t') + sentence_list = sentence.split() + labels_list = labels.split() + + src_dict.update(sentence_list) + tgt_dict.update(labels_list) + + verb_index = labels_list.index('B-V') + verb_feature = sentence_list[verb_index] + + mark = [0] * len(labels_list) + if verb_index > 0: + mark[verb_index - 1] = 1 + ctx_n1 = sentence_list[verb_index - 1] + else: + ctx_n1 = 'bos' + ctx_n1_feature = ctx_n1 + + mark[verb_index] = 1 + ctx_0_feature = sentence_list[verb_index] + + if verb_index < len(labels_list) - 2: + mark[verb_index + 1] = 1 + ctx_p1 = sentence_list[verb_index + 1] + else: + ctx_p1 = 'eos' + ctx_p1_feature = ctx_p1 + + feature_str = sentence + '\t' \ + + verb_feature + '\t' \ + + ctx_n1_feature + '\t' \ + + ctx_0_feature + '\t' \ + + ctx_p1_feature + '\t' \ 
+ + ' '.join([str(i) for i in mark]) + '\t' \ + + labels + + feature_out.write(feature_str + '\n') + + src_dict_out.write('\n') + src_dict_out.write('\n'.join(list(src_dict))) + + tgt_dict_out.write('\n'.join(list(tgt_dict))) + + +if __name__ == '__main__': + + usage = '-p pair_file -f feature_file -s source dictionary -t target dictionary ' + parser = OptionParser(usage) + parser.add_option('-p', dest='pair_file', help='the pair file') + parser.add_option( + '-f', dest='feature_file', help='the file to store feature') + parser.add_option( + '-s', dest='src_dict', help='the file to store source dictionary') + parser.add_option( + '-t', dest='tgt_dict', help='the file to store target dictionary') + + (options, args) = parser.parse_args() + + extract_dict_features(options.pair_file, options.feature_file, + options.src_dict, options.tgt_dict) diff --git a/demo/semantic_role_labeling/data/extract_pairs.py b/demo/semantic_role_labeling/data/extract_pairs.py new file mode 100644 index 00000000000000..4d1bef8f958a62 --- /dev/null +++ b/demo/semantic_role_labeling/data/extract_pairs.py @@ -0,0 +1,118 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from optparse import OptionParser + + +def read_labels(props_file): + ''' + a sentence maybe has more than one verb, each verb has its label sequence + label[], is a 3-dimension list. 
+ the first dim is to store all sentence's label seqs, len is the sentence number + the second dim is to store all label sequences for one sentences + the third dim is to store each label for one word + ''' + labels = [] + with open(props_file) as fin: + label_seqs_for_one_sentences = [] + one_seg_in_file = [] + for line in fin: + line = line.strip() + if line == '': + for i in xrange(len(one_seg_in_file[0])): + a_kind_lable = [x[i] for x in one_seg_in_file] + label_seqs_for_one_sentences.append(a_kind_lable) + labels.append(label_seqs_for_one_sentences) + one_seg_in_file = [] + label_seqs_for_one_sentences = [] + else: + part = line.split() + one_seg_in_file.append(part) + return labels + + +def read_sentences(words_file): + sentences = [] + with open(words_file) as fin: + s = '' + for line in fin: + line = line.strip() + if line == '': + sentences.append(s.lower()) + s = '' + else: + s += line + ' ' + return sentences + + +def transform_labels(sentences, labels): + sen_lab_pair = [] + for i in xrange(len(sentences)): + if len(labels[i]) == 1: + continue + else: + for j in xrange(1, len(labels[i])): + label_list = labels[i][j] + current_tag = 'O' + is_in_bracket = False + label_seq = [] + verb_word = '' + for ll in label_list: + if ll == '*' and is_in_bracket == False: + label_seq.append('O') + elif ll == '*' and is_in_bracket == True: + label_seq.append('I-' + current_tag) + elif ll == '*)': + label_seq.append('I-' + current_tag) + is_in_bracket = False + elif ll.find('(') != -1 and ll.find(')') != -1: + current_tag = ll[1:ll.find('*')] + label_seq.append('B-' + current_tag) + is_in_bracket = False + elif ll.find('(') != -1 and ll.find(')') == -1: + current_tag = ll[1:ll.find('*')] + label_seq.append('B-' + current_tag) + is_in_bracket = True + else: + print 'error:', ll + + sen_lab_pair.append((sentences[i], label_seq)) + return sen_lab_pair + + +def write_file(sen_lab_pair, output_file): + with open(output_file, 'w') as fout: + for x in sen_lab_pair: + sentence 
= x[0] + label_seq = ' '.join(x[1]) + assert len(sentence.split()) == len(x[1]) + fout.write(sentence + '\t' + label_seq + '\n') + + +if __name__ == '__main__': + + usage = '-w words_file -p props_file -o output_file' + parser = OptionParser(usage) + parser.add_option('-w', dest='words_file', help='the words file') + parser.add_option('-p', dest='props_file', help='the props file') + parser.add_option('-o', dest='output_file', help='the output_file') + (options, args) = parser.parse_args() + + sentences = read_sentences(options.words_file) + labels = read_labels(options.props_file) + sen_lab_pair = transform_labels(sentences, labels) + + write_file(sen_lab_pair, options.output_file) diff --git a/demo/semantic_role_labeling/data/get_data.sh b/demo/semantic_role_labeling/data/get_data.sh new file mode 100644 index 00000000000000..268c0995e27006 --- /dev/null +++ b/demo/semantic_role_labeling/data/get_data.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +wget http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz +tar -xzvf conll05st-tests.tar.gz +rm conll05st-tests.tar.gz +cp ./conll05st-release/test.wsj/words/test.wsj.words.gz . +cp ./conll05st-release/test.wsj/props/test.wsj.props.gz . 
+gunzip test.wsj.words.gz +gunzip test.wsj.props.gz + +python extract_pairs.py -w test.wsj.words -p test.wsj.props -o test.wsj.seq_pair +python extract_dict_feature.py -p test.wsj.seq_pair -f feature -s src.dict -t tgt.dict diff --git a/demo/semantic_role_labeling/data/test.list b/demo/semantic_role_labeling/data/test.list new file mode 100644 index 00000000000000..ec370e897a7811 --- /dev/null +++ b/demo/semantic_role_labeling/data/test.list @@ -0,0 +1 @@ +./data/feature diff --git a/demo/semantic_role_labeling/data/train.list b/demo/semantic_role_labeling/data/train.list new file mode 100644 index 00000000000000..ec370e897a7811 --- /dev/null +++ b/demo/semantic_role_labeling/data/train.list @@ -0,0 +1 @@ +./data/feature diff --git a/demo/semantic_role_labeling/dataprovider.py b/demo/semantic_role_labeling/dataprovider.py new file mode 100644 index 00000000000000..ca7346b3db97e8 --- /dev/null +++ b/demo/semantic_role_labeling/dataprovider.py @@ -0,0 +1,57 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddle.trainer.PyDataProvider2 import * + +UNK_IDX = 0 + + +def hook(settings, word_dict, label_dict, **kwargs): + settings.word_dict = word_dict + settings.label_dict = label_dict + #all inputs are integral and sequential type + settings.slots = [ + integer_value(len(word_dict), seq_type=SequenceType.SEQUENCE), + integer_value(len(word_dict), seq_type=SequenceType.SEQUENCE), + integer_value(len(word_dict), seq_type=SequenceType.SEQUENCE), + integer_value(len(word_dict), seq_type=SequenceType.SEQUENCE), + integer_value(len(word_dict), seq_type=SequenceType.SEQUENCE), + integer_value(2, seq_type=SequenceType.SEQUENCE), + integer_value(len(label_dict), seq_type=SequenceType.SEQUENCE)] + + +@provider(init_hook=hook) +def process(obj, file_name): + with open(file_name, 'r') as fdata: + for line in fdata: + sentence, predicate, ctx_n1, ctx_0, ctx_p1, mark, label = \ + line.strip().split('\t') + + words = sentence.split() + sen_len = len(words) + word_slot = [obj.word_dict.get(w, UNK_IDX) for w in words] + + predicate_slot = [obj.word_dict.get(predicate, UNK_IDX)] * sen_len + ctx_n1_slot = [obj.word_dict.get(ctx_n1, UNK_IDX)] * sen_len + ctx_0_slot = [obj.word_dict.get(ctx_0, UNK_IDX)] * sen_len + ctx_p1_slot = [obj.word_dict.get(ctx_p1, UNK_IDX)] * sen_len + + marks = mark.split() + mark_slot = [int(w) for w in marks] + + label_list = label.split() + label_slot = [obj.label_dict.get(w) for w in label_list] + + yield word_slot, predicate_slot, ctx_n1_slot, \ + ctx_0_slot, ctx_p1_slot, mark_slot, label_slot diff --git a/demo/semantic_role_labeling/db_lstm.py b/demo/semantic_role_labeling/db_lstm.py new file mode 100644 index 00000000000000..364460afbe31ca --- /dev/null +++ b/demo/semantic_role_labeling/db_lstm.py @@ -0,0 +1,141 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import os +import sys +from paddle.trainer_config_helpers import * + +#file paths +word_dict_file = './data/src.dict' +label_dict_file = './data/tgt.dict' +train_list_file = './data/train.list' +test_list_file = './data/test.list' + +is_test = get_config_arg('is_test', bool, False) +is_predict = get_config_arg('is_predict', bool, False) + +if not is_predict: + #load dictionaries + word_dict = dict() + label_dict = dict() + with open(word_dict_file, 'r') as f_word, \ + open(label_dict_file, 'r') as f_label: + for i, line in enumerate(f_word): + w = line.strip() + word_dict[w] = i + + for i, line in enumerate(f_label): + w = line.strip() + label_dict[w] = i + + if is_test: + train_list_file = None + + #define data provider + define_py_data_sources2( + train_list=train_list_file, + test_list=test_list_file, + module='dataprovider', + obj='process', + args={'word_dict': word_dict, + 'label_dict': label_dict}) + + word_dict_len = len(word_dict) + label_dict_len = len(label_dict) + +else: + word_dict_len = get_config_arg('dict_len', int) + label_dict_len = get_config_arg('label_len', int) + +mark_dict_len = 2 +word_dim = 32 +mark_dim = 5 +hidden_dim = 128 +depth = 8 +emb_lr = 1e-2 +fc_lr = 1e-2 +lstm_lr = 2e-2 + +settings( + batch_size=150, + learning_method=AdamOptimizer(), + learning_rate=1e-3, + regularization=L2Regularization(8e-4), + gradient_clipping_threshold=25) + +#6 features +word = data_layer(name='word_data', size=word_dict_len) +predicate = data_layer(name='verb_data', size=word_dict_len) +ctx_n1 = data_layer(name='ctx_n1_data', 
size=word_dict_len) +ctx_0 = data_layer(name='ctx_0_data', size=word_dict_len) +ctx_p1 = data_layer(name='ctx_p1_data', size=word_dict_len) +mark = data_layer(name='mark_data', size=mark_dict_len) + +if not is_predict: + target = data_layer(name='target', size=label_dict_len) + +ptt = ParameterAttribute(name='src_emb', learning_rate=emb_lr) +layer_attr = ExtraLayerAttribute(drop_rate=0.5) +fc_para_attr = ParameterAttribute(learning_rate=fc_lr) +lstm_para_attr = ParameterAttribute(initial_std=0., learning_rate=lstm_lr) +para_attr = [fc_para_attr, lstm_para_attr] + +word_embedding = embedding_layer(size=word_dim, input=word, param_attr=ptt) +predicate_embedding = embedding_layer( + size=word_dim, input=predicate, param_attr=ptt) +ctx_n1_embedding = embedding_layer(size=word_dim, input=ctx_n1, param_attr=ptt) +ctx_0_embedding = embedding_layer(size=word_dim, input=ctx_0, param_attr=ptt) +ctx_p1_embedding = embedding_layer(size=word_dim, input=ctx_p1, param_attr=ptt) +mark_embedding = embedding_layer(size=mark_dim, input=mark) + +hidden_0 = mixed_layer( + size=hidden_dim, + input=[ + full_matrix_projection(input=word_embedding), + full_matrix_projection(input=predicate_embedding), + full_matrix_projection(input=ctx_n1_embedding), + full_matrix_projection(input=ctx_0_embedding), + full_matrix_projection(input=ctx_p1_embedding), + full_matrix_projection(input=mark_embedding), + ]) + +lstm_0 = lstmemory(input=hidden_0, layer_attr=layer_attr) + +#stack L-LSTM and R-LSTM with direct edges +input_tmp = [hidden_0, lstm_0] + +for i in range(1, depth): + + fc = fc_layer(input=input_tmp, size=hidden_dim, param_attr=para_attr) + + lstm = lstmemory( + input=fc, + act=ReluActivation(), + reverse=(i % 2) == 1, + layer_attr=layer_attr) + input_tmp = [fc, lstm] + +prob = fc_layer( + input=input_tmp, + size=label_dict_len, + act=SoftmaxActivation(), + param_attr=para_attr) + +if not is_predict: + cls = classification_cost(input=prob, label=target) + outputs(cls) +else: + outputs(prob) 
diff --git a/demo/semantic_role_labeling/predict.py b/demo/semantic_role_labeling/predict.py new file mode 100644 index 00000000000000..5250ec6dc68559 --- /dev/null +++ b/demo/semantic_role_labeling/predict.py @@ -0,0 +1,164 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +from optparse import OptionParser +from py_paddle import swig_paddle, util, DataProviderWrapperConverter +from paddle.trainer.PyDataProviderWrapper import IndexSlot +from paddle.trainer.config_parser import parse_config +""" +Usage: run following command to show help message. + python predict.py -h +""" +UNK_IDX = 0 + + +class Prediction(): + def __init__(self, train_conf, dict_file, model_dir, label_file): + """ + train_conf: trainer configure. + dict_file: word dictionary file name. + model_dir: directory of model. 
+ """ + + self.dict = {} + self.labels = {} + self.labels_reverse = {} + self.load_dict_label(dict_file, label_file) + + len_dict = len(self.dict) + len_label = len(self.labels) + + conf = parse_config( + train_conf, + 'dict_len=' + str(len_dict) + + ',label_len=' + str(len_label) + + ',is_predict=True') + self.network = swig_paddle.GradientMachine.createFromConfigProto( + conf.model_config) + self.network.loadParameters(model_dir) + + slots = [IndexSlot(len_dict), IndexSlot(len_dict), IndexSlot(len_dict), + IndexSlot(len_dict), IndexSlot(len_dict), IndexSlot(2)] + self.converter = util.DataProviderWrapperConverter(True, slots) + + def load_dict_label(self, dict_file, label_file): + """ + Load dictionary from self.dict_file. + """ + for line_count, line in enumerate(open(dict_file, 'r')): + self.dict[line.strip()] = line_count + + for line_count, line in enumerate(open(label_file, 'r')): + self.labels[line.strip()] = line_count + self.labels_reverse[line_count] = line.strip() + + def get_data(self, data_file): + """ + Get input data of paddle format. + """ + with open(data_file, 'r') as fdata: + for line in fdata: + sentence, predicate, ctx_n1, ctx_0, ctx_p1, mark, label = line.strip( + ).split('\t') + words = sentence.split() + sen_len = len(words) + + word_slot = [self.dict.get(w, UNK_IDX) for w in words] + predicate_slot = [self.dict.get(predicate, UNK_IDX)] * sen_len + ctx_n1_slot = [self.dict.get(ctx_n1, UNK_IDX)] * sen_len + ctx_0_slot = [self.dict.get(ctx_0, UNK_IDX)] * sen_len + ctx_p1_slot = [self.dict.get(ctx_p1, UNK_IDX)] * sen_len + + marks = mark.split() + mark_slot = [int(w) for w in marks] + + yield word_slot, predicate_slot, ctx_n1_slot, \ + ctx_0_slot, ctx_p1_slot, mark_slot + + def predict(self, data_file): + """ + data_file: file name of input data. 
+ """ + input = self.converter(self.get_data(data_file)) + output = self.network.forwardTest(input) + prob = output[0]["value"] + lab = list(np.argsort(-prob)[:, 0]) + + with open(data_file, 'r') as fin, open('predict.res', 'w') as fout: + index = 0 + for line in fin: + sen = line.split('\t')[0] + len_sen = len(sen.split()) + line_labels = lab[index:index + len_sen] + index += len_sen + fout.write(sen + '\t' + ' '.join([self.labels_reverse[ + i] for i in line_labels]) + '\n') + + +def option_parser(): + usage = ("python predict.py -c config -w model_dir " + "-d word dictionary -l label_file -i input_file") + parser = OptionParser(usage="usage: %s [options]" % usage) + parser.add_option( + "-c", + "--tconf", + action="store", + dest="train_conf", + help="network config") + parser.add_option( + "-d", + "--dict", + action="store", + dest="dict_file", + help="dictionary file") + parser.add_option( + "-l", + "--label", + action="store", + dest="label_file", + default=None, + help="label file") + parser.add_option( + "-i", + "--data", + action="store", + dest="data_file", + help="data file to predict") + parser.add_option( + "-w", + "--model", + action="store", + dest="model_path", + default=None, + help="model path") + return parser.parse_args() + + +def main(): + options, args = option_parser() + train_conf = options.train_conf + data_file = options.data_file + dict_file = options.dict_file + model_path = options.model_path + label_file = options.label_file + + swig_paddle.initPaddle("--use_gpu=0") + predict = Prediction(train_conf, dict_file, model_path, label_file) + predict.predict(data_file) + + +if __name__ == '__main__': + main() diff --git a/demo/semantic_role_labeling/predict.sh b/demo/semantic_role_labeling/predict.sh new file mode 100644 index 00000000000000..a545b9a5d591b4 --- /dev/null +++ b/demo/semantic_role_labeling/predict.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Copyright (c) 2016 Baidu, Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e + +function get_best_pass() { + cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \ + sed -r 'N;s/Test.* cost=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' | \ + sort | head -n 1 +} + +log=train.log +LOG=`get_best_pass $log` +LOG=(${LOG}) +best_model_path="output/pass-${LOG[1]}" + + +config_file=db_lstm.py +dict_file=./data/src.dict +label_file=./data/tgt.dict +input_file=./data/feature + +python predict.py \ + -c $config_file \ + -w $best_model_path \ + -l $label_file \ + -d $dict_file \ + -i $input_file diff --git a/demo/semantic_role_labeling/test.sh b/demo/semantic_role_labeling/test.sh new file mode 100644 index 00000000000000..804f722e5b8e9e --- /dev/null +++ b/demo/semantic_role_labeling/test.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e + +function get_best_pass() { + cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \ + sed -r 'N;s/Test.* cost=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' |\ + sort | head -n 1 +} + +log=train.log +LOG=`get_best_pass $log` +LOG=(${LOG}) +evaluate_pass="output/pass-${LOG[1]}" + +echo 'evaluating from pass '$evaluate_pass +model_list=./model.list +touch $model_list | echo $evaluate_pass > $model_list + +paddle train \ + --config=./db_lstm.py \ + --model_list=$model_list \ + --job=test \ + --use_gpu=false \ + --config_args=is_test=1 \ +2>&1 | tee 'test.log' + diff --git a/demo/semantic_role_labeling/train.sh b/demo/semantic_role_labeling/train.sh new file mode 100644 index 00000000000000..94c7b6f31df3b5 --- /dev/null +++ b/demo/semantic_role_labeling/train.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e +paddle train \ + --config=./db_lstm.py \ + --save_dir=./output \ + --trainer_count=4 \ + --log_period=10 \ + --num_passes=500 \ + --use_gpu=false \ + --show_parameter_stats_period=10 \ + --test_all_data_in_one_period=1 \ +2>&1 | tee 'train.log' + diff --git a/demo/sentiment/.gitignore b/demo/sentiment/.gitignore new file mode 100644 index 00000000000000..bf2a9ab1ce3c93 --- /dev/null +++ b/demo/sentiment/.gitignore @@ -0,0 +1,11 @@ +data/aclImdb +data/imdb +data/pre-imdb +data/mosesdecoder-master +logs/ +model_output +dataprovider_copy_1.py +model.list +test.log +train.log +*.pyc diff --git a/demo/sentiment/data/get_imdb.sh b/demo/sentiment/data/get_imdb.sh new file mode 100755 index 00000000000000..41523927afe754 --- /dev/null +++ b/demo/sentiment/data/get_imdb.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +set -x + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd $DIR + +#download the dataset +echo "Downloading aclImdb..." +#http://ai.stanford.edu/%7Eamaas/data/sentiment/ +wget http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz + +echo "Downloading mosesdecoder..." +#https://github.com/moses-smt/mosesdecoder +wget https://github.com/moses-smt/mosesdecoder/archive/master.zip + +#extract package +echo "Unzipping..." 
+tar -zxvf aclImdb_v1.tar.gz +unzip master.zip + +#move train and test set to imdb_data directory +#in order to process when traing +mkdir -p imdb/train +mkdir -p imdb/test + +cp -r aclImdb/train/pos/ imdb/train/ +cp -r aclImdb/train/neg/ imdb/train/ + +cp -r aclImdb/test/pos/ imdb/test/ +cp -r aclImdb/test/neg/ imdb/test/ + +#remove compressed package +rm aclImdb_v1.tar.gz +rm master.zip + +echo "Done." diff --git a/demo/sentiment/dataprovider.py b/demo/sentiment/dataprovider.py new file mode 100755 index 00000000000000..c325d33485c872 --- /dev/null +++ b/demo/sentiment/dataprovider.py @@ -0,0 +1,34 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from paddle.trainer.PyDataProvider2 import *
+
+
+def hook(settings, dictionary, **kwargs):
+    settings.word_dict = dictionary
+    settings.input_types = [
+        integer_value(len(settings.word_dict), seq_type=SequenceType.SEQUENCE),
+        integer_value(2)]
+    settings.logger.info('dict len : %d' % (len(settings.word_dict)))
+
+
+@provider(init_hook=hook)
+def process(settings, file_name):
+    with open(file_name, 'r') as fdata:
+        for line_count, line in enumerate(fdata):
+            label, comment = line.strip().split('\t\t')
+            label = int(label)
+            words = comment.split()
+            word_slot = [settings.word_dict[w] for w in words if w in
+                         settings.word_dict]
+            yield word_slot, label
diff --git a/demo/sentiment/predict.py b/demo/sentiment/predict.py
new file mode 100755
index 00000000000000..4ece6bb06d9e30
--- /dev/null
+++ b/demo/sentiment/predict.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+from optparse import OptionParser
+from py_paddle import swig_paddle, util, DataProviderWrapperConverter
+from paddle.trainer.PyDataProviderWrapper import IndexSlot
+from paddle.trainer.config_parser import parse_config
+
+"""
+Usage: run following command to show help message.
+    python predict.py -h
+"""
+
+class SentimentPrediction():
+    def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
+        """
+        train_conf: trainer configure.
+        dict_file: word dictionary file name.
+        model_dir: directory of model.
+        """
+        self.train_conf = train_conf
+        self.dict_file = dict_file
+        self.word_dict = {}
+        self.dict_dim = self.load_dict()
+        self.model_dir = model_dir
+        if model_dir is None:
+            self.model_dir = os.path.dirname(train_conf)
+
+        self.label = None
+        if label_file is not None:
+            self.load_label(label_file)
+
+        conf = parse_config(train_conf, "is_predict=1")
+        self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
+        self.network.loadParameters(self.model_dir)
+        slots = [IndexSlot(self.dict_dim)]
+        self.converter = util.DataProviderWrapperConverter(True, slots)
+
+    def load_dict(self):
+        """
+        Load dictionary from self.dict_file.
+        """
+        for line_count, line in enumerate(open(self.dict_file, 'r')):
+            self.word_dict[line.strip().split('\t')[0]] = line_count
+        return len(self.word_dict)
+
+    def load_label(self, label_file):
+        """
+        Load label.
+        """
+        self.label = {}
+        for v in open(label_file, 'r'):
+            self.label[int(v.split('\t')[1])] = v.split('\t')[0]
+
+    def get_data(self, data_file):
+        """
+        Get input data of paddle format.
+        """
+        with open(data_file, 'r') as fdata:
+            for line in fdata:
+                words = line.strip().split()
+                word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
+                if not word_slot:
+                    print "all words are not in dictionary: %s" % line
+                    continue
+                yield [word_slot]
+
+    def predict(self, data_file):
+        """
+        data_file: file name of input data.
+        """
+        input = self.converter(self.get_data(data_file))
+        output = self.network.forwardTest(input)
+        prob = output[0]["value"]
+        lab = np.argsort(-prob)
+        if self.label is None:
+            print("%s: predicting label is %d" % (data_file, lab[0][0]))
+        else:
+            print("%s: predicting label is %s" % (data_file, self.label[lab[0][0]]))
+
+def option_parser():
+    usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
+    parser = OptionParser(usage="usage: %s [options]" % usage)
+    parser.add_option("-n", "--tconf", action="store",
+                      dest="train_conf", help="network config")
+    parser.add_option("-d", "--dict", action="store",
+                      dest="dict_file", help="dictionary file")
+    parser.add_option("-b", "--label", action="store",
+                      dest="label", default=None,
+                      help="label file")
+    parser.add_option("-i", "--data", action="store",
+                      dest="data", help="data file to predict")
+    parser.add_option("-w", "--model", action="store",
+                      dest="model_path", default=None,
+                      help="model path")
+    return parser.parse_args()
+
+def main():
+    options, args = option_parser()
+    train_conf = options.train_conf
+    data = options.data
+    dict_file = options.dict_file
+    model_path = options.model_path
+    label = options.label
+    swig_paddle.initPaddle("--use_gpu=0")
+    predict = SentimentPrediction(train_conf, dict_file, model_path, label)
+    predict.predict(data)
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/sentiment/predict.sh b/demo/sentiment/predict.sh
new file mode 100755
index 00000000000000..c3bfc1c8b61921
--- /dev/null
+++ b/demo/sentiment/predict.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+config=trainer_config.py
+model=model_output/pass-00002/
+label=data/pre-imdb/labels.list
+python predict.py \
+    -n $config \
+    -w $model \
+    -b $label \
+    -d ./data/pre-imdb/dict.txt \
+    -i ./data/aclImdb/test/pos/10007_10.txt
diff --git a/demo/sentiment/preprocess.py b/demo/sentiment/preprocess.py
new file mode 100755
index 00000000000000..49b53d500a1bf8
--- /dev/null
+++ b/demo/sentiment/preprocess.py
@@ -0,0 +1,338 @@
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import random
+import operator
+import numpy as np
+from subprocess import Popen, PIPE
+from os.path import join as join_path
+from optparse import OptionParser
+
+from paddle.utils.preprocess_util import *
+
+"""
+Usage: run following command to show help message.
+    python preprocess.py -h
+"""
+
+def save_dict(dict, filename, is_reverse=True):
+    """
+    Save dictionary into file.
+    dict: input dictionary.
+    filename: output file name, string.
+    is_reverse: True, descending order by value.
+                False, ascending order by value.
+    """
+    f = open(filename, 'w')
+    for k, v in sorted(dict.items(), key=operator.itemgetter(1),
+                       reverse=is_reverse):
+        f.write('%s\t%s\n' % (k, v))
+    f.close()
+
+def tokenize(sentences):
+    """
+    Use tokenizer.perl to tokenize input sentences.
+    tokenizer.perl is tool of Moses.
+    sentences : a list of input sentences.
+    return: a list of processed text.
+    """
+    dir = './data/mosesdecoder-master/scripts/tokenizer/tokenizer.perl'
+    tokenizer_cmd = [dir, '-l', 'en', '-q', '-']
+    assert isinstance(sentences, list)
+    text = "\n".join(sentences)
+    tokenizer = Popen(tokenizer_cmd, stdin=PIPE, stdout=PIPE)
+    tok_text, _ = tokenizer.communicate(text)
+    toks = tok_text.split('\n')[:-1]
+    return toks
+
+def read_lines(path):
+    """
+    path: String, file path.
+    return a list of sequence.
+    """
+    seqs = []
+    with open(path, 'r') as f:
+        for line in f.readlines():
+            line = line.strip()
+            if len(line):
+                seqs.append(line)
+    return seqs
+
+class SentimentDataSetCreate():
+    """
+    A class to process data for sentiment analysis task.
+    """
+    def __init__(self, data_path, output_path,
+                 use_tokenizer=True, multi_lines=False):
+        """
+        data_path: string, training and testing dataset path
+        output_path: string, output path, store processed dataset
+        multi_lines: whether a file has multi lines.
+                     In order to shuffle fully, it needs to read all files into
+                     memory, then shuffle them if one file has multi lines.
+        """
+        self.output_path = output_path
+        self.data_path = data_path
+
+        self.train_dir = 'train'
+        self.test_dir = 'test'
+
+        self.train_list = "train.list"
+        self.test_list = "test.list"
+
+        self.label_list = "labels.list"
+        self.classes_num = 0
+
+        self.batch_size = 50000
+        self.batch_dir = 'batches'
+
+        self.dict_file = "dict.txt"
+        self.dict_with_test = False
+        self.dict_size = 0
+        self.word_count = {}
+
+        self.tokenizer = use_tokenizer
+        self.overwrite = False
+
+        self.multi_lines = multi_lines
+
+        self.train_dir = join_path(data_path, self.train_dir)
+        self.test_dir = join_path(data_path, self.test_dir)
+        self.train_list = join_path(output_path, self.train_list)
+        self.test_list = join_path(output_path, self.test_list)
+        self.label_list = join_path(output_path, self.label_list)
+        self.dict_file = join_path(output_path, self.dict_file)
+
+    def data_list(self, path):
+        """
+        create dataset from path
+        path: data path
+        return: data list
+        """
+        label_set = get_label_set_from_dir(path)
+        data = []
+        for lab_name in label_set.keys():
+            file_paths = list_files(join_path(path, lab_name))
+            for p in file_paths:
+                data.append({"label": label_set[lab_name],
+                             "seq_path": p})
+        return data, label_set
+
+    def create_dict(self, data):
+        """
+        create dict for input data.
+        data: list, [sequence, sequence, ...]
+        """
+        for seq in data:
+            for w in seq.strip().lower().split():
+                if w not in self.word_count:
+                    self.word_count[w] = 1
+                else:
+                    self.word_count[w] += 1
+
+    def create_dataset(self):
+        """
+        create file batches and dictionary of train data set.
+        If the self.overwrite is false and train.list already exists in
+        self.output_path, this function will not create and save file
+        batches from the data set path.
+        return: dictionary size, class number.
+        """
+        out_path = self.output_path
+        if out_path and not os.path.exists(out_path):
+            os.makedirs(out_path)
+
+        # If self.overwrite is false or self.train_list has existed,
+        # it will not process dataset.
+        if not (self.overwrite or not os.path.exists(self.train_list)):
+            print "%s already exists." % self.train_list
+            return
+
+        # Preprocess train data.
+        train_data, train_lab_set = self.data_list(self.train_dir)
+        print "processing train set..."
+        file_lists = self.save_data(train_data,
+                                    "train",
+                                    self.batch_size,
+                                    True,
+                                    True)
+        save_list(file_lists, self.train_list)
+
+        # If have test data path, preprocess test data.
+        if os.path.exists(self.test_dir):
+            test_data, test_lab_set = self.data_list(self.test_dir)
+            assert(train_lab_set == test_lab_set)
+            print "processing test set..."
+            file_lists = self.save_data(test_data,
+                                        "test",
+                                        self.batch_size,
+                                        False,
+                                        self.dict_with_test)
+            save_list(file_lists, self.test_list)
+
+        # save labels set.
+        save_dict(train_lab_set, self.label_list, False)
+        self.classes_num = len(train_lab_set.keys())
+
+        # save dictionary.
+        save_dict(self.word_count, self.dict_file, True)
+        self.dict_size = len(self.word_count)
+
+    def save_data(self, data, prefix="",
+                  batch_size=50000,
+                  is_shuffle=False,
+                  build_dict=False):
+        """
+        Create batches for a Dataset object.
+        data: the Dataset object to process.
+        prefix: the prefix of each batch.
+        batch_size: number of data in each batch.
+        build_dict: whether to build dictionary for data
+
+        return: list of batch names
+        """
+        if is_shuffle and self.multi_lines:
+            return self.save_data_multi_lines(data, prefix, batch_size, build_dict)
+
+        if is_shuffle:
+            random.shuffle(data)
+        num_batches = int(math.ceil(len(data) / float(batch_size)))
+        batch_names = []
+        for i in range(num_batches):
+            batch_name = join_path(self.output_path,
+                                   "%s_part_%03d" % (prefix, i))
+            begin = i * batch_size
+            end = min((i + 1) * batch_size, len(data))
+            # read a batch of data
+            label_list, data_list = self.get_data_list(begin, end, data)
+            if build_dict:
+                self.create_dict(data_list)
+            self.save_file(label_list, data_list, batch_name)
+            batch_names.append(batch_name)
+
+        return batch_names
+
+    def get_data_list(self, begin, end, data):
+        """
+        begin: int, begining index of data.
+        end: int, ending index of data.
+        data: a list of {"seq_path": sequence path, "label": label index}
+
+        return a list of label and a list of sequence.
+        """
+        label_list = []
+        data_list = []
+        for j in range(begin, end):
+            seqs = read_lines(data[j]["seq_path"])
+            lab = int(data[j]["label"])
+            # File may have multiple lines.
+            for seq in seqs:
+                data_list.append(seq)
+                label_list.append(lab)
+        if self.tokenizer:
+            data_list = tokenize(data_list)
+        return label_list, data_list
+
+    def save_data_multi_lines(self, data, prefix="",
+                              batch_size=50000,
+                              build_dict=False):
+        """
+        In order to shuffle fully, there is no need to load all data if
+        each file only contains one sample, it only needs to shuffle list
+        of file name. But one file contains multi lines, each line is one
+        sample. It needs to read all data into memory to shuffle fully.
+        This interface is mainly for data containing multi lines in each
+        file, which consumes more memory if there is a great mount of data.
+
+        data: the Dataset object to process.
+        prefix: the prefix of each batch.
+        batch_size: number of data in each batch.
+        build_dict: whether to build dictionary for data
+
+        return: list of batch names
+        """
+        assert self.multi_lines
+        label_list = []
+        data_list = []
+
+        # read all data
+        label_list, data_list = self.get_data_list(0, len(data), data)
+        if build_dict:
+            self.create_dict(data_list)
+
+        length = len(label_list)
+        perm_list = np.array([i for i in xrange(length)])
+        random.shuffle(perm_list)
+
+        num_batches = int(math.ceil(length / float(batch_size)))
+        batch_names = []
+        for i in range(num_batches):
+            batch_name = join_path(self.output_path,
+                                   "%s_part_%03d" % (prefix, i))
+            begin = i * batch_size
+            end = min((i + 1) * batch_size, length)
+            sub_label = [label_list[perm_list[j]] for j in range(begin, end)]
+            sub_data = [data_list[perm_list[j]] for j in range(begin, end)]
+            self.save_file(sub_label, sub_data, batch_name)
+            batch_names.append(batch_name)
+
+        return batch_names
+
+    def save_file(self, label_list, data_list, filename):
+        """
+        Save data into file.
+        label_list: a list of int value.
+        data_list: a list of sequence.
+        filename: output file name.
+        """
+        f = open(filename, 'w')
+        print "saving file: %s" % filename
+        for lab, seq in zip(label_list, data_list):
+            f.write('%s\t\t%s\n' % (lab, seq))
+        f.close()
+
+def option_parser():
+    parser = OptionParser(usage="usage: python preprocess.py "
+                                "-i data_dir [options]")
+    parser.add_option("-i", "--data", action="store",
+                      dest="input", help="Input data directory.")
+    parser.add_option("-o", "--output", action="store",
+                      dest="output", default=None,
+                      help="Output directory.")
+    parser.add_option("-t", "--tokenizer", action="store",
+                      dest="use_tokenizer", default=True,
+                      help="Whether to use tokenizer.")
+    parser.add_option("-m", "--multi_lines", action="store",
+                      dest="multi_lines", default=False,
+                      help="If input text files have multi lines and they "
+                           "need to be shuffled, you should set -m True,")
+    return parser.parse_args()
+
+def main():
+    options, args = option_parser()
+    data_dir = options.input
+    output_dir = options.output
+    use_tokenizer = options.use_tokenizer
+    multi_lines = options.multi_lines
+    if output_dir is None:
+        outname = os.path.basename(options.input)
+        output_dir = join_path(os.path.dirname(data_dir), 'pre-' + outname)
+    data_creator = SentimentDataSetCreate(data_dir, output_dir,
+                                          use_tokenizer, multi_lines)
+    data_creator.create_dataset()
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/sentiment/preprocess.sh b/demo/sentiment/preprocess.sh
new file mode 100755
index 00000000000000..5f5c78e222917d
--- /dev/null
+++ b/demo/sentiment/preprocess.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+echo "Start to preprocess..."
+
+data_dir="./data/imdb"
+python preprocess.py -i $data_dir
+
+echo "Done."
diff --git a/demo/sentiment/sentiment_net.py b/demo/sentiment/sentiment_net.py
new file mode 100644
index 00000000000000..f9f784c1f0b20e
--- /dev/null
+++ b/demo/sentiment/sentiment_net.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from os.path import join as join_path
+
+from paddle.trainer_config_helpers import *
+
+
+def sentiment_data(data_dir=None,
+                   is_test=False,
+                   is_predict=False,
+                   train_list="train.list",
+                   test_list="test.list",
+                   dict_file="dict.txt"):
+    """
+    Predefined data provider for sentiment analysis.
+    is_test: whether this config is used for test.
+    is_predict: whether this config is used for prediction.
+    train_list: text file name, containing a list of training set.
+    test_list: text file name, containing a list of testing set.
+    dict_file: text file name, containing dictionary.
+    """
+    if data_dir is not None:
+        train_list = join_path(data_dir, train_list)
+        test_list = join_path(data_dir, test_list)
+        dict_file = join_path(data_dir, dict_file)
+
+    dict_dim = len(open(dict_file).readlines())
+    class_dim = len(open(join_path(data_dir, 'labels.list')).readlines())
+    if is_predict:
+        return dict_dim, class_dim
+
+    train_list = train_list if not is_test else None
+    word_dict = dict()
+    with open(dict_file, 'r') as f:
+        for i, line in enumerate(f):
+            word_dict[line.split('\t')[0]] = i
+
+    define_py_data_sources2(train_list, test_list,
+                            module="dataprovider",
+                            obj="process",
+                            args={'dictionary': word_dict})
+
+    return dict_dim, class_dim
+
+
+def bidirectional_lstm_net(input_dim,
+                           class_dim=2,
+                           emb_dim=128,
+                           lstm_dim=128,
+                           is_predict=False):
+    data = data_layer("word", input_dim)
+    emb = embedding_layer(input=data, size=emb_dim)
+    bi_lstm = bidirectional_lstm(input=emb, size=lstm_dim)
+    dropout = dropout_layer(input=bi_lstm, dropout_rate=0.5)
+    output = fc_layer(input=dropout, size=class_dim,
+                      act=SoftmaxActivation())
+
+    if not is_predict:
+        lbl = data_layer("label", 1)
+        outputs(classification_cost(input=output, label=lbl))
+    else:
+        outputs(output)
+
+
+def stacked_lstm_net(input_dim,
+                     class_dim=2,
+                     emb_dim=128,
+                     hid_dim=512,
+                     stacked_num=3,
+                     is_predict=False):
+    """
+    A Wrapper for sentiment classification task.
+    This network uses bi-directional recurrent network,
+    consisting three LSTM layers. This configure is referred to
+    the paper as following url, but use fewer layers.
+    http://www.aclweb.org/anthology/P15-1109
+
+    input_dim: here is word dictionary dimension.
+    class_dim: number of categories.
+    emb_dim: dimension of word embedding.
+    hid_dim: dimension of hidden layer.
+    stacked_num: number of stacked lstm-hidden layer.
+    is_predict: is predicting or not.
+                Some layers are not needed in network when predicting.
+    """
+    hid_lr = 1e-3
+    assert stacked_num % 2 == 1
+
+    layer_attr = ExtraLayerAttribute(drop_rate=0.5)
+    fc_para_attr = ParameterAttribute(learning_rate=hid_lr)
+    lstm_para_attr = ParameterAttribute(initial_std=0., learning_rate=1.)
+    para_attr = [fc_para_attr, lstm_para_attr]
+    bias_attr = ParameterAttribute(initial_std=0., l2_rate=0.)
+    relu = ReluActivation()
+    linear = LinearActivation()
+
+    data = data_layer("word", input_dim)
+    emb = embedding_layer(input=data, size=emb_dim)
+
+    fc1 = fc_layer(input=emb, size=hid_dim, act=linear,
+                   bias_attr=bias_attr)
+    lstm1 = lstmemory(input=fc1, act=relu, bias_attr=bias_attr,
+                      layer_attr=layer_attr)
+
+    inputs = [fc1, lstm1]
+    for i in range(2, stacked_num + 1):
+        fc = fc_layer(input=inputs, size=hid_dim, act=linear,
+                      param_attr=para_attr, bias_attr=bias_attr)
+        lstm = lstmemory(input=fc, reverse=(i % 2) == 0, act=relu,
+                         bias_attr=bias_attr, layer_attr=layer_attr)
+        inputs = [fc, lstm]
+
+    fc_last = pooling_layer(input=inputs[0], pooling_type=MaxPooling())
+    lstm_last = pooling_layer(input=inputs[1], pooling_type=MaxPooling())
+    output = fc_layer(input=[fc_last, lstm_last], size=class_dim,
+                      act=SoftmaxActivation(),
+                      bias_attr=bias_attr, param_attr=para_attr)
+
+    if is_predict:
+        outputs(output)
+    else:
+        outputs(
+            classification_cost(input=output, label=data_layer('label', 1)))
diff --git a/demo/sentiment/test.sh b/demo/sentiment/test.sh
new file mode 100755
index 00000000000000..ffe404de6b5227
--- /dev/null
+++ b/demo/sentiment/test.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+function get_best_pass() {
+  cat $1 | grep -Pzo 'Test .*\n.*pass-.*' | \
+  sed -r 'N;s/Test.* cost=([0-9]+\.[0-9]+).*\n.*pass-([0-9]+)/\1 \2/g' |\
+  sort -n | head -n 1
+}
+
+log=train.log
+LOG=`get_best_pass $log`
+LOG=(${LOG})
+evaluate_pass="model_output/pass-${LOG[1]}"
+
+echo 'evaluating from pass '$evaluate_pass
+
+model_list=./model.list
+touch $model_list; echo $evaluate_pass > $model_list
+net_conf=trainer_config.py
+paddle train --config=$net_conf \
+             --model_list=$model_list \
+             --job=test \
+             --use_gpu=false \
+             --trainer_count=4 \
+             --config_args=is_test=1 \
+             2>&1 | tee 'test.log'
diff --git a/demo/sentiment/train.sh b/demo/sentiment/train.sh
new file mode 100755
index 00000000000000..f44a9a53f2db9a
--- /dev/null
+++ b/demo/sentiment/train.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+config=trainer_config.py
+output=./model_output
+paddle train --config=$config \
+             --save_dir=$output \
+             --job=train \
+             --use_gpu=false \
+             --trainer_count=4 \
+             --num_passes=10 \
+             --log_period=10 \
+             --dot_period=20 \
+             --show_parameter_stats_period=100 \
+             --test_all_data_in_one_period=1 \
+             2>&1 | tee 'train.log'
diff --git a/demo/sentiment/trainer_config.py b/demo/sentiment/trainer_config.py
new file mode 100644
index 00000000000000..db24182a8d7359
--- /dev/null
+++ b/demo/sentiment/trainer_config.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sentiment_net import *
+from paddle.trainer_config_helpers import *
+
+# whether this config is used for test
+is_test = get_config_arg('is_test', bool, False)
+# whether this config is used for prediction
+is_predict = get_config_arg('is_predict', bool, False)
+
+data_dir = "./data/pre-imdb"
+dict_dim, class_dim = sentiment_data(data_dir, is_test, is_predict)
+
+################## Algorithm Config #####################
+
+settings(
+    batch_size=128,
+    learning_rate=2e-3,
+    learning_method=AdamOptimizer(),
+    regularization=L2Regularization(8e-4),
+    gradient_clipping_threshold=25
+)
+
+#################### Network Config ######################
+stacked_lstm_net(dict_dim, class_dim=class_dim,
+                 stacked_num=3, is_predict=is_predict)
+# bidirectional_lstm_net(dict_dim, class_dim=class_dim, is_predict=is_predict)
diff --git a/demo/seqToseq/.gitignore b/demo/seqToseq/.gitignore
new file mode 100644
index 00000000000000..21cec2c2c1f342
--- /dev/null
+++ b/demo/seqToseq/.gitignore
@@ -0,0 +1,17 @@
+data/wmt14
+data/pre-wmt14
+data/wmt14_model
+data/paraphrase
+data/pre-paraphrase
+data/paraphrase_model
+translation/gen.log
+translation/gen_result
+translation/train.log
+paraphrase/train.log
+dataprovider_copy_1.py
+translation/thirdparty.tgz
+translation/thirdparty/train.conf
+translation/thirdparty/dataprovider.py
+translation/thirdparty/seqToseq_net.py
+translation/thirdparty/*.dict
+*.pyc
diff --git a/demo/seqToseq/data/paraphrase_data.sh b/demo/seqToseq/data/paraphrase_data.sh
new file mode 100755
index 00000000000000..ea1f8dbcfad356
--- /dev/null
+++ b/demo/seqToseq/data/paraphrase_data.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +set -x + +# download the in-house paraphrase dataset +# following is the google drive address +# you can also directly download from https://pan.baidu.com/s/1o8q577s +wget https://www.googledrive.com/host/0B7Q8d52jqeI9ejh6Q1RpMTFQT1k/embedding/paraphrase.tar.gz --no-check-certificate + +# untar the dataset +tar -zxvf paraphrase.tar.gz +rm paraphrase.tar.gz diff --git a/demo/seqToseq/data/paraphrase_model.sh b/demo/seqToseq/data/paraphrase_model.sh new file mode 100755 index 00000000000000..041f69cf467b13 --- /dev/null +++ b/demo/seqToseq/data/paraphrase_model.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e +set -x + +dim=32 +pretrained_dir='../../model_zoo/embedding/' +preModel=$pretrained_dir'model_'$dim'.emb' +preDict=$pretrained_dir'baidu.dict' + +usrDict_dir='pre-paraphrase/' +srcDict=$usrDict_dir'src.dict' +trgDict=$usrDict_dir'trg.dict' + +usrModel_dir='paraphrase_model/' +mkdir $usrModel_dir +srcModel=$usrModel_dir'_source_language_embedding' +trgModel=$usrModel_dir'_target_language_embedding' + +echo 'extract desired parameters based on user dictionary' +script=$pretrained_dir'extract_para.py' +python $script --preModel $preModel --preDict $preDict \ + --usrModel $srcModel --usrDict $srcDict -d $dim +python $script --preModel $preModel --preDict $preDict \ + --usrModel $trgModel --usrDict $trgDict -d $dim diff --git a/demo/seqToseq/data/wmt14_data.sh b/demo/seqToseq/data/wmt14_data.sh new file mode 100755 index 00000000000000..6c360b206011a7 --- /dev/null +++ b/demo/seqToseq/data/wmt14_data.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e +set -x +mkdir wmt14 +cd wmt14 + +# download the dataset +wget http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/bitexts.tgz +wget http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz + +# untar the dataset +tar -zxvf bitexts.tgz +tar -zxvf dev+test.tgz +gunzip bitexts.selected/* +mv bitexts.selected train +rm bitexts.tgz +rm dev+test.tgz + +# separate the dev and test dataset +mkdir test gen +mv dev/ntst1213.* test +mv dev/ntst14.* gen +rm -rf dev + +set +x +# rename the suffix, .fr->.src, .en->.trg +for dir in train test gen +do + filelist=`ls $dir` + cd $dir + for file in $filelist + do + if [ ${file##*.} = "fr" ]; then + mv $file ${file/%fr/src} + elif [ ${file##*.} = 'en' ]; then + mv $file ${file/%en/trg} + fi + done + cd .. +done diff --git a/demo/seqToseq/data/wmt14_model.sh b/demo/seqToseq/data/wmt14_model.sh new file mode 100755 index 00000000000000..2cec30688d27a5 --- /dev/null +++ b/demo/seqToseq/data/wmt14_model.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+set -e +set -x + +# download the pretrained model +# following is the google drive address +# you can also directly download from https://pan.baidu.com/s/1o8q577s +wget https://www.googledrive.com/host/0B7Q8d52jqeI9ejh6Q1RpMTFQT1k/wmt14_model.tar.gz --no-check-certificate + +# untar the model +tar -zxvf wmt14_model.tar.gz +rm wmt14_model.tar.gz diff --git a/demo/seqToseq/dataprovider.py b/demo/seqToseq/dataprovider.py new file mode 100755 index 00000000000000..a646667977d3eb --- /dev/null +++ b/demo/seqToseq/dataprovider.py @@ -0,0 +1,82 @@ +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddle.trainer.PyDataProvider2 import * + +UNK_IDX = 2 +START = "" +END = "" + + +def hook(settings, src_dict, trg_dict, file_list, **kwargs): + # job_mode = 1: training mode + # job_mode = 0: generating mode + settings.job_mode = trg_dict is not None + settings.src_dict = src_dict + settings.logger.info("src dict len : %d" % (len(settings.src_dict))) + settings.sample_count = 0 + + if settings.job_mode: + settings.trg_dict = trg_dict + settings.slots = [ + integer_value( + len(settings.src_dict), + seq_type=SequenceType.SEQUENCE), integer_value( + len(settings.trg_dict), + seq_type=SequenceType.SEQUENCE), integer_value( + len(settings.trg_dict), + seq_type=SequenceType.SEQUENCE) + ] + settings.logger.info("trg dict len : %d" % (len(settings.trg_dict))) + else: + settings.slots = [ + integer_value( + len(settings.src_dict), + seq_type=SequenceType.SEQUENCE), integer_value( + len(open(file_list[0], "r").readlines()), + seq_type=SequenceType.SEQUENCE) + ] + + +def _get_ids(s, dictionary): + words = s.strip().split() + return [dictionary[START]] + \ + [dictionary.get(w, UNK_IDX) for w in words] + \ + [dictionary[END]] + + +@provider(init_hook=hook, pool_size=50000) +def process(settings, file_name): + with open(file_name, 'r') as f: + for line_count, line in enumerate(f): + line_split = line.strip().split('\t') + if settings.job_mode and len(line_split) != 2: + continue + src_seq = line_split[0] # one source sequence + src_ids = _get_ids(src_seq, settings.src_dict) + + if settings.job_mode: + trg_seq = line_split[1] # one target sequence + trg_words = trg_seq.split() + trg_ids = [settings.trg_dict.get(w, UNK_IDX) + for w in trg_words] + + # remove sequence whose length > 80 in training mode + if len(src_ids) > 80 or len(trg_ids) > 80: + continue + trg_ids_next = trg_ids + [settings.trg_dict[END]] + trg_ids = [settings.trg_dict[START]] + trg_ids + yield src_ids, trg_ids, trg_ids_next + else: + yield src_ids, [line_count] diff --git 
a/demo/seqToseq/paraphrase/train.conf b/demo/seqToseq/paraphrase/train.conf new file mode 100644 index 00000000000000..748920e2c72537 --- /dev/null +++ b/demo/seqToseq/paraphrase/train.conf @@ -0,0 +1,33 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("..") + +from seqToseq_net import * + +is_generating = False +### Data Definiation +train_conf = seq_to_seq_data(data_dir = "./data/pre-paraphrase", + is_generating = is_generating) + +### Algorithm Configuration +settings( + learning_method = AdamOptimizer(), + batch_size = 50, + learning_rate = 5e-4) + +### Network Architecture +gru_encoder_decoder(train_conf, is_generating, word_vector_dim = 32) diff --git a/demo/seqToseq/paraphrase/train.sh b/demo/seqToseq/paraphrase/train.sh new file mode 100755 index 00000000000000..2aa7b84060b198 --- /dev/null +++ b/demo/seqToseq/paraphrase/train.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +cd .. + +paddle train \ + --config='paraphrase/train.conf' \ + --save_dir='paraphrase/model' \ + --init_model_path='data/paraphrase_model' \ + --load_missing_parameter_strategy=rand \ + --use_gpu=false \ + --num_passes=16 \ + --show_parameter_stats_period=100 \ + --trainer_count=4 \ + --log_period=10 \ + --dot_period=5 \ + 2>&1 | tee 'paraphrase/train.log' diff --git a/demo/seqToseq/preprocess.py b/demo/seqToseq/preprocess.py new file mode 100755 index 00000000000000..5efb17a664b9a2 --- /dev/null +++ b/demo/seqToseq/preprocess.py @@ -0,0 +1,204 @@ +#!/bin/env python +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Example: + python preprocess.py -i INPUT [-d DICTSIZE] [-m] + +Options: + -h, --help show this help message and exit + -i INPUT input original dataset path + -d DICTSIZE specified word count of dictionary + -m --mergeDict merge source and target dictionary +""" +import os +import sys + +import string +from optparse import OptionParser +from paddle.utils.preprocess_util import save_list, DatasetCreater + +class SeqToSeqDatasetCreater(DatasetCreater): + """ + A class to process data for sequence to sequence application. + """ + + def __init__(self, data_path, output_path): + """ + data_path: the path to store the train data, test data and gen data + output_path: the path to store the processed dataset + """ + DatasetCreater.__init__(self, data_path) + self.gen_dir_name = 'gen' + self.gen_list_name = 'gen.list' + self.output_path = output_path + + def concat_file(self, file_path, file1, file2, output_path, output): + """ + Concat file1 and file2 to be one output file + The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 + file_path: the path to store file1 and file2 + output_path: the path to store output file + """ + file1 = os.path.join(file_path, file1) + file2 = os.path.join(file_path, file2) + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) + + def cat_file(self, dir_path, suffix, output_path, output): + """ + Cat all the files in dir_path with suffix to be one output file + dir_path: the base directory to store input file + suffix: suffix of file name + output_path: the path to store output file + """ + cmd = 'cat ' + file_list = os.listdir(dir_path) + file_list.sort() + for file in file_list: + if file.endswith(suffix): + cmd += os.path.join(dir_path, file) + ' ' + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system(cmd + '> ' + output) + + def build_dict(self, file_path, dict_path, dict_size = -1): + """ + 
Create the dictionary for the file, Note that + 1. Valid characters include all printable characters + 2. There is distinction between uppercase and lowercase letters + 3. There is 3 special token: + : the start of a sequence + : the end of a sequence + : a word not included in dictionary + file_path: the path to store file + dict_path: the path to store dictionary + dict_size: word count of dictionary + if is -1, dictionary will contains all the words in file + """ + if not os.path.exists(dict_path): + dictory = dict() + with open(file_path, "r") as fdata: + for line in fdata: + line = line.split('\t') + for line_split in line: + words = line_split.strip().split() + for word in words: + if word not in dictory: + dictory[word] = 1 + else: + dictory[word] += 1 + output = open(dict_path, "w+") + output.write('\n\n\n') + count = 3 + for key, value in sorted(dictory.items(), key = lambda d:d[1], reverse = True): + output.write(key + "\n") + count += 1 + if count == dict_size: + break + self.dict_size = count + + def create_dataset(self, dict_size = -1, mergeDict = False, + suffixes = ['.src', '.trg']): + """ + Create seqToseq dataset + """ + # dataset_list and dir_list has one-to-one relationship + train_dataset = os.path.join(self.data_path, self.train_dir_name) + test_dataset = os.path.join(self.data_path, self.test_dir_name) + gen_dataset = os.path.join(self.data_path, self.gen_dir_name) + dataset_list = [train_dataset, test_dataset, gen_dataset] + + train_dir = os.path.join(self.output_path, self.train_dir_name) + test_dir = os.path.join(self.output_path, self.test_dir_name) + gen_dir = os.path.join(self.output_path, self.gen_dir_name) + dir_list = [train_dir, test_dir, gen_dir] + + # create directory + for dir in dir_list: + if not os.path.exists(dir): + os.mkdir(dir) + + # checkout dataset should be parallel corpora + suffix_len = len(suffixes[0]) + for dataset in dataset_list: + file_list = os.listdir(dataset) + if len(file_list) % 2 == 1: + raise 
RuntimeError("dataset should be parallel corpora") + file_list.sort() + for i in range(0, len(file_list), 2): + if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: + raise RuntimeError("source and target file name should be equal") + + # cat all the files with the same suffix in dataset + for suffix in suffixes: + for dataset in dataset_list: + outname = os.path.basename(dataset) + suffix + self.cat_file(dataset, suffix, dataset, outname) + + # concat parallel corpora and create file.list + print 'concat parallel corpora for dataset' + id = 0 + list = ['train.list', 'test.list', 'gen.list'] + for dataset in dataset_list: + outname = os.path.basename(dataset) + self.concat_file(dataset, outname + suffixes[0], + outname + suffixes[1], dir_list[id], outname) + save_list([os.path.join(dir_list[id], outname)], + os.path.join(self.output_path, list[id])) + id += 1 + + # build dictionary for train data + dict = ['src.dict', 'trg.dict'] + dict_path = [os.path.join(self.output_path, dict[0]), + os.path.join(self.output_path, dict[1])] + if mergeDict: + outname = os.path.join(train_dir, train_dataset.split('/')[-1]) + print 'build src dictionary for train data' + self.build_dict(outname, dict_path[0], dict_size) + print 'build trg dictionary for train data' + os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) + else: + outname = os.path.join(train_dataset, self.train_dir_name) + for id in range(0,2): + suffix = suffixes[id] + print 'build ' + suffix[1:] + ' dictionary for train data' + self.build_dict(outname + suffix, dict_path[id], dict_size) + print 'dictionary size is', self.dict_size + +def main(): + usage = "usage: \n" \ + "python %prog -i INPUT [-d DICTSIZE] [-m]" + parser = OptionParser(usage) + parser.add_option("-i", action="store", dest="input", + help="input original dataset path") + parser.add_option("-d", action="store", dest="dictsize", + help="specified word count of dictionary") + parser.add_option("-m", "--mergeDict", action="store_true", 
dest="mergeDict", + help="merge source and target dictionary") + (options, args) = parser.parse_args() + if options.input[-1] == os.path.sep: + options.input = options.input[:-1] + outname = os.path.basename(options.input) + output_path = os.path.join(os.path.dirname(options.input), 'pre-' + outname) + dictsize = int(options.dictsize) if options.dictsize else -1 + if not os.path.exists(output_path): + os.mkdir(output_path) + data_creator = SeqToSeqDatasetCreater(options.input, output_path) + data_creator.create_dataset(dictsize, options.mergeDict) + +if __name__ == "__main__": + main(); diff --git a/demo/seqToseq/seqToseq_net.py b/demo/seqToseq/seqToseq_net.py new file mode 100644 index 00000000000000..8b613de71ade4d --- /dev/null +++ b/demo/seqToseq/seqToseq_net.py @@ -0,0 +1,183 @@ +# edit-mode: -*- python -*- + +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os +from paddle.trainer_config_helpers import * + + +def seq_to_seq_data(data_dir, + is_generating, + dict_size=30000, + train_list='train.list', + test_list='test.list', + gen_list='gen.list', + gen_result='gen_result'): + """ + Predefined seqToseq train data provider for application + is_generating: whether this config is used for generating + dict_size: word count of dictionary + train_list: a text file containing a list of training data + test_list: a text file containing a list of testing data + gen_list: a text file containing a list of generating data + gen_result: a text file containing generating result + """ + src_lang_dict = os.path.join(data_dir, 'src.dict') + trg_lang_dict = os.path.join(data_dir, 'trg.dict') + src_dict = dict() + for line_count, line in enumerate(open(src_lang_dict, "r")): + src_dict[line.strip()] = line_count + trg_dict = dict() + for line_count, line in enumerate(open(trg_lang_dict, "r")): + trg_dict[line.strip()] = line_count + + if is_generating: + train_list = None + test_list = os.path.join(data_dir, gen_list) + trg_dict = None + else: + train_list = os.path.join(data_dir, train_list) + test_list = os.path.join(data_dir,test_list) + + define_py_data_sources2(train_list, test_list, + module = "dataprovider", + obj = "process", + args = {"src_dict": src_dict, + "trg_dict": trg_dict}) + + return {"src_dict_path": src_lang_dict, "trg_dict_path": trg_lang_dict, + "gen_result": gen_result} + + +def gru_encoder_decoder(data_conf, + is_generating, + word_vector_dim=512, + encoder_size=512, + decoder_size=512, + beam_size=3, + max_length=250): + """ + A wrapper for an attention version of GRU Encoder-Decoder network + is_generating: whether this config is used for generating + encoder_size: dimension of hidden unit in GRU Encoder network + decoder_size: dimension of hidden unit in GRU Decoder network + word_vector_dim: dimension of word vector + beam_size: expand width in beam search + max_length: a stop condition 
of sequence generation + """ + for k, v in data_conf.iteritems(): + globals()[k] = v + source_dict_dim = len(open(src_dict_path, "r").readlines()) + target_dict_dim = len(open(trg_dict_path, "r").readlines()) + gen_trans_file = gen_result + + src_word_id = data_layer(name='source_language_word', size=source_dict_dim) + src_embedding = embedding_layer( + input=src_word_id, + size=word_vector_dim, + param_attr=ParamAttr(name='_source_language_embedding'), ) + src_forward = simple_gru(input=src_embedding, size=encoder_size, ) + src_backward = simple_gru(input=src_embedding, + size=encoder_size, + reverse=True, ) + encoded_vector = concat_layer(input=[src_forward, src_backward]) + + with mixed_layer(size=decoder_size) as encoded_proj: + encoded_proj += full_matrix_projection(encoded_vector) + + backward_first = first_seq(input=src_backward) + with mixed_layer(size=decoder_size, + act=TanhActivation(), ) as decoder_boot: + decoder_boot += full_matrix_projection(backward_first) + + def gru_decoder_with_attention(enc_vec, enc_proj, current_word): + decoder_mem = memory(name='gru_decoder', + size=decoder_size, + boot_layer=decoder_boot) + + context = simple_attention(encoded_sequence=enc_vec, + encoded_proj=enc_proj, + decoder_state=decoder_mem, ) + + with mixed_layer(size=decoder_size * 3) as decoder_inputs: + decoder_inputs += full_matrix_projection(context) + decoder_inputs += full_matrix_projection(current_word) + + gru_step = gru_step_layer(name='gru_decoder', + input=decoder_inputs, + output_mem=decoder_mem, + size=decoder_size) + + with mixed_layer(size=target_dict_dim, + bias_attr=True, + act=SoftmaxActivation()) as out: + out += full_matrix_projection(input=gru_step) + return out + + decoder_group_name = "decoder_group" + if not is_generating: + trg_embedding = embedding_layer( + input=data_layer(name='target_language_word', + size=target_dict_dim), + size=word_vector_dim, + param_attr=ParamAttr(name='_target_language_embedding')) + + # For decoder equipped with 
attention mechanism, in training, + # target embeding (the groudtruth) is the data input, + # while encoded source sequence is accessed to as an unbounded memory. + # Here, the StaticInput defines a read-only memory + # for the recurrent_group. + decoder = recurrent_group(name=decoder_group_name, + step=gru_decoder_with_attention, + input=[ + StaticInput(input=encoded_vector, + is_seq=True), + StaticInput(input=encoded_proj, + is_seq=True), trg_embedding + ], ) + + lbl = data_layer(name='target_language_next_word', + size=target_dict_dim) + cost = classification_cost(input=decoder, label=lbl, ) + outputs(cost) + else: + gen_inputs = [StaticInput(input=encoded_vector, + is_seq=True), + StaticInput(input=encoded_proj, + is_seq=True), ] + # In generation, decoder predicts a next target word based on + # the encoded source sequence and the last generated target word. + # The encoded source sequence (encoder's output) must be specified by + # StaticInput which is a read-only memory. + # Here, GeneratedInputs automatically fetchs the last generated word, + # which is initialized by a start mark, such as . + trg_embedding = GeneratedInput( + size=target_dict_dim, + embedding_name='_target_language_embedding', + embedding_size=word_vector_dim) + gen_inputs.append(trg_embedding) + beam_gen = beam_search(name=decoder_group_name, + step=gru_decoder_with_attention, + input=gen_inputs, + id_input=data_layer(name="sent_id", + size=1), + dict_file=trg_dict_path, + bos_id=0, + eos_id=1, + beam_size=beam_size, + max_length=max_length, + result_file=gen_trans_file) + outputs(beam_gen) diff --git a/demo/seqToseq/translation/eval_bleu.sh b/demo/seqToseq/translation/eval_bleu.sh new file mode 100755 index 00000000000000..ef0ede717a740f --- /dev/null +++ b/demo/seqToseq/translation/eval_bleu.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +gen_file=$1 +beam_size=$2 + +# find top1 generating result +top1=$(printf '%s_top1.txt' `basename $gen_file .txt`) +if [ $beam_size -eq 1 ]; then + awk -F "\t" '{sub(" ","",$2);sub(" ","",$2);print $2}' $gen_file >$top1 +else + awk 'BEGIN{ + FS="\t"; + OFS="\t"; + read_pos = 2} { + if (NR == read_pos){ + sub(" ","",$3); + sub(" ","",$3); + print $3; + read_pos += (2 + res_num); + }}' res_num=$beam_size $gen_file >$top1 +fi + +# evalute bleu value +bleu_script=multi-bleu.perl +standard_res=../data/wmt14/gen/ntst14.trg +bleu_res=`perl $bleu_script $standard_res <$top1` + +echo $bleu_res +rm $top1 diff --git a/demo/seqToseq/translation/gen.conf b/demo/seqToseq/translation/gen.conf new file mode 100644 index 00000000000000..63c5c2f9a6052c --- /dev/null +++ b/demo/seqToseq/translation/gen.conf @@ -0,0 +1,36 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("..") + +from seqToseq_net import * + +# whether this config is used for generating +is_generating = True + +### Data Definiation +gen_conf = seq_to_seq_data(data_dir = "./data/pre-wmt14", + is_generating = is_generating, + gen_result = "./translation/gen_result") + +### Algorithm Configuration +settings( + learning_method = AdamOptimizer(), + batch_size = 1, + learning_rate = 0) + +### Network Architecture +gru_encoder_decoder(gen_conf, is_generating) diff --git a/demo/seqToseq/translation/gen.sh b/demo/seqToseq/translation/gen.sh new file mode 100755 index 00000000000000..ad977c05ff9897 --- /dev/null +++ b/demo/seqToseq/translation/gen.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +cd .. + +paddle train \ + --job=test \ + --config='translation/gen.conf' \ + --save_dir='data/wmt14_model' \ + --use_gpu=false \ + --num_passes=13 \ + --test_pass=12 \ + --trainer_count=1 \ + 2>&1 | tee 'translation/gen.log' diff --git a/demo/seqToseq/translation/moses_bleu.sh b/demo/seqToseq/translation/moses_bleu.sh new file mode 100755 index 00000000000000..bfaba40b26905c --- /dev/null +++ b/demo/seqToseq/translation/moses_bleu.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +set -x +echo "Downloading multi-bleu.perl" +wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl --no-check-certificate diff --git a/demo/seqToseq/translation/train.conf b/demo/seqToseq/translation/train.conf new file mode 100644 index 00000000000000..cf1bde15c4a8aa --- /dev/null +++ b/demo/seqToseq/translation/train.conf @@ -0,0 +1,36 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +sys.path.append("..") + +from seqToseq_net import * + +# whether this config is used for generating +is_generating = False + +### Data Definiation +data_dir = "./data/pre-wmt14" +train_conf = seq_to_seq_data(data_dir = data_dir, + is_generating = is_generating) + +### Algorithm Configuration +settings( + learning_method = AdamOptimizer(), + batch_size = 50, + learning_rate = 5e-4) + +### Network Architecture +gru_encoder_decoder(train_conf, is_generating) diff --git a/demo/seqToseq/translation/train.sh b/demo/seqToseq/translation/train.sh new file mode 100755 index 00000000000000..976b5ba3b054c4 --- /dev/null +++ b/demo/seqToseq/translation/train.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -e +cd .. 
+ +paddle train \ +--config='translation/train.conf' \ +--save_dir='translation/model' \ +--use_gpu=false \ +--num_passes=16 \ +--show_parameter_stats_period=100 \ +--trainer_count=4 \ +--log_period=10 \ +--dot_period=5 \ +2>&1 | tee 'translation/train.log' diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt new file mode 100644 index 00000000000000..b8ccfc6be5d34c --- /dev/null +++ b/doc/CMakeLists.txt @@ -0,0 +1,49 @@ + + + +if(NOT DEFINED SPHINX_THEME) + set(SPHINX_THEME default) +endif() + +if(NOT DEFINED SPHINX_THEME_DIR) + set(SPHINX_THEME_DIR) +endif() + +# configured documentation tools and intermediate build results +set(BINARY_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/_build") + +# Sphinx cache with pickled ReST documents +set(SPHINX_CACHE_DIR "${CMAKE_CURRENT_BINARY_DIR}/_doctrees") + +# HTML output directory +set(SPHINX_HTML_DIR "${CMAKE_CURRENT_BINARY_DIR}/html") + + +set(PADDLE_DOXYGEN_OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/doxygen_xml") + +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/conf.py.in" + "${BINARY_BUILD_DIR}/conf.py" + @ONLY) + +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile.in" + "${CMAKE_CURRENT_BINARY_DIR}/Doxyfile" + @ONLY + ) + +add_custom_target(paddle_doxygen_docs ALL + ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + +sphinx_add_target(paddle_docs + html + ${BINARY_BUILD_DIR} + ${SPHINX_CACHE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR} + ${SPHINX_HTML_DIR}) + +add_dependencies(paddle_docs + gen_proto_py + paddle_doxygen_docs) \ No newline at end of file diff --git a/doc/Doxyfile.in b/doc/Doxyfile.in new file mode 100644 index 00000000000000..a1fc3801925dd3 --- /dev/null +++ b/doc/Doxyfile.in @@ -0,0 +1,2384 @@ +# Doxyfile 1.8.10 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. 
+# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the config file +# that follow. The default is UTF-8 which is also the encoding used for all text +# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv +# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv +# for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = "paddle" + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + +PROJECT_NUMBER = 1.0.0 + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + +PROJECT_BRIEF = + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. 
Doxygen will copy +# the logo to the output directory. + +PROJECT_LOGO = + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + +OUTPUT_DIRECTORY = @PADDLE_DOXYGEN_OUTPUT@ + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- +# directories (in 2 levels) under the output directory of each output format and +# will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. +# The default value is: NO. + +CREATE_SUBDIRS = NO + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + +ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. 
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, +# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), +# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, +# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), +# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, +# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, +# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, +# Ukrainian and Vietnamese. +# The default value is: English. + +OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + +BRIEF_MEMBER_DESC = YES + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + +REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. 
+ +ABBREVIATE_BRIEF = + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + +ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + +INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + +FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + +STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. 
+ +STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + +SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + +JAVADOC_AUTOBRIEF = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + +QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + +MULTILINE_CPP_IS_BRIEF = NO + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + +INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. 
+# The default value is: NO. + +SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + +TAB_SIZE = 2 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:\n" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". You can put \n's in the value part of an alias to insert +# newlines. + +ALIASES = + +# This tag can be used to specify a number of word-keyword mappings (TCL only). +# A mapping has the form "name=value". For example adding "class=itcl::class" +# will allow you to use the command class in the itcl::class meaning. + +TCL_SUBST = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + +OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. 
Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + +OPTIMIZE_OUTPUT_VHDL = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, Javascript, +# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: +# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: +# Fortran. In the later case the parser tries to guess whether the code is fixed +# or free formatted code, this is the default for Fortran type files), VHDL. For +# instance to make doxygen treat .inc files as Fortran files (default is PHP), +# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. + +EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See http://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + +MARKDOWN_SUPPORT = YES + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. 
+# The default value is: YES. + +AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + +BUILTIN_STL_SUPPORT = YES + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + +CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + +SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + +IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. 
+ +DISTRIBUTE_GROUP_DOC = NO + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + +GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + +SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + +INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + +INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. 
And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + +TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + +LOOKUP_CACHE_SIZE = 0 + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. 
+ +EXTRACT_PRIVATE = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + +EXTRACT_ANON_NSPACES = NO + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. 
This option +# has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# (class|struct|union) declarations. If set to NO, these declarations will be +# included in the documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file +# names in lower-case letters. If set to YES, upper-case letters are also +# allowed. This is useful if you have classes or files whose names only differ +# in case and if your file system supports case sensitive file names. Windows +# and Mac users are advised to set this option to NO. +# The default value is: system dependent. + +CASE_SENSE_NAMES = YES + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + +HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. 
+ +HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = NO + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + +SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + +FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + +INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + +SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + +SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. 
+# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + +SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + +SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + +SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + +STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + +GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + +GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. 
+# The default value is: YES. + +GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + +GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + +ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + +MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. + +SHOW_USED_FILES = YES + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + +SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. 
+ +SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + +FILE_VERSION_FILTER = + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + +LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. 
+ +CITE_BIB_FILES = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = NO + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as not documenting some parameters +# in a documented function, or documenting parameters that don't exist or using +# markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong or incomplete +# parameter documentation, but not about the absence of documentation. +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. 
Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# The default value is: $file:$line: $text. + +WARN_FORMAT = "$file:$line: $text" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). + +WARN_LOGFILE = + +#--------------------------------------------------------------------------- +# Configuration options related to the input files +#--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + +INPUT = @PROJ_ROOT@/paddle + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: http://www.gnu.org/software/libiconv) for the list of +# possible encodings. +# The default value is: UTF-8. + +INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. 
+# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, +# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, +# *.vhdl, *.ucf, *.qsf, *.as and *.js. + +FILE_PATTERNS = *.c *.cc *.cpp *.cu *.h *.hpp *.cuh *.ph + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + +RECURSIVE = YES + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + +EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + +EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + +EXCLUDE_PATTERNS = */x86_64-scm-linux-gnu/* */internals/* */mkl/* */test/* */tests/* */platform/* + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. 
Examples: ANamespace, AClass, +# AClass::ANamespace, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + +EXCLUDE_SYMBOLS = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + +EXAMPLE_PATTERNS = + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. + +EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + +IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# +# +# where is the value of the INPUT_FILTER tag, and is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. + +INPUT_FILTER = + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. 
Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. + +FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + +FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + +FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + +USE_MDFILE_AS_MAINPAGE = + +#--------------------------------------------------------------------------- +# Configuration options related to source browsing +#--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. 
+ +SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + +INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + +STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# function all documented functions referencing it will be listed. +# The default value is: NO. + +REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + +REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + +REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. 
The htags tool is part of GNU's global source tagging system
+# (see http://www.gnu.org/software/global/global.html). You will need version
+# 4.8.6 or higher.
+#
+# To use it do the following:
+# - Install the latest version of global
+# - Enable SOURCE_BROWSER and USE_HTAGS in the config file
+# - Make sure the INPUT points to the root of the source tree
+# - Run doxygen as normal
+#
+# Doxygen will invoke htags (and that will in turn invoke gtags), so these
+# tools must be available from the command line (i.e. in the search path).
+#
+# The result: instead of the source browser generated by doxygen, the links to
+# source code will now point to the output of htags.
+# The default value is: NO.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+USE_HTAGS = NO
+
+# If the VERBATIM_HEADERS tag is set to YES then doxygen will generate a
+# verbatim copy of the header file for each class for which an include is
+# specified. Set to NO to disable this.
+# See also: Section \class.
+# The default value is: YES.
+
+VERBATIM_HEADERS = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to the alphabetical class index
+#---------------------------------------------------------------------------
+
+# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
+# compounds will be generated. Enable this if the project contains a lot of
+# classes, structs, unions or interfaces.
+# The default value is: YES.
+
+ALPHABETICAL_INDEX = YES
+
+# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
+# which the alphabetical index list will be split.
+# Minimum value: 1, maximum value: 20, default value: 5.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+COLS_IN_ALPHA_INDEX = 5
+
+# In case all classes in a project start with a common prefix, all classes will
+# be put under the same header in the alphabetical index.
The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +IGNORE_PREFIX = + +#--------------------------------------------------------------------------- +# Configuration options related to the HTML output +#--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + +GENERATE_HTML = NO + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. 
For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. 
+# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a colorwheel, see +# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use grayscales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. 
+# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: http://developer.apple.com/tools/xcode/), introduced with +# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a +# Makefile in the HTML output directory. 
Running make will produce the docset in +# that directory and running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html +# for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. 
The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on +# Windows. +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the master .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. 
+ +CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- +# folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. 
+ +QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. For more information please see Qt Help Project / Custom +# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location of Qt's +# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the +# generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. 
+ +GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. + +ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can +# further fine-tune the look of the index. As an example, the default style +# sheet generated by doxygen has an example that shows how to put an image at +# the root of the tree instead of the PROJECT_NAME. Since the tree basically has +# the same information as the tab index, you could consider setting +# DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. 
+# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_TREEVIEW = NO + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. +# This tag requires that the tag GENERATE_HTML is set to YES. + +ENUM_VALUES_PER_LINE = 4 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + +TREEVIEW_WIDTH = 250 + +# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +EXT_LINKS_IN_WINDOW = NO + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_FONTSIZE = 10 + +# Use the FORMULA_TRANPARENT tag to determine whether or not the images +# generated for formulas are transparent PNGs. Transparent PNGs are not +# supported properly for IE 6.0, but are supported on all modern browsers. +# +# Note that when changing this option you need to delete any form_*.png files in +# the HTML output directory before the changes have effect. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. 
+
+FORMULA_TRANSPARENT = YES
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# http://www.mathjax.org) which uses client side Javascript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want your formulas to look prettier in the HTML output.
+# When enabled you may also need to install MathJax separately and configure
+# the path to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX = NO
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. See the MathJax site (see:
+# http://docs.mathjax.org/en/latest/output.html) for more details.
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility), NativeMML (i.e. MathML) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from http://www.mathjax.org before deployment.
+# The default value is: http://cdn.mathjax.org/mathjax/latest.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering.
For example
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE =
+
+# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
+# the HTML output. The underlying search engine uses javascript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the javascript based search engine can be slow, then
+# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
+# search using the keyboard; to jump to the search box use <access key> + S
+# (what the <access key> is depends on the OS and browser, but it is typically
+# <CTRL>, <ALT>/<option>