Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions src/main/java/org/duckdb/DuckDBConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public final class DuckDBConnection implements java.sql.Connection {
public static final String DEFAULT_SCHEMA = "main";

ByteBuffer connRef;
final Lock connRefLock = new ReentrantLock();
final ReentrantLock connRefLock = new ReentrantLock();
final LinkedHashSet<DuckDBPreparedStatement> preparedStatements = new LinkedHashSet<>();
volatile boolean closing = false;

Expand Down Expand Up @@ -488,14 +488,11 @@ void checkOpen() throws SQLException {
* This function calls the underlying C++ interrupt function which aborts the query running on this connection.
*/
void interrupt() throws SQLException {
checkOpen();
connRefLock.lock();
try {
checkOpen();
DuckDBNative.duckdb_jdbc_interrupt(connRef);
} finally {
connRefLock.unlock();
if (!connRefLock.isHeldByCurrentThread()) {
throw new SQLException("Connection lock state error");
}
checkOpen();
DuckDBNative.duckdb_jdbc_interrupt(connRef);
}

QueryProgress queryProgress() throws SQLException {
Expand Down
41 changes: 35 additions & 6 deletions src/main/java/org/duckdb/DuckDBPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class DuckDBPreparedStatement implements PreparedStatement {
private DuckDBConnection conn;

private ByteBuffer stmtRef = null;
final Lock stmtRefLock = new ReentrantLock();
final ReentrantLock stmtRefLock = new ReentrantLock();
volatile boolean closeOnCompletion = false;

private DuckDBResultSet selectResult = null;
Expand Down Expand Up @@ -159,6 +159,11 @@ private boolean execute(boolean startTransaction) throws SQLException {
checkOpen();
checkPrepared();

// Wait with dispatching a new query if connection is locked by cancel() call
Lock connLock = getConnRefLock();
connLock.lock();
connLock.unlock();

ByteBuffer resultRef = null;

stmtRefLock.lock();
Expand Down Expand Up @@ -442,12 +447,27 @@ public void setQueryTimeout(int seconds) throws SQLException {
@Override
public void cancel() throws SQLException {
checkOpen();
// Only proceed to interrupt call after ensuring that the query on
// this statement is still running.
if (!stmtRefLock.isLocked()) {
return;
}
// Cancel is intended to be called concurrently with execute,
// thus we cannot take the statement lock that is held while
// query is running. NPE may be thrown if connection is closed
// concurrently.
try {
// Cancel is intended to be called concurrently with execute,
// thus we cannot take the statement lock that is held while
// query is running. NPE may be thrown if connection is closed
// concurrently.
conn.interrupt();
// Taking connection lock will prevent new queries to be executed
Lock connLock = getConnRefLock();
connLock.lock();
try {
if (!stmtRefLock.isLocked()) {
return;
}
conn.interrupt();
} finally {
connLock.unlock();
}
} catch (NullPointerException e) {
throw new SQLException(e);
}
Expand Down Expand Up @@ -1215,4 +1235,13 @@ private int[] intArrayFromLong(long[] arr) {
}
return res;
}

private Lock getConnRefLock() throws SQLException {
// NPE can be thrown if statement is closed concurrently.
try {
return conn.connRefLock;
} catch (NullPointerException e) {
throw new SQLException(e);
}
}
}
34 changes: 34 additions & 0 deletions src/test/java/org/duckdb/TestClosure.java
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,38 @@ public static void test_results_fetch_no_hang() throws Exception {
}
}
}

public static void test_stmt_can_only_cancel_self() throws Exception {
try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt1 = conn.createStatement();
Statement stmt2 = conn.createStatement()) {
stmt1.execute("DROP TABLE IF EXISTS test_fib1");
stmt1.execute("CREATE TABLE test_fib1(i bigint, p double, f double)");
stmt1.execute("INSERT INTO test_fib1 values(1, 0, 1)");
long start = System.currentTimeMillis();
Thread th = new Thread(() -> {
try {
Thread.sleep(200);
stmt1.cancel();
} catch (Exception e) {
e.printStackTrace();
}
});
th.start();
try (
ResultSet rs = stmt2.executeQuery(
"WITH RECURSIVE cte AS ("
+
"SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from cte WHERE cte.i < 40000) "
+ "SELECT avg(f) FROM cte")) {
rs.next();
assertTrue(rs.getDouble(1) > 0);
}
th.join();
long elapsed = System.currentTimeMillis() - start;
assertTrue(elapsed > 1000);
assertFalse(conn.isClosed());
assertFalse(stmt1.isClosed());
assertFalse(stmt2.isClosed());
}
}
}