-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
119 lines (117 loc) · 3.81 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package main
import (
"log"
"os"
"suvvm.work/toad_ocr_engine/common"
"suvvm.work/toad_ocr_engine/model"
"suvvm.work/toad_ocr_engine/nn"
"suvvm.work/toad_ocr_engine/rpc"
"suvvm.work/toad_ocr_engine/utils"
)
// main is the command-line entry point of the Toad OCR Engine.
//
// Usage: <binary> <command> [network]
//
// The command must be a key of common.CMDMap. server, client, help and
// nnlist take no further argument. train, test and reset additionally
// require a network name that is a key of common.NNMap.
func main() {
	if len(os.Args) == 1 {
		log.Printf("Please provide command parameters\n Running with " +
			"`help` to show currently supported commands")
		return
	}
	cmd := os.Args[1]
	if _, ok := common.CMDMap[cmd]; !ok {
		log.Printf("Unknow command!\n")
		return
	}

	// Commands that do not take a target network.
	switch cmd {
	case common.CmdServer:
		rpc.RunRPCServer()
		return
	case common.CmdClient:
		rpc.RunRpcClient()
		return
	case common.CmdHelp:
		printHelp()
		return
	case common.CmdList:
		listNetworks()
		return
	}

	// Every remaining command needs a valid target network.
	if len(os.Args) < 3 {
		log.Printf("Please provide target neural networks\n")
		return
	}
	cmdnn := os.Args[2]
	if _, ok := common.NNMap[cmdnn]; !ok {
		log.Printf("Please confirm the name of neural networks, use cmd " +
			"`nnlist` to show all supported networks\n")
		return
	}

	switch cmd {
	case common.CmdTrain:
		trainNetwork(cmdnn)
	case common.CmdTest:
		testNetwork(cmdnn)
	case common.CmdReset:
		resetNetwork(cmdnn)
	default:
		log.Printf("Unknow command!\n")
	}
}

// printHelp logs a usage summary for every supported command.
func printHelp() {
	log.Printf("\nToad OCR Engine Help:\ntrain: use command `%s` + target neural networks"+
		" to training networks(use mnist train set when training)\n"+
		"test: use command `%s` + target neural networks to testing"+
		" networks(use mnist test set when testing)\n"+
		"reset: use command `%s` + target neural networks to delete the weights file\n"+
		"nnlist: use command `%s` to show all supported networks\n"+
		"server: use command `%s` to run rpc server(etdc load blance control center must be online)\n"+
		"client: use command `%s` to run rpc client to sent one snn predict request and one cnn predict request"+
		"(etdc load blance control center must be online and at least one server registered)",
		common.CmdTrain, common.CmdTest, common.CmdReset, common.CmdList, common.CmdServer, common.CmdClient)
}

// listNetworks logs the name of every network registered in common.NNMap.
// Map iteration order is unspecified, so the listing order may vary per run.
func listNetworks() {
	log.Printf("supported networks:\n")
	for name := range common.NNMap {
		log.Printf("\t%s\n", name)
	}
}

// trainNetwork runs the training routine for the named network.
func trainNetwork(cmdnn string) {
	switch cmdnn {
	case common.CnnName:
		nn.RunCNN()
	case common.SnnName:
		nn.RunSNN()
	}
}

// testNetwork evaluates the saved weights of the named network on the test
// split returned by utils.LoadNIST. It returns early, with a hint to train
// first, when the network's weights file does not exist. The dataset is only
// loaded after that check succeeds.
func testNetwork(cmdnn string) {
	var weightsPath string
	switch cmdnn {
	case common.SnnName:
		weightsPath = "snn_weights"
	case common.CnnName:
		weightsPath = "cnn_weights"
	default:
		return
	}
	if weightsMissing(weightsPath) {
		log.Printf("Please training first!\n")
		return
	}

	// Only the test split is used; the training split is discarded.
	_, _, testData, testLbl := utils.LoadNIST(common.EMNSITByClassTrainImagesPath,
		common.EMNISTByClassTrainLabelsPath, common.EMNISTByClassTestImagesPath, common.EMNISTByClassTestLabelsPath)

	switch cmdnn {
	case common.SnnName:
		snn, err := model.LoadSNNFromSave()
		if err != nil {
			log.Fatalf("Failed at load snn weights %v", err)
		}
		nn.SNNTesting(snn, testData, testLbl)
	case common.CnnName:
		cnn, err := model.LoadCNNFromSave()
		if err != nil {
			log.Fatalf("Unable to load cnn file %v", err)
		}
		// Defer only after the error check: a defer evaluates cnn.VM
		// immediately, and the result is not guaranteed usable when err != nil.
		defer cnn.VM.Close()
		nn.CNNTesting(cnn, testData, testLbl)
	}
}

// resetNetwork deletes the saved weights file of the named network.
func resetNetwork(cmdnn string) {
	switch cmdnn {
	case common.CnnName:
		removeWeights("cnn_weights", "cnn", "Cnn")
	case common.SnnName:
		removeWeights("snn_weights", "snn", "Snn")
	}
}

// removeWeights deletes the weights file at path, logging the outcome.
// name is used in the failure message and title in the status messages.
func removeWeights(path, name, title string) {
	if weightsMissing(path) {
		log.Printf("%s weights file had been deleted!\n", title)
		return
	}
	if err := os.Remove(path); err != nil {
		log.Printf("Failed to delete %s weights:%v\n", name, err)
		return // do not report success when the removal failed
	}
	log.Printf("%s weights file deleted!\n", title)
}

// weightsMissing reports whether the file at path does not exist. Other
// Stat errors (e.g. permission denied) are deliberately not treated as
// "missing", so the subsequent load/remove surfaces the real error.
func weightsMissing(path string) bool {
	_, err := os.Stat(path)
	return os.IsNotExist(err)
}