Skip to content

Commit 0ff6fb4

Browse files
author
Mark Wolff
authored
feature: add before/after InvocationRequest hooks (#339)
* feature: add before/after InvocationRequest hooks * rm return workerChannel * refactor to register cb pattern * remove double callback on promise resolve
1 parent ac130db commit 0ff6fb4

File tree

2 files changed

+230
-58
lines changed

2 files changed

+230
-58
lines changed

src/WorkerChannel.ts

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@ import { toTypedData } from './converters';
77
import { augmentTriggerMetadata } from './augmenters';
88
import { systemError } from './utils/Logger';
99
import { InternalException } from './utils/InternalException';
10+
import { Context } from './public/Interfaces';
1011
import LogCategory = rpc.RpcLog.RpcLogCategory;
1112
import LogLevel = rpc.RpcLog.Level;
1213

14+
type InvocationRequestBefore = (context: Context, userFn: Function) => Function;
15+
type InvocationRequestAfter = (context: Context) => void;
16+
1317
/**
1418
* The worker channel should have a way to handle all incoming gRPC messages.
1519
* This includes all incoming StreamingMessage types (exclude *Response types and RpcLog type)
@@ -25,6 +29,8 @@ interface IWorkerChannel {
2529
invocationRequest(requestId: string, msg: rpc.InvocationRequest): void;
2630
invocationCancel(requestId: string, msg: rpc.InvocationCancel): void;
2731
functionEnvironmentReloadRequest(requestId: string, msg: rpc.IFunctionEnvironmentReloadRequest): void;
32+
registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void;
33+
registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void;
2834
}
2935

3036
/**
@@ -34,11 +40,15 @@ export class WorkerChannel implements IWorkerChannel {
3440
private _eventStream: IEventStream;
3541
private _functionLoader: IFunctionLoader;
3642
private _workerId: string;
43+
private _invocationRequestBefore: InvocationRequestBefore[];
44+
private _invocationRequestAfter: InvocationRequestAfter[];
3745

3846
constructor(workerId: string, eventStream: IEventStream, functionLoader: IFunctionLoader) {
3947
this._workerId = workerId;
4048
this._eventStream = eventStream;
4149
this._functionLoader = functionLoader;
50+
this._invocationRequestBefore = [];
51+
this._invocationRequestAfter = [];
4252

4353
// call the method with the matching 'event' name on this class, passing the requestId and event message
4454
eventStream.on('data', (msg) => {
@@ -82,6 +92,21 @@ export class WorkerChannel implements IWorkerChannel {
8292
});
8393
}
8494

95+
/**
96+
* Register a patching function to be run before User Function is executed.
97+
* Hook should return a patched version of User Function.
98+
*/
99+
public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void {
100+
this._invocationRequestBefore.push(beforeCb);
101+
}
102+
103+
/**
104+
* Register a function to be run after User Function resolves.
105+
*/
106+
public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void {
107+
this._invocationRequestAfter.push(afterCb);
108+
}
109+
85110
/**
86111
* Host sends capabilities/init data to worker and requests the worker to initialize itself
87112
* @param requestId gRPC message request id
@@ -160,7 +185,7 @@ export class WorkerChannel implements IWorkerChannel {
160185
invocationId: msg.invocationId,
161186
result: this.getStatus(err)
162187
}
163-
188+
164189
try {
165190
if (result) {
166191
if (result.return) {
@@ -183,19 +208,27 @@ export class WorkerChannel implements IWorkerChannel {
183208
requestId: requestId,
184209
invocationResponse: response
185210
});
211+
212+
this.runInvocationRequestAfter(context);
186213
}
187214

188215
let { context, inputs } = CreateContextAndInputs(info, msg, logCallback, resultCallback);
189216
let userFunction = this._functionLoader.getFunc(<string>msg.functionId);
190217

218+
userFunction = this.runInvocationRequestBefore(context, userFunction);
219+
191220
// catch user errors from the same async context in the event loop and correlate with invocation
192221
// throws from asynchronous work (setTimeout, etc) are caught by 'unhandledException' and cannot be correlated with invocation
193222
try {
194-
let result = userFunction(context, ...inputs);
223+
let result = userFunction(context, ...inputs);
195224

196-
if (result && isFunction(result.then)) {
197-
result.then(result => (<any>context.done)(null, result, true))
198-
.catch(err => (<any>context.done)(err, null, true));
225+
if (result && isFunction(result.then)) {
226+
result.then(result => {
227+
(<any>context.done)(null, result, true)
228+
})
229+
.catch(err => {
230+
(<any>context.done)(err, null, true)
231+
});
199232
}
200233
} catch (err) {
201234
resultCallback(err);
@@ -208,7 +241,7 @@ export class WorkerChannel implements IWorkerChannel {
208241
public startStream(requestId: string, msg: rpc.StartStream): void {
209242
// Not yet implemented
210243
}
211-
244+
212245
/**
213246
* Message is empty by design - Will add more fields in future if needed
214247
*/
@@ -304,4 +337,18 @@ export class WorkerChannel implements IWorkerChannel {
304337

305338
return status;
306339
}
340+
341+
private runInvocationRequestBefore(context: Context, userFunction: Function): Function {
342+
let wrappedFunction = userFunction;
343+
for (let before of this._invocationRequestBefore) {
344+
wrappedFunction = before(context, wrappedFunction);
345+
}
346+
return wrappedFunction;
347+
}
348+
349+
private runInvocationRequestAfter(context: Context) {
350+
for (let after of this._invocationRequestAfter) {
351+
after(context);
352+
}
353+
}
307354
}

