Skip to content

Commit

Permalink
[new] Check arith expr args in # positions (#397)
Browse files Browse the repository at this point in the history
Allow arithmetic expressions in type signatures with some bespoke logic for handling them as NumValues
  • Loading branch information
croyzor committed Jul 23, 2024
1 parent aad2d4c commit 85f1274
Show file tree
Hide file tree
Showing 16 changed files with 208 additions and 22 deletions.
28 changes: 24 additions & 4 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import Brat.Constructors
import Brat.Error
import Brat.Eval
import Brat.FC hiding (end)
import qualified Brat.FC as FC
import Brat.Graph
import Brat.Naming
-- import Brat.Search
Expand Down Expand Up @@ -395,7 +396,7 @@ check' (Simple tm) ((), ((hungry, ty):unders)) = do
(Braty, Left Nat, Num n) -> do
(_, _, [(dangling, _)], _) <- next "" (Const (Num n)) (S0,Some (Zy :* S0))
R0 (REx ("value", Nat) (S0 ::- R0))
let val = VNum (nConstant n)
let val = VNum (nConstant (fromIntegral n))
defineSrc dangling val
defineTgt hungry val
wire (dangling, kindType Nat, hungry)
Expand Down Expand Up @@ -543,11 +544,29 @@ kindCheck unders (Emb (WC fc (Var v))) = localFC fc $ vlup v >>= f unders
-- TODO: Add other operations on numbers
kindCheck ((hungry, Nat):unders) (Simple (Num n)) | n >= 0 = do
(_, _, [(dangling, _)], _) <- next "" (Const (Num n)) (S0,Some (Zy :* S0)) R0 (REx ("value", Nat) (S0 ::- R0))
let value = VNum (nConstant n)
let value = VNum (nConstant (fromIntegral n))
defineTgt hungry value
defineSrc dangling value
wire (dangling, TNat, hungry)
pure ([value], unders)
kindCheck ((hungry, Nat):unders) (Arith op lhs rhs) = do
(_, arithUnders, [(dangling,_)], _) <- next ("arith_" ++ show op) (ArithNode op) (S0, Some (Zy :* S0))
(REx ("lhs", Nat) (S0 ::- (REx ("rhs", Nat) (S0 ::- R0))))
(REx ("value", Nat) (S0 ::- R0))
([vlhs, vrhs], []) <- kindCheck [ (p, k) | (p, Left k) <- arithUnders ] (lhs :|: rhs)
let arithFC = FC (FC.start (fcOf lhs)) (FC.end (fcOf rhs))
localFC arithFC $ case (vlhs, vrhs) of
(VNum lhs, VNum rhs) -> do
case runArith lhs op rhs of
Nothing -> typeErr "Type level arithmetic too confusing"
Just result -> do
defineTgt hungry (VNum result)
defineSrc dangling (VNum result)
wire (dangling, kindType Nat, hungry)
pure ([VNum result], unders)
(VNum _, x) -> localFC (fcOf rhs) . typeErr $ "Expected numeric expression, found " ++ show x
(x, VNum _) -> localFC (fcOf lhs) . typeErr $ "Expected numeric expression, found " ++ show x
_ -> typeErr "Expected arguments to arithmetic operators to have kind #"
kindCheck ((hungry, Nat):unders) (Con c arg)
| Just (_, f) <- M.lookup c natConstructors = do
-- All nat constructors have one input and one output
Expand All @@ -562,7 +581,8 @@ kindCheck ((hungry, Nat):unders) (Con c arg)
defineTgt hungry v
pure ([v], unders)

kindCheck unders tm = err $ Unimplemented "kindCheck" [showRow unders, show tm]
kindCheck ((_, k):_) tm = typeErr $ "Expected " ++ show tm ++ " to have kind " ++ show k


