Skip to content

Commit fc91b1c

Browse files
authored
Merge pull request #10 from Shimuuar/callbacks
Fix calling python from callbacks in threaded runtime
2 parents 9f7f68a + 1504a40 commit fc91b1c

File tree

14 files changed

+267
-165
lines changed

14 files changed

+267
-165
lines changed

cbits/python.c

Lines changed: 82 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,93 @@
11
#include <inline-python.h>
22
#include <stdlib.h>
33

4-
PyObject *inline_py_function_wrapper(PyCFunction fun, int flags) {
5-
PyMethodDef *meth = malloc(sizeof(PyMethodDef));
6-
meth->ml_name = "[inline_python]";
7-
meth->ml_meth = fun;
8-
meth->ml_flags = flags;
9-
meth->ml_doc = "Wrapper constructed by inline-python";
10-
// Python wrapper which carries PyMethodDef
11-
PyObject* meth_obj = PyCapsule_New(meth, NULL, &inline_py_free_capsule);
4+
// ================================================================
5+
// Callbacks
6+
//
7+
// General idea: we store function pointer (haskell's FunPtr) in
8+
// PyCapsule and use to call function. Most importantly we must
9+
// release GIL before calling into haskell. Haskell callback will
10+
// happen on different thread (on threaded RTS). So it'll have to
11+
// reacquire GIL there.
12+
// ================================================================
13+
14+
int inline_py_callback_depth = 0;
15+
16+
static PyObject* callback_METH_O(PyObject* self, PyObject* arg) {
17+
PyObject *res;
18+
PyCFunction *fun = PyCapsule_GetPointer(self, NULL);
19+
//--
20+
inline_py_callback_depth++;
21+
Py_BEGIN_ALLOW_THREADS
22+
res = (*fun)(self, arg);
23+
Py_END_ALLOW_THREADS
24+
inline_py_callback_depth--;
25+
return res;
26+
}
27+
28+
static PyObject* callback_METH_FASTCALL(PyObject* self, PyObject** args, Py_ssize_t nargs) {
29+
PyObject *res;
30+
PyCFunctionFast *fun = PyCapsule_GetPointer(self, NULL);
31+
//--
32+
inline_py_callback_depth++;
33+
Py_BEGIN_ALLOW_THREADS
34+
res = (*fun)(self, args, nargs);
35+
Py_END_ALLOW_THREADS
36+
inline_py_callback_depth--;
37+
return res;
38+
}
39+
40+
static void capsule_free_FunPtr(PyObject* capsule) {
41+
PyCFunction *fun = PyCapsule_GetPointer(capsule, NULL);
42+
// We call directly to haskell RTS to free FunPtr. Only question
43+
// is how stable is this API.
44+
freeHaskellFunctionPtr(*fun);
45+
free(fun);
46+
}
47+
48+
static PyMethodDef method_METH_O = {
49+
.ml_name = "[inline_python]",
50+
.ml_meth = callback_METH_O,
51+
.ml_flags = METH_O,
52+
.ml_doc = "Wrapper for haskell callback"
53+
};
54+
55+
static PyMethodDef method_METH_FASTCALL = {
56+
.ml_name = "[inline_python]",
57+
.ml_meth = (PyCFunction)callback_METH_FASTCALL,
58+
.ml_flags = METH_FASTCALL,
59+
.ml_doc = "Wrapper for haskell callback"
60+
};
61+
62+
PyObject *inline_py_callback_METH_O(PyCFunction fun) {
63+
PyCFunction *buf = malloc(sizeof(PyCFunction));
64+
*buf = fun;
65+
PyObject* self = PyCapsule_New(buf, NULL, &capsule_free_FunPtr);
1266
if( PyErr_Occurred() )
1367
return NULL;
1468
// Python function
15-
PyObject* f = PyCFunction_New(meth, meth_obj);
16-
Py_DECREF(meth_obj);
17-
return f;
69+
PyObject* f = PyCFunction_New(&method_METH_O, self);
70+
Py_DECREF(self);
71+
return f;
1872
}
1973

74+
PyObject *inline_py_callback_METH_FASTCALL(PyCFunctionFast fun) {
75+
PyCFunctionFast *buf = malloc(sizeof(PyCFunctionFast));
76+
*buf = fun;
77+
PyObject* self = PyCapsule_New(buf, NULL, &capsule_free_FunPtr);
78+
if( PyErr_Occurred() )
79+
return NULL;
80+
// Python function
81+
PyObject* f = PyCFunction_New(&method_METH_FASTCALL, self);
82+
Py_DECREF(self);
83+
return f;
84+
}
85+
86+
87+
// ================================================================
88+
// Marshalling
89+
// ================================================================
90+
2091
int inline_py_unpack_iterable(PyObject *iterable, int n, PyObject **out) {
2192
// Initialize iterator. If object is not an iterable we treat this
2293
// as not an exception but as a conversion failure
@@ -57,11 +128,3 @@ int inline_py_unpack_iterable(PyObject *iterable, int n, PyObject **out) {
57128
return -1;
58129
}
59130

60-
void inline_py_free_capsule(PyObject* py) {
61-
PyMethodDef *meth = PyCapsule_GetPointer(py, NULL);
62-
// HACK: We want to release wrappers created by wrapper. It
63-
// doesn't seems to be nice and stable C API
64-
freeHaskellFunctionPtr(meth->ml_meth);
65-
free(meth);
66-
}
67-

include/inline-python.h

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,39 @@
55
#include <Rts.h>
66

77

8+
// Use new stable API from
9+
#ifndef PyCFunctionFast
10+
typedef _PyCFunctionFast PyCFunctionFast;
11+
#endif
12+
813
// ----------------------------------------------------------------
914
// Standard status codes
1015

1116
#define IPY_OK 0
1217
#define IPY_ERR_PYTHON 1
1318
#define IPY_ERR_COMPILE 2
1419

15-
// ----------------------------------------------------------------
20+
21+
22+
// ================================================================
23+
// Callbacks
24+
// ================================================================
25+
26+
// Callback depth. It's used to decide whether we want to just
27+
// continue in bound thread. Should only be modified while GIL is held
28+
extern int inline_py_callback_depth;
29+
30+
// Wrap haskell callback using METH_O calling convention
31+
PyObject *inline_py_callback_METH_O(PyCFunction fun);
32+
33+
// Wrap haskell callback using METH_FASTCALL calling convention
34+
PyObject *inline_py_callback_METH_FASTCALL(PyCFunctionFast fun);
35+
36+
37+
38+
// ================================================================
39+
// Callbacks
40+
// ================================================================
1641

1742
// Unpack iterable into array of PyObjects. Iterable must contain
1843
// exactly N elements.
@@ -27,13 +52,3 @@ int inline_py_unpack_iterable(
2752
int n,
2853
PyObject **out
2954
);
30-
31-
// Allocate python function object which carrries its own PyMethodDef.
32-
// Returns function object or NULL with error raised.
33-
//
34-
// See NOTE: [Creation of python functions]
35-
PyObject *inline_py_function_wrapper(PyCFunction fun, int flags);
36-
37-
// Free malloc'd buffer inside PyCapsule
38-
void inline_py_free_capsule(PyObject*);
39-

inline-python.cabal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ Library
7979
----------------------------------------------------------------
8080
library test
8181
import: language
82+
Default-Extensions:
83+
QuasiQuotes
8284
build-depends: base
8385
, inline-python
8486
, tasty >=1.2
@@ -88,6 +90,8 @@ library test
8890
TST.Run
8991
TST.ToPy
9092
TST.FromPy
93+
TST.Callbacks
94+
TST.Util
9195

9296
test-suite inline-python-tests
9397
import: language

src/Python/Inline/Literal.hs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ module Python.Inline.Literal
1212
, fromPy'
1313
) where
1414

15-
import Control.Exception
1615
import Control.Monad
1716
import Control.Monad.IO.Class
1817
import Control.Monad.Trans.Class
@@ -23,7 +22,6 @@ import Data.Word
2322
import Data.Foldable
2423
import Foreign.Ptr
2524
import Foreign.C.Types
26-
import Foreign.Marshal.Alloc
2725
import Foreign.Storable
2826

2927
import Language.C.Inline qualified as C
@@ -285,7 +283,7 @@ instance (FromPy a) => FromPy [a] where
285283
-- with async exception out of the blue
286284

287285

288-
instance (FromPy a, ToPy b) => ToPy (a -> IO b) where
286+
instance (FromPy a, Show a, ToPy b) => ToPy (a -> IO b) where
289287
basicToPy f = Py $ do
290288
-- C function pointer for callback
291289
f_ptr <- wrapO $ \_ p_a -> pyCallback $ do
@@ -295,10 +293,9 @@ instance (FromPy a, ToPy b) => ToPy (a -> IO b) where
295293
Right a -> pure a
296294
liftIO $ unPy . basicToPy =<< f a
297295
--
298-
[C.exp| PyObject* {
299-
inline_py_function_wrapper(
300-
$(PyObject* (*f_ptr)(PyObject*, PyObject*)),
301-
METH_O)
296+
[CU.block| PyObject* {
297+
inline_py_callback_METH_O(
298+
$(PyObject* (*f_ptr)(PyObject*, PyObject*)));
302299
}|]
303300

304301
instance (FromPy a1, FromPy a2, ToPy b) => ToPy (a1 -> a2 -> IO b) where
@@ -311,10 +308,8 @@ instance (FromPy a1, FromPy a2, ToPy b) => ToPy (a1 -> a2 -> IO b) where
311308
liftIO $ unPy . basicToPy =<< f a b
312309
-- Create python function
313310
[C.block| PyObject* {
314-
_PyCFunctionFast impl = $(PyObject* (*f_ptr)(PyObject*, PyObject*const*, int64_t));
315-
return inline_py_function_wrapper(
316-
(PyCFunction)impl,
317-
METH_FASTCALL);
311+
PyCFunctionFast impl = $(PyObject* (*f_ptr)(PyObject*, PyObject*const*, int64_t));
312+
return inline_py_callback_METH_FASTCALL(impl);
318313
}|]
319314

320315
loadArgFastcall :: FromPy a => Ptr (Ptr PyObject) -> Int -> Int64 -> Program (Ptr PyObject) a
@@ -331,7 +326,7 @@ loadArgFastcall p_arr i tot = do
331326
----------------------------------------------------------------
332327

333328
pyCallback :: Program (Ptr PyObject) (Ptr PyObject) -> IO (Ptr PyObject)
334-
pyCallback io = unPy $ evalContT io `catchPy` convertHaskell2Py
329+
pyCallback io = unPy $ ensureGIL $ evalContT io `catchPy` convertHaskell2Py
335330

336331
raiseUndecodedArg :: CInt -> CInt -> Py (Ptr PyObject)
337332
raiseUndecodedArg n tot = Py [CU.block| PyObject* {

src/Python/Internal/Eval.hs

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module Python.Internal.Eval
1515
-- * PyObject wrapper
1616
, newPyObject
1717
, decref
18+
, ensureGIL
1819
-- * Exceptions
1920
, convertHaskell2Py
2021
, convertPy2Haskell
@@ -186,26 +187,30 @@ runPy :: Py a -> IO a
186187
-- See NOTE: [Threading and exceptions]
187188
runPy py
188189
-- Multithreaded RTS
189-
| rtsSupportsBoundThreads = do
190-
result <- newEmptyMVar
191-
status <- newMVar Pending
192-
let onExc :: SomeException -> IO b
193-
onExc e = do
194-
modifyMVar_ status $ \case
195-
Pending -> pure Cancelled
196-
Cancelled -> pure Cancelled
197-
Done -> pure Done
198-
Running -> Cancelled <$ [CU.exp| void { PyErr_SetInterrupt() } |]
199-
throwIO e
200-
(do putMVar toPythonThread $ PyEvalReq{ prog=py, ..}
201-
takeMVar result >>= \case
202-
Left e -> throwIO e
203-
Right a -> pure a
204-
) `catch` onExc
190+
--
191+
-- Here we check whether we're in callback or creating a new call
192+
| rtsSupportsBoundThreads = [CU.exp| int { inline_py_callback_depth } |] >>= \case
193+
0 -> do
194+
result <- newEmptyMVar
195+
status <- newMVar Pending
196+
let onExc :: SomeException -> IO b
197+
onExc e = do
198+
modifyMVar_ status $ \case
199+
Pending -> pure Cancelled
200+
Cancelled -> pure Cancelled
201+
Done -> pure Done
202+
Running -> Cancelled <$ [CU.exp| void { PyErr_SetInterrupt() } |]
203+
throwIO e
204+
(do putMVar toPythonThread $ PyEvalReq{ prog=py, ..}
205+
takeMVar result >>= \case
206+
Left e -> throwIO e
207+
Right a -> pure a
208+
) `catch` onExc
209+
_ -> unPy $ ensureGIL py
205210
-- Single-threaded RTS
206211
--
207212
-- See NOTE: [Async exceptions]
208-
| otherwise = mask_ $ unPy py
213+
| otherwise = mask_ $ unPy $ ensureGIL py
209214

210215

211216
-- | Execute python action. This function is unsafe and should be only
@@ -311,7 +316,7 @@ evalReq :: IO ()
311316
-- See NOTE: [Python and threading]
312317
-- See NOTE: [Threading and exceptions]
313318
evalReq = do
314-
PyEvalReq{prog=Py io, result, status} <- takeMVar toPythonThread
319+
PyEvalReq{prog, result, status} <- takeMVar toPythonThread
315320
-- GC
316321
let decrefList Nil = pure ()
317322
decrefList (p `Cons` ps) = do [CU.exp| void { Py_XDECREF($(PyObject* p)) } |]
@@ -324,7 +329,7 @@ evalReq = do
324329
Cancelled -> return (Cancelled,False)
325330
Pending -> return (Running, True)
326331
when do_eval $ do
327-
a <- (Right <$> mask_ io) `catches`
332+
a <- (Right <$> mask_ (unPy $ ensureGIL prog)) `catches`
328333
[ Handler $ \(e :: AsyncException) -> throwIO e
329334
, Handler $ \(e :: SomeAsyncException) -> throwIO e
330335
, Handler $ \(e :: SomeException) -> pure (Left e)
@@ -347,6 +352,16 @@ evalReq = do
347352
decref :: Ptr PyObject -> Py ()
348353
decref p = Py [CU.exp| void { Py_DECREF($(PyObject* p)) } |]
349354

355+
-- | Ensure that we hold GIL for duration of action
356+
ensureGIL :: Py a -> Py a
357+
ensureGIL action = do
358+
-- NOTE: We're cheating here and looking behind the veil.
359+
-- PyGILState_STATE is defined as enum. Let hope it will stay
360+
-- this way.
361+
gil_state <- Py [CU.exp| int { PyGILState_Ensure() } |]
362+
action `finallyPy` Py [CU.exp| void { PyGILState_Release($(int gil_state)) } |]
363+
364+
350365
-- | Wrap raw python object into
351366
newPyObject :: Ptr PyObject -> Py PyObject
352367
-- We need to use different implementation for different RTS

src/Python/Internal/EvalQQ.hs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@ module Python.Internal.EvalQQ
1414
, unindent
1515
) where
1616

17-
import Control.Exception
1817
import Control.Monad.IO.Class
1918
import Control.Monad.Trans.Class
2019
import Control.Monad.Trans.Cont
2120
import Data.Char
2221
import Foreign.C.Types
23-
import Foreign.C.String
24-
import Foreign.Marshal.Alloc
2522
import Foreign.Ptr
2623
import Foreign.Storable
2724
import System.Exit

src/Python/Internal/Program.hs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ module Python.Internal.Program
1414
, withPyWCString
1515
) where
1616

17-
import Control.Exception
1817
import Control.Monad.Trans.Cont
1918
import Data.Coerce
2019
import Foreign.Ptr

src/Python/Types.hs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,8 @@ module Python.Types
1010
) where
1111

1212
import Data.Coerce
13-
1413
import Foreign.Ptr
15-
import Foreign.ForeignPtr
16-
import Language.C.Inline qualified as C
17-
1814
import GHC.ForeignPtr
19-
2015
import Python.Internal.Types
2116

2217
unsafeWithPyObject :: forall a. PyObject -> (Ptr PyObject -> Py a) -> Py a

0 commit comments

Comments
 (0)