Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ Run it with `mvn -B install -Pdocling` (installs PyTorch + Docling, ~5GB on firs

See also the [Langchain4j agent demo](examples/langchain4j-agent/) showing a Java AI agent using Python spaCy for named entity extraction with bidirectional callbacks.

## Thread safety

PythonEngine is thread-safe. All public methods acquire and release CPython's GIL automatically via `PyGILState_Ensure`/`PyGILState_Release`. This works with both platform threads and virtual threads. Multiple threads (or virtual threads) can call the engine concurrently; the GIL serializes access to CPython internally.

Note: CPython's GIL means only one thread executes Python code at a time. Concurrency is safe but not parallel for Python-bound work. Native extensions (NumPy, PyTorch) release the GIL during computation, so actual parallelism happens in native code.

## Building

```bash
Expand Down
40 changes: 36 additions & 4 deletions core/src/main/java/io/roastedroot/cpython4j/core/LibPython.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ final class LibPython {
private final MethodHandle pyErrOccurred;
private final MethodHandle pyErrPrint;
private final MethodHandle pyErrSetString;
private final MethodHandle pyErrFetch;
private final MethodHandle pyObjectStr;
private final MethodHandle pyLongFromLong;
private final MethodHandle pyUnicodeFromString;
Expand All @@ -42,6 +41,9 @@ final class LibPython {
private final MethodHandle pyTupleSetItem;
private final MethodHandle pyDecRef;
private final MethodHandle pyIncRef;
private final MethodHandle pyEvalSaveThread;
private final MethodHandle pyGILStateEnsure;
private final MethodHandle pyGILStateRelease;
private final MemorySegment cachedNone;

static final AddressLayout POINTER = ValueLayout.ADDRESS;
Expand Down Expand Up @@ -152,9 +154,6 @@ final class LibPython {
// void PyErr_SetString(PyObject *type, const char *message)
pyErrSetString = downcall("PyErr_SetString", FunctionDescriptor.ofVoid(POINTER, POINTER));

// void PyErr_Fetch(PyObject **ptype, PyObject **pvalue, PyObject **ptraceback)
pyErrFetch = downcall("PyErr_Fetch", FunctionDescriptor.ofVoid(POINTER, POINTER, POINTER));

// PyObject* PyObject_Str(PyObject *o)
pyObjectStr = downcall("PyObject_Str", FunctionDescriptor.of(POINTER, POINTER));

Expand Down Expand Up @@ -201,6 +200,15 @@ final class LibPython {

// void Py_IncRef(PyObject *o)
pyIncRef = downcall("Py_IncRef", FunctionDescriptor.ofVoid(POINTER));

// PyThreadState* PyEval_SaveThread(void)
pyEvalSaveThread = downcall("PyEval_SaveThread", FunctionDescriptor.of(POINTER));

// PyGILState_STATE PyGILState_Ensure(void)
pyGILStateEnsure = downcall("PyGILState_Ensure", FunctionDescriptor.of(C_INT));

// void PyGILState_Release(PyGILState_STATE)
pyGILStateRelease = downcall("PyGILState_Release", FunctionDescriptor.ofVoid(C_INT));
}

private MethodHandle downcall(String name, FunctionDescriptor desc) {
Expand Down Expand Up @@ -422,6 +430,30 @@ MemorySegment getExcRuntimeError() {
.get(POINTER, 0);
}

MemorySegment evalSaveThread() {
try {
return (MemorySegment) pyEvalSaveThread.invokeExact();
} catch (Throwable t) {
throw new RuntimeException("PyEval_SaveThread failed", t);
}
}

int gilStateEnsure() {
try {
return (int) pyGILStateEnsure.invokeExact();
} catch (Throwable t) {
throw new RuntimeException("PyGILState_Ensure failed", t);
}
}

void gilStateRelease(int state) {
try {
pyGILStateRelease.invokeExact(state);
} catch (Throwable t) {
throw new RuntimeException("PyGILState_Release failed", t);
}
}

Linker linker() {
return linker;
}
Expand Down
90 changes: 57 additions & 33 deletions core/src/main/java/io/roastedroot/cpython4j/core/PythonEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ private PythonEngine(
throw new PythonException("Failed to add sys.path entry: " + path);
}
}

py.evalSaveThread();
}

private void registerTypedHostModule(String moduleName, List<HostFunction> functions) {
Expand Down Expand Up @@ -165,58 +167,75 @@ private void registerTypedHostModule(String moduleName, List<HostFunction> funct
}

public int exec(String code) {
return py.runSimpleString(arena, code);
var gstate = py.gilStateEnsure();
try {
return py.runSimpleString(arena, code);
} finally {
py.gilStateRelease(gstate);
}
}

public <T> T invokeFunction(
String moduleName, String functionName, List<Object> args, Class<T> returnType) {
var module = py.importModule(arena, moduleName);
if (module.equals(MemorySegment.NULL)) {
checkError();
throw new PythonException("Module not found: " + moduleName);
var gstate = py.gilStateEnsure();
try {
return invokeFunctionLocked(moduleName, functionName, args, returnType);
} finally {
py.gilStateRelease(gstate);
}
}

public <T> T invokeFunction(
String moduleName,
String functionName,
List<Object> args,
TypeReference<T> returnTypeRef) {
var gstate = py.gilStateEnsure();
try {
var func = py.objectGetAttrString(arena, module, functionName);
if (func.equals(MemorySegment.NULL)) {
var module = py.importModule(arena, moduleName);
if (module.equals(MemorySegment.NULL)) {
checkError();
throw new PythonException(
"Function not found: " + functionName + " in module " + moduleName);
throw new PythonException("Module not found: " + moduleName);
}

try {
var pyArgs = py.tupleNew(args.size());
for (int i = 0; i < args.size(); i++) {
py.tupleSetItem(pyArgs, i, javaToPython(args.get(i)));
}

var result = py.objectCallObject(func, pyArgs);
py.decRef(pyArgs);

if (result.equals(MemorySegment.NULL)) {
var func = py.objectGetAttrString(arena, module, functionName);
if (func.equals(MemorySegment.NULL)) {
checkError();
throw new PythonException("Call failed: " + moduleName + "." + functionName);
throw new PythonException(
"Function not found: " + functionName + " in module " + moduleName);
}

if (returnType == void.class || returnType == Void.class) {
py.decRef(result);
return null;
try {
var pyArgs = py.tupleNew(args.size());
for (int i = 0; i < args.size(); i++) {
py.tupleSetItem(pyArgs, i, javaToPython(args.get(i)));
}

var result = py.objectCallObject(func, pyArgs);
py.decRef(pyArgs);

if (result.equals(MemorySegment.NULL)) {
checkError();
throw new PythonException(
"Call failed: " + moduleName + "." + functionName);
}

var javaType = mapper.getTypeFactory().constructType(returnTypeRef);
return pythonToJavaType(result, javaType, true);
} finally {
py.decRef(func);
}

return pythonToJava(result, returnType);
} finally {
py.decRef(func);
py.decRef(module);
}
} finally {
py.decRef(module);
py.gilStateRelease(gstate);
}
}

public <T> T invokeFunction(
String moduleName,
String functionName,
List<Object> args,
TypeReference<T> returnTypeRef) {
private <T> T invokeFunctionLocked(
String moduleName, String functionName, List<Object> args, Class<T> returnType) {
var module = py.importModule(arena, moduleName);
if (module.equals(MemorySegment.NULL)) {
checkError();
Expand Down Expand Up @@ -245,8 +264,12 @@ public <T> T invokeFunction(
throw new PythonException("Call failed: " + moduleName + "." + functionName);
}

var javaType = mapper.getTypeFactory().constructType(returnTypeRef);
return pythonToJavaType(result, javaType, true);
if (returnType == void.class || returnType == Void.class) {
py.decRef(result);
return null;
}

return pythonToJava(result, returnType);
} finally {
py.decRef(func);
}
Expand Down Expand Up @@ -452,6 +475,7 @@ private void checkError() {
@Override
public void close() {
if (initialized) {
py.gilStateEnsure();
if (jsonDumps != null) {
py.decRef(jsonDumps);
}
Expand Down
Loading
Loading