Skip to content

Commit

Permalink
Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
bmsherman committed Feb 24, 2020
1 parent 33c76ff commit bfe0680
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 25 deletions.
68 changes: 50 additions & 18 deletions src/FwdMode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ addD (D x) (D y) = D (dSum x y)
scalarMultD :: VectorSpace v => g :~> M.Real -> g :~> v -> g :~> v
scalarMultD (D c) (D x) = D (scalarMult c x)

multD :: RE.CNum a => g :~> a -> g :~> a -> g :~> a
multD (D x) (D y) = D (dMult x y)

{-| Composition of two smooth maps yields a smooth map -}
(@.) :: Additive c => (b :~> c) -> (a :~> b) -> (a :~> c)
(D g@(g0 :# g')) @. (D f@(f0 :# f')) = D $
Expand Down Expand Up @@ -238,8 +241,27 @@ dap2 :: Additive c => (a, b) :~> c -> g :~> a -> g :~> b -> g :~> c
dap2 f x y = f @. pairD x y

-- Seems right. Could inline scalarMult if I wanted
-- lift1 :: RE.CNum a => CMap a a -> a :~> a -> a :~> a
-- lift1 f (D f') = D $ (f <<< arr fst) :# dMult (dWkn (arr snd) f') (arr (fst . snd) :# dZero)

fromFuncs :: RE.CNum a => [CMap a a] -> a :~> a
fromFuncs = D . go 1
where
go :: RE.CNum a => CMap (a, k) a -> [CMap a a] -> Df a a a k
go prods (f : fs) = ((f <<< arr fst) * prods) :# go ((prods <<< (arr id *** arr snd)) * arr (fst . snd)) fs

toFuncs :: RE.CNum a => a :~> a -> [CMap a a]
toFuncs (D f) = go f (arr (\_ -> ())) where
go :: RE.CNum a => Df g a b k -> CMap g k -> [CMap g b]
go (g :# g') y = (g <<< (C.id &&& y)) : go g' (1 &&& y)

-- Alternative version
-- lift1 :: RE.CNum a => CMap a a -> a :~> a -> a :~> a
-- lift1 f f' = fromFwd f (multD (f' @. fstD) sndD)

-- Pray this one works
lift1 :: RE.CNum a => CMap a a -> a :~> a -> a :~> a
lift1 f (D f') = D $ (f <<< arr fst) :# dMult (dWkn (arr snd) f') (arr (fst . snd) :# dZero)
lift1 f f' = fromFuncs (f : toFuncs f')

negate' :: RE.CNum a => a :~> a
negate' = linearD RE.cnegate
Expand Down Expand Up @@ -272,18 +294,25 @@ sqrt' :: RE.CFloating a => a :~> a
sqrt' = lift1 RE.csqrt (recip' @. linearD ((2 *) (arr id)) @. sqrt')
pow' :: R.Rounded a => Int -> Interval a :~> Interval a
pow' 0 = lift1 1 zeroD
pow' 2 = lift1 (RE.pow 2) (linearD ((2 *) (arr id)))
pow' k = lift1 (RE.pow k) (linearD ((fromIntegral k *) (arr id)) @. (pow' (k - 1)))

square'' :: R.Rounded a => Interval a :~> Interval a
square'' = lift1 (RE.pow 2) (linearD ((2 *) (arr id)))


