Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… add_dilation
  • Loading branch information
NHZlX committed Oct 27, 2017
2 parents f0c3c49 + 12d862b commit e06f8fc
Show file tree
Hide file tree
Showing 139 changed files with 4,588 additions and 677 deletions.
4 changes: 2 additions & 2 deletions go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Info("Could not find the pserver checkpoint.")
log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func main() {
candy.Must(err)

go func() {
log.Info("starting pserver", log.Ctx{"port": *port})
log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
Expand Down
3 changes: 2 additions & 1 deletion go/master/c/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
}
err := c.SetDataset(paths)
if err != nil {
log.Error("error set dataset", log.Ctx{"error": err})
log.Error("error set dataset",
log.Ctx{"error": err, "paths": paths})
return C.PADDLE_MASTER_ERROR
}

Expand Down
19 changes: 14 additions & 5 deletions go/master/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ func (c *Client) StartGetRecords(passID int) {
}

func (c *Client) getRecords(passID int) {
i := 0
for {
t, err := c.getTask(passID)
if err != nil {
Expand All @@ -130,12 +131,20 @@ func (c *Client) getRecords(passID int) {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue

if i%60 == 0 {
log.Debug("getTask of passID error.",
log.Ctx{"error": err, "passID": passID})
i = 0
}
log.Error("getTask error.", log.Ctx{"error": err})

// if err.Error() == ErrPassAfter.Error()
// wait util last pass finishes
// if other error such as network error
// wait to reconnect or task time out
time.Sleep(time.Second * 3)
i += 3
continue
}

for _, chunk := range t.Chunks {
Expand Down
1 change: 1 addition & 0 deletions go/master/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ func TestNextRecord(t *testing.T) {
if e != nil {
panic(e)
}

// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
Expand Down
8 changes: 7 additions & 1 deletion go/pserver/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}

var cptr (*C.uchar)
if len(c) > 0 {
cptr = (*C.uchar)(&c[0])
} else {
log.Error("empty config", "param name", paramWithConfigs.Param.Name)
}
o.config = c
o.opt = C.paddle_create_optimizer(
(*C.uchar)(&c[0]),
cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
Expand Down
58 changes: 31 additions & 27 deletions go/pserver/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
Expand All @@ -40,7 +39,7 @@ type ElementType int

// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")

// RPC error message.
const (
Expand Down Expand Up @@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
MD5 string `json:"md5"`
CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}

Expand All @@ -92,7 +91,7 @@ type Service struct {
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
client KVStore

mu sync.Mutex
optMap map[string]*optimizer
Expand All @@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte
}

func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
type KVStore interface {
GetKey(key string, timeout time.Duration) ([]byte, error)
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}

func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
Expand All @@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}

// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")

Expand All @@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err
}

// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
crc32 := crc32.ChecksumIEEE(content)
if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}

Expand All @@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil {
return nil, err
}

return cp, nil
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
Expand All @@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
close(s.initialized)
}
return s, nil
}
Expand Down Expand Up @@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t {
err := s.checkpoint()
if err != nil {
log.Error("finish init params error", log.Ctx{"error": err})
log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
Expand Down Expand Up @@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()

log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
Expand Down Expand Up @@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {

oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Info("Do not have existing checkpoint.")
log.Info("old meta not found, skip removing old meta")
err = nil
} else if err == nil {
log.Info("removing old meta")
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
}

if err != nil {
return
}

h := md5.New()
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
MD5: md5,
CRC32: crc32,
Path: p,
}

Expand All @@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}

if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}

return
}
86 changes: 86 additions & 0 deletions go/pserver/service_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package pserver

import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

const testDir = "./test_data"

type myKV struct {
m map[string][]byte
}

func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
if m.m == nil {
m.m = make(map[string][]byte)
}
return m.m[key], nil
}

func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
if m.m == nil {
m.m = make(map[string][]byte)
}
m.m[key] = value
return nil
}

func TestCheckpoint(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
err = s.checkpoint()
assert.Nil(t, err)
_, err = LoadCheckpoint(kv, 0)
assert.Nil(t, err)
}

func float32ToByte(f float32) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, f)
if err != nil {
fmt.Println("binary.Write failed:", err)
}
return buf.Bytes()
}

func TestCheckpointWithData(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)

var content []byte
for i := 0; i < 50000; i++ {
content = append(content, float32ToByte(float32(i))...)
}

p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
assert.Nil(t, err)

err = s.FinishInitParams(0, nil)
assert.Nil(t, err)

var p2 Parameter
err = s.GetParam(p1.Name, &p2)
assert.Nil(t, err)
assert.Equal(t, p1, p2)

err = s.checkpoint()
assert.Nil(t, err)
cp, err := LoadCheckpoint(kv, 0)
assert.Nil(t, err)
s1, err := NewService(0, time.Hour, testDir, kv, cp)
assert.Nil(t, err)

var p3 Parameter
err = s1.GetParam(p1.Name, &p3)
assert.Nil(t, err)
assert.Equal(t, p1, p3)
}
4 changes: 0 additions & 4 deletions go/pserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {

wg.Wait()
}

func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}
10 changes: 8 additions & 2 deletions paddle/capi/gradient_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config;
paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
!modelConfig.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
}
} else {
modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create(
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
for (auto& para : parameters) {
para->load(is);
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)

cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
Expand All @@ -42,7 +42,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})

cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)

cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)

Expand Down
Loading

0 comments on commit e06f8fc

Please sign in to comment.