Skip to content

Commit cdf611d

Browse files
committed
Updated concat layer and code generation structure
1 parent d59c0e9 commit cdf611d

File tree

7 files changed

+194
-45
lines changed

7 files changed

+194
-45
lines changed

src/TensorSafe/Commands/Examples.hs

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
module TensorSafe.Commands.Examples (examples) where
33

44
import TensorSafe.Examples.Examples
5-
( mnistConcatenateExample,
5+
( mnistConcatenateComplexExample,
6+
mnistConcatenateExample,
67
mnistExample,
78
mnistExampleDense,
89
simpleExample,
@@ -18,3 +19,5 @@ examples = do
1819
mnistExampleDense
1920
putStrLn "\n\n"
2021
mnistConcatenateExample
22+
putStrLn "\n\n"
23+
mnistConcatenateComplexExample

src/TensorSafe/Compile/Expr.hs

+43-11
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ data DLayer
4040

4141
-- | Defines the
4242
data CNetwork
43-
= CNSequence CNetwork
43+
= CNSequence (Map String String) CNetwork
4444
| CNConcatenate CNetwork CNetwork
4545
| CNCons CNetwork CNetwork
4646
| CNLayer DLayer (Map String String)
@@ -98,21 +98,53 @@ class Generator l where
9898
-- have the CNetwork compiled at a separate file.
9999
generateFile :: l -> CNetwork -> Text
100100

101+
data Model = Model String Integer
102+
103+
instance Show Model where
104+
show (Model name level) = name ++ "_" ++ show level
105+
106+
newModel :: Model
107+
newModel = Model "x" 0
108+
109+
nextModel :: String -> Model -> Model
110+
nextModel name (Model _ level) = Model name (level + 1)
111+
101112
instance Generator JavaScript where
102113
generate l =
103-
T.intercalate "\n" . generateJS
114+
T.intercalate "\n" . generateJS newModel
104115
where
105-
generateJS :: CNetwork -> [Text]
106-
generateJS (CNSequence cn) = "var model = tf.sequential();" : generateJS cn
107-
generateJS (CNConcatenate cn1 cn2) = generateJS cn1 ++ generateJS cn2 -- FIX
108-
generateJS (CNCons cn1 cn2) = generateJS cn1 ++ generateJS cn2
109-
generateJS CNNil = []
110-
generateJS CNReturn = []
111-
generateJS (CNLayer layer params) =
116+
generateJS :: Model -> CNetwork -> [Text]
117+
generateJS model (CNSequence params cn) =
118+
format ("var input = tf.input(" % string % ");") (paramsToJS params) :
119+
format ("var " % string % " = input;") (show model) :
120+
generateJS model cn
121+
++ [ format
122+
("model = tf.model({ inputs: input, outputs: " % string % " });")
123+
(show model)
124+
]
125+
generateJS model (CNConcatenate cn1 cn2) =
126+
let modelA = nextModel "a" model
127+
modelB = nextModel "b" model
128+
in format ("var " % string % " = " % string % ";") (show modelA) (show model) :
129+
generateJS modelA cn1
130+
++ format ("var " % string % " = " % string % ";") (show modelB) (show model) :
131+
generateJS modelB cn2
132+
++ [ format
133+
(string % " = tf.layers.concatenate().apply([" % string % ", " % string % "])")
134+
(show model)
135+
(show modelA)
136+
(show modelB)
137+
]
138+
generateJS model (CNCons cn1 cn2) = generateJS model cn1 ++ generateJS model cn2
139+
generateJS _ CNNil = []
140+
generateJS _ CNReturn = []
141+
generateJS model (CNLayer layer params) =
112142
[ format
113-
("model.add(tf.layers." % string % "(" % string % "));")
143+
(string % " = tf.layers." % string % "(" % string % ").apply(" % string % ")")
144+
(show model)
114145
(generateName l layer)
115146
(paramsToJS params)
147+
(show model)
116148
]
117149

118150
generateFile l cn =
@@ -152,7 +184,7 @@ instance Generator Python where
152184
T.intercalate "\n" . generatePy
153185
where
154186
generatePy :: CNetwork -> [Text]
155-
generatePy (CNSequence cn) = "model = tf.keras.models.Sequential()" : generatePy cn
187+
generatePy (CNSequence params cn) = "model = tf.keras.models.Sequential()" : generatePy cn
156188
generatePy (CNConcatenate cn1 cn2) = generatePy cn1 ++ generatePy cn2 -- FIX
157189
generatePy (CNCons cn1 cn2) = generatePy cn1 ++ generatePy cn2
158190
generatePy CNNil = []

src/TensorSafe/Examples/ConcatenateExample.hs

+66-10
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
-- | This module implements the MNIST model using some Concatenate layers
55
module TensorSafe.Examples.ConcatenateExample
66
( mnistConcatenate,
7+
mnistConcatenateComplex,
78
)
89
where
910

1011
import TensorSafe.Layers
11-
( Conv2D,
12+
( Concatenate,
13+
Conv2D,
1214
Dense,
1315
Flatten,
1416
MaxPooling,
1517
Relu,
1618
Sigmoid,
1719
)
18-
import TensorSafe.Network (Concatenate, MkINetwork, mkINetwork)
20+
import TensorSafe.Network (MkINetwork, mkINetwork)
1921
import TensorSafe.Shape (Shape (D1, D3))
2022

2123
type Ls1 =
@@ -25,26 +27,80 @@ type Ls1 =
2527
MaxPooling 2 2 2 2,
2628
Flatten
2729
]
28-
('D3 64 64 1) -- Input
29-
('D1 28800) -- Output
30+
('D3 28 28 1) -- Input
31+
('D1 4608) -- Output
3032

3133
type Ls2 =
3234
MkINetwork
3335
'[Conv2D 1 16 8 8 1 1, Relu, MaxPooling 2 2 2 2, Flatten]
34-
('D3 64 64 1) -- Input
35-
('D1 12544) -- Output
36+
('D3 28 28 1) -- Input
37+
('D1 1600) -- Output
3638

3739
type MNISTConcatenate =
3840
MkINetwork
3941
'[ Concatenate Ls1 Ls2,
40-
Dense 41344 1024,
42+
Dense 6208 1024,
4143
Relu,
4244
Dense 1024 10,
4345
Sigmoid
4446
]
45-
('D3 64 64 1) -- Input
47+
('D3 28 28 1) -- Input
4648
('D1 10) -- Output
4749

48-
-- | MNIST implementation using Convolutional layers
50+
-- | MNIST implementation using Concatenate layer
4951
mnistConcatenate :: MNISTConcatenate
50-
mnistConcatenate = mkINetwork
52+
mnistConcatenate = mkINetwork
53+
54+
--
55+
--
56+
--
57+
type Ls211 =
58+
MkINetwork
59+
'[ Conv2D 8 32 4 4 1 1,
60+
Relu,
61+
MaxPooling 2 2 2 2,
62+
Flatten
63+
]
64+
('D3 14 14 8) -- Input
65+
('D1 800) -- Output
66+
67+
type Ls212 =
68+
MkINetwork
69+
'[ Conv2D 8 64 8 8 1 1,
70+
Relu,
71+
MaxPooling 2 2 2 2,
72+
Flatten
73+
]
74+
('D3 14 14 8) -- Input
75+
('D1 576) -- Output
76+
77+
type Ls21 =
78+
MkINetwork
79+
'[ Concatenate Ls211 Ls212
80+
]
81+
('D3 14 14 8) -- Input
82+
('D1 1376) -- Output
83+
84+
type Ls22 =
85+
MkINetwork
86+
'[Conv2D 8 16 8 8 1 1, Relu, MaxPooling 2 2 2 2, Flatten]
87+
('D3 14 14 8) -- Input
88+
('D1 144) -- Output
89+
90+
type MNISTConcatenateComplex =
91+
MkINetwork
92+
'[ Conv2D 1 8 1 1 1 1,
93+
Relu,
94+
MaxPooling 2 2 2 2,
95+
Concatenate Ls21 Ls22,
96+
Dense 1520 512,
97+
Relu,
98+
Dense 512 10,
99+
Sigmoid
100+
]
101+
('D3 28 28 1) -- Input
102+
('D1 10) -- Output
103+
104+
-- | MNIST implementation using Convolutional layers
105+
mnistConcatenateComplex :: MNISTConcatenateComplex
106+
mnistConcatenateComplex = mkINetwork

src/TensorSafe/Examples/Examples.hs

+18-2
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ module TensorSafe.Examples.Examples
77
mnistExampleDense,
88
simpleExample,
99
mnistConcatenateExample,
10+
mnistConcatenateComplexExample,
1011
)
1112
where
1213

1314
import Data.Text.Lazy (unpack)
1415
import TensorSafe.Compile.Expr (JavaScript (..), generate)
15-
import TensorSafe.Examples.ConcatenateExample (mnistConcatenate)
16+
import TensorSafe.Examples.ConcatenateExample (mnistConcatenate, mnistConcatenateComplex)
1617
import TensorSafe.Examples.MnistExample (mnist, mnistDense)
1718
import TensorSafe.Examples.SimpleExample
1819
( lstm,
@@ -84,4 +85,19 @@ mnistConcatenateExample =
8485
putStrLn "\n"
8586
putStrLn "MNIST Concatenate generation"
8687
putStrLn "------------------------------"
87-
putStrLn $ unpack $ generate JavaScript (toCNetwork mnistConcatenate)
88+
putStrLn $ unpack $ generate JavaScript (toCNetwork mnistConcatenate)
89+
90+
mnistConcatenateComplexExample :: IO ()
91+
mnistConcatenateComplexExample =
92+
do
93+
putStrLn "MNIST Concatenate Complex with Concatenate example"
94+
putStrLn "------------------------------"
95+
print mnistConcatenateComplex
96+
putStrLn "\n"
97+
putStrLn "MNIST Concatenate Complex compilation"
98+
putStrLn "------------------------------"
99+
print (toCNetwork mnistConcatenateComplex)
100+
putStrLn "\n"
101+
putStrLn "MNIST Concatenate Complex generation"
102+
putStrLn "------------------------------"
103+
putStrLn $ unpack $ generate JavaScript (toCNetwork mnistConcatenateComplex)

src/TensorSafe/Layers.hs

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
-- | This module exposes all Layers declared at TensorSafe.Layers.
22
module TensorSafe.Layers
33
( BatchNormalization,
4+
Concatenate (..),
45
Conv2D,
56
Dense,
67
Dropout,
@@ -17,6 +18,7 @@ module TensorSafe.Layers
1718
where
1819

1920
import TensorSafe.Layers.BatchNormalization (BatchNormalization)
21+
import TensorSafe.Layers.Concatenate (Concatenate (..))
2022
import TensorSafe.Layers.Conv2D (Conv2D)
2123
import TensorSafe.Layers.Dense (Dense)
2224
import TensorSafe.Layers.Dropout (Dropout)

src/TensorSafe/Layers/Concatenate.hs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE GADTs #-}
3+
{-# LANGUAGE ScopedTypeVariables #-}
4+
{-# LANGUAGE TypeFamilies #-}
5+
6+
-- | This module declares the Concatenate layer data type.
7+
module TensorSafe.Layers.Concatenate where
8+
9+
import Data.Kind (Type)
10+
11+
-- | Concatenates two valid INetwork types
12+
data Concatenate :: Type -> Type -> Type where
13+
Concatenate :: in1 -> in2 -> Concatenate in1 in2
14+
deriving (Show)
15+
16+
-- |
17+
-- | Layer instance of Concatenate defined at Network.hs module
18+
-- |

0 commit comments

Comments
 (0)