Skip to content

Commit 8f2e203

Browse files
committed
Refactor nexus completion handler
1 parent 39c205c commit 8f2e203

8 files changed

Lines changed: 203 additions & 104 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package nexusoperation
2+
3+
import (
4+
"context"
5+
6+
"github.com/nexus-rpc/sdk-go/nexus"
7+
"go.temporal.io/server/common/nexus/nexusrpc"
8+
)
9+
10+
type CompletionHandler struct{}
11+
12+
func NewCompletionHandler() *CompletionHandler {
13+
return &CompletionHandler{}
14+
}
15+
16+
func (h *CompletionHandler) CompleteOperation(context.Context, *nexusrpc.CompletionRequest) error {
17+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "CHASM nexus completion is not implemented")
18+
}

chasm/lib/nexusoperation/fx.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ var Module = fx.Module(
3232
fx.Invoke(endpointRegistryLifetimeHooks),
3333
fx.Provide(defaultNexusTransportProvider),
3434
fx.Provide(clientProviderFactory),
35+
fx.Provide(NewCompletionHandler),
3536
fx.Provide(newCancellationBackoffTaskHandler),
3637
fx.Provide(newCancellationInvocationTaskHandler),
3738
fx.Provide(newOperationBackoffTaskHandler),
Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
package frontend
22

33
import (
4-
"net/http"
5-
6-
"github.com/gorilla/mux"
74
"go.temporal.io/server/common/dynamicconfig"
8-
"go.temporal.io/server/common/headers"
9-
"go.temporal.io/server/common/log"
10-
"go.temporal.io/server/common/metrics"
115
commonnexus "go.temporal.io/server/common/nexus"
12-
"go.temporal.io/server/common/nexus/nexusrpc"
13-
"go.temporal.io/server/common/rpc"
146
"go.temporal.io/server/components/nexusoperations"
157
"go.uber.org/fx"
168
)
@@ -19,7 +11,7 @@ var Module = fx.Module(
1911
"component.nexusoperations.frontend",
2012
fx.Provide(ConfigProvider),
2113
fx.Provide(commonnexus.NewCallbackTokenGenerator),
22-
fx.Invoke(RegisterHTTPHandler),
14+
fx.Provide(NewCompletionHandler),
2315
)
2416

2517
func ConfigProvider(coll *dynamicconfig.Collection) *Config {
@@ -29,29 +21,3 @@ func ConfigProvider(coll *dynamicconfig.Collection) *Config {
2921
MaxOperationTokenLength: nexusoperations.MaxOperationTokenLength.Get(coll),
3022
}
3123
}
32-
33-
func RegisterHTTPHandler(options HandlerOptions, logger log.Logger, router *mux.Router) {
34-
h := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{
35-
Handler: &completionHandler{
36-
options,
37-
headers.NewDefaultVersionChecker(),
38-
options.MetricsHandler.Counter(metrics.NexusCompletionRequestPreProcessErrors.Name()),
39-
},
40-
Logger: log.NewSlogLogger(logger),
41-
Serializer: commonnexus.PayloadSerializer,
42-
})
43-
router.Path("/" + commonnexus.RouteCompletionCallback.Representation()).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
44-
// Limit the request body to max allowed Payload size.
45-
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
46-
// limit is enforced on top of this in the CompleteOperation method.
47-
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
48-
h.ServeHTTP(w, r)
49-
})
50-
router.Path(commonnexus.PathCompletionCallbackNoIdentifier).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51-
// Limit the request body to max allowed Payload size.
52-
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
53-
// limit is enforced on top of this in the CompleteOperation method.
54-
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
55-
h.ServeHTTP(w, r)
56-
})
57-
}

components/nexusoperations/frontend/handler.go

Lines changed: 87 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ type completionHandler struct {
8080
preProcessErrorsCounter metrics.CounterIface
8181
}
8282

