Skip to content

Commit b7dce0f

Browse files
committed
Implemented backend generators as instances of a class
1 parent d0530a5 commit b7dce0f

File tree

5 files changed

+52
-39
lines changed

5 files changed

+52
-39
lines changed

app/Main.hs

+15-4
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@ import TensorSafe.Commands.Check (check)
77
import TensorSafe.Commands.Compile (compile)
88
import TensorSafe.Commands.Examples (examples)
99

10+
data Backend = JavaScript | Python deriving (Data,Typeable,Show,Eq)
11+
1012
data TensorSafe = Check { path :: FilePath }
11-
| Compile { path :: FilePath, module_name :: String }
13+
| Compile {
14+
path :: FilePath,
15+
module_name :: String,
16+
backend :: Backend,
17+
out :: Maybe FilePath
18+
}
1219
| Examples
1320
deriving (Data, Typeable, Show, Eq)
1421

@@ -21,6 +28,10 @@ cCompile :: TensorSafe
2128
cCompile = Compile
2229
{ path = def &= typ "PATH" &= help "Path to Haskell module with TensorSafe model inside"
2330
, module_name = def &= help "The module name inside the TensorSafe model file"
31+
, backend = enum
32+
[ JavaScript &= help "Compile to JavaScript backend"
33+
, Python &= help "Compile to Python backend"]
34+
, out = def &= help "If specified, the output file path to which the network will be generated"
2435
} &= help "Compiles module and outputs Neural Network model for the specified backend"
2536

2637
cExamples :: TensorSafe
@@ -34,6 +45,6 @@ main = do
3445
-- print =<< tensorSafe
3546
r <- tensorSafe
3647
case r of
37-
Check { path = p } -> check p
38-
Compile { path = p, module_name = m } -> compile p m
39-
Examples -> examples
48+
Check { path = p } -> check p
49+
Compile { path = p, module_name = m, backend = b, out = o } -> compile p m (show b) o
50+
Examples -> examples

