diff --git a/brat/Brat/Checker.hs b/brat/Brat/Checker.hs index 23084657..ec165766 100644 --- a/brat/Brat/Checker.hs +++ b/brat/Brat/Checker.hs @@ -37,6 +37,7 @@ import Brat.Constructors import Brat.Error import Brat.Eval import Brat.FC hiding (end) +import qualified Brat.FC as FC import Brat.Graph import Brat.Naming -- import Brat.Search @@ -395,7 +396,7 @@ check' (Simple tm) ((), ((hungry, ty):unders)) = do (Braty, Left Nat, Num n) -> do (_, _, [(dangling, _)], _) <- next "" (Const (Num n)) (S0,Some (Zy :* S0)) R0 (REx ("value", Nat) (S0 ::- R0)) - let val = VNum (nConstant n) + let val = VNum (nConstant (fromIntegral n)) defineSrc dangling val defineTgt hungry val wire (dangling, kindType Nat, hungry) @@ -543,11 +544,29 @@ kindCheck unders (Emb (WC fc (Var v))) = localFC fc $ vlup v >>= f unders -- TODO: Add other operations on numbers kindCheck ((hungry, Nat):unders) (Simple (Num n)) | n >= 0 = do (_, _, [(dangling, _)], _) <- next "" (Const (Num n)) (S0,Some (Zy :* S0)) R0 (REx ("value", Nat) (S0 ::- R0)) - let value = VNum (nConstant n) + let value = VNum (nConstant (fromIntegral n)) defineTgt hungry value defineSrc dangling value wire (dangling, TNat, hungry) pure ([value], unders) +kindCheck ((hungry, Nat):unders) (Arith op lhs rhs) = do + (_, arithUnders, [(dangling,_)], _) <- next ("arith_" ++ show op) (ArithNode op) (S0, Some (Zy :* S0)) + (REx ("lhs", Nat) (S0 ::- (REx ("rhs", Nat) (S0 ::- R0)))) + (REx ("value", Nat) (S0 ::- R0)) + ([vlhs, vrhs], []) <- kindCheck [ (p, k) | (p, Left k) <- arithUnders ] (lhs :|: rhs) + let arithFC = FC (FC.start (fcOf lhs)) (FC.end (fcOf rhs)) + localFC arithFC $ case (vlhs, vrhs) of + (VNum lhs, VNum rhs) -> do + case runArith lhs op rhs of + Nothing -> typeErr "Type level arithmetic too confusing" + Just result -> do + defineTgt hungry (VNum result) + defineSrc dangling (VNum result) + wire (dangling, kindType Nat, hungry) + pure ([VNum result], unders) + (VNum _, x) -> localFC (fcOf rhs) . typeErr $ "Expected numeric expression, found " ++ show x + (x, VNum _) -> localFC (fcOf lhs) . typeErr $ "Expected numeric expression, found " ++ show x + _ -> typeErr "Expected arguments to arithmetic operators to have kind #" kindCheck ((hungry, Nat):unders) (Con c arg) | Just (_, f) <- M.lookup c natConstructors = do -- All nat constructors have one input and one output @@ -562,7 +581,8 @@ kindCheck ((hungry, Nat):unders) (Con c arg) defineTgt hungry v pure ([v], unders) -kindCheck unders tm = err $ Unimplemented "kindCheck" [showRow unders, show tm] +kindCheck ((_, k):_) tm = typeErr $ "Expected " ++ show tm ++ " to have kind " ++ show k + -- Checks the kinds of the types in a dependent row kindCheckRow :: Modey m @@ -713,7 +733,7 @@ abstractPattern Braty (dangling, Left k) pat = abstractKind k pat abstractKind _ (Bind x) = let ?my = Braty in singletonEnv x (dangling, Left k) abstractKind _ (DontCare) = pure emptyEnv abstractKind k (Lit x) = case (k, x) of - (Nat, Num n) -> defineSrc dangling (VNum (nConstant n)) $> emptyEnv + (Nat, Num n) -> defineSrc dangling (VNum (nConstant (fromIntegral n))) $> emptyEnv (Star _, _) -> err MatchingOnTypes _ -> err (PattErr $ "Couldn't resolve pattern " ++ show pat ++ " with kind " ++ show k) -- abstractKind Braty Nat p = abstractPattern Braty (src, Right TNat) p diff --git a/brat/Brat/Checker/Clauses.hs b/brat/Brat/Checker/Clauses.hs index 9843d2d7..bf9f7458 100644 --- a/brat/Brat/Checker/Clauses.hs +++ b/brat/Brat/Checker/Clauses.hs @@ -172,7 +172,7 @@ solve my ((src, Lit tm):p) = do (Braty, Left Nat) | Num n <- tm -> do unless (n >= 0) $ typeErr "Negative Nat kind" - unifyNum (nConstant n) (nVar (VPar (ExEnd (end src)))) + unifyNum (nConstant (fromIntegral n)) (nVar (VPar (ExEnd (end src)))) (Braty, Right ty) -> do throwLeft (simpleCheck Braty ty tm) _ -> typeErr $ "Literal " ++ show tm ++ " isn't valid at this type" diff --git a/brat/Brat/Checker/Helpers.hs b/brat/Brat/Checker/Helpers.hs index df496010..41684217 100644 --- a/brat/Brat/Checker/Helpers.hs +++ b/brat/Brat/Checker/Helpers.hs @@ -35,6 +35,7 @@ import Brat.Syntax.Value import Brat.UserName import Bwd import Hasochism +import Util (log2) import Control.Monad.Freer (req, Free(Ret)) import Control.Arrow ((***)) @@ -376,3 +377,53 @@ roToTuple R0 = TNil roToTuple (RPr (_, ty) ro) = TCons ty (roToTuple ro) roToTuple (REx _ ro) = case ro of _ -> error "the impossible happened" + +-- Low hanging fruit that we can easily do to our normal forms of numbers +runArith :: NumVal Z -> ArithOp -> NumVal Z -> Maybe (NumVal Z) +runArith (NumValue upl grol) Add (NumValue upr gror) +-- We can add when one of the sides is a constant... + | Constant0 <- grol = pure $ NumValue (upl + upr) gror + | Constant0 <- gror = pure $ NumValue (upl + upr) grol + -- ... or when Fun00s are the same + | grol == gror = pure $ NumValue (upl + upr) grol +runArith (NumValue upl grol) Sub (NumValue upr gror) +-- We can subtract when the rhs is a constant... + | Constant0 <- gror, upl >= upr = pure $ NumValue (upl - upr) grol + -- ... or when the Fun00s are the same... + | grol == gror, upl >= upr = pure $ NumValue (upl - upr) Constant0 + -- ... or we have (c + 2^(k + 1) * x) - (c' + 2^k * x) + | StrictMonoFun (StrictMono k m) <- grol + , StrictMonoFun (StrictMono k' m') <- gror + , m == m' + , k == (k' + 1) + , upl >= upr = pure $ NumValue (upl - upr) gror +runArith (NumValue upl grol) Mul (NumValue upr gror) + -- We can multiply two constants... + | Constant0 <- grol + , Constant0 <- gror = pure $ NumValue (upl * upr) grol + -- ... or we can multiply by a power of 2 + | Constant0 <- grol + , StrictMonoFun (StrictMono k' m) <- gror + , Just k <- log2 upl = pure $ NumValue (upl * upr) (StrictMonoFun (StrictMono (k + k') m)) + | Constant0 <- gror + , StrictMonoFun (StrictMono k' m) <- grol + , Just k <- log2 upr = pure $ NumValue (upl * upr) (StrictMonoFun (StrictMono (k + k') m)) +runArith (NumValue upl grol) Pow (NumValue upr gror) + -- We can take constants to the power of constants... + | Constant0 <- grol + , Constant0 <- gror = pure $ NumValue (upl ^ upr) Constant0 + -- ... or we can take a constant to the power of a NumValue if the constant + -- is 2^(2^c) + | Constant0 <- grol + , Just l <- log2 upl + , Just k <- log2 l + , StrictMonoFun (StrictMono k' mono) <- gror + -- 2^(2^k) ^ (upr + (2^k' * mono)) + -- (2^(2^k))^upr * (2^(2^k))^(2^k' * mono) + -- 2^(2^k * upr) * 2^(2^k * (2^k' * mono)) + -- 2^(2^k * upr) * (1 + 2^(2^k * (2^k' * mono)) - 1) + -- 2^(2^k * upr) + 2^(2^k * upr) * (2^(2^k * (2^k' * mono)) - 1) + -- 2^(2^k * upr) + 2^(2^k * upr) * (full(2^k * (2^k' * mono)) + -- 2^(2^k * upr) + 2^(2^k * upr) * (full(2^(k + k') * mono)) + = pure $ NumValue (upl ^ upr) (StrictMonoFun (StrictMono (l * upr) (Full (StrictMono (k + k') mono)))) +runArith _ _ _ = Nothing diff --git a/brat/Brat/Search.hs b/brat/Brat/Search.hs index c4dfe3fd..6fb9e176 100644 --- a/brat/Brat/Search.hs +++ b/brat/Brat/Search.hs @@ -52,7 +52,7 @@ tokenValues fc (TList ty) = concat $ do [[vec fc (WC fc <$> list)]] tokenValues fc (TVec ty (VNum (NumValue n Constant0))) = do tm <- tokenValues fc ty - [vec fc (replicate n $ WC fc tm)] + [vec fc (replicate (fromIntegral n) $ WC fc tm)] tokenValues _ (TVec _ _) = [] -- not enough info -- HACK: Lookup in the default constructor table, rather than using the Checking -- monad to look up definitions. @@ -102,7 +102,7 @@ tokenValues fc (VFun Kerny (ss :->> ts)) = tokenSType (TVec TQ _) = [] tokenSType (TVec sty (VNum (NumValue n Constant0))) = do tm <- tokenSType sty - [vec fc (replicate n $ WC fc tm)] + [vec fc (replicate (fromIntegral n) $ WC fc tm)] tokenSType _ = [] tokenValues _ _ = [] diff --git a/brat/Brat/Syntax/Value.hs b/brat/Brat/Syntax/Value.hs index b38e364a..4f20fd8e 100644 --- a/brat/Brat/Syntax/Value.hs +++ b/brat/Brat/Syntax/Value.hs @@ -188,7 +188,7 @@ instance forall m top bot. MODEY m => Show (Ro' m top bot) where -------------------------------- Number Values --------------------------------- data NumVal n = NumValue - { upshift :: Int + { upshift :: Integer , grower :: Fun00 n } deriving Eq @@ -209,7 +209,7 @@ instance Show (Fun00 n) where -- Strictly increasing function data StrictMono n = StrictMono - { multBy2ToThe :: Int + { multBy2ToThe :: Integer , monotone :: Monotone n } deriving Eq @@ -234,7 +234,7 @@ class NumFun (t :: N -> Type) where numValue :: t n -> NumVal n instance NumFun NumVal where - numEval NumValue{..} = ((fromIntegral upshift) +) . numEval grower + numEval NumValue{..} = (upshift +) . numEval grower numValue = id instance NumFun Fun00 where @@ -267,16 +267,16 @@ nVar v = NumValue }) } -nConstant :: Int -> NumVal n +nConstant :: Integer -> NumVal n nConstant n = NumValue n Constant0 nZero :: NumVal n nZero = nConstant 0 -nPlus :: Int -> NumVal n -> NumVal n +nPlus :: Integer -> NumVal n -> NumVal n nPlus n (NumValue up g) = NumValue (up + n) g -n2PowTimes :: Int -> NumVal n -> NumVal n +n2PowTimes :: Integer -> NumVal n -> NumVal n n2PowTimes n NumValue{..} = NumValue { upshift = upshift * (2 ^ n) , grower = mult2PowGrower grower diff --git a/brat/Util.hs b/brat/Util.hs index e80b061b..456da740 100644 --- a/brat/Util.hs +++ b/brat/Util.hs @@ -44,3 +44,8 @@ names = do infixr 3 **^ infixr 3 ^** + +log2 :: Integer -> Maybe Integer +log2 m | m > 1, (n, 0) <- m `divMod` 2 = (1+) <$> log2 n +log2 1 = pure 0 +log2 _ = Nothing diff --git a/brat/brat.cabal b/brat/brat.cabal index ede87cf5..c44e4cf6 100644 --- a/brat/brat.cabal +++ b/brat/brat.cabal @@ -167,6 +167,7 @@ test-suite tests Test.Search, Test.Substitution, Test.Syntax.Let, + Test.TypeArith, Test.Util build-depends: base <5, diff --git a/brat/examples/adder.brat b/brat/examples/adder.brat index b7af760b..12e9c06e 100644 --- a/brat/examples/adder.brat +++ b/brat/examples/adder.brat @@ -48,17 +48,17 @@ if(X :: *, Bool, X, X) -> X if(_, true, then, _) = then if(_, false, _, else) = else -fastAdder(n :: #, Vec(Bool, succ(full(n))), Vec(Bool, succ(full(n))), carryIn :: Bool) -> carryOut :: Bool, Vec(Bool, succ(full(n))) +fastAdder(n :: #, Vec(Bool, 2^n), Vec(Bool, 2^n), carryIn :: Bool) -> carryOut :: Bool, Vec(Bool, 2^n) fastAdder(0, [x], [y], b) = let c, z = fullAdder(x, y, b) in c, [z] fastAdder(succ(n), xsh =,= xsl, ysh =,= ysl, b) = fastAdder(n, xsh, ysh, true), fastAdder(n, xsh, ysh, false), fastAdder(n, xsl, ysl, b) |> (d1, zsh1, d0, zsh0, c, zsl => if(Bool, c, d1, d0), if(Vec(Bool, succ(full(n))), c, zsh1, zsh0) =,= zsl) -chop(n :: #, Vec(Bool, doub(n))) -> Vec(Bool, n), Vec(Bool, n) +chop(n :: #, Vec(Bool, 2 * n)) -> Vec(Bool, n), Vec(Bool, n) chop(n, hi =,= lo) = hi, lo -multAndAddTwo(n :: #, mul1 :: Vec(Bool, succ(full(n))), mul2 :: Vec(Bool, succ(full(n))), add1 :: Vec(Bool, succ(full(n))), add2 :: Vec(Bool, succ(full(n)))) - -> Vec(Bool, succ(full(succ(n)))) +multAndAddTwo(n :: #, mul1 :: Vec(Bool, 2^n), mul2 :: Vec(Bool, 2^n), add1 :: Vec(Bool, 2^n), add2 :: Vec(Bool, 2^n)) + -> Vec(Bool, 2^(n + 1)) multAndAddTwo(0, [m1], [m2], [a1], [a2]) = let b, a = fullAdder(and(m1, m2), a1, a2) in [b,a] multAndAddTwo(succ(n), msh1 =,= msl1, msh2 =,= msl2, ash1 =,= asl1, ash2 =,= asl2) = let hilo1, lohi1, lohi2, lolo = chop(succ(full(n)), multAndAddTwo(n, msh1, msl2, ash1, ash2)), chop(succ(full(n)), multAndAddTwo(n, msl1, msl2, asl1, asl2)) in diff --git a/brat/examples/batcher-merge-sort.brat b/brat/examples/batcher-merge-sort.brat index cba60d8a..10e7f949 100644 --- a/brat/examples/batcher-merge-sort.brat +++ b/brat/examples/batcher-merge-sort.brat @@ -1,4 +1,4 @@ -sort(n :: #, Vec(Nat, succ(full(n)))) -> Vec(Nat, succ(full(n))) +sort(n :: #, Vec(Nat, 2^n)) -> Vec(Nat, 2^n) sort(0, [x]) = [x] sort(succ(n), xs =,= ys) = merge(n, sort(n, xs), sort(n, ys)) @@ -12,7 +12,7 @@ cas(succ(a), succ(b)) = let a', b' = cas(a, b) in succ(a'), succ(b') -- Should make some syntactic sugar for 2^n -- Merging two sorted vectors into one -merge(n :: #, Vec(Nat, succ(full(n))), Vec(Nat, succ(full(n)))) -> Vec(Nat, succ(full(succ(n)))) +merge(n :: #, Vec(Nat, 2^n), Vec(Nat, 2^n)) -> Vec(Nat, 2^(n + 1)) merge(0, [x], [y]) = cas(x, y) |> (x, y => [x, y]) -- Need to merge fan{in,out}! merge(succ(n), xs0 =%= xs1, ys0 =%= ys1) = fixOffBy1(succ(n), merge(n, xs0, ys0) =%= merge(n, xs1, ys1)) @@ -32,10 +32,10 @@ merge(succ(n), xs0 =%= xs1, ys0 =%= ys1) -- let mid0, mid1 = (full(n) of cas)(mid0, mid1) -- in lo, (mid0' =%= mid1'), hi -fixOffBy1(n :: #, Vec(Nat, succ(full(succ(n))))) -> Vec(Nat, succ(full(succ(n)))) +fixOffBy1(n :: #, Vec(Nat, 2^(n + 1))) -> Vec(Nat, 2^(n + 1)) fixOffBy1(n, lo ,- mid -, hi) = lo ,- casses(full(n), mid) -, hi -casses(m :: #, Vec(Nat, doub(m))) -> Vec(Nat, doub(m)) +casses(m :: #, Vec(Nat, 2 * m)) -> Vec(Nat, 2 * m) casses(0, []) = [] casses(succ(m), a ,- b ,- cs) = let a', b' = cas(a, b) in diff --git a/brat/test/Main.hs b/brat/test/Main.hs index 499b6d36..5f279373 100644 --- a/brat/test/Main.hs +++ b/brat/test/Main.hs @@ -15,6 +15,7 @@ import Test.Parsing import Test.Search import Test.Substitution import Test.Syntax.Let +import Test.TypeArith main = do failureTests <- getFailureTests @@ -35,4 +36,5 @@ main = do ,abstractorTests ,eqTests ,compilationTests + ,typeArithTests ] diff --git a/brat/test/Test/Search.hs b/brat/test/Test/Search.hs index 13345c02..2cd2634c 100644 --- a/brat/test/Test/Search.hs +++ b/brat/test/Test/Search.hs @@ -41,7 +41,7 @@ arbitrarySValue d = case d of vec d = do n <- chooseInt bounds ty <- arbitrarySValue d - pure (TVec ty (VNum (NumValue n Constant0))) -- Only the simplest values of `n` + pure (TVec ty (VNum (NumValue (fromIntegral n) Constant0))) -- Only the simplest values of `n` instance Arbitrary (Val Z) where diff --git a/brat/test/Test/TypeArith.hs b/brat/test/Test/TypeArith.hs new file mode 100644 index 00000000..30087f32 --- /dev/null +++ b/brat/test/Test/TypeArith.hs @@ -0,0 +1,91 @@ +-- Test our ability to do arithmetic in types +module Test.TypeArith where + +import Brat.Checker.Helpers (runArith) +import Brat.FC +import Brat.Naming (Name(..)) +import Brat.Syntax.Common (ArithOp(..), TypeKind(Nat)) +import Brat.Syntax.Port +import Brat.Syntax.Simple (SimpleTerm(..)) +import Brat.Syntax.Value +import Hasochism (N(..), Ny(..), Some(..), (:*)(..)) + +import Data.List (sort) +import Test.Tasty +import Test.Tasty.HUnit +import Test.Tasty.QuickCheck hiding ((^)) + +-- A dummy variable to make NumVals with +var = VPar (ExEnd (Ex (MkName []) 0)) + +instance Arbitrary (NumVal Z) where + arbitrary = NumValue <$> (abs <$> arbitrary) <*> arbitrary + +instance Arbitrary (Fun00 Z) where + arbitrary = sized aux + where + aux 0 = pure Constant0 + aux n = oneof [pure Constant0, StrictMonoFun <$> resize (n `div` 2) arbitrary] + +instance Arbitrary (StrictMono Z) where + arbitrary = StrictMono <$> (abs <$> arbitrary) <*> arbitrary + +instance Arbitrary (Monotone Z) where + arbitrary = sized aux + where + aux 0 = pure (Linear var) + aux n = oneof [Full <$> resize (n `div` 2) arbitrary, pure (Linear var)] + +adding = testProperty "adding" $ \x a' b' -> + let (a, b) = (abs a', abs b') + lhs = NumValue a x + rhs = NumValue b x + in Just (NumValue (a + b) x) == runArith lhs Add rhs + +subtractEq = testProperty "subtract equal Fun00" $ \x a' b' -> + let [b, a] = sort [abs a', abs b'] + lhs = NumValue a x + rhs = NumValue b x + in Just (NumValue (a - b) Constant0) == runArith lhs Sub rhs + +subtractConst = testProperty "subtract const" $ \x a' b' -> + let [b, a] = sort [abs a', abs b'] + lhs = NumValue a x + rhs = NumValue b Constant0 + in Just (NumValue (a - b) x) == runArith lhs Sub rhs + +subtractFactorOf2 = testProperty "subtract factor of 2" $ \x a' b' -> + let [b, a] = sort [abs a', abs b'] + lhs = nPlus a (n2PowTimes 1 (NumValue 0 x)) + rhs = NumValue b x + in Just (NumValue (a - b) x) == runArith lhs Sub rhs + +multiplyByPowerOf2 = testProperty "multiply by a power of 2" $ \x k' b' coin -> + let (k, b) = (abs k', abs b') + a = 2 ^ k + -- This should be commutative, so flip the arguments sometimes + (lhs, rhs) = if coin + then (NumValue a Constant0, NumValue b x) + else (NumValue b x, NumValue a Constant0) + NumValue 0 x' = n2PowTimes k (NumValue 0 x) + in Just (NumValue (a * b) x') == runArith lhs Mul rhs + +exponentiateConst = testProperty "exponentiate constants" $ \a' b' -> + let (a, b) = (abs a', abs b') + lhs = NumValue a Constant0 + rhs = NumValue b Constant0 + in Just (NumValue (a ^ b) Constant0) == runArith lhs Pow rhs + +exponentiatePow2 = testCase "(2^(2^k)) ^ n" $ assertEqual "" -- k = 2; 2^k = 4; 2^(2^k) = 16 + (runArith (nConstant 16) Pow (nVar var)) -- (2 ^ (2 ^ 4))^n + (Just (NumValue 1 (StrictMonoFun (StrictMono 0 (Full (StrictMono 2 (Linear var))))))) -- 1 + ((2^k * n) - 1) = 2 ^ 4n + +typeArithTests = testGroup "Type Arithmetic" + [adding + ,subtractEq + ,subtractConst + ,subtractFactorOf2 + ,multiplyByPowerOf2 + ,exponentiateConst + ,exponentiatePow2 + ] diff --git a/brat/test/golden/error/type-arith.brat b/brat/test/golden/error/type-arith.brat new file mode 100644 index 00000000..6f0909b7 --- /dev/null +++ b/brat/test/golden/error/type-arith.brat @@ -0,0 +1,2 @@ +f(n :: #, Vec(Nat, n ^ 3)) -> Bool +f = ?f diff --git a/brat/test/golden/error/type-arith.brat.golden b/brat/test/golden/error/type-arith.brat.golden new file mode 100644 index 00000000..ac2ab7bf --- /dev/null +++ b/brat/test/golden/error/type-arith.brat.golden @@ -0,0 +1,6 @@ +Error in test/golden/error/type-arith.brat@FC {start = Pos {line = 1, col = 20}, end = Pos {line = 1, col = 25}}: +f(n :: #, Vec(Nat, n ^ 3)) -> Bool + ^^^^^ + + Type error: Type level arithmetic too confusing + diff --git a/brat/test/golden/error/type-arith2.brat b/brat/test/golden/error/type-arith2.brat new file mode 100644 index 00000000..bb53c751 --- /dev/null +++ b/brat/test/golden/error/type-arith2.brat @@ -0,0 +1,2 @@ +f(n :: #, Vec(Nat, n * n)) -> Bool +f = ?f diff --git a/brat/test/golden/error/type-arith2.brat.golden b/brat/test/golden/error/type-arith2.brat.golden new file mode 100644 index 00000000..f6e41097 --- /dev/null +++ b/brat/test/golden/error/type-arith2.brat.golden @@ -0,0 +1,6 @@ +Error in test/golden/error/type-arith2.brat@FC {start = Pos {line = 1, col = 20}, end = Pos {line = 1, col = 25}}: +f(n :: #, Vec(Nat, n * n)) -> Bool + ^^^^^ + + Type error: Type level arithmetic too confusing +