@@ -40,7 +40,7 @@ data DLayer
40
40
41
41
-- | Defines the
42
42
data CNetwork
43
- = CNSequence CNetwork
43
+ = CNSequence ( Map String String ) CNetwork
44
44
| CNConcatenate CNetwork CNetwork
45
45
| CNCons CNetwork CNetwork
46
46
| CNLayer DLayer (Map String String )
@@ -98,21 +98,53 @@ class Generator l where
98
98
-- have the CNetwork compiled at a separate file.
99
99
generateFile :: l -> CNetwork -> Text
100
100
101
+ data Model = Model String Integer
102
+
103
+ instance Show Model where
104
+ show (Model name level) = name ++ " _" ++ show level
105
+
106
+ newModel :: Model
107
+ newModel = Model " x" 0
108
+
109
+ nextModel :: String -> Model -> Model
110
+ nextModel name (Model _ level) = Model name (level + 1 )
111
+
101
112
instance Generator JavaScript where
102
113
generate l =
103
- T. intercalate " \n " . generateJS
114
+ T. intercalate " \n " . generateJS newModel
104
115
where
105
- generateJS :: CNetwork -> [Text ]
106
- generateJS (CNSequence cn) = " var model = tf.sequential();" : generateJS cn
107
- generateJS (CNConcatenate cn1 cn2) = generateJS cn1 ++ generateJS cn2 -- FIX
108
- generateJS (CNCons cn1 cn2) = generateJS cn1 ++ generateJS cn2
109
- generateJS CNNil = []
110
- generateJS CNReturn = []
111
- generateJS (CNLayer layer params) =
116
+ generateJS :: Model -> CNetwork -> [Text ]
117
+ generateJS model (CNSequence params cn) =
118
+ format (" var input = tf.input(" % string % " );" ) (paramsToJS params) :
119
+ format (" var " % string % " = input;" ) (show model) :
120
+ generateJS model cn
121
+ ++ [ format
122
+ (" model = tf.model({ inputs: input, outputs: " % string % " });" )
123
+ (show model)
124
+ ]
125
+ generateJS model (CNConcatenate cn1 cn2) =
126
+ let modelA = nextModel " a" model
127
+ modelB = nextModel " b" model
128
+ in format (" var " % string % " = " % string % " ;" ) (show modelA) (show model) :
129
+ generateJS modelA cn1
130
+ ++ format (" var " % string % " = " % string % " ;" ) (show modelB) (show model) :
131
+ generateJS modelB cn2
132
+ ++ [ format
133
+ (string % " = tf.layers.concatenate().apply([" % string % " , " % string % " ])" )
134
+ (show model)
135
+ (show modelA)
136
+ (show modelB)
137
+ ]
138
+ generateJS model (CNCons cn1 cn2) = generateJS model cn1 ++ generateJS model cn2
139
+ generateJS _ CNNil = []
140
+ generateJS _ CNReturn = []
141
+ generateJS model (CNLayer layer params) =
112
142
[ format
113
- (" model.add(tf.layers." % string % " (" % string % " ));" )
143
+ (string % " = tf.layers." % string % " (" % string % " ).apply(" % string % " )" )
144
+ (show model)
114
145
(generateName l layer)
115
146
(paramsToJS params)
147
+ (show model)
116
148
]
117
149
118
150
generateFile l cn =
@@ -152,7 +184,7 @@ instance Generator Python where
152
184
T. intercalate " \n " . generatePy
153
185
where
154
186
generatePy :: CNetwork -> [Text ]
155
- generatePy (CNSequence cn) = " model = tf.keras.models.Sequential()" : generatePy cn
187
+ generatePy (CNSequence params cn) = " model = tf.keras.models.Sequential()" : generatePy cn
156
188
generatePy (CNConcatenate cn1 cn2) = generatePy cn1 ++ generatePy cn2 -- FIX
157
189
generatePy (CNCons cn1 cn2) = generatePy cn1 ++ generatePy cn2
158
190
generatePy CNNil = []
0 commit comments