Skip to content

Commit 330df21

Browse files
authored
Unified error reporting for Android ExecuTorch JNI layers (pytorch#18128)
All four Android JNI modules (Generic, LLM, ASR, Training) previously used inconsistent exception types — a mix of RuntimeException, IllegalStateException, IllegalArgumentException, and the structured ExecutorchRuntimeException. This unifies them so every JNI error path throws ExecutorchRuntimeException with a proper error code, giving callers a single type to catch and programmatic access to the underlying runtime error. Key changes per module: LLM (jni_layer_llama.cpp) — generate() now captures the Error return from runner_->generate(), wraps the call in try/catch, reports failures via a new onError(errorCode, message) callback, and returns the actual error code instead of always returning 0. ASR (jni_layer_asr.cpp) — replaced six env->ThrowNew(...) calls with setExecutorchPendingException (for pure JNI path) Training (jni_layer_training.cpp) — added jni_helper.h include; replaced five throwNewJavaException("java/lang/Exception", ...) calls with throwExecutorchException, preserving the actual error codes from the failed operations. Additionally: added default onError callbacks to LlmCallback (Java) and AsrCallback (Kotlin); This PR was authored with the assistance of Claude. cc @kirklandsign @cbilgin
1 parent 3604d3e commit 330df21

7 files changed

Lines changed: 268 additions & 64 deletions

File tree

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrCallback.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,12 @@ interface AsrCallback {
2525
* @param token The decoded text token
2626
*/
2727
fun onToken(token: String)
28+
29+
/**
30+
* Called when an error occurs during transcription.
31+
*
32+
* @param errorCode Error code from the ExecuTorch runtime
33+
* @param message Human-readable error description
34+
*/
35+
fun onError(errorCode: Int, message: String) {}
2836
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,14 @@ public interface LlmCallback {
3737
*/
3838
@DoNotStrip
3939
default void onStats(String stats) {}
40+
41+
/**
42+
* Called when an error occurs during generate().
43+
*
44+
* @param errorCode Error code from the ExecuTorch runtime (see {@link
45+
* org.pytorch.executorch.ExecutorchRuntimeException})
46+
* @param message Human-readable error description
47+
*/
48+
@DoNotStrip
49+
default void onError(int errorCode, String message) {}
4050
}

extension/android/jni/jni_helper.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,115 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) {
3535
facebook::jni::throwNewJavaException(exception.get());
3636
}
3737

38+
void setExecutorchPendingException(
39+
JNIEnv* env,
40+
uint32_t errorCode,
41+
const std::string& details) {
42+
if (!env) {
43+
return;
44+
}
45+
if (env->ExceptionCheck()) {
46+
// Preserve any preexisting pending exception; do not overwrite it here.
47+
return;
48+
}
49+
50+
jclass exceptionClass =
51+
env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException");
52+
if (env->ExceptionCheck()) {
53+
if (exceptionClass) {
54+
env->DeleteLocalRef(exceptionClass);
55+
}
56+
// Preserve the original exception; do not clear or overwrite it.
57+
return;
58+
}
59+
60+
if (!exceptionClass) {
61+
// FindClass failed. It should have set a pending exception; leave it as is.
62+
return;
63+
}
64+
65+
jmethodID factoryMethod = env->GetStaticMethodID(
66+
exceptionClass,
67+
"makeExecutorchException",
68+
"(ILjava/lang/String;)Ljava/lang/RuntimeException;");
69+
if (env->ExceptionCheck()) {
70+
// Preserve the original exception; do not overwrite it.
71+
env->DeleteLocalRef(exceptionClass);
72+
return;
73+
}
74+
75+
if (!factoryMethod) {
76+
// Factory unavailable; try the (int, String) constructor directly so
77+
// callers still receive a structured ExecutorchRuntimeException.
78+
jmethodID ctor =
79+
env->GetMethodID(exceptionClass, "<init>", "(ILjava/lang/String;)V");
80+
if (!env->ExceptionCheck() && ctor) {
81+
jstring jDetails = env->NewStringUTF(details.c_str());
82+
if (!env->ExceptionCheck() && jDetails) {
83+
jobject exObj = env->NewObject(
84+
exceptionClass, ctor, static_cast<jint>(errorCode), jDetails);
85+
if (!env->ExceptionCheck() && exObj) {
86+
env->Throw(reinterpret_cast<jthrowable>(exObj));
87+
env->DeleteLocalRef(exObj);
88+
}
89+
env->DeleteLocalRef(jDetails);
90+
}
91+
} else {
92+
// Clear any NoSuchMethodError from GetMethodID before falling back.
93+
if (env->ExceptionCheck()) {
94+
env->ExceptionClear();
95+
}
96+
jclass runtimeExClass = env->FindClass("java/lang/RuntimeException");
97+
if (!env->ExceptionCheck() && runtimeExClass) {
98+
env->ThrowNew(runtimeExClass, details.c_str());
99+
env->DeleteLocalRef(runtimeExClass);
100+
}
101+
}
102+
env->DeleteLocalRef(exceptionClass);
103+
return;
104+
}
105+
106+
jstring jDetails = env->NewStringUTF(details.c_str());
107+
if (env->ExceptionCheck()) {
108+
// Preserve the original exception; do not overwrite it.
109+
if (jDetails) {
110+
env->DeleteLocalRef(jDetails);
111+
}
112+
env->DeleteLocalRef(exceptionClass);
113+
return;
114+
}
115+
116+
if (!jDetails) {
117+
// NewStringUTF returned null without setting an exception; fall back to
118+
// a standard RuntimeException since ExecutorchRuntimeException lacks a
119+
// (String) ctor.
120+
jclass runtimeExClass = env->FindClass("java/lang/RuntimeException");
121+
if (!env->ExceptionCheck() && runtimeExClass) {
122+
env->ThrowNew(runtimeExClass, details.c_str());
123+
env->DeleteLocalRef(runtimeExClass);
124+
}
125+
env->DeleteLocalRef(exceptionClass);
126+
return;
127+
}
128+
129+
auto exception = static_cast<jthrowable>(env->CallStaticObjectMethod(
130+
exceptionClass, factoryMethod, static_cast<jint>(errorCode), jDetails));
131+
if (env->ExceptionCheck() || !exception) {
132+
// If a Java exception was thrown, it is already pending; just clean up.
133+
if (exception) {
134+
env->DeleteLocalRef(exception);
135+
}
136+
env->DeleteLocalRef(jDetails);
137+
env->DeleteLocalRef(exceptionClass);
138+
return;
139+
}
140+
141+
env->Throw(exception);
142+
env->DeleteLocalRef(exception);
143+
env->DeleteLocalRef(jDetails);
144+
env->DeleteLocalRef(exceptionClass);
145+
}
146+
38147
bool utf8_check_validity(const char* str, size_t length) {
39148
for (size_t i = 0; i < length; ++i) {
40149
uint8_t byte = static_cast<uint8_t>(str[i]);

extension/android/jni/jni_helper.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,35 @@ namespace executorch::jni_helper {
2020
* code and details. Uses the Java factory method
2121
* ExecutorchRuntimeException.makeExecutorchException(int, String).
2222
*
23+
* IMPORTANT: This attempts to throw a C++ exception (via fbjni). Only use in
24+
* fbjni HybridClass methods where fbjni catches it at the JNI boundary.
25+
* For plain extern "C" JNIEXPORT functions, use setExecutorchPendingException.
26+
*
27+
* Note: If there is no current JNI environment (for example, if
28+
* facebook::jni::Environment::current() returns null), this function is a
29+
* no-op and does not throw. Callers must not rely on this always aborting
30+
* control flow.
31+
*
2332
* @param errorCode The error code from the C++ Executorch runtime.
2433
* @param details Additional details to include in the exception message.
2534
*/
2635
void throwExecutorchException(uint32_t errorCode, const std::string& details);
2736

37+
/**
38+
* Sets a pending Java ExecutorchRuntimeException without throwing a C++
39+
* exception. Safe to call from plain extern "C" JNIEXPORT functions.
40+
* After calling this, the caller must return from the JNI function promptly;
41+
* the Java exception will be delivered when control returns to the JVM.
42+
*
43+
* @param env The JNI environment pointer.
44+
* @param errorCode The error code from the C++ Executorch runtime.
45+
* @param details Additional details to include in the exception message.
46+
*/
47+
void setExecutorchPendingException(
48+
JNIEnv* env,
49+
uint32_t errorCode,
50+
const std::string& details);
51+
2852
// Define the JavaClass wrapper
2953
struct JExecutorchRuntimeException
3054
: public facebook::jni::JavaClass<JExecutorchRuntimeException> {

extension/android/jni/jni_layer_asr.cpp

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeCreate(
128128
auto load_error = handle->preprocessor->load();
129129
if (load_error != Error::Ok) {
130130
ET_LOG(Error, "Failed to load preprocessor module");
131-
env->ThrowNew(
132-
env->FindClass("java/lang/RuntimeException"),
131+
executorch::jni_helper::setExecutorchPendingException(
132+
env,
133+
static_cast<uint32_t>(load_error),
133134
"Failed to load preprocessor module");
134135
return 0;
135136
}
@@ -138,9 +139,10 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeCreate(
138139
return reinterpret_cast<jlong>(handle.release());
139140
} catch (const std::exception& e) {
140141
ET_LOG(Error, "Failed to create AsrModule: %s", e.what());
141-
env->ThrowNew(
142-
env->FindClass("java/lang/RuntimeException"),
143-
("Failed to create AsrModule: " + std::string(e.what())).c_str());
142+
executorch::jni_helper::setExecutorchPendingException(
143+
env,
144+
static_cast<uint32_t>(Error::Internal),
145+
"Failed to create AsrModule: " + std::string(e.what()));
144146
return 0;
145147
}
146148
}
@@ -172,8 +174,9 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeLoad(
172174
jobject /* this */,
173175
jlong nativeHandle) {
174176
if (nativeHandle == 0) {
175-
env->ThrowNew(
176-
env->FindClass("java/lang/IllegalStateException"),
177+
executorch::jni_helper::setExecutorchPendingException(
178+
env,
179+
static_cast<uint32_t>(Error::InvalidState),
177180
"Module has been destroyed");
178181
return -1;
179182
}
@@ -218,15 +221,17 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
218221
jlong decoderStartTokenId,
219222
jobject callback) {
220223
if (nativeHandle == 0) {
221-
env->ThrowNew(
222-
env->FindClass("java/lang/IllegalStateException"),
224+
executorch::jni_helper::setExecutorchPendingException(
225+
env,
226+
static_cast<uint32_t>(Error::InvalidState),
223227
"Module has been destroyed");
224228
return -1;
225229
}
226230

227231
if (wavPath == nullptr) {
228-
env->ThrowNew(
229-
env->FindClass("java/lang/IllegalArgumentException"),
232+
executorch::jni_helper::setExecutorchPendingException(
233+
env,
234+
static_cast<uint32_t>(Error::InvalidArgument),
230235
"WAV path cannot be null");
231236
return -1;
232237
}
@@ -239,15 +244,17 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
239244
try {
240245
audioData = ::executorch::extension::llm::load_wav_audio_data(wavPathStr);
241246
} catch (const std::exception& e) {
242-
env->ThrowNew(
243-
env->FindClass("java/lang/RuntimeException"),
244-
("Failed to load WAV file: " + std::string(e.what())).c_str());
247+
executorch::jni_helper::setExecutorchPendingException(
248+
env,
249+
static_cast<uint32_t>(Error::AccessFailed),
250+
"Failed to load WAV file: " + std::string(e.what()));
245251
return -1;
246252
}
247253

248254
if (audioData.empty()) {
249-
env->ThrowNew(
250-
env->FindClass("java/lang/IllegalArgumentException"),
255+
executorch::jni_helper::setExecutorchPendingException(
256+
env,
257+
static_cast<uint32_t>(Error::InvalidArgument),
251258
"WAV file contains no audio data");
252259
return -1;
253260
}
@@ -267,16 +274,18 @@ Java_org_pytorch_executorch_extension_asr_AsrModule_nativeTranscribe(
267274
auto processedResult =
268275
handle->preprocessor->execute("forward", audioTensor);
269276
if (processedResult.error() != Error::Ok) {
270-
env->ThrowNew(
271-
env->FindClass("java/lang/RuntimeException"),
277+
executorch::jni_helper::setExecutorchPendingException(
278+
env,
279+
static_cast<uint32_t>(processedResult.error()),
272280
"Audio preprocessing failed");
273281
return -1;
274282
}
275283

276284
auto outputs = std::move(processedResult.get());
277285
if (outputs.empty() || !outputs[0].isTensor()) {
278-
env->ThrowNew(
279-
env->FindClass("java/lang/RuntimeException"),
286+
executorch::jni_helper::setExecutorchPendingException(
287+
env,
288+
static_cast<uint32_t>(Error::Internal),
280289
"Preprocessor returned unexpected output");
281290
return -1;
282291
}

extension/android/jni/jni_layer_llama.cpp

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ class ExecuTorchLlmCallbackJni
7171
facebook::jni::make_jstring(
7272
executorch::extension::llm::stats_to_json_string(result)));
7373
}
74+
75+
void onError(int errorCode, const std::string& message) const {
76+
static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
77+
static const auto on_error_method =
78+
cls->getMethod<void(jint, facebook::jni::local_ref<jstring>)>(
79+
"onError");
80+
on_error_method(self(), errorCode, facebook::jni::make_jstring(message));
81+
}
7482
};
7583

7684
class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
@@ -198,6 +206,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
198206
jfloat temperature,
199207
jint num_bos,
200208
jint num_eos) {
209+
Error err = Error::Ok;
210+
if (!prompt) {
211+
err = Error::InvalidArgument;
212+
if (callback) {
213+
callback->onError(
214+
static_cast<int>(err),
215+
"generate() failed: prompt must not be null");
216+
}
217+
return static_cast<jint>(err);
218+
}
219+
201220
float effective_temperature = temperature >= 0 ? temperature : temperature_;
202221
std::string token_buffer;
203222
auto token_callback = [callback, &token_buffer](const std::string& token) {
@@ -209,26 +228,53 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
209228
}
210229
std::string result = token_buffer;
211230
token_buffer.clear();
212-
callback->onResult(result);
231+
if (callback) {
232+
callback->onResult(result);
233+
}
213234
};
214235

215236
if (!runner_) {
216-
return static_cast<jint>(Error::InvalidState);
217-
}
218-
executorch::extension::llm::GenerationConfig config{
219-
.echo = static_cast<bool>(echo),
220-
.seq_len = seq_len,
221-
.temperature = effective_temperature,
222-
.num_bos = needs_bos_ ? num_bos_ : 0,
223-
.num_eos = num_eos,
224-
};
225-
auto err = runner_->generate(
226-
prompt->toStdString(),
227-
config,
228-
token_callback,
229-
[callback](const llm::Stats& result) { callback->onStats(result); });
230-
if (err == Error::Ok) {
231-
needs_bos_ = false;
237+
err = Error::InvalidState;
238+
if (callback) {
239+
callback->onError(
240+
static_cast<int>(err), "generate() failed: runner not initialized");
241+
}
242+
return static_cast<jint>(err);
243+
}
244+
245+
try {
246+
executorch::extension::llm::GenerationConfig config{
247+
.echo = static_cast<bool>(echo),
248+
.seq_len = seq_len,
249+
.temperature = effective_temperature,
250+
.num_bos = needs_bos_ ? num_bos_ : 0,
251+
.num_eos = num_eos,
252+
};
253+
err = runner_->generate(
254+
prompt->toStdString(),
255+
config,
256+
token_callback,
257+
[callback](const llm::Stats& result) {
258+
if (callback) {
259+
callback->onStats(result);
260+
}
261+
});
262+
if (err == Error::Ok) {
263+
needs_bos_ = false;
264+
}
265+
if (err != Error::Ok && callback) {
266+
callback->onError(
267+
static_cast<int>(err),
268+
"generate() failed with error code " +
269+
std::to_string(static_cast<int>(err)));
270+
}
271+
} catch (const std::exception& e) {
272+
if (callback) {
273+
callback->onError(
274+
static_cast<int>(Error::Internal),
275+
std::string("generate() threw: ") + e.what());
276+
}
277+
return static_cast<jint>(Error::Internal);
232278
}
233279
return static_cast<jint>(err);
234280
}

0 commit comments

Comments
 (0)