Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
BigPeng committed Jul 22, 2014
1 parent de0043c commit 3937030
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/edu/hitsz/c102c/cnn/RunCNN.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
public class RunCNN {

public static void runCnn() {
//创建一个卷积神经网络
LayerBuilder builder = new LayerBuilder();
builder.addLayer(Layer.buildInputLayer(new Size(28, 28)));
builder.addLayer(Layer.buildConvLayer(6, new Size(5, 5)));
Expand All @@ -18,18 +19,18 @@ public static void runCnn() {
builder.addLayer(Layer.buildSampLayer(new Size(2, 2)));
builder.addLayer(Layer.buildOutputLayer(10));
CNN cnn = new CNN(builder, 50);

//导入数据集
String fileName = "dataset/train.format";
Dataset dataset = Dataset.load(fileName, ",", 784);
cnn.train(dataset, 100);//
cnn.train(dataset, 3);//
String modelName = "model/model.cnn";
cnn.saveModel(modelName);
// CNN cnn = CNN.loadModel(modelName);
// Dataset dataset = Dataset.load(fileName,
// ",", 784);
// cnn.train(dataset, 100);
// cnn.saveModel(modelName);
cnn.saveModel(modelName);
dataset.clear();
dataset = null;

//预测
// CNN cnn = CNN.loadModel(modelName);
Dataset testset = Dataset.load("dataset/test.format", ",", -1);
cnn.predict(testset, "dataset/test.predict");
}
Expand Down

0 comments on commit 3937030

Please sign in to comment.