@@ -80,6 +80,20 @@ type completionHandler struct {
8080 preProcessErrorsCounter metrics.CounterIface
8181}
8282
83+ type CompletionHandler struct {
84+ * completionHandler
85+ }
86+
87+ func NewCompletionHandler (options HandlerOptions ) * CompletionHandler {
88+ return & CompletionHandler {
89+ completionHandler : & completionHandler {
90+ HandlerOptions : options ,
91+ clientVersionChecker : headers .NewDefaultVersionChecker (),
92+ preProcessErrorsCounter : options .MetricsHandler .Counter (metrics .NexusCompletionRequestPreProcessErrors .Name ()),
93+ },
94+ }
95+ }
96+
8397// CompleteOperation implements nexus.CompletionHandler.
8498// nolint:revive // (cyclomatic complexity) This function is long but the complexity is justified.
8599func (h * completionHandler ) CompleteOperation (ctx context.Context , r * nexusrpc.CompletionRequest ) (retErr error ) {
@@ -89,63 +103,16 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
89103 h .Logger .Error ("failed to decode callback token" , tag .Error (err ))
90104 return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid callback token" )
91105 }
92-
93106 completion , err := h .CallbackTokenGenerator .DecodeCompletion (token )
94107 if err != nil {
95108 h .Logger .Error ("failed to decode completion from token" , tag .Error (err ))
96109 return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid callback token" )
97110 }
98- ns , err := h .NamespaceRegistry . GetNamespaceByID ( namespace . ID ( completion .NamespaceId ))
111+ rCtx , logger , ctx , err := h .newRequestContext ( ctx , r , startTime , completion . GetNamespaceId (), completion . GetWorkflowId (), completion .GetRunId ( ))
99112 if err != nil {
100- h .Logger .Error ("failed to get namespace for nexus completion request" , tag .WorkflowNamespaceID (completion .NamespaceId ), tag .Error (err ))
101- h .preProcessErrorsCounter .Record (1 )
102- var nfe * serviceerror.NamespaceNotFound
103- if errors .As (err , & nfe ) {
104- return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeNotFound , "namespace %q not found" , completion .NamespaceId )
105- }
106- return commonnexus .ConvertGRPCError (err , false )
107- }
108- logger := log .With (
109- h .Logger ,
110- tag .WorkflowNamespace (ns .Name ().String ()),
111- tag .WorkflowID (completion .GetWorkflowId ()),
112- tag .WorkflowRunID (completion .GetRunId ()),
113- )
114- rCtx := & requestContext {
115- completionHandler : h ,
116- namespace : ns ,
117- workflowID : completion .GetWorkflowId (),
118- logger : log .With (h .Logger , tag .WorkflowNamespace (ns .Name ().String ())),
119- metricsHandler : h .MetricsHandler .WithTags (metrics .NamespaceTag (ns .Name ().String ())),
120- metricsHandlerForInterceptors : h .MetricsHandler .WithTags (
121- metrics .OperationTag (methodNameForMetrics ),
122- metrics .NamespaceTag (ns .Name ().String ()),
123- ),
124- requestStartTime : startTime ,
125- }
126- if r .HTTPRequest .Header != nil {
127- rCtx .originalHeaders = r .HTTPRequest .Header .Clone ()
113+ return err
128114 }
129- ctx = rCtx .augmentContext (ctx , r .HTTPRequest .Header )
130115 defer rCtx .capturePanicAndRecordMetrics (& ctx , & retErr )
131- if r .HTTPRequest .URL .Path != commonnexus .PathCompletionCallbackNoIdentifier {
132- nsNameEscaped := commonnexus .RouteCompletionCallback .Deserialize (mux .Vars (r .HTTPRequest ))
133- nsName , err := url .PathUnescape (nsNameEscaped )
134- if err != nil {
135- h .Logger .Error ("failed to extract namespace from request" , tag .Error (err ))
136- h .preProcessErrorsCounter .Record (1 )
137- return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid URL" )
138- }
139- if nsName != ns .Name ().String () {
140- logger .Error (
141- "namespace ID in token doesn't match the token" ,
142- tag .WorkflowNamespaceID (ns .ID ().String ()),
143- tag .Error (err ),
144- tag .String ("completion-namespace-id" , completion .GetNamespaceId ()),
145- )
146- return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid callback token" )
147- }
148- }
149116
150117 if err := rCtx .interceptRequest (ctx , r ); err != nil {
151118 var notActiveErr * serviceerror.NamespaceNotActive
@@ -154,7 +121,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
154121 }
155122 return err
156123 }
157- tokenLimit := h .Config .MaxOperationTokenLength (ns .Name ().String ())
124+ tokenLimit := h .Config .MaxOperationTokenLength (rCtx . namespace .Name ().String ())
158125 if len (r .OperationToken ) > tokenLimit {
159126 return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "operation token length exceeds allowed limit (%d/%d)" , len (r .OperationToken ), tokenLimit )
160127 }
@@ -201,8 +168,8 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
201168 logger .Error ("cannot deserialize payload from completion result" , tag .Error (err ))
202169 return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid result content" )
203170 }
204- if result .Size () > h .Config .PayloadSizeLimit (ns .Name ().String ()) {
205- logger .Error ("payload size exceeds error limit for Nexus CompleteOperation request" , tag .WorkflowNamespace (ns .Name ().String ()))
171+ if result .Size () > h .Config .PayloadSizeLimit (rCtx . namespace .Name ().String ()) {
172+ logger .Error ("payload size exceeds error limit for Nexus CompleteOperation request" , tag .WorkflowNamespace (rCtx . namespace .Name ().String ()))
206173 return nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "result exceeds size limit" )
207174 }
208175 hr .Outcome = & historyservice.CompleteNexusOperationRequest_Success {
@@ -229,6 +196,70 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
229196 return nil
230197}
231198
199+ func (h * completionHandler ) newRequestContext (
200+ ctx context.Context ,
201+ r * nexusrpc.CompletionRequest ,
202+ startTime time.Time ,
203+ namespaceID string ,
204+ workflowID string ,
205+ runID string ,
206+ ) (* requestContext , log.Logger , context.Context , error ) {
207+ ns , err := h .NamespaceRegistry .GetNamespaceByID (namespace .ID (namespaceID ))
208+ if err != nil {
209+ h .Logger .Error ("failed to get namespace for nexus completion request" , tag .WorkflowNamespaceID (namespaceID ), tag .Error (err ))
210+ h .preProcessErrorsCounter .Record (1 )
211+ var nfe * serviceerror.NamespaceNotFound
212+ if errors .As (err , & nfe ) {
213+ return nil , nil , nil , nexus .NewHandlerErrorf (nexus .HandlerErrorTypeNotFound , "namespace %q not found" , namespaceID )
214+ }
215+ return nil , nil , nil , commonnexus .ConvertGRPCError (err , false )
216+ }
217+
218+ logger := log .With (
219+ h .Logger ,
220+ tag .WorkflowNamespace (ns .Name ().String ()),
221+ tag .WorkflowID (workflowID ),
222+ tag .WorkflowRunID (runID ),
223+ )
224+ rCtx := & requestContext {
225+ completionHandler : h ,
226+ namespace : ns ,
227+ workflowID : workflowID ,
228+ logger : log .With (h .Logger , tag .WorkflowNamespace (ns .Name ().String ())),
229+ metricsHandler : h .MetricsHandler .WithTags (metrics .NamespaceTag (ns .Name ().String ())),
230+ metricsHandlerForInterceptors : h .MetricsHandler .WithTags (
231+ metrics .OperationTag (methodNameForMetrics ),
232+ metrics .NamespaceTag (ns .Name ().String ()),
233+ ),
234+ requestStartTime : startTime ,
235+ }
236+ if r .HTTPRequest .Header != nil {
237+ rCtx .originalHeaders = r .HTTPRequest .Header .Clone ()
238+ }
239+ ctx = rCtx .augmentContext (ctx , r .HTTPRequest .Header )
240+
241+ if r .HTTPRequest .URL .Path != commonnexus .PathCompletionCallbackNoIdentifier {
242+ nsNameEscaped := commonnexus .RouteCompletionCallback .Deserialize (mux .Vars (r .HTTPRequest ))
243+ nsName , err := url .PathUnescape (nsNameEscaped )
244+ if err != nil {
245+ h .Logger .Error ("failed to extract namespace from request" , tag .Error (err ))
246+ h .preProcessErrorsCounter .Record (1 )
247+ return nil , nil , nil , nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid URL" )
248+ }
249+ if nsName != ns .Name ().String () {
250+ logger .Error (
251+ "namespace ID in token doesn't match the token" ,
252+ tag .WorkflowNamespaceID (ns .ID ().String ()),
253+ tag .Error (err ),
254+ tag .String ("completion-namespace-id" , namespaceID ),
255+ )
256+ return nil , nil , nil , nexus .NewHandlerErrorf (nexus .HandlerErrorTypeBadRequest , "invalid callback token" )
257+ }
258+ }
259+
260+ return rCtx , logger , ctx , nil
261+ }
262+
232263func (h * completionHandler ) forwardCompleteOperation (ctx context.Context , r * nexusrpc.CompletionRequest , rCtx * requestContext ) error {
233264 client , err := h .ForwardingClients .Get (rCtx .namespace .ActiveClusterName (rCtx .workflowID ))
234265 if err != nil {
@@ -320,6 +351,10 @@ type requestContext struct {
320351 originalHeaders http.Header
321352}
322353
354+ type RequestContext struct {
355+ * requestContext
356+ }
357+
323358func (c * requestContext ) augmentContext (ctx context.Context , header http.Header ) context.Context {
324359 ctx = metrics .AddMetricsContext (ctx )
325360 ctx = interceptor .AddTelemetryContext (ctx , c .metricsHandlerForInterceptors )
0 commit comments