test/WorkerChannelTests.ts

Lines changed: 177 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,73 @@ import * as sinon from 'sinon';
66
import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc';
77
import 'mocha';
88
import { load } from 'grpc';
9+
import { worker } from 'cluster';
910

1011
describe('WorkerChannel', () => {
1112
var channel: WorkerChannel;
1213
var stream: TestEventStream;
1314
var loader: sinon.SinonStubbedInstance<FunctionLoader>;
1415
var functions;
16+
17+
const runInvokedFunction = () => {
18+
const triggerDataMock: { [k: string]: rpc.ITypedData } = {
19+
"Headers": {
20+
json: JSON.stringify({Connection: 'Keep-Alive'})
21+
},
22+
"Sys": {
23+
json: JSON.stringify({MethodName: 'test-js', UtcNow: '2018', RandGuid: '3212'})
24+
}
25+
};
26+
27+
const inputDataValue = {
28+
name: "req",
29+
data: {
30+
data: "http",
31+
http:
32+
{
33+
body:
34+
{
35+
data: "string",
36+
body: "blahh"
37+
},
38+
rawBody:
39+
{
40+
data: "string",
41+
body: "blahh"
42+
}
43+
}
44+
}
45+
};
46+
47+
const actualInvocationRequest: rpc.IInvocationRequest = <rpc.IInvocationRequest> {
48+
functionId: 'id',
49+
invocationId: '1',
50+
inputData: [inputDataValue],
51+
triggerMetadata: triggerDataMock,
52+
};
53+
54+
stream.addTestMessage({
55+
invocationRequest: actualInvocationRequest
56+
});
57+
58+
return [inputDataValue, actualInvocationRequest];
59+
}
60+
61+
const assertInvokedFunction = (inputDataValue, actualInvocationRequest) => {
62+
sinon.assert.calledWithMatch(stream.written, <rpc.IStreamingMessage> {
63+
invocationResponse: {
64+
invocationId: '1',
65+
result: {
66+
status: rpc.StatusResult.Status.Success
67+
},
68+
outputData: []
69+
}
70+
});
71+
72+
// triggerMedata will be augmented with inpuDataValue since "RpcHttpTriggerMetadataRemoved" capability is set to true and therefore not populated by the host.
73+
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.$request)).to.equal(JSON.stringify(inputDataValue.data));
74+
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.req)).to.equal(JSON.stringify(inputDataValue.data));
75+
}
1576

