Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/Python/Inline/Literal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import Data.Char
import Data.Int
import Data.Word
import Data.Set qualified as Set
import Data.Map.Strict qualified as Map
import Foreign.Ptr
import Foreign.C.Types
import Foreign.Storable
Expand Down Expand Up @@ -408,6 +409,33 @@ instance (FromPy a, Ord a) => FromPy (Set.Set a) where
pure $! Set.insert a s)
Set.empty


instance (ToPy k, ToPy v, Ord k) => ToPy (Map.Map k v) where
basicToPy dct = runProgram $ do
p_dict <- takeOwnership =<< checkNull basicNewDict
progPy $ do
let loop [] = p_dict <$ incref p_dict
loop ((k,v):xs) = basicToPy k >>= \case
NULL -> mustThrowPyError
p_k -> flip finally (decref p_k) $ basicToPy v >>= \case
NULL -> mustThrowPyError
p_v -> Py [CU.exp| int { PyDict_SetItem($(PyObject *p_dict), $(PyObject* p_k), $(PyObject *p_v)) }|] >>= \case
0 -> loop xs
_ -> nullPtr <$ decref p_v
loop $ Map.toList dct

instance (FromPy k, FromPy v, Ord k) => FromPy (Map.Map k v) where
basicFromPy p_dct = basicGetIter p_dct >>= \case
NULL -> do Py [C.exp| void { PyErr_Clear() } |]
throwM BadPyType
p_iter -> foldPyIterable p_iter
(\m p -> do k <- basicFromPy p
v <- Py [CU.exp| PyObject* { PyDict_GetItem($(PyObject* p_dct), $(PyObject *p)) }|] >>= \case
NULL -> throwM BadPyType
p_v -> basicFromPy p_v
pure $! Map.insert k v m)
Map.empty

-- | Fold over iterable. Function takes ownership over iterator.
foldPyIterable
:: Ptr PyObject -- ^ Python iterator (not checked)
Expand Down
2 changes: 2 additions & 0 deletions test/TST/Roundtrip.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Data.Int
import Data.Word
import Data.Typeable
import Data.Set (Set)
import Data.Map.Strict (Map)
import Foreign.C.Types

import Test.Tasty
Expand Down Expand Up @@ -53,6 +54,7 @@ tests = testGroup "Roundtrip"
, testRoundtrip @[Int]
, testRoundtrip @[[Int]]
, testRoundtrip @(Set Int)
, testRoundtrip @(Map Int Int)
-- , testRoundtrip @String -- Trips on zero byte as it should
]
, testGroup "OutOfRange"
Expand Down
9 changes: 8 additions & 1 deletion test/TST/ToPy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
module TST.ToPy (tests) where

import Data.Set qualified as Set
import Data.Map.Strict qualified as Map
import Test.Tasty
import Test.Tasty.HUnit
import Python.Inline
Expand Down Expand Up @@ -38,5 +39,11 @@ tests = testGroup "ToPy"
in [py_| assert x_hs == {1,3,5} |]
, testCase "set unhashable" $ runPy $
let x = Set.fromList [[1], [5], [3::Int]]
in throwsPy [py_| assert x_hs == {1,3,5} |]
in throwsPy [py_| x_hs |]
, testCase "dict<int,int>" $ runPy $
let x = Map.fromList [(1,10), (5,50), (3,30)] :: Map.Map Int Int
in [py_| assert x_hs == {1:10, 3:30, 5:50} |]
, testCase "dict unhashable" $ runPy $
let x = Map.fromList [([1],10), ([5],50), ([3],30)] :: Map.Map [Int] Int
in throwsPy [py_| x_hs |]
]
Loading