Description
I am working on a machine learning / mathematical optimization library that uses Accelerate for array computation. An optimization algorithm typically takes an objective function, such as `A.Acc (A.Vector Bool) -> A.Acc (A.Scalar b)`, as an argument. Not all objective functions can be expressed in terms of `Acc`, so it is important that a user can "lift" a non-`Acc` function to `Acc`, for example with `A.use . f . A.run`.
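The lifting pattern described above can be packaged as a generic helper. This is only an illustrative sketch: `liftAcc` and `countTrue` are hypothetical names (they are not part of Accelerate), and the sketch assumes the reference interpreter's `run`, imported here as `I.run`.

```haskell
import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter as I

-- | Hypothetical helper (not part of Accelerate): embed an ordinary Haskell
-- array function in an Acc computation by running the argument with the
-- reference interpreter, applying f, and re-embedding the result with A.use.
liftAcc :: (A.Arrays a, A.Arrays b) => (a -> b) -> A.Acc a -> A.Acc b
liftAcc f = A.use . f . I.run

-- | Hypothetical plain-Haskell objective: count the True entries of a
-- Bool vector and return the count as a scalar array.
countTrue :: A.Vector Bool -> A.Scalar Double
countTrue v = A.fromList A.Z [sum [if b then 1 else 0 | b <- A.toList v]]
```

For example, `I.run (liftAcc countTrue (A.use (A.fromList (A.Z A.:. 3) [True, False, True])))` applies the helper to a closed `A.use` argument, outside any `awhile` loop, and should yield a scalar containing `2.0`.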
To check that Accelerate supports this, I put together a trivial lifted objective function:
```haskell
import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter as A

liftedSumBools :: A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double)
liftedSumBools = A.use . A.fromList A.Z . (: []) . sumBools . A.toList . A.run

sumBools :: [Bool] -> Double
sumBools = sum . fmap (\b -> if b then 1 else 0)
```

However, when I tried to run an optimizer on this, I got an error:
```
*** Internal error in package accelerate ***
*** Please submit a bug report at https://github.com/AccelerateHS/accelerate/issues
inconsistent valuation @ shared 'Acc' tree with stable name 224;
aenv = [296]
CallStack (from HasCallStack):
  internalError: Data.Array.Accelerate.Trafo.Sharing:267:5
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:292:16
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:292:16
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:292:16
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:292:16
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:285:13
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:282:14
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:292:16
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:292:16
  convertSharingAcc: Data.Array.Accelerate.Trafo.Sharing:243:3
  convertOpenAcc: Data.Array.Accelerate.Trafo.Sharing:161:35
  convertAccWith: Data.Array.Accelerate.Trafo:69:37
```
If you replace `liftedSumBools` with

```haskell
sumBoolsAcc :: A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double)
sumBoolsAcc = A.sum . A.map (\b -> A.cond b 1 0)
```

there is no error.
Steps to reproduce
Run this program: https://gist.github.com/JustinLovinger/49b81dc83284732c05e4b657670b57c0.
Expected behaviour
Program runs without error.
Your environment
- Accelerate: 1.3
- Accelerate backend(s): Reference interpreter
- GHC: 8.10.3
- OS: NixOS 20.09
Additional context
While trying to create a minimal reproduction, I ran into a different error:

```
derivative-free-comparison: Cyclic definition of a value of type 'Acc' (sa = 26)
```

It is triggered by the following program:
```haskell
module Main where

import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter as A

main :: IO ()
main = do
  print $ A.run $ aiterate 2 (step liftedSumBools) $ A.use $ A.fromList
    (A.Z A.:. 2)
    [False, False]

step
  :: (A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double))
  -> A.Acc (A.Vector Bool)
  -> A.Acc (A.Vector Bool)
step f xs = A.acond (A.the (f xs) A.> 1) xs (A.map A.not xs)

liftedSumBools :: A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double)
liftedSumBools = A.use . A.fromList A.Z . (: []) . sumBools . A.toList . A.run

sumBools :: [Bool] -> Double
sumBools = sum . fmap (\b -> if b then 1 else 0)

-- | Repeatedly apply a function a fixed number of times.
aiterate
  :: (A.Arrays a)
  => A.Exp Int             -- ^ number of times to apply function
  -> (A.Acc a -> A.Acc a)  -- ^ function to apply
  -> A.Acc a               -- ^ initial value
  -> A.Acc a
aiterate n f xs0 = A.asnd $ A.awhile
  (A.unit . (A.< n) . A.the . A.afst)
  (\(A.T2 i xs) -> A.T2 (A.map (+ 1) i) (f xs))
  (A.lift (A.unit $ A.constant (0 :: Int), xs0))
```

This program doesn't give an error if you replace `liftedSumBools` with

```haskell
sumBoolsAcc :: A.Acc (A.Vector Bool) -> A.Acc (A.Scalar Double)
sumBoolsAcc = A.sum . A.map (\b -> A.cond b 1 0)
```

or if you replace `aiterate 2 (step liftedSumBools)` with `step liftedSumBools $ step liftedSumBools`.
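Because the manually unrolled `step liftedSumBools $ step liftedSumBools` runs without error, one possible stop-gap when the iteration count is a plain Haskell `Int` (rather than an `A.Exp Int`) is to unroll the loop at the Haskell level instead of generating an `A.awhile`. The helper below is a hypothetical sketch, not part of Accelerate, and only the two-step case reported above is known to avoid the error.

```haskell
-- | Hypothetical helper (not part of Accelerate): apply f exactly n times
-- by building the Haskell-level composition f . f . ... . f.
-- Unlike 'aiterate', the count is an ordinary Int fixed when the Acc term
-- is constructed, so no A.awhile loop appears in the resulting term.
-- A non-positive n yields the identity function.
unrollN :: Int -> (a -> a) -> a -> a
unrollN n f = foldr (.) id (replicate n f)
```

With this, `unrollN 2 (step liftedSumBools) xs0` builds the same term as `step liftedSumBools $ step liftedSumBools $ xs0`. The trade-off is that the term grows linearly with `n`, so this only suits small counts that are known when the program is built.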