-- Checks the kinds of the types in a dependent row
kindCheckRow :: Modey m
Expand Down Expand Up @@ -713,7 +733,7 @@ abstractPattern Braty (dangling, Left k) pat = abstractKind k pat
abstractKind _ (Bind x) = let ?my = Braty in singletonEnv x (dangling, Left k)
abstractKind _ (DontCare) = pure emptyEnv
abstractKind k (Lit x) = case (k, x) of
(Nat, Num n) -> defineSrc dangling (VNum (nConstant n)) $> emptyEnv
(Nat, Num n) -> defineSrc dangling (VNum (nConstant (fromIntegral n))) $> emptyEnv
(Star _, _) -> err MatchingOnTypes
_ -> err (PattErr $ "Couldn't resolve pattern " ++ show pat ++ " with kind " ++ show k)
-- abstractKind Braty Nat p = abstractPattern Braty (src, Right TNat) p
Expand Down
2 changes: 1 addition & 1 deletion brat/Brat/Checker/Clauses.hs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ solve my ((src, Lit tm):p) = do
(Braty, Left Nat)
| Num n <- tm -> do
unless (n >= 0) $ typeErr "Negative Nat kind"
unifyNum (nConstant n) (nVar (VPar (ExEnd (end src))))
unifyNum (nConstant (fromIntegral n)) (nVar (VPar (ExEnd (end src))))
(Braty, Right ty) -> do
throwLeft (simpleCheck Braty ty tm)
_ -> typeErr $ "Literal " ++ show tm ++ " isn't valid at this type"
Expand Down
51 changes: 51 additions & 0 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import Brat.Syntax.Value
import Brat.UserName
import Bwd
import Hasochism
import Util (log2)

import Control.Monad.Freer (req, Free(Ret))
import Control.Arrow ((***))
Expand Down Expand Up @@ -376,3 +377,53 @@ roToTuple R0 = TNil
roToTuple (RPr (_, ty) ro) = TCons ty (roToTuple ro)
roToTuple (REx _ ro) = case ro of
_ -> error "the impossible happened"