src/TensorSafe.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
{-# LANGUAGE DataKinds #-}
22
{-# LANGUAGE ScopedTypeVariables #-}
33
module TensorSafe (
4-
Backend (..),
5-
evalCNetwork,
4+
JavaScript (..),
5+
generate,
66
INetwork,
77
MkINetwork,
88
mkINetwork,
99
toCNetwork
1010
) where
1111

12-
import TensorSafe.Compile.Expr (Backend (..), evalCNetwork)
12+
import TensorSafe.Compile.Expr (JavaScript (..), generate)
1313
import TensorSafe.Network (INetwork, MkINetwork, mkINetwork,
1414
toCNetwork)

src/TensorSafe/Commands/Compile.hs

+12-7
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,26 @@ import System.Exit
66
import TensorSafe.Commands.Utils
77

88

9-
compile :: String -> String -> IO ()
10-
compile path moduleName = do
11-
r <- runInterpreter $ checkAndCompile path moduleName
9+
compile :: String -> String -> String -> Maybe FilePath -> IO ()
10+
compile path moduleName backend out = do
11+
r <- runInterpreter $ checkAndCompile path moduleName backend out
1212
case r of
1313
Left err -> do
1414
putStrLn $ errorString err
1515
exitWith $ ExitFailure 1
1616
Right () -> do
1717
exitWith $ ExitSuccess
1818

19-
checkAndCompile :: String -> String -> Interpreter ()
20-
checkAndCompile path moduleName = do
19+
checkAndCompile :: String -> String -> String -> Maybe FilePath -> Interpreter ()
20+
checkAndCompile path moduleName backend out = do
2121
loadModules [path]
2222
setTopLevelModules [moduleName]
2323
setImportsQ [("TensorSafe", Nothing), ("Data.Text.Lazy", Nothing)]
2424

25-
r <- interpret "unpack $ evalCNetwork JavaScript (toCNetwork nn)" (as :: String)
26-
liftIO $ putStrLn r
25+
case out of
26+
Nothing -> do
27+
r <- interpret ("unpack $ generate " ++ backend ++ " (toCNetwork nn)") (as :: String)
28+
liftIO $ putStrLn r
29+
Just f -> do
30+
r <- interpret ("unpack $ generate " ++ backend ++ " (toCNetwork nn)") (as :: String)
31+
liftIO $ writeFile f r

src/TensorSafe/Compile/Expr.hs

+19-22
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,25 @@ data CNetwork = CNSequence CNetwork
1313
| CNNil
1414
deriving Show
1515

16-
-- TODO
17-
-- Use generate instead of eval
18-
-- Use classes and instanciate for each language.
19-
-- ie: class Generate l where
20-
-- generate :: l -> CN -> Text
21-
data Backend = JavaScript | Python deriving Show
22-
23-
evalCNetwork :: Backend -> CNetwork -> Text
24-
evalCNetwork Python cn = T.intercalate "\n" (evalPython cn)
25-
evalCNetwork JavaScript cn = T.intercalate "\n" (evalJS cn)
26-
27-
28-
evalPython :: CNetwork -> [Text]
29-
evalPython = undefined
30-
31-
evalJS :: CNetwork -> [Text]
32-
evalJS (CNSequence cn) = ["const model = tf.sequential();"] ++ evalJS cn
33-
evalJS (CNCons cn1 cn2) = (evalJS cn1) ++ (evalJS cn2)
34-
evalJS (CNNil) = []
35-
evalJS CNReturn = [] -- ["return model"]
36-
evalJS (CNLayer layer params) =
37-
[format ("model.add(tf.layers." % string % "(" % string % "))") layer (paramsToJS params)]
16+
17+
data JavaScript = JavaScript deriving Show
18+
19+
-- | Class that defines which languages are supported for CNetworks generation to text
20+
class Generator l where
21+
generate :: l -> CNetwork -> Text
22+
23+
-- | Instance for JavaScript generation
24+
instance Generator JavaScript where
25+
generate _ =
26+
T.intercalate "\n" . evalJS
27+
where
28+
evalJS :: CNetwork -> [Text]
29+
evalJS (CNSequence cn) = ["const model = tf.sequential();"] ++ evalJS cn
30+
evalJS (CNCons cn1 cn2) = (evalJS cn1) ++ (evalJS cn2)
31+
evalJS (CNNil) = []
32+
evalJS CNReturn = [] -- ["return model"]
33+
evalJS (CNLayer layer params) =
34+
[format ("model.add(tf.layers." % string % "(" % string % "))") layer (paramsToJS params)]
3835

3936

4037
paramsToJS :: Map String String -> String

src/TensorSafe/Examples/Examples.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module TensorSafe.Examples.Examples where
55
import Data.Maybe (Maybe, fromJust)
66
import Data.Text.Lazy (unpack)
77

8-
import TensorSafe.Compile.Expr (Backend (..), evalCNetwork)
8+
import TensorSafe.Compile.Expr (JavaScript (..), generate)
99
import TensorSafe.Examples.MnistExample (mnist, mnistDense)
1010
import TensorSafe.Examples.SimpleExample (myNet)
1111
import TensorSafe.Generic.Shape
@@ -84,7 +84,7 @@ mnistExample =
8484
putStrLn $ "\n"
8585
putStrLn $ "MNIST generation"
8686
putStrLn $ "-------------"
87-
putStrLn $ unpack $ evalCNetwork JavaScript (toCNetwork mnist)
87+
putStrLn $ unpack $ generate JavaScript (toCNetwork mnist)
8888

8989
mnistExampleDense :: IO ()
9090
mnistExampleDense =
@@ -99,5 +99,5 @@ mnistExampleDense =
9999
putStrLn $ "\n"
100100
putStrLn $ "MNIST generation"
101101
putStrLn $ "-------------"
102-
putStrLn $ unpack $ evalCNetwork JavaScript (toCNetwork mnistDense)
102+
putStrLn $ unpack $ generate JavaScript (toCNetwork mnistDense)
103103

0 commit comments

Comments
 (0)