@@ -7,9 +7,13 @@ import { toTypedData } from './converters';
7
7
import { augmentTriggerMetadata } from './augmenters' ;
8
8
import { systemError , systemWarn } from './utils/Logger' ;
9
9
import { InternalException } from './utils/InternalException' ;
10
+ import { Context } from './public/Interfaces' ;
10
11
import LogCategory = rpc . RpcLog . RpcLogCategory ;
11
12
import LogLevel = rpc . RpcLog . Level ;
12
13
14
+ type InvocationRequestBefore = ( context : Context , userFn : Function ) => Function ;
15
+ type InvocationRequestAfter = ( context : Context ) => void ;
16
+
13
17
/**
14
18
* The worker channel should have a way to handle all incoming gRPC messages.
15
19
* This includes all incoming StreamingMessage types (exclude *Response types and RpcLog type)
@@ -25,6 +29,8 @@ interface IWorkerChannel {
25
29
invocationRequest ( requestId : string , msg : rpc . InvocationRequest ) : void ;
26
30
invocationCancel ( requestId : string , msg : rpc . InvocationCancel ) : void ;
27
31
functionEnvironmentReloadRequest ( requestId : string , msg : rpc . IFunctionEnvironmentReloadRequest ) : void ;
32
+ registerBeforeInvocationRequest ( beforeCb : InvocationRequestBefore ) : void ;
33
+ registerAfterInvocationRequest ( afterCb : InvocationRequestAfter ) : void ;
28
34
}
29
35
30
36
/**
@@ -35,13 +41,17 @@ export class WorkerChannel implements IWorkerChannel {
35
41
private _functionLoader : IFunctionLoader ;
36
42
private _workerId : string ;
37
43
private _v1WorkerBehavior : boolean ;
44
+ private _invocationRequestBefore : InvocationRequestBefore [ ] ;
45
+ private _invocationRequestAfter : InvocationRequestAfter [ ] ;
38
46
39
47
constructor ( workerId : string , eventStream : IEventStream , functionLoader : IFunctionLoader ) {
40
48
this . _workerId = workerId ;
41
49
this . _eventStream = eventStream ;
42
50
this . _functionLoader = functionLoader ;
43
51
// default value
44
52
this . _v1WorkerBehavior = false ;
53
+ this . _invocationRequestBefore = [ ] ;
54
+ this . _invocationRequestAfter = [ ] ;
45
55
46
56
// call the method with the matching 'event' name on this class, passing the requestId and event message
47
57
eventStream . on ( 'data' , ( msg ) => {
@@ -85,6 +95,21 @@ export class WorkerChannel implements IWorkerChannel {
85
95
} ) ;
86
96
}
87
97
98
+ /**
99
+ * Register a patching function to be run before User Function is executed.
100
+ * Hook should return a patched version of User Function.
101
+ */
102
+ public registerBeforeInvocationRequest ( beforeCb : InvocationRequestBefore ) : void {
103
+ this . _invocationRequestBefore . push ( beforeCb ) ;
104
+ }
105
+
106
+ /**
107
+ * Register a function to be run after User Function resolves.
108
+ */
109
+ public registerAfterInvocationRequest ( afterCb : InvocationRequestAfter ) : void {
110
+ this . _invocationRequestAfter . push ( afterCb ) ;
111
+ }
112
+
88
113
/**
89
114
* Host sends capabilities/init data to worker and requests the worker to initialize itself
90
115
* @param requestId gRPC message request id
@@ -252,19 +277,27 @@ export class WorkerChannel implements IWorkerChannel {
252
277
requestId : requestId ,
253
278
invocationResponse : response
254
279
} ) ;
280
+
281
+ this . runInvocationRequestAfter ( context ) ;
255
282
}
256
283
257
284
let { context, inputs } = CreateContextAndInputs ( info , msg , logCallback , resultCallback , this . _v1WorkerBehavior ) ;
258
285
let userFunction = this . _functionLoader . getFunc ( < string > msg . functionId ) ;
259
286
287
+ userFunction = this . runInvocationRequestBefore ( context , userFunction ) ;
288
+
260
289
// catch user errors from the same async context in the event loop and correlate with invocation
261
290
// throws from asynchronous work (setTimeout, etc) are caught by 'unhandledException' and cannot be correlated with invocation
262
291
try {
263
- let result = userFunction ( context , ...inputs ) ;
292
+ let result = userFunction ( context , ...inputs ) ;
264
293
265
- if ( result && isFunction ( result . then ) ) {
266
- result . then ( result => ( < any > context . done ) ( null , result , true ) )
267
- . catch ( err => ( < any > context . done ) ( err , null , true ) ) ;
294
+ if ( result && isFunction ( result . then ) ) {
295
+ result . then ( result => {
296
+ ( < any > context . done ) ( null , result , true )
297
+ } )
298
+ . catch ( err => {
299
+ ( < any > context . done ) ( err , null , true )
300
+ } ) ;
268
301
}
269
302
} catch ( err ) {
270
303
resultCallback ( err ) ;
@@ -277,7 +310,7 @@ export class WorkerChannel implements IWorkerChannel {
277
310
public startStream ( requestId : string , msg : rpc . StartStream ) : void {
278
311
// Not yet implemented
279
312
}
280
-
313
+
281
314
/**
282
315
* Message is empty by design - Will add more fields in future if needed
283
316
*/
@@ -369,4 +402,18 @@ export class WorkerChannel implements IWorkerChannel {
369
402
370
403
return status ;
371
404
}
405
+
406
+ private runInvocationRequestBefore ( context : Context , userFunction : Function ) : Function {
407
+ let wrappedFunction = userFunction ;
408
+ for ( let before of this . _invocationRequestBefore ) {
409
+ wrappedFunction = before ( context , wrappedFunction ) ;
410
+ }
411
+ return wrappedFunction ;
412
+ }
413
+
414
+ private runInvocationRequestAfter ( context : Context ) {
415
+ for ( let after of this . _invocationRequestAfter ) {
416
+ after ( context ) ;
417
+ }
418
+ }
372
419
}
0 commit comments