Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@
import com.tencent.trpc.proto.http.common.RpcServerContextWithHttp;
import com.tencent.trpc.proto.http.common.TrpcServletRequestWrapper;
import com.tencent.trpc.proto.http.common.TrpcServletResponseWrapper;
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand All @@ -65,34 +69,41 @@ public abstract class AbstractHttpExecutor {

protected void execute(HttpServletRequest request, HttpServletResponse response,
RpcMethodInfoAndInvoker methodInfoAndInvoker) {

AtomicBoolean responded = new AtomicBoolean(false);
try {

DefRequest rpcRequest = buildDefRequest(request, response, methodInfoAndInvoker);

CountDownLatch countDownLatch = new CountDownLatch(1);
CompletableFuture<Void> completionFuture = new CompletableFuture<>();

// use a thread pool for asynchronous processing
invokeRpcRequest(methodInfoAndInvoker.getInvoker(), rpcRequest, countDownLatch);
invokeRpcRequest(methodInfoAndInvoker.getInvoker(), rpcRequest, completionFuture, responded);

// If the request carries a timeout, use this timeout to wait for the request to be processed.
// If not carried, use the default timeout.
long requestTimeout = rpcRequest.getMeta().getTimeout();
if (requestTimeout <= 0) {
requestTimeout = methodInfoAndInvoker.getInvoker().getConfig().getRequestTimeout();
}
if (requestTimeout > 0 && !countDownLatch.await(requestTimeout, TimeUnit.MILLISECONDS)) {
throw TRpcException.newFrameException(ErrorCode.TRPC_SERVER_TIMEOUT_ERR,
"wait http request execute timeout");
if (requestTimeout > 0) {
try {
completionFuture.get(requestTimeout, TimeUnit.MILLISECONDS);
} catch (TimeoutException ex) {
if (responded.compareAndSet(false, true)) {
doErrorReply(request, response,
TRpcException.newFrameException(ErrorCode.TRPC_SERVER_TIMEOUT_ERR,
"wait http request execute timeout"));
}
}
} else {
countDownLatch.await();
completionFuture.get();
}

} catch (Exception ex) {
logger.error("dispatch request [{}] error", request, ex);
doErrorReply(request, response, ex);
if (responded.compareAndSet(false, true)) {
doErrorReply(request, response, ex);
}
}

}

/**
Expand All @@ -107,55 +118,83 @@ protected void execute(HttpServletRequest request, HttpServletResponse response,
/**
* Request processing
*
* @param countDownLatch latch used to wait for the request processing
* @param invoker the invoker
* @param rpcRequest the rpc request
* @param completionFuture the completion future
* @param responded the responded flag
*/
private void invokeRpcRequest(ProviderInvoker<?> invoker, DefRequest rpcRequest, CountDownLatch countDownLatch) {
private void invokeRpcRequest(ProviderInvoker<?> invoker, DefRequest rpcRequest,
CompletableFuture<Void> completionFuture,
AtomicBoolean responded) {

WorkerPool workerPool = invoker.getConfig().getWorkerPoolObj();

if (null == workerPool) {
logger.error("dispatch rpcRequest [{}] error, workerPool is empty", rpcRequest);
throw TRpcException.newFrameException(ErrorCode.TRPC_SERVER_NOSERVICE_ERR,
"not found service, workerPool is empty");
completionFuture.completeExceptionally(TRpcException.newFrameException(ErrorCode.TRPC_SERVER_NOSERVICE_ERR,
"not found service, workerPool is empty"));
return;
}

workerPool.execute(() -> {

// Get the original http response
HttpServletResponse response = getOriginalResponse(rpcRequest);

// Invoke the routing implementation method to handle the request.
CompletionStage<Response> future = invoker.invoke(rpcRequest);
future.whenComplete((result, t) -> {
try {
// Throw the call exception, which will be handled uniformly by the exception handling program.
if (t != null) {
throw t;
}

// Throw a business logic exception, which will be handled uniformly
// by the exception handling program.
Throwable ex = result.getException();
if (ex != null) {
throw ex;
try {
// Get the original http response
HttpServletResponse response = getOriginalResponse(rpcRequest);
// Invoke the routing implementation method to handle the request.
CompletionStage<Response> rpcFuture = invoker.invoke(rpcRequest);

rpcFuture.whenComplete((result, throwable) -> {
try {
if (responded.get()) {
return;
}

// Throw the call exception, which will be handled uniformly by the exception handling program.
if (throwable != null) {
throw throwable;
}

// Throw a business logic exception, which will be handled uniformly
// by the exception handling program.
if (result.getException() != null) {
throw result.getException();
}

// normal response
if (responded.compareAndSet(false, true)) {
response.setStatus(HttpStatus.SC_OK);
httpCodec.writeHttpResponse(response, result);
response.flushBuffer();
}

completionFuture.complete(null);
} catch (Throwable t) {
handleError(t, rpcRequest, response, responded, completionFuture);
}
});

// normal response
response.setStatus(HttpStatus.SC_OK);
httpCodec.writeHttpResponse(response, result);
response.flushBuffer();
} catch (Throwable e) {
HttpServletRequest request = getOriginalRequest(rpcRequest);
logger.warn("reply message error, channel: [{}], msg:[{}]", request.getRemoteAddr(), request, e);
httpErrorReply(request, response,
ErrorResponse.create(request, HttpStatus.SC_SERVICE_UNAVAILABLE, e));
} finally {
countDownLatch.countDown();
}
});
} catch (Exception e) {
handleError(e, rpcRequest, getOriginalResponse(rpcRequest), responded, completionFuture);
}
});
}

/**
* Handle error
*/
private void handleError(Throwable t, DefRequest rpcRequest, HttpServletResponse response,
AtomicBoolean responded, CompletableFuture<Void> completionFuture) {
try {
if (responded.compareAndSet(false, true)) {
HttpServletRequest request = getOriginalRequest(rpcRequest);
logger.warn("reply message error, channel: [{}], msg:[{}]", request.getRemoteAddr(), request, t);
httpErrorReply(request, response, ErrorResponse.create(request, HttpStatus.SC_SERVICE_UNAVAILABLE, t));
}
} finally {
completionFuture.completeExceptionally(t);
}
}

/**
* Build the context request.
*
Expand Down Expand Up @@ -480,4 +519,4 @@ private String getString(String[] callInfos, int length, int cursor) {
return callInfos.length < length ? StringUtils.EMPTY : callInfos[cursor];
}

}
}
Loading