83+
type CompletionHandler struct {
84+
*completionHandler
85+
}
86+
87+
func NewCompletionHandler(options HandlerOptions) *CompletionHandler {
88+
return &CompletionHandler{
89+
completionHandler: &completionHandler{
90+
HandlerOptions: options,
91+
clientVersionChecker: headers.NewDefaultVersionChecker(),
92+
preProcessErrorsCounter: options.MetricsHandler.Counter(metrics.NexusCompletionRequestPreProcessErrors.Name()),
93+
},
94+
}
95+
}
96+
8397
// CompleteOperation implements nexus.CompletionHandler.
8498
// nolint:revive // (cyclomatic complexity) This function is long but the complexity is justified.
8599
func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest) (retErr error) {
@@ -89,63 +103,16 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
89103
h.Logger.Error("failed to decode callback token", tag.Error(err))
90104
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
91105
}
92-
93106
completion, err := h.CallbackTokenGenerator.DecodeCompletion(token)
94107
if err != nil {
95108
h.Logger.Error("failed to decode completion from token", tag.Error(err))
96109
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
97110
}
98-
ns, err := h.NamespaceRegistry.GetNamespaceByID(namespace.ID(completion.NamespaceId))
111+
rCtx, logger, ctx, err := h.newRequestContext(ctx, r, startTime, completion.GetNamespaceId(), completion.GetWorkflowId(), completion.GetRunId())
99112
if err != nil {
100-
h.Logger.Error("failed to get namespace for nexus completion request", tag.WorkflowNamespaceID(completion.NamespaceId), tag.Error(err))
101-
h.preProcessErrorsCounter.Record(1)
102-
var nfe *serviceerror.NamespaceNotFound
103-
if errors.As(err, &nfe) {
104-
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId)
105-
}
106-
return commonnexus.ConvertGRPCError(err, false)
107-
}
108-
logger := log.With(
109-
h.Logger,
110-
tag.WorkflowNamespace(ns.Name().String()),
111-
tag.WorkflowID(completion.GetWorkflowId()),
112-
tag.WorkflowRunID(completion.GetRunId()),
113-
)
114-
rCtx := &requestContext{
115-
completionHandler: h,
116-
namespace: ns,
117-
workflowID: completion.GetWorkflowId(),
118-
logger: log.With(h.Logger, tag.WorkflowNamespace(ns.Name().String())),
119-
metricsHandler: h.MetricsHandler.WithTags(metrics.NamespaceTag(ns.Name().String())),
120-
metricsHandlerForInterceptors: h.MetricsHandler.WithTags(
121-
metrics.OperationTag(methodNameForMetrics),
122-
metrics.NamespaceTag(ns.Name().String()),
123-
),
124-
requestStartTime: startTime,
125-
}
126-
if r.HTTPRequest.Header != nil {
127-
rCtx.originalHeaders = r.HTTPRequest.Header.Clone()
113+
return err
128114
}
129-
ctx = rCtx.augmentContext(ctx, r.HTTPRequest.Header)
130115
defer rCtx.capturePanicAndRecordMetrics(&ctx, &retErr)
131-
if r.HTTPRequest.URL.Path != commonnexus.PathCompletionCallbackNoIdentifier {
132-
nsNameEscaped := commonnexus.RouteCompletionCallback.Deserialize(mux.Vars(r.HTTPRequest))
133-
nsName, err := url.PathUnescape(nsNameEscaped)
134-
if err != nil {
135-
h.Logger.Error("failed to extract namespace from request", tag.Error(err))
136-
h.preProcessErrorsCounter.Record(1)
137-
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")
138-
}
139-
if nsName != ns.Name().String() {
140-
logger.Error(
141-
"namespace ID in token doesn't match the token",
142-
tag.WorkflowNamespaceID(ns.ID().String()),
143-
tag.Error(err),
144-
tag.String("completion-namespace-id", completion.GetNamespaceId()),
145-
)
146-
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
147-
}
148-
}
149116

150117
if err := rCtx.interceptRequest(ctx, r); err != nil {
151118
var notActiveErr *serviceerror.NamespaceNotActive
@@ -154,7 +121,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
154121
}
155122
return err
156123
}
157-
tokenLimit := h.Config.MaxOperationTokenLength(ns.Name().String())
124+
tokenLimit := h.Config.MaxOperationTokenLength(rCtx.namespace.Name().String())
158125
if len(r.OperationToken) > tokenLimit {
159126
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation token length exceeds allowed limit (%d/%d)", len(r.OperationToken), tokenLimit)
160127
}
@@ -201,8 +168,8 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
201168
logger.Error("cannot deserialize payload from completion result", tag.Error(err))
202169
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result content")
203170
}
204-
if result.Size() > h.Config.PayloadSizeLimit(ns.Name().String()) {
205-
logger.Error("payload size exceeds error limit for Nexus CompleteOperation request", tag.WorkflowNamespace(ns.Name().String()))
171+
if result.Size() > h.Config.PayloadSizeLimit(rCtx.namespace.Name().String()) {
172+
logger.Error("payload size exceeds error limit for Nexus CompleteOperation request", tag.WorkflowNamespace(rCtx.namespace.Name().String()))
206173
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "result exceeds size limit")
207174
}
208175
hr.Outcome = &historyservice.CompleteNexusOperationRequest_Success{
@@ -229,6 +196,70 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
229196
return nil
230197
}
231198

