Skip to content

Commit ca4c950

Browse files
committed
Refactor nexus completion handler
1 parent 39c205c commit ca4c950

8 files changed

Lines changed: 142 additions & 50 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package nexusoperation
2+
3+
import (
4+
"context"
5+
6+
"github.com/nexus-rpc/sdk-go/nexus"
7+
"go.temporal.io/server/common/nexus/nexusrpc"
8+
)
9+
10+
type CompletionHandler struct{}
11+
12+
func newCompletionHandler() *CompletionHandler {
13+
return &CompletionHandler{}
14+
}
15+
16+
func (h *CompletionHandler) CompleteOperation(
17+
context.Context,
18+
*nexusrpc.CompletionRequest,
19+
) error {
20+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "CHASM nexus completion is not implemented")
21+
}

chasm/lib/nexusoperation/fx.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ var Module = fx.Module(
3232
fx.Invoke(endpointRegistryLifetimeHooks),
3333
fx.Provide(defaultNexusTransportProvider),
3434
fx.Provide(clientProviderFactory),
35+
fx.Provide(newCompletionHandler),
3536
fx.Provide(newCancellationBackoffTaskHandler),
3637
fx.Provide(newCancellationInvocationTaskHandler),
3738
fx.Provide(newOperationBackoffTaskHandler),

common/nexus/callback_token.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66

77
tokenspb "go.temporal.io/server/api/token/v1"
8+
"go.temporal.io/server/common/nexus/nexusrpc"
89
"google.golang.org/grpc/codes"
910
"google.golang.org/grpc/status"
1011
"google.golang.org/protobuf/proto"
@@ -14,7 +15,7 @@ const (
1415
// Currently supported token version.
1516
TokenVersion = 1
1617
// Header key for the callback token in StartOperation requests.
17-
CallbackTokenHeader = "Temporal-Callback-Token"
18+
CallbackTokenHeader = nexusrpc.CallbackTokenHeader
1819
)
1920

2021
// CallbackToken contains an encoded NexusOperationCompletion message.

common/nexus/nexusrpc/completion.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ import (
1010
"time"
1111

1212
"github.com/nexus-rpc/sdk-go/nexus"
13+
tokenspb "go.temporal.io/server/api/token/v1"
1314
)
1415

16+
// Header key for the callback token in StartOperation requests.
17+
const CallbackTokenHeader = "Temporal-Callback-Token"
18+
1519
// CompletionHTTPClient is a client for sending Nexus operation completion callbacks via HTTP.
1620
type CompletionHTTPClient struct {
1721
baseHTTPClient
@@ -184,6 +188,8 @@ func (c CompleteOperationOptions) applyToHTTPRequest(cc *CompletionHTTPClient, r
184188
type CompletionRequest struct {
185189
// The original HTTP request.
186190
HTTPRequest *http.Request
191+
// Decoded Temporal callback token, if present and valid.
192+
CompletionToken *tokenspb.NexusOperationCompletion
187193
// State of the operation.
188194
State nexus.OperationState
189195
// OperationToken is the unique token for this operation. Used when a completion callback is received before a
@@ -220,6 +226,9 @@ type CompletionHandlerOptions struct {
220226
// A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to
221227
// [DefaultFailureConverter].
222228
FailureConverter FailureConverter
229+
// DecodeCompletionToken decodes the callback token from the original HTTP request into the
230+
// server-specific completion payload. If nil, CompletionToken is left unset.
231+
DecodeCompletionToken func(string) (*tokenspb.NexusOperationCompletion, error)
223232
}
224233

225234
type completionHTTPHandler struct {
@@ -234,6 +243,16 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h
234243
OperationToken: request.Header.Get(nexus.HeaderOperationToken),
235244
HTTPRequest: request,
236245
}
246+
if h.options.DecodeCompletionToken != nil {
247+
if callbackToken := request.Header.Get(CallbackTokenHeader); callbackToken != "" {
248+
var err error
249+
completion.CompletionToken, err = h.options.DecodeCompletionToken(callbackToken)
250+
if err != nil {
251+
h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token"))
252+
return
253+
}
254+
}
255+
}
237256
if startTimeHeader := request.Header.Get(headerOperationStartTime); startTimeHeader != "" {
238257
var parseTimeErr error
239258
if completion.StartTime, parseTimeErr = http.ParseTime(startTimeHeader); parseTimeErr != nil {
Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
package frontend
22

33
import (
4-
"net/http"
5-
6-
"github.com/gorilla/mux"
74
"go.temporal.io/server/common/dynamicconfig"
8-
"go.temporal.io/server/common/headers"
9-
"go.temporal.io/server/common/log"
10-
"go.temporal.io/server/common/metrics"
115
commonnexus "go.temporal.io/server/common/nexus"
12-
"go.temporal.io/server/common/nexus/nexusrpc"
13-
"go.temporal.io/server/common/rpc"
146
"go.temporal.io/server/components/nexusoperations"
157
"go.uber.org/fx"
168
)
@@ -19,7 +11,7 @@ var Module = fx.Module(
1911
"component.nexusoperations.frontend",
2012
fx.Provide(ConfigProvider),
2113
fx.Provide(commonnexus.NewCallbackTokenGenerator),
22-
fx.Invoke(RegisterHTTPHandler),
14+
fx.Provide(newCompletionHandler),
2315
)
2416

2517
func ConfigProvider(coll *dynamicconfig.Collection) *Config {
@@ -29,29 +21,3 @@ func ConfigProvider(coll *dynamicconfig.Collection) *Config {
2921
MaxOperationTokenLength: nexusoperations.MaxOperationTokenLength.Get(coll),
3022
}
3123
}
32-
33-
func RegisterHTTPHandler(options HandlerOptions, logger log.Logger, router *mux.Router) {
34-
h := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{
35-
Handler: &completionHandler{
36-
options,
37-
headers.NewDefaultVersionChecker(),
38-
options.MetricsHandler.Counter(metrics.NexusCompletionRequestPreProcessErrors.Name()),
39-
},
40-
Logger: log.NewSlogLogger(logger),
41-
Serializer: commonnexus.PayloadSerializer,
42-
})
43-
router.Path("/" + commonnexus.RouteCompletionCallback.Representation()).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
44-
// Limit the request body to max allowed Payload size.
45-
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
46-
// limit is enforced on top of this in the CompleteOperation method.
47-
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
48-
h.ServeHTTP(w, r)
49-
})
50-
router.Path(commonnexus.PathCompletionCallbackNoIdentifier).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51-
// Limit the request body to max allowed Payload size.
52-
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
53-
// limit is enforced on top of this in the CompleteOperation method.
54-
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
55-
h.ServeHTTP(w, r)
56-
})
57-
}

components/nexusoperations/frontend/handler.go

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,27 +74,30 @@ type HandlerOptions struct {
7474
HTTPTraceProvider commonnexus.HTTPClientTraceProvider
7575
}
7676

77-
type completionHandler struct {
77+
type CompletionHandler struct {
7878
HandlerOptions
7979
clientVersionChecker headers.VersionChecker
8080
preProcessErrorsCounter metrics.CounterIface
8181
}
8282

83+
func newCompletionHandler(options HandlerOptions) *CompletionHandler {
84+
return &CompletionHandler{
85+
HandlerOptions: options,
86+
clientVersionChecker: headers.NewDefaultVersionChecker(),
87+
preProcessErrorsCounter: options.MetricsHandler.Counter(metrics.NexusCompletionRequestPreProcessErrors.Name()),
88+
}
89+
}
90+
8391
// CompleteOperation implements nexus.CompletionHandler.
8492
// nolint:revive // (cyclomatic complexity) This function is long but the complexity is justified.
85-
func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest) (retErr error) {
93+
func (h *CompletionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest) (retErr error) {
8694
startTime := time.Now()
87-
token, err := commonnexus.DecodeCallbackToken(r.HTTPRequest.Header.Get(commonnexus.CallbackTokenHeader))
88-
if err != nil {
89-
h.Logger.Error("failed to decode callback token", tag.Error(err))
90-
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
91-
}
92-
93-
completion, err := h.CallbackTokenGenerator.DecodeCompletion(token)
94-
if err != nil {
95-
h.Logger.Error("failed to decode completion from token", tag.Error(err))
95+
completion := r.CompletionToken
96+
if completion == nil {
97+
h.Logger.Error("missing completion token")
9698
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
9799
}
100+
var err error
98101
ns, err := h.NamespaceRegistry.GetNamespaceByID(namespace.ID(completion.NamespaceId))
99102
if err != nil {
100103
h.Logger.Error("failed to get namespace for nexus completion request", tag.WorkflowNamespaceID(completion.NamespaceId), tag.Error(err))
@@ -112,7 +115,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
112115
tag.WorkflowRunID(completion.GetRunId()),
113116
)
114117
rCtx := &requestContext{
115-
completionHandler: h,
118+
CompletionHandler: h,
116119
namespace: ns,
117120
workflowID: completion.GetWorkflowId(),
118121
logger: log.With(h.Logger, tag.WorkflowNamespace(ns.Name().String())),
@@ -229,7 +232,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C
229232
return nil
230233
}
231234

232-
func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest, rCtx *requestContext) error {
235+
func (h *CompletionHandler) forwardCompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest, rCtx *requestContext) error {
233236
client, err := h.ForwardingClients.Get(rCtx.namespace.ActiveClusterName(rCtx.workflowID))
234237
if err != nil {
235238
h.Logger.Error("unable to get HTTP client for forward request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err), tag.SourceCluster(h.ClusterMetadata.GetCurrentClusterName()), tag.TargetCluster(rCtx.namespace.ActiveClusterName(rCtx.workflowID)))
@@ -307,7 +310,7 @@ func (f *forwardingHTTPHeaderWrapper) Do(req *http.Request) (*http.Response, err
307310
}
308311

309312
type requestContext struct {
310-
*completionHandler
313+
*CompletionHandler
311314
logger log.Logger
312315
metricsHandler metrics.Handler
313316
metricsHandlerForInterceptors metrics.Handler

service/frontend/fx.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"go.temporal.io/server/api/adminservice/v1"
99
"go.temporal.io/server/chasm"
1010
"go.temporal.io/server/chasm/lib/activity"
11+
chasmnexus "go.temporal.io/server/chasm/lib/nexusoperation"
1112
"go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1"
1213
"go.temporal.io/server/client"
1314
"go.temporal.io/server/common"
@@ -108,6 +109,7 @@ var Module = fx.Options(
108109
fx.Provide(NewVersionChecker),
109110
fx.Provide(ServiceResolverProvider),
110111
fx.Invoke(RegisterNexusHTTPHandler),
112+
fx.Invoke(RegisterNexusCompletionHandler),
111113
fx.Invoke(RegisterOpenAPIHTTPHandler),
112114
fx.Provide(HTTPAPIServerProvider),
113115
fx.Provide(NewServiceProvider),
@@ -116,6 +118,7 @@ var Module = fx.Options(
116118
fx.Invoke(ServiceLifetimeHooks),
117119
fx.Invoke(EndpointRegistryLifetimeHooks),
118120
fx.Provide(schedulerpb.NewSchedulerServiceLayeredClient),
121+
chasmnexus.Module,
119122
nexusfrontend.Module,
120123
activity.FrontendModule,
121124
fx.Provide(visibility.ChasmVisibilityManagerProvider),
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package frontend
2+
3+
import (
4+
"context"
5+
"net/http"
6+
7+
"github.com/gorilla/mux"
8+
"github.com/nexus-rpc/sdk-go/nexus"
9+
persistencespb "go.temporal.io/server/api/persistence/v1"
10+
tokenspb "go.temporal.io/server/api/token/v1"
11+
chasmnexus "go.temporal.io/server/chasm/lib/nexusoperation"
12+
"go.temporal.io/server/common/log"
13+
commonnexus "go.temporal.io/server/common/nexus"
14+
"go.temporal.io/server/common/nexus/nexusrpc"
15+
"go.temporal.io/server/common/rpc"
16+
nexusfrontend "go.temporal.io/server/components/nexusoperations/frontend"
17+
)
18+
19+
type completionRouter struct {
20+
hsm nexusrpc.CompletionHandler
21+
chasm nexusrpc.CompletionHandler
22+
}
23+
24+
func (h *completionRouter) CompleteOperation(ctx context.Context, r *nexusrpc.CompletionRequest) error {
25+
completion := r.CompletionToken
26+
if completion == nil {
27+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
28+
}
29+
30+
// TODO: need to handle migration between HSM and CHASM
31+
32+
if len(completion.GetComponentRef()) > 0 {
33+
ref := &persistencespb.ChasmComponentRef{}
34+
if err := ref.Unmarshal(completion.GetComponentRef()); err != nil {
35+
return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token")
36+
}
37+
return h.chasm.CompleteOperation(ctx, r)
38+
}
39+
return h.hsm.CompleteOperation(ctx, r)
40+
}
41+
42+
func RegisterNexusCompletionHandler(
43+
hsmHandler *nexusfrontend.CompletionHandler,
44+
chasmHandler *chasmnexus.CompletionHandler,
45+
logger log.Logger,
46+
router *mux.Router,
47+
) {
48+
httpHandler := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{
49+
Handler: &completionRouter{
50+
hsm: hsmHandler,
51+
chasm: chasmHandler,
52+
},
53+
Logger: log.NewSlogLogger(logger),
54+
Serializer: commonnexus.PayloadSerializer,
55+
DecodeCompletionToken: func(encoded string) (*tokenspb.NexusOperationCompletion, error) {
56+
token, err := commonnexus.DecodeCallbackToken(encoded)
57+
if err != nil {
58+
return nil, err
59+
}
60+
return commonnexus.NewCallbackTokenGenerator().DecodeCompletion(token)
61+
},
62+
})
63+
64+
router.Path("/" + commonnexus.RouteCompletionCallback.Representation()).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
// Limit the request body to max allowed Payload size.
66+
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
67+
// limit is enforced on top of this in the CompleteOperation method.
68+
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
69+
httpHandler.ServeHTTP(w, r)
70+
})
71+
router.Path(commonnexus.PathCompletionCallbackNoIdentifier).HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
72+
// Limit the request body to max allowed Payload size.
73+
// Content headers are transformed to Payload metadata and contribute to the Payload size as well. A separate
74+
// limit is enforced on top of this in the CompleteOperation method.
75+
r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes)
76+
httpHandler.ServeHTTP(w, r)
77+
})
78+
}

0 commit comments

Comments
 (0)