-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathnets.hpp
130 lines (99 loc) · 2.9 KB
/
nets.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// nets.hpp - declarations of the network architectures used by this project.
//
// `#pragma once` replaces the previously commented-out include guard (whose
// opening/closing macro names, ALEXNET_H and NETS_H, did not even agree).
// Without it, including this header twice in one translation unit redefines
// every class below.
#pragma once

#include <cstddef>
#include <string>

#include "torch/torch.h"

// NOTE(review): a using-directive in a header leaks into every includer.
// It is kept because the declarations below rely on unqualified `Tensor`
// and `nn::` names (and implementation files presumably do too).
using namespace torch;

/// Unsigned integer type used below for plane counts, strides and block
/// expansion factors.
using alnet_t = std::size_t;

/// Defined in an implementation file not visible here; presumably the number
/// of input image channels - confirm against the definition.
extern const alnet_t CHANELS_NUM;

// Currently unused (commented-out) layer-size constants:
/*extern const alnet_t STRIDE_SIZE;
extern const alnet_t PADDING_SIZE_1;
extern const alnet_t PADDING_SIZE_2;
extern const alnet_t OUT_NEURONS_CONV_1; // 64
extern const alnet_t OUT_NEURONS_CONV_2; // 192
extern const alnet_t OUT_NEURONS_CONV_3; // 384
extern const alnet_t OUT_NEURONS_CONV_4; // 256
extern const alnet_t OUT_NEURONS_CONV_5; // 256*/

/// Error message defined elsewhere. NOTE(review): presumably reported by the
/// base Net::forward when a model does not provide its own - verify in the .cpp.
extern const std::string FORWARD_ERROR_STR;
class Net: public nn::Module
{
public:
virtual Tensor forward(torch::Tensor x);
virtual ~Net();
};
/// Custom network: six 2-D convolutions (named in pairs conv1_*, conv2_*,
/// conv3_*) plus two fully-connected layers.
///
/// Every layer starts as an empty (null) module holder; the constructor,
/// defined in the implementation file, presumably creates and registers them.
class OurNet : public Net
{
public:
    nn::Conv2d conv1_1{nullptr};
    nn::Conv2d conv1_2{nullptr};
    nn::Conv2d conv2_1{nullptr};
    nn::Conv2d conv2_2{nullptr};
    nn::Conv2d conv3_1{nullptr};
    nn::Conv2d conv3_2{nullptr};
    nn::Linear fc1{nullptr};
    nn::Linear fc2{nullptr};

    OurNet();

    /// Overrides Net::forward. `override` makes the compiler reject any future
    /// signature drift that would silently turn this into a hiding overload.
    Tensor forward(Tensor x) override;
};
/// AlexNet-style network: five 2-D convolution layers and three
/// fully-connected layers.
///
/// All layers start as empty (null) module holders; the constructor, defined
/// in the implementation file, presumably creates and registers them.
class AlexNet : public Net
{
public:
    nn::Conv2d conv1 = nullptr;
    nn::Conv2d conv2 = nullptr;
    nn::Conv2d conv3 = nullptr;
    nn::Conv2d conv4 = nullptr;
    nn::Conv2d conv5 = nullptr;
    nn::Linear fc1{nullptr};
    nn::Linear fc2{nullptr};
    nn::Linear fc3{nullptr};

    AlexNet();

    /// Overrides Net::forward; `override` lets the compiler verify the match.
    Tensor forward(Tensor x) override;
};
/// Small network: two 2-D convolution layers and two fully-connected layers.
///
/// All layers start as empty (null) module holders; the constructor, defined
/// in the implementation file, presumably creates and registers them.
class SmallNet : public Net
{
public:
    nn::Conv2d conv1{nullptr};
    nn::Conv2d conv2{nullptr};
    nn::Linear fc1{nullptr};
    nn::Linear fc2{nullptr};

    SmallNet();

    /// Overrides Net::forward; `override` lets the compiler verify the match.
    Tensor forward(Tensor x) override;
};
/// VGG-style network: eight 2-D convolution layers and three fully-connected
/// layers.
///
/// NOTE(review): the canonical VGG16 has 13 convolution layers; this class
/// declares eight, so it is presumably a reduced variant - confirm intent.
/// All layers start as empty (null) module holders; the constructor, defined
/// in the implementation file, presumably creates and registers them.
class VGG16 : public Net
{
public:
    nn::Conv2d conv1 = nullptr;
    nn::Conv2d conv2 = nullptr;
    nn::Conv2d conv3 = nullptr;
    nn::Conv2d conv4 = nullptr;
    nn::Conv2d conv5 = nullptr;
    nn::Conv2d conv6 = nullptr;
    nn::Conv2d conv7 = nullptr;
    nn::Conv2d conv8 = nullptr;
    nn::Linear fc1{nullptr};
    nn::Linear fc2{nullptr};
    nn::Linear fc3{nullptr};

    VGG16();

    /// Overrides Net::forward; `override` lets the compiler verify the match.
    Tensor forward(Tensor x) override;
};
/// Residual building block used by ResNet18: two 2-D convolutions plus an
/// optional `downsample` sequence (presumably applied to the shortcut path
/// when shapes differ - confirm in the implementation).
///
/// The BatchNorm layers are commented out, i.e. disabled in this version.
class BasicBlock : public Net
{
public:
    // Brace-initialized so they never hold indeterminate values, even if a
    // constructor path forgets to assign them. The constructor defined in the
    // implementation file is still expected to set the real values.
    alnet_t expansion{0};
    alnet_t planes{0};
    //nn::BatchNorm2d bn1 = nullptr;
    nn::Conv2d conv1 = nullptr;
    //nn::BatchNorm2d bn2 = nullptr;
    nn::Conv2d conv2 = nullptr;
    nn::Sequential downsample = nullptr;
    alnet_t stride{0};

    BasicBlock(alnet_t inplanes, alnet_t planes, alnet_t stride, nn::Sequential downsample);

    /// Overrides Net::forward; `override` lets the compiler verify the match.
    Tensor forward(Tensor x) override;
};
/// ResNet-18-style network: a stem convolution (`convv1`), four residual
/// stages (`layer1`..`layer4`) and a final fully-connected layer.
class ResNet18 : public Net
{
public:
    // NOTE(review): presumably the running input-channel count that
    // make_layer reads and updates while stacking stages - confirm in the .cpp.
    // Brace-initialized so it is never read while indeterminate.
    alnet_t inplanes{0};
    nn::Conv2d convv1 = nullptr;
    nn::Sequential layer1 = nullptr;
    nn::Sequential layer2 = nullptr;
    nn::Sequential layer3 = nullptr;
    nn::Sequential layer4 = nullptr;
    nn::Linear fc = nullptr;

    ResNet18();

    /// Builds one residual stage with `planes` output planes; `stride`
    /// defaults to 1. Defined in the implementation file.
    nn::Sequential make_layer(alnet_t planes, alnet_t stride = 1);

    /// Overrides Net::forward; `override` lets the compiler verify the match.
    Tensor forward(Tensor x) override;
};
//#endif // NETS_H