partialIfThenElse :: R.Rounded a => CMap g (Maybe Bool) -> g :~> Interval a -> g :~> Interval a -> g :~> Interval a
partialIfThenElse cond (D (t :# t')) (D (f :# f')) = D ((RE.partialIfThenElse cond t1 f1 <<< arr (\(x, ()) -> x)) :# partialIfThenElse' cond t' f')
where
t1 = t <<< arr (\x -> (x, ()))
f1 = f <<< arr (\x -> (x, ()))

getDerivTower :: R.Rounded a => Interval a :~> Interval a -> CMap g (Interval a) -> [CMap g (Interval a)]
getDerivTower (D f) x = go (wknValue x f) (arr (\_ -> ())) where
go :: R.Rounded a => Df g (Interval a) b k -> CMap g k -> [CMap g b]
go (g :# g') y = (g <<< (C.id &&& y)) : go g' (1 &&& y)
-- XXX: CHANGED IT TO TEST BROKEN lift1
-- CHANGE THE 0 back to a 1 AFTER!!!
getDerivTower :: RE.CNum a => a :~> a -> CMap g a -> [CMap g a]
getDerivTower (D f) x = go 2 (wknValue x f) (arr (\_ -> ())) where
go :: RE.CNum a => CMap g a -> Df g a b k -> CMap g k -> [CMap g b]
go dx (g :# g') y = (g <<< (C.id &&& y)) : go 0 g' (dx &&& y)

getValue :: g :~> a -> CMap g a
getValue (D (f :# f')) = f <<< arr (\x -> (x, ()))
Expand All @@ -299,12 +328,6 @@ fwdDeriv' g (f :# f') = (f <<< g) :# fwdDeriv' g2 f'
fwdDeriv :: Additive a => (g, a) :~> b -> (g, (a, a)):~> b
fwdDeriv (D f) = D (fwdDeriv' (arr (\((g, (a, da)), ()) -> ((g, a), ()))) f)

getDeriv :: Additive g => g :~> a -> g :~> a
getDeriv (D (f :# f')) = D (dWkn1 (zeroV &&& arr snd) f')

justDeriv :: Additive g => Additive a => g :~> a -> g :~> a
justDeriv (D f) = D (zeroV :# dWkn zeroV f)

genDeriv' :: Additive d => CMap (g, k) a
-> Df g (d, a) b k -> Df g (d, a) b k
genDeriv' dx (f :# f') = dWkn1 ((zeroV &&& dx) &&& arr snd) f'
Expand All @@ -323,6 +346,15 @@ genDeriv'' :: CMap (g, k) a
-> Df g a b k -> Df g a b k
genDeriv'' dx (f :# f') = dWkn1 (dx &&& arr snd) f'

-- The part of the fwd(f) that gives us these:
-- fwd(f)^(0)(x, v) = f^(1)(v)
-- fwd(f)^(1)(x, v)(dxa, dva) = f^(2)(v, dxa)
-- fwd(f)^(2)(x, v)(dxa, dva)(dxb, dvb) = f^(3)(v, dxa, dxb)
fwdDerU :: Df g g b (g, k) -> Df (g, g) (g, g) b k
fwdDerU = dWknA (arr fst) -- ignore all of the dv's
. dWkn1' (arr (\((x, v), dxs) -> (x, (v, dxs)))) -- make the input v the first dx


fwdDerDU :: Additive b => CMap k' k ->
Df g g b (g, k) -> Df (g, g) (g, g) b ((g, g), k')
fwdDerDU fext f@(f0 :# f') = f1 :#
Expand All @@ -333,10 +365,6 @@ fwdDerDU fext f@(f0 :# f') = f1 :#
f0 -< (x, (du, k))
fext' = arr fst *** fext

fwdDerU :: Df g g b (g, k) -> Df (g, g) (g, g) b k
fwdDerU = dWknA (arr fst) . dWkn1' (arr (\((x, dx), dxs) -> (x, (dx, dxs))))


fwdDerDUs :: Additive b => CMap k' k ->
Df g g b (g, k) -> Df (g, g) (g, g) b ((g, g), k')
fwdDerDUs fext f'@(_ :# f'') =
Expand All @@ -345,8 +373,7 @@ fwdDerDUs fext f'@(_ :# f'') =

fwdDer' :: Additive b =>
Df g g b k -> Df (g, g) (g, g) b k
fwdDer' (f :# f'@(f0' :# f'')) =
dSum (fwdDerU f') (zeroV :# fwdDerDUs (arr id) f')
fwdDer' (f :# f') = dSum (fwdDerU f') (zeroV :# fwdDerDUs (arr id) f')

fwdDer :: Additive b => g :~> b -> (g, g) :~> b
fwdDer (D f) = D (fwdDer' f)
Expand All @@ -373,7 +400,12 @@ diffeoExample n = getDerivTower (exp' @. linearD ((*2) C.id)) (E.asMPFR 0) !! n
-- A generalization of lift1. Give a continuous map for the value,
-- and a smooth map of the derivative for the rest.
fromFwd :: Additive a => CMap a b -> (a, a) :~> b -> a :~> b
fromFwd f (D f') = D $ (f <<< arr fst) :# convertFwdDeriv f'
fromFwd f (D f') = D $ (f <<< arr fst) :# convertFwdDerivFixed f'

-- the opposite of `fwdDerU`
convertFwdDerivFixed :: Additive g => Df (g, g) (g, g) b k -> Df g g b (g, k)
convertFwdDerivFixed = dWkn1' (arr (\(x, (v, dxs)) -> ((x, v), dxs)))
. dWknA (arr id &&& zeroV) -- make all the dv's 0s

convertFwdDeriv :: Additive a => Df (a, a) (a, a) b () -> Df a a b (a, ())
convertFwdDeriv = convertFwdDeriv' (arr id)
Expand Down
4 changes: 4 additions & 0 deletions src/Types/KShape.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Types.OShape (OShape)
import qualified Types.OShape as O

type KShape a = (a :=> SBool) :=> SBool
-- type Maximizer a = (a :=> DReal) :=> DReal

point :: Additive g => PShD a => a g -> KShape a g
point x = ArrD $ \wk f -> f # dmap wk x
Expand Down Expand Up @@ -116,3 +117,6 @@ simplerMaximization = supremum (intersect unit_interval (ArrD $ \wk x -> x < 0.5
-- Still not converging, but it should
simplerMaximizationDeriv :: DReal ()
simplerMaximizationDeriv = deriv (ArrD (\_ r -> supremum (intersect unit_interval (ArrD $ \wk x -> x < dmap wk r)))) 0.5

simpleDerivTest :: DReal ()
simpleDerivTest = deriv (ArrD (\_ c -> supremum ((map (ArrD (\wk x -> dmap wk c * x))) unit_interval))) 1.0
47 changes: 40 additions & 7 deletions src/Types/SmoothBool.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ testBSqrt z = let R f = dedekind_cut (ArrD (\c x -> x < 0 || x^2 < R c)) in
getDerivTower f z

testBCubert :: CPoint Real -> [CPoint Real]
testBCubert z = let R f = dedekind_cut (ArrD (\c x -> x < 0 || x^3 < R c)) in
testBCubert z = let R f = dedekind_cut (ArrD (\c x -> x^3 < R c)) in
getDerivTower f z

-- Only working via bisection, so derivatives must not be good.
Expand Down Expand Up @@ -118,13 +118,46 @@ simplerMaximizationPart = getDerivTower' (\q -> FwdPSh.argmax01 (\x -> min1 (0.5
simplerMaximizationPartExample :: CPoint Real
simplerMaximizationPartExample = simplerMaximizationPart 0.4 !! 1

-- BROKEN!
tester :: (Real, (Real, Real)) :~> Real
tester = fwdSecondDer ((\q-> pow q 2) dId)

-- Still not returning 0 when it should!
evalTester :: () :~> Real
evalTester = let f = ((\q-> pow q 2) dId) in
let f' = fwdDer f in
fwdDer f' @. pairD (pairD 0 1) (pairD 0 0)

tester1 :: CPoint Real -> [CPoint Real]
tester1 = let f = ((\q -> pow q 2) dId) in
getDerivTower' (\x -> let f' = fwdDer f in f' @. pairD x 0)
tester1 = let f = ((\q-> pow q 2) dId) in
getDerivTower (fwdDer f @. pairD 0 1)

-- tester1 x !! n = 2.0
-- when n >= 1
-- This is BROKEN!
-- should be 0, because it is the constant 0 function.
-- The error is in (@. dId)!

tester2 :: CPoint Real -> [CPoint Real]
tester2 = let f = ((\q -> pow q 2) dId) in getDerivTower f
tester2 = let f = ((\q-> pow q 3) dId) in
getDerivTower f

-- When I look at the derivatives for f(x) = x^3, I find that
-- f^(3)(dx1, dx2, dx3) = 6 * dx1^3
-- rather than 6 * dx1 * dx2 * dx3


-- !!! (pow q 2) dId is not the same as pow' 2!!!
-- pow q 2 = pow' 2 @. dId
-- i.e., f @. dId =/= f

tester3 :: CPoint Real -> [CPoint Real]
tester3 = let f = (square'' @. dId) in
getDerivTower f

bloat :: a -> [[a]] -> [[[a]]]
bloat x [] = [[[x]]]
bloat x (xs:xss) = ((x:xs):xss) : map (xs:) (bloat x xss)

-- tester1 0 !! 1 = 2.0
-- This is very very wrong! It should be constant 0, since the dx is 0.
partitions :: [a] -> [[[a]]]
partitions [] = [[]]
partitions (x:xs) = concatMap (bloat x) (partitions xs)

0 comments on commit bfe0680

Please sign in to comment.