From 83483624804dabc70fd699185f9bc4bd95d6caa7 Mon Sep 17 00:00:00 2001 From: andreatp Date: Mon, 8 Jun 2026 21:22:42 +0100 Subject: [PATCH 1/2] Add GIL management for thread safety (TDD) ThreadSafetyTest (3 tests): concurrent exec, concurrent invokeFunction, and callbacks during concurrent access. These crashed the JVM (SIGABRT) before the fix. LibPython: add PyEval_SaveThread, PyGILState_Ensure, PyGILState_Release bindings (all Stable ABI since Python 3.2). Remove dead PyErr_Fetch. PythonEngine: release GIL after initialization via PyEval_SaveThread. Wrap exec(), invokeFunction() (both overloads) with GILState ensure/ release. close() acquires GIL but does not release after Py_FinalizeEx (finalization invalidates all thread states). FunctionDispatcher unchanged (GIL already held by CPython during callbacks). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../roastedroot/cpython4j/core/LibPython.java | 40 ++++- .../cpython4j/core/PythonEngine.java | 90 +++++++---- .../cpython4j/core/ThreadSafetyTest.java | 145 ++++++++++++++++++ 3 files changed, 238 insertions(+), 37 deletions(-) create mode 100644 core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java diff --git a/core/src/main/java/io/roastedroot/cpython4j/core/LibPython.java b/core/src/main/java/io/roastedroot/cpython4j/core/LibPython.java index 3ac5f1b..7cfb414 100644 --- a/core/src/main/java/io/roastedroot/cpython4j/core/LibPython.java +++ b/core/src/main/java/io/roastedroot/cpython4j/core/LibPython.java @@ -30,7 +30,6 @@ final class LibPython { private final MethodHandle pyErrOccurred; private final MethodHandle pyErrPrint; private final MethodHandle pyErrSetString; - private final MethodHandle pyErrFetch; private final MethodHandle pyObjectStr; private final MethodHandle pyLongFromLong; private final MethodHandle pyUnicodeFromString; @@ -42,6 +41,9 @@ final class LibPython { private final MethodHandle pyTupleSetItem; private final MethodHandle pyDecRef; private final MethodHandle pyIncRef; 
+ private final MethodHandle pyEvalSaveThread; + private final MethodHandle pyGILStateEnsure; + private final MethodHandle pyGILStateRelease; private final MemorySegment cachedNone; static final AddressLayout POINTER = ValueLayout.ADDRESS; @@ -152,9 +154,6 @@ final class LibPython { // void PyErr_SetString(PyObject *type, const char *message) pyErrSetString = downcall("PyErr_SetString", FunctionDescriptor.ofVoid(POINTER, POINTER)); - // void PyErr_Fetch(PyObject **ptype, PyObject **pvalue, PyObject **ptraceback) - pyErrFetch = downcall("PyErr_Fetch", FunctionDescriptor.ofVoid(POINTER, POINTER, POINTER)); - // PyObject* PyObject_Str(PyObject *o) pyObjectStr = downcall("PyObject_Str", FunctionDescriptor.of(POINTER, POINTER)); @@ -201,6 +200,15 @@ final class LibPython { // void Py_IncRef(PyObject *o) pyIncRef = downcall("Py_IncRef", FunctionDescriptor.ofVoid(POINTER)); + + // PyThreadState* PyEval_SaveThread(void) + pyEvalSaveThread = downcall("PyEval_SaveThread", FunctionDescriptor.of(POINTER)); + + // PyGILState_STATE PyGILState_Ensure(void) + pyGILStateEnsure = downcall("PyGILState_Ensure", FunctionDescriptor.of(C_INT)); + + // void PyGILState_Release(PyGILState_STATE) + pyGILStateRelease = downcall("PyGILState_Release", FunctionDescriptor.ofVoid(C_INT)); } private MethodHandle downcall(String name, FunctionDescriptor desc) { @@ -422,6 +430,30 @@ MemorySegment getExcRuntimeError() { .get(POINTER, 0); } + MemorySegment evalSaveThread() { + try { + return (MemorySegment) pyEvalSaveThread.invokeExact(); + } catch (Throwable t) { + throw new RuntimeException("PyEval_SaveThread failed", t); + } + } + + int gilStateEnsure() { + try { + return (int) pyGILStateEnsure.invokeExact(); + } catch (Throwable t) { + throw new RuntimeException("PyGILState_Ensure failed", t); + } + } + + void gilStateRelease(int state) { + try { + pyGILStateRelease.invokeExact(state); + } catch (Throwable t) { + throw new RuntimeException("PyGILState_Release failed", t); + } + } + Linker linker() 
{ return linker; } diff --git a/core/src/main/java/io/roastedroot/cpython4j/core/PythonEngine.java b/core/src/main/java/io/roastedroot/cpython4j/core/PythonEngine.java index ecfc684..376f51d 100644 --- a/core/src/main/java/io/roastedroot/cpython4j/core/PythonEngine.java +++ b/core/src/main/java/io/roastedroot/cpython4j/core/PythonEngine.java @@ -68,6 +68,8 @@ private PythonEngine( throw new PythonException("Failed to add sys.path entry: " + path); } } + + py.evalSaveThread(); } private void registerTypedHostModule(String moduleName, List functions) { @@ -165,58 +167,75 @@ private void registerTypedHostModule(String moduleName, List funct } public int exec(String code) { - return py.runSimpleString(arena, code); + var gstate = py.gilStateEnsure(); + try { + return py.runSimpleString(arena, code); + } finally { + py.gilStateRelease(gstate); + } } public T invokeFunction( String moduleName, String functionName, List args, Class returnType) { - var module = py.importModule(arena, moduleName); - if (module.equals(MemorySegment.NULL)) { - checkError(); - throw new PythonException("Module not found: " + moduleName); + var gstate = py.gilStateEnsure(); + try { + return invokeFunctionLocked(moduleName, functionName, args, returnType); + } finally { + py.gilStateRelease(gstate); } + } + public T invokeFunction( + String moduleName, + String functionName, + List args, + TypeReference returnTypeRef) { + var gstate = py.gilStateEnsure(); try { - var func = py.objectGetAttrString(arena, module, functionName); - if (func.equals(MemorySegment.NULL)) { + var module = py.importModule(arena, moduleName); + if (module.equals(MemorySegment.NULL)) { checkError(); - throw new PythonException( - "Function not found: " + functionName + " in module " + moduleName); + throw new PythonException("Module not found: " + moduleName); } try { - var pyArgs = py.tupleNew(args.size()); - for (int i = 0; i < args.size(); i++) { - py.tupleSetItem(pyArgs, i, javaToPython(args.get(i))); - } - - var 
result = py.objectCallObject(func, pyArgs); - py.decRef(pyArgs); - - if (result.equals(MemorySegment.NULL)) { + var func = py.objectGetAttrString(arena, module, functionName); + if (func.equals(MemorySegment.NULL)) { checkError(); - throw new PythonException("Call failed: " + moduleName + "." + functionName); + throw new PythonException( + "Function not found: " + functionName + " in module " + moduleName); } - if (returnType == void.class || returnType == Void.class) { - py.decRef(result); - return null; + try { + var pyArgs = py.tupleNew(args.size()); + for (int i = 0; i < args.size(); i++) { + py.tupleSetItem(pyArgs, i, javaToPython(args.get(i))); + } + + var result = py.objectCallObject(func, pyArgs); + py.decRef(pyArgs); + + if (result.equals(MemorySegment.NULL)) { + checkError(); + throw new PythonException( + "Call failed: " + moduleName + "." + functionName); + } + + var javaType = mapper.getTypeFactory().constructType(returnTypeRef); + return pythonToJavaType(result, javaType, true); + } finally { + py.decRef(func); } - - return pythonToJava(result, returnType); } finally { - py.decRef(func); + py.decRef(module); } } finally { - py.decRef(module); + py.gilStateRelease(gstate); } } - public T invokeFunction( - String moduleName, - String functionName, - List args, - TypeReference returnTypeRef) { + private T invokeFunctionLocked( + String moduleName, String functionName, List args, Class returnType) { var module = py.importModule(arena, moduleName); if (module.equals(MemorySegment.NULL)) { checkError(); @@ -245,8 +264,12 @@ public T invokeFunction( throw new PythonException("Call failed: " + moduleName + "." 
+ functionName); } - var javaType = mapper.getTypeFactory().constructType(returnTypeRef); - return pythonToJavaType(result, javaType, true); + if (returnType == void.class || returnType == Void.class) { + py.decRef(result); + return null; + } + + return pythonToJava(result, returnType); } finally { py.decRef(func); } @@ -452,6 +475,7 @@ private void checkError() { @Override public void close() { if (initialized) { + py.gilStateEnsure(); if (jsonDumps != null) { py.decRef(jsonDumps); } diff --git a/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java b/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java new file mode 100644 index 0000000..48222c4 --- /dev/null +++ b/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java @@ -0,0 +1,145 @@ +package io.roastedroot.cpython4j.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class ThreadSafetyTest { + + @TempDir static Path tempDir; + static PythonEngine engine; + + @BeforeAll + static void setUp() throws IOException { + Files.writeString( + tempDir.resolve("pyproject.toml"), + """ + [project] + name = "thread-test" + version = "0.1.0" + requires-python = ">=3.13" + dependencies = [] + """); + + Files.writeString( + tempDir.resolve("thread_ops.py"), + """ + def double(x): + return x * 2 + + def greet(name): + return "hello " + name + """); + + var env = PythonEnv.uvProject(tempDir).sync(true).build(); + + engine = + PythonEngine.builder() + .withEnv(env) + .expose( + "host", + new HostFunction( + "mirror", + List.of(String.class), + 
String.class, + args -> "mirrored:" + args.get(0))) + .build(); + } + + @AfterAll + static void tearDown() { + if (engine != null) { + engine.close(); + } + } + + @Test + void concurrentExec() throws Exception { + int threads = 4; + var executor = Executors.newFixedThreadPool(threads); + var latch = new CountDownLatch(1); + var futures = new ArrayList>(); + + for (int i = 0; i < threads; i++) { + futures.add( + executor.submit( + () -> { + latch.await(); + for (int j = 0; j < 10; j++) { + engine.exec("x = 1 + 1"); + } + return null; + })); + } + + latch.countDown(); + for (var f : futures) { + f.get(); + } + executor.shutdown(); + } + + @Test + void concurrentInvokeFunction() throws Exception { + int threads = 4; + var executor = Executors.newFixedThreadPool(threads); + var latch = new CountDownLatch(1); + var futures = new ArrayList>(); + + for (int i = 0; i < threads; i++) { + int arg = i; + futures.add( + executor.submit( + () -> { + latch.await(); + return engine.invokeFunction( + "thread_ops", "double", List.of(arg), int.class); + })); + } + + latch.countDown(); + for (int i = 0; i < threads; i++) { + assertEquals(i * 2, futures.get(i).get()); + } + executor.shutdown(); + } + + @Test + void callbackDuringConcurrentExec() throws Exception { + int threads = 4; + var executor = Executors.newFixedThreadPool(threads); + var latch = new CountDownLatch(1); + var futures = new ArrayList>(); + + for (int i = 0; i < threads; i++) { + int idx = i; + futures.add( + executor.submit( + () -> { + latch.await(); + return engine.invokeFunction( + "thread_ops", + "greet", + List.of("thread" + idx), + String.class); + })); + } + + latch.countDown(); + for (int i = 0; i < threads; i++) { + assertEquals("hello thread" + i, futures.get(i).get()); + } + executor.shutdown(); + } +} From da1359b78d03bfbbe06aa4fc8c987d94c7b9e480 Mon Sep 17 00:00:00 2001 From: andreatp Date: Tue, 9 Jun 2026 13:03:13 +0100 Subject: [PATCH 2/2] Add virtual thread tests, document thread safety - Add 
virtualThreadsInvokeFunction test (8 virtual threads) - Add mixedPlatformAndVirtualThreads test (2 platform + 4 virtual) - All 5 thread safety tests pass: GIL ensure/release pins virtual threads to their carrier during native calls automatically - Document thread safety in README Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 6 ++ .../cpython4j/core/ThreadSafetyTest.java | 71 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/README.md b/README.md index 09b39e1..f2efbfb 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,12 @@ Run it with `mvn -B install -Pdocling` (installs PyTorch + Docling, ~5GB on firs See also the [Langchain4j agent demo](examples/langchain4j-agent/) showing a Java AI agent using Python spaCy for named entity extraction with bidirectional callbacks. +## Thread safety + +PythonEngine is thread-safe. `exec` and `invokeFunction` acquire and release CPython's GIL automatically via `PyGILState_Ensure`/`PyGILState_Release`. This works with both platform threads and virtual threads. Multiple threads (or virtual threads) can call the engine concurrently; the GIL serializes access to CPython internally. `close()` acquires the GIL and finalizes the interpreter without releasing it, so call it only after all other threads have finished using the engine. + +Note: CPython's GIL means only one thread executes Python code at a time. Concurrency is safe but not parallel for Python-bound work. Native extensions (NumPy, PyTorch) can release the GIL during computation, so actual parallelism happens in native code. 
+ ## Building ```bash diff --git a/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java b/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java index 48222c4..7b4925a 100644 --- a/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java +++ b/core/src/test/java/io/roastedroot/cpython4j/core/ThreadSafetyTest.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.junit.jupiter.api.AfterAll; @@ -142,4 +143,74 @@ void callbackDuringConcurrentExec() throws Exception { } executor.shutdown(); } + + @Test + void virtualThreadsInvokeFunction() throws Exception { + int count = 8; + var futures = new ArrayList>(); + + try (var executor = Executors.newVirtualThreadPerTaskExecutor()) { + for (int i = 0; i < count; i++) { + int arg = i; + futures.add( + executor.submit( + () -> + engine.invokeFunction( + "thread_ops", "double", List.of(arg), int.class))); + } + for (int i = 0; i < count; i++) { + assertEquals(i * 2, futures.get(i).get()); + } + } + } + + @Test + void mixedPlatformAndVirtualThreads() throws InterruptedException, ExecutionException { + var platformFutures = new ArrayList>(); + var virtualFutures = new ArrayList>(); + var latch = new CountDownLatch(1); + + var platformExecutor = Executors.newFixedThreadPool(2); + var virtualExecutor = Executors.newVirtualThreadPerTaskExecutor(); + + for (int i = 0; i < 2; i++) { + int idx = i; + platformFutures.add( + platformExecutor.submit( + () -> { + latch.await(); + return engine.invokeFunction( + "thread_ops", + "greet", + List.of("platform" + idx), + String.class); + })); + } + + for (int i = 0; i < 4; i++) { + int idx = i; + virtualFutures.add( + virtualExecutor.submit( + () -> { + latch.await(); + return engine.invokeFunction( + "thread_ops", + "greet", + List.of("virtual" + idx), + 
String.class); + })); + } + + latch.countDown(); + + for (int i = 0; i < 2; i++) { + assertEquals("hello platform" + i, platformFutures.get(i).get()); + } + for (int i = 0; i < 4; i++) { + assertEquals("hello virtual" + i, virtualFutures.get(i).get()); + } + + platformExecutor.shutdown(); + virtualExecutor.close(); + } }