199+
func (h *completionHandler) newRequestContext(
200+
ctx context.Context,
201+
r *nexusrpc.CompletionRequest,
202+
startTime time.Time,
203+
namespaceID string,
204+
workflowID string,
205+
runID string,
206+
) (*requestContext, log.Logger, context.Context, error) {
207+
ns, err := h.NamespaceRegistry.GetNamespaceByID(namespace.ID(namespaceID))
208+
if err != nil {
209+
h.Logger.Error("failed to get namespace for nexus completion request", tag.WorkflowNamespaceID(namespaceID), tag.Error(err))
210+
h.preProcessErrorsCounter.Record(1)
211+
var nfe *serviceerror.NamespaceNotFound
212+
if errors.As(err, &nfe) {
213+
return nil, nil, nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", namespaceID)
214+
}
215+
return nil, nil, nil, commonnexus.ConvertGRPCError(err, false)
216+
}
217+
218+
logger := log.With(
219+
h.Logger,
220+
tag.WorkflowNamespace(ns.Name().String()),
221+
tag.WorkflowID(workflowID),
222+
tag.WorkflowRunID(runID),
223+
)
224+
rCtx := &requestContext{
225+
completionHandler: h,
226+
namespace: ns,
227+
workflowID: workflowID,
228+
logger: log.With(h.Logger, tag.WorkflowNamespace(ns.Name().String())),
229+
metricsHandler: h.MetricsHandler.WithTags(metrics.NamespaceTag(ns.Name().String())),
230+
metricsHandlerForInterceptors: h.MetricsHandler.WithTags(
231+
metrics.OperationTag(methodNameForMetrics),
232+
metrics.NamespaceTag(ns.Name().String()),
233+
),
234+
requestStartTime: startTime,
235+
}
236+
if r.HTTPRequest.Header != nil {
237+
rCtx.originalHeaders = r.HTTPRequest.Header.Clone()
238+
}
239+
ctx = rCtx.augmentContext(ctx, r.HTTPRequest.Header)
240+
241+
if r.HTTPRequest.URL.Path != commonnexus.PathCompletionCallbackNoIdentifier {
242+
nsNameEscaped := commonnexus.RouteCompletionCallback.Deserialize(mux.Vars(r.HTTPRequest))
243+
nsName, err := url.PathUnescape(nsNameEscaped)
244+
if err != nil {
245+
h.Logger.Error("failed to extract namespace from request", tag.Error(err))
246+
h.preProcessErrorsCounter.Record(1)
247+
return nil, nil, nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")
248+
}
249+
if nsName != ns.Name().String() {
250+
logger.Error(
251+
"namespace ID in token doesn't match the token",
252+
tag.WorkflowNamespaceID(ns.ID().String()),
253+
tag.Error(err),
254+
tag.String("completion-namespace-id", namespaceID),
255+
)
256+
return nil, nil, nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
257+
}
258+
}
259+
260+
return rCtx, logger, ctx, nil
261+
}
262+
232263
func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest, rCtx *requestContext) error {
233264
client, err := h.ForwardingClients.Get(rCtx.namespace.ActiveClusterName(rCtx.workflowID))
234265
if err != nil {
@@ -320,6 +351,10 @@ type requestContext struct {
320351
originalHeaders http.Header
321352
}
322353

354+
type RequestContext struct {
355+
*requestContext
356+
}
357+
323358
func (c *requestContext) augmentContext(ctx context.Context, header http.Header) context.Context {
324359
ctx = metrics.AddMetricsContext(ctx)
325360
ctx = interceptor.AddTelemetryContext(ctx, c.metricsHandlerForInterceptors)

service/frontend/fx.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ var Module = fx.Options(
107107
fx.Provide(OperatorHandlerProvider),
108108
fx.Provide(NewVersionChecker),
109109
fx.Provide(ServiceResolverProvider),
110-
fx.Invoke(RegisterNexusHTTPHandler),
110+
fx.Invoke(RegisterNexusOperationHTTPHandler),
111+
fx.Invoke(RegisterNexusCompletionHTTPHandler),
111112
fx.Invoke(RegisterOpenAPIHTTPHandler),
112113
fx.Provide(HTTPAPIServerProvider),
113114
fx.Provide(NewServiceProvider),
@@ -862,7 +863,7 @@ func HandlerProvider(
862863
return wfHandler
863864
}
864865

865-
func RegisterNexusHTTPHandler(
866+
func RegisterNexusOperationHTTPHandler(
866867
serviceConfig *Config,
867868
serviceName primitives.ServiceName,
868869
matchingClient resource.MatchingClient,
@@ -883,7 +884,7 @@ func RegisterNexusHTTPHandler(
883884
router *mux.Router,
884885
httpTraceProvider nexus.HTTPClientTraceProvider,
885886
) {
886-
h := NewNexusHTTPHandler(
887+
h := NewNexusOperationHTTPHandler(
887888
serviceConfig,
888889
matchingClient,
889890
metricsHandler,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package frontend
2+
3+
import (
4+
"context"
5+
"net/http"
6+
7+
"github.com/gorilla/mux"
8+
"github.com/nexus-rpc/sdk-go/nexus"
9+
persistencespb "go.temporal.io/server/api/persistence/v1"
10+
chasmnexusoperation "go.temporal.io/server/chasm/lib/nexusoperation"
11+
"go.temporal.io/server/common/log"
12+
commonnexus "go.temporal.io/server/common/nexus"
13+
"go.temporal.io/server/common/nexus/nexusrpc"
14+
"go.temporal.io/server/common/rpc"
15+
nexusfrontend "go.temporal.io/server/components/nexusoperations/frontend"
16+
)
17+
18+
type completionRoutingHandler struct {
19+
callbackTokenGenerator *commonnexus.CallbackTokenGenerator
20+
hsm nexusrpc.CompletionHandler
21+
chasm nexusrpc.CompletionHandler
22+
}
23+
24+
func (h *completionRoutingHandler) CompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest) error {
25+
token, err := commonnexus.DecodeCallbackToken(r.HTTPRequest.Header.Get(commonnexus.CallbackTokenHeader))
26+
if err != nil {
27+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
28+
}
29+
30+
completion, err := h.callbackTokenGenerator.DecodeCompletion(token)
31+
if err != nil {
32+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
33+
}
34+
35+
// TODO: need to handle migration between HSM and CHASM
36+
37+
if len(completion.GetComponentRef()) > 0 {
38+
ref := &persistencespb.ChasmComponentRef{}
39+
if err := ref.Unmarshal(completion.GetComponentRef()); err != nil {
40+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
41+
}
42+
return h.chasm.CompleteOperation(ctx, r)
43+
}
44+
return h.hsm.CompleteOperation(ctx, r)
45+
}
46+
47+
func RegisterNexusCompletionHTTPHandler(
48+
hsmHandler *nexusfrontend.CompletionHandler,
49+
chasmHandler *chasmnexusoperation.CompletionHandler,
50+
callbackTokenGenerator *commonnexus.CallbackTokenGenerator,
51+
logger log.Logger,
52+
router *mux.Router,
53+
) {
54+
h := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{
55+
Handler: &completionRoutingHandler{
56+
callbackTokenGenerator: callbackTokenGenerator,
57+
hsm: hsmHandler,
58+
chasm: chasmHandler,
59+
},
60+
Logger: log.NewSlogLogger(logger),
61+
Serializer: commonnexus.PayloadSerializer,
62+
})
63+
64+
router.Path("/" + commonnexus.RouteCompletionCallback.Representation()).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
// Limit the request body to max allowed Payload size.
66+
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
67+
// limit is enforced on top of this in the CompleteOperation method.
68+
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
69+
h.ServeHTTP(w, r)
70+
})
71+
router.Path(commonnexus.PathCompletionCallbackNoIdentifier).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
72+
// Limit the request body to max allowed Payload size.
73+
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
74+
// limit is enforced on top of this in the CompleteOperation method.
75+
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
76+
h.ServeHTTP(w, r)
77+
})
78+
}

0 commit comments

Comments
 (0)