-- Low hanging fruit that we can easily do to our normal forms of numbers
runArith :: NumVal Z -> ArithOp -> NumVal Z -> Maybe (NumVal Z)
runArith (NumValue upl grol) Add (NumValue upr gror)
-- We can add when one of the sides is a constant...
| Constant0 <- grol = pure $ NumValue (upl + upr) gror
| Constant0 <- gror = pure $ NumValue (upl + upr) grol
-- ... or when Fun00s are the same
| grol == gror = pure $ NumValue (upl + upr) grol
runArith (NumValue upl grol) Sub (NumValue upr gror)
-- We can subtract when the rhs is a constant...
| Constant0 <- gror, upl >= upr = pure $ NumValue (upl - upr) grol
-- ... or when the Fun00s are the same...
| grol == gror, upl >= upr = pure $ NumValue (upl - upr) Constant0
-- ... or we have (c + 2^(k + 1) * x) - (c' + 2^k * x)
| StrictMonoFun (StrictMono k m) <- grol
, StrictMonoFun (StrictMono k' m') <- gror
, m == m'
, k == (k' + 1)
, upl >= upr = pure $ NumValue (upl - upr) gror
runArith (NumValue upl grol) Mul (NumValue upr gror)
-- We can multiply two constants...
| Constant0 <- grol
, Constant0 <- gror = pure $ NumValue (upl * upr) grol
-- ... or we can multiply by a power of 2
| Constant0 <- grol
, StrictMonoFun (StrictMono k' m) <- gror
, Just k <- log2 upl = pure $ NumValue (upl * upr) (StrictMonoFun (StrictMono (k + k') m))
| Constant0 <- gror
, StrictMonoFun (StrictMono k' m) <- grol
, Just k <- log2 upr = pure $ NumValue (upl * upr) (StrictMonoFun (StrictMono (k + k') m))
runArith (NumValue upl grol) Pow (NumValue upr gror)
-- We can take constants to the power of constants...
| Constant0 <- grol
, Constant0 <- gror = pure $ NumValue (upl ^ upr) Constant0
-- ... or we can take a constant to the power of a NumValue if the constant
-- is 2^(2^c)
| Constant0 <- grol
, Just l <- log2 upl
, Just k <- log2 l
, StrictMonoFun (StrictMono k' mono) <- gror
-- 2^(2^k) ^ (upr + (2^k' * mono))
-- (2^(2^k))^upr * (2^(2^k))^(2^k' * mono)
-- 2^(2^k * upr) * 2^(2^k * (2^k' * mono))
-- 2^(2^k * upr) * (1 + 2^(2^k * (2^k' * mono)) - 1)
-- 2^(2^k * upr) + 2^(2^k * upr) * (2^(2^k * (2^k' * mono)) - 1)
-- 2^(2^k * upr) + 2^(2^k * upr) * (full(2^k * (2^k' * mono))
-- 2^(2^k * upr) + 2^(2^k * upr) * (full(2^(k + k') * mono))
= pure $ NumValue (upl ^ upr) (StrictMonoFun (StrictMono (l * upr) (Full (StrictMono (k + k') mono))))
runArith _ _ _ = Nothing
4 changes: 2 additions & 2 deletions brat/Brat/Search.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ tokenValues fc (TList ty) = concat $ do
[[vec fc (WC fc <$> list)]]
tokenValues fc (TVec ty (VNum (NumValue n Constant0))) = do
tm <- tokenValues fc ty
[vec fc (replicate n $ WC fc tm)]
[vec fc (replicate (fromIntegral n) $ WC fc tm)]
tokenValues _ (TVec _ _) = [] -- not enough info
-- HACK: Lookup in the default constructor table, rather than using the Checking
-- monad to look up definitions.
Expand Down Expand Up @@ -102,7 +102,7 @@ tokenValues fc (VFun Kerny (ss :->> ts)) =
tokenSType (TVec TQ _) = []
tokenSType (TVec sty (VNum (NumValue n Constant0))) = do
tm <- tokenSType sty
[vec fc (replicate n $ WC fc tm)]
[vec fc (replicate (fromIntegral n) $ WC fc tm)]
tokenSType _ = []
tokenValues _ _ = []

Expand Down
12 changes: 6 additions & 6 deletions brat/Brat/Syntax/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ instance forall m top bot. MODEY m => Show (Ro' m top bot) where

-------------------------------- Number Values ---------------------------------
data NumVal n = NumValue
{ upshift :: Int
{ upshift :: Integer
, grower :: Fun00 n
} deriving Eq

Expand All @@ -209,7 +209,7 @@ instance Show (Fun00 n) where

-- Strictly increasing function
data StrictMono n = StrictMono
{ multBy2ToThe :: Int
{ multBy2ToThe :: Integer
, monotone :: Monotone n
} deriving Eq

Expand All @@ -234,7 +234,7 @@ class NumFun (t :: N -> Type) where
numValue :: t n -> NumVal n

instance NumFun NumVal where
numEval NumValue{..} = ((fromIntegral upshift) +) . numEval grower
numEval NumValue{..} = (upshift +) . numEval grower
numValue = id

instance NumFun Fun00 where
Expand Down Expand Up @@ -267,16 +267,16 @@ nVar v = NumValue
})
}

nConstant :: Int -> NumVal n
nConstant :: Integer -> NumVal n
nConstant n = NumValue n Constant0

nZero :: NumVal n
nZero = nConstant 0

nPlus :: Int -> NumVal n -> NumVal n
nPlus :: Integer -> NumVal n -> NumVal n
nPlus n (NumValue up g) = NumValue (up + n) g

n2PowTimes :: Int -> NumVal n -> NumVal n
n2PowTimes :: Integer -> NumVal n -> NumVal n
n2PowTimes n NumValue{..}
= NumValue { upshift = upshift * (2 ^ n)
, grower = mult2PowGrower grower
Expand Down
5 changes: 5 additions & 0 deletions brat/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ names = do

infixr 3 **^
infixr 3 ^**

log2 :: Integer -> Maybe Integer
log2 m | m > 1, (n, 0) <- m `divMod` 2 = (1+) <$> log2 n
log2 1 = pure 0
log2 _ = Nothing
1 change: 1 addition & 0 deletions brat/brat.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ test-suite tests
Test.Search,
Test.Substitution,
Test.Syntax.Let,
Test.TypeArith,
Test.Util

build-depends: base <5,
Expand Down
8 changes: 4 additions & 4 deletions brat/examples/adder.brat
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ if(X :: *, Bool, X, X) -> X
if(_, true, then, _) = then
if(_, false, _, else) = else

fastAdder(n :: #, Vec(Bool, succ(full(n))), Vec(Bool, succ(full(n))), carryIn :: Bool) -> carryOut :: Bool, Vec(Bool, succ(full(n)))
fastAdder(n :: #, Vec(Bool, 2^n), Vec(Bool, 2^n), carryIn :: Bool) -> carryOut :: Bool, Vec(Bool, 2^n)
fastAdder(0, [x], [y], b) = let c, z = fullAdder(x, y, b) in c, [z]
fastAdder(succ(n), xsh =,= xsl, ysh =,= ysl, b) =
fastAdder(n, xsh, ysh, true), fastAdder(n, xsh, ysh, false), fastAdder(n, xsl, ysl, b) |>
(d1, zsh1, d0, zsh0, c, zsl => if(Bool, c, d1, d0), if(Vec(Bool, succ(full(n))), c, zsh1, zsh0) =,= zsl)

chop(n :: #, Vec(Bool, doub(n))) -> Vec(Bool, n), Vec(Bool, n)
chop(n :: #, Vec(Bool, 2 * n)) -> Vec(Bool, n), Vec(Bool, n)
chop(n, hi =,= lo) = hi, lo

multAndAddTwo(n :: #, mul1 :: Vec(Bool, succ(full(n))), mul2 :: Vec(Bool, succ(full(n))), add1 :: Vec(Bool, succ(full(n))), add2 :: Vec(Bool, succ(full(n))))
-> Vec(Bool, succ(full(succ(n))))
multAndAddTwo(n :: #, mul1 :: Vec(Bool, 2^n), mul2 :: Vec(Bool, 2^n), add1 :: Vec(Bool, 2^n), add2 :: Vec(Bool, 2^n))
-> Vec(Bool, 2^(n + 1))
multAndAddTwo(0, [m1], [m2], [a1], [a2]) = let b, a = fullAdder(and(m1, m2), a1, a2) in [b,a]
multAndAddTwo(succ(n), msh1 =,= msl1, msh2 =,= msl2, ash1 =,= asl1, ash2 =,= asl2)
= let hilo1, lohi1, lohi2, lolo = chop(succ(full(n)), multAndAddTwo(n, msh1, msl2, ash1, ash2)), chop(succ(full(n)), multAndAddTwo(n, msl1, msl2, asl1, asl2)) in
Expand Down
8 changes: 4 additions & 4 deletions brat/examples/batcher-merge-sort.brat
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sort(n :: #, Vec(Nat, succ(full(n)))) -> Vec(Nat, succ(full(n)))
sort(n :: #, Vec(Nat, 2^n)) -> Vec(Nat, 2^n)
sort(0, [x]) = [x]
sort(succ(n), xs =,= ys) = merge(n, sort(n, xs), sort(n, ys))

Expand All @@ -12,7 +12,7 @@ cas(succ(a), succ(b)) = let a', b' = cas(a, b) in succ(a'), succ(b')

-- Should make some syntactic sugar for 2^n
-- Merging two sorted vectors into one
merge(n :: #, Vec(Nat, succ(full(n))), Vec(Nat, succ(full(n)))) -> Vec(Nat, succ(full(succ(n))))
merge(n :: #, Vec(Nat, 2^n), Vec(Nat, 2^n)) -> Vec(Nat, 2^(n + 1))
merge(0, [x], [y]) = cas(x, y) |> (x, y => [x, y]) -- Need to merge fan{in,out}!
merge(succ(n), xs0 =%= xs1, ys0 =%= ys1)
= fixOffBy1(succ(n), merge(n, xs0, ys0) =%= merge(n, xs1, ys1))
Expand All @@ -32,10 +32,10 @@ merge(succ(n), xs0 =%= xs1, ys0 =%= ys1)
-- let mid0, mid1 = (full(n) of cas)(mid0, mid1)
-- in lo, (mid0' =%= mid1'), hi

fixOffBy1(n :: #, Vec(Nat, succ(full(succ(n))))) -> Vec(Nat, succ(full(succ(n))))
fixOffBy1(n :: #, Vec(Nat, 2^(n + 1))) -> Vec(Nat, 2^(n + 1))
fixOffBy1(n, lo ,- mid -, hi) = lo ,- casses(full(n), mid) -, hi

casses(m :: #, Vec(Nat, doub(m))) -> Vec(Nat, doub(m))
casses(m :: #, Vec(Nat, 2 * m)) -> Vec(Nat, 2 * m)
casses(0, []) = []
casses(succ(m), a ,- b ,- cs) =
let a', b' = cas(a, b) in
Expand Down
2 changes: 2 additions & 0 deletions brat/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Test.Parsing
import Test.Search
import Test.Substitution
import Test.Syntax.Let
import Test.TypeArith

main = do
failureTests <- getFailureTests
Expand All @@ -35,4 +36,5 @@ main = do
,abstractorTests
,eqTests
,compilationTests
,typeArithTests
]
2 changes: 1 addition & 1 deletion brat/test/Test/Search.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ arbitrarySValue d = case d of
vec d = do
n <- chooseInt bounds
ty <- arbitrarySValue d
pure (TVec ty (VNum (NumValue n Constant0))) -- Only the simplest values of `n`
pure (TVec ty (VNum (NumValue (fromIntegral n) Constant0))) -- Only the simplest values of `n`


instance Arbitrary (Val Z) where
Expand Down
91 changes: 91 additions & 0 deletions brat/test/Test/TypeArith.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
-- Test our ability to do arithmetic in types
module Test.TypeArith where

import Brat.Checker.Helpers (runArith)
import Brat.FC
import Brat.Naming (Name(..))
import Brat.Syntax.Common (ArithOp(..), TypeKind(Nat))
import Brat.Syntax.Port
import Brat.Syntax.Simple (SimpleTerm(..))
import Brat.Syntax.Value
import Hasochism (N(..), Ny(..), Some(..), (:*)(..))

import Data.List (sort)
import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck hiding ((^))

-- A dummy variable to make NumVals with
var = VPar (ExEnd (Ex (MkName []) 0))

instance Arbitrary (NumVal Z) where
arbitrary = NumValue <$> (abs <$> arbitrary) <*> arbitrary

instance Arbitrary (Fun00 Z) where
arbitrary = sized aux
where
aux 0 = pure Constant0
aux n = oneof [pure Constant0, StrictMonoFun <$> resize (n `div` 2) arbitrary]

instance Arbitrary (StrictMono Z) where
arbitrary = StrictMono <$> (abs <$> arbitrary) <*> arbitrary

instance Arbitrary (Monotone Z) where
arbitrary = sized aux
where
aux 0 = pure (Linear var)
aux n = oneof [Full <$> resize (n `div` 2) arbitrary, pure (Linear var)]

adding = testProperty "adding" $ \x a' b' ->
let (a, b) = (abs a', abs b')
lhs = NumValue a x
rhs = NumValue b x
in Just (NumValue (a + b) x) == runArith lhs Add rhs

subtractEq = testProperty "subtract equal Fun00" $ \x a' b' ->
let [b, a] = sort [abs a', abs b']
lhs = NumValue a x
rhs = NumValue b x
in Just (NumValue (a - b) Constant0) == runArith lhs Sub rhs

subtractConst = testProperty "subtract const" $ \x a' b' ->
let [b, a] = sort [abs a', abs b']
lhs = NumValue a x
rhs = NumValue b Constant0
in Just (NumValue (a - b) x) == runArith lhs Sub rhs

subtractFactorOf2 = testProperty "subtract factor of 2" $ \x a' b' ->
let [b, a] = sort [abs a', abs b']
lhs = nPlus a (n2PowTimes 1 (NumValue 0 x))
rhs = NumValue b x
in Just (NumValue (a - b) x) == runArith lhs Sub rhs

multiplyByPowerOf2 = testProperty "multiply by a power of 2" $ \x k' b' coin ->
let (k, b) = (abs k', abs b')
a = 2 ^ k
-- This should be commutative, so flip the arguments sometimes
(lhs, rhs) = if coin
then (NumValue a Constant0, NumValue b x)
else (NumValue b x, NumValue a Constant0)
NumValue 0 x' = n2PowTimes k (NumValue 0 x)
in Just (NumValue (a * b) x') == runArith lhs Mul rhs

exponentiateConst = testProperty "exponentiate constants" $ \a' b' ->
let (a, b) = (abs a', abs b')
lhs = NumValue a Constant0
rhs = NumValue b Constant0
in Just (NumValue (a ^ b) Constant0) == runArith lhs Pow rhs

exponentiatePow2 = testCase "(2^(2^k)) ^ n" $ assertEqual "" -- k = 2; 2^k = 4; 2^(2^k) = 16
(runArith (nConstant 16) Pow (nVar var)) -- (2 ^ (2 ^ 4))^n
(Just (NumValue 1 (StrictMonoFun (StrictMono 0 (Full (StrictMono 2 (Linear var))))))) -- 1 + ((2^k * n) - 1) = 2 ^ 4n

typeArithTests = testGroup "Type Arithmetic"
[adding
,subtractEq
,subtractConst
,subtractFactorOf2
,multiplyByPowerOf2
,exponentiateConst
,exponentiatePow2
]
2 changes: 2 additions & 0 deletions brat/test/golden/error/type-arith.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(n :: #, Vec(Nat, n ^ 3)) -> Bool
f = ?f
6 changes: 6 additions & 0 deletions brat/test/golden/error/type-arith.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Error in test/golden/error/type-arith.brat@FC {start = Pos {line = 1, col = 20}, end = Pos {line = 1, col = 25}}:
f(n :: #, Vec(Nat, n ^ 3)) -> Bool
^^^^^

Type error: Type level arithmetic too confusing

2 changes: 2 additions & 0 deletions brat/test/golden/error/type-arith2.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(n :: #, Vec(Nat, n * n)) -> Bool
f = ?f
Loading

0 comments on commit 85f1274

Please sign in to comment.