1677
beforeEach(() => {
1778
stream = new TestEventStream();
@@ -226,67 +287,131 @@ describe('WorkerChannel', () => {
226287
workerStatusResponse: {}
227288
});
228289
});
290+
291+
describe('#invocationRequestBefore, #invocationRequestAfter', () => {
292+
afterEach(() => {
293+
channel['_invocationRequestAfter'] = [];
294+
channel['_invocationRequestBefore'] = [];
295+
});
229296

230-
it ('invokes function', () => {
231-
loader.getFunc.returns((context) => context.done());
232-
loader.getInfo.returns({
233-
name: 'test',
234-
outputBindings: {}
235-
})
297+
it("should apply hook before user function is executed", () => {
298+
channel.registerBeforeInvocationRequest((context, userFunction) => {
299+
context['magic_flag'] = 'magic value';
300+
return userFunction.bind({ __wrapped: true });
301+
});
302+
303+
channel.registerBeforeInvocationRequest((context, userFunction) => {
304+
context["secondary_flag"] = 'magic value';
305+
return userFunction;
306+
});
307+
308+
loader.getFunc.returns(function (this: any, context) {
309+
expect(context['magic_flag']).to.equal('magic value');
310+
expect(context['secondary_flag']).to.equal('magic value');
311+
expect(this.__wrapped).to.equal(true);
312+
expect(channel['_invocationRequestBefore'].length).to.equal(2);
313+
expect(channel['_invocationRequestAfter'].length).to.equal(0);
314+
context.done();
315+
});
316+
loader.getInfo.returns({
317+
name: 'test',
318+
outputBindings: {}
319+
});
320+
321+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
322+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
323+
});
236324

237-
var triggerDataMock: { [k: string]: rpc.ITypedData } = {
238-
"Headers": {
239-
json: JSON.stringify({Connection: 'Keep-Alive'})
240-
},
241-
"Sys": {
242-
json: JSON.stringify({MethodName: 'test-js', UtcNow: '2018', RandGuid: '3212'})
243-
}
244-
};
325+
it('should apply hook after user function is executed (callback)', (done) => {
326+
let finished = false;
327+
let count = 0;
328+
channel.registerAfterInvocationRequest((context) => {
329+
expect(finished).to.equal(true);
330+
count += 1;
331+
});
245332

246-
var inputDataValue = {
247-
name: "req",
248-
data: {
249-
data: "http",
250-
http:
251-
{
252-
body:
253-
{
254-
data: "string",
255-
body: "blahh"
256-
},
257-
rawBody:
258-
{
259-
data: "string",
260-
body: "blahh"
261-
}
262-
}
263-
}
264-
};
333+
loader.getFunc.returns(function (this: any, context) {
334+
finished = true;
335+
expect(channel['_invocationRequestBefore'].length).to.equal(0);
336+
expect(channel['_invocationRequestAfter'].length).to.equal(1);
337+
expect(count).to.equal(0);
338+
context.done();
339+
expect(count).to.equal(1);
340+
done();
341+
});
342+
loader.getInfo.returns({
343+
name: 'test',
344+
outputBindings: {}
345+
});
265346

266-
var actualInvocationRequest: rpc.IInvocationRequest = <rpc.IInvocationRequest> {
267-
functionId: 'id',
268-
invocationId: '1',
269-
inputData: [inputDataValue],
270-
triggerMetadata: triggerDataMock,
271-
};
347+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
348+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
349+
});
350+
351+
it('should apply hook after user function resolves (promise)', (done) => {
352+
let finished = false;
353+
let count = 0;
354+
let inputDataValue, actualInvocationRequest;
355+
channel.registerAfterInvocationRequest((context) => {
356+
expect(finished).to.equal(true);
357+
count += 1;
358+
expect(count).to.equal(1);
359+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
360+
done();
361+
});
272362

273-
stream.addTestMessage({
274-
invocationRequest: actualInvocationRequest
363+
loader.getFunc.returns(() => new Promise((resolve) => {
364+
finished = true;
365+
expect(channel['_invocationRequestBefore'].length).to.equal(0);
366+
expect(channel['_invocationRequestAfter'].length).to.equal(1);
367+
expect(count).to.equal(0);
368+
resolve();
369+
}));
370+
loader.getInfo.returns({
371+
name: 'test',
372+
outputBindings: {}
373+
});
374+
375+
[inputDataValue, actualInvocationRequest] = runInvokedFunction();
275376
});
377+
378+
379+
it('should apply hook after user function rejects (promise)', (done) => {
380+
let finished = false;
381+
let count = 0;
382+
channel.registerAfterInvocationRequest((context) => {
383+
expect(finished).to.equal(true);
384+
count += 1;
385+
expect(count).to.equal(1);
386+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
387+
done();
388+
});
276389

277-
sinon.assert.calledWithMatch(stream.written, <rpc.IStreamingMessage> {
278-
invocationResponse: {
279-
invocationId: '1',
280-
result: {
281-
status: rpc.StatusResult.Status.Success
282-
},
283-
outputData: []
284-
}
390+
loader.getFunc.returns((context) => new Promise((_, reject) => {
391+
finished = true;
392+
expect(channel['_invocationRequestBefore'].length).to.equal(0);
393+
expect(channel['_invocationRequestAfter'].length).to.equal(1);
394+
expect(count).to.equal(0);
395+
reject();
396+
}));
397+
loader.getInfo.returns({
398+
name: 'test',
399+
outputBindings: {}
400+
});
401+
402+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
285403
});
404+
});
286405

287-
// triggerMedata will be augmented with inpuDataValue since "RpcHttpTriggerMetadataRemoved" capability is set to true and therefore not populated by the host.
288-
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.$request)).to.equal(JSON.stringify(inputDataValue.data));
289-
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.req)).to.equal(JSON.stringify(inputDataValue.data));
406+
it ('invokes function', () => {
407+
loader.getFunc.returns((context) => context.done());
408+
loader.getInfo.returns({
409+
name: 'test',
410+
outputBindings: {}
411+
})
412+
413+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
414+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
290415
});
291416

292417
it ('throws for malformed messages', () => {

0 commit comments

Comments
 (0)