Skip to content

Commit c984f23

Browse files
Mark Wolffmhoeger
authored andcommitted
feature: add before/after InvocationRequest hooks (Azure#339)
* feature: add before/after InvocationRequest hooks * rm return workerChannel * refactor to register cb pattern * remove double callback on promise resolve
1 parent ec942fc commit c984f23

File tree

2 files changed

+249
-7
lines changed

2 files changed

+249
-7
lines changed

src/WorkerChannel.ts

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@ import { toTypedData } from './converters';
77
import { augmentTriggerMetadata } from './augmenters';
88
import { systemError, systemWarn } from './utils/Logger';
99
import { InternalException } from './utils/InternalException';
10+
import { Context } from './public/Interfaces';
1011
import LogCategory = rpc.RpcLog.RpcLogCategory;
1112
import LogLevel = rpc.RpcLog.Level;
1213

14+
type InvocationRequestBefore = (context: Context, userFn: Function) => Function;
15+
type InvocationRequestAfter = (context: Context) => void;
16+
1317
/**
1418
* The worker channel should have a way to handle all incoming gRPC messages.
1519
* This includes all incoming StreamingMessage types (exclude *Response types and RpcLog type)
@@ -25,6 +29,8 @@ interface IWorkerChannel {
2529
invocationRequest(requestId: string, msg: rpc.InvocationRequest): void;
2630
invocationCancel(requestId: string, msg: rpc.InvocationCancel): void;
2731
functionEnvironmentReloadRequest(requestId: string, msg: rpc.IFunctionEnvironmentReloadRequest): void;
32+
registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void;
33+
registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void;
2834
}
2935

3036
/**
@@ -35,13 +41,17 @@ export class WorkerChannel implements IWorkerChannel {
3541
private _functionLoader: IFunctionLoader;
3642
private _workerId: string;
3743
private _v1WorkerBehavior: boolean;
44+
private _invocationRequestBefore: InvocationRequestBefore[];
45+
private _invocationRequestAfter: InvocationRequestAfter[];
3846

3947
constructor(workerId: string, eventStream: IEventStream, functionLoader: IFunctionLoader) {
4048
this._workerId = workerId;
4149
this._eventStream = eventStream;
4250
this._functionLoader = functionLoader;
4351
// default value
4452
this._v1WorkerBehavior = false;
53+
this._invocationRequestBefore = [];
54+
this._invocationRequestAfter = [];
4555

4656
// call the method with the matching 'event' name on this class, passing the requestId and event message
4757
eventStream.on('data', (msg) => {
@@ -85,6 +95,21 @@ export class WorkerChannel implements IWorkerChannel {
8595
});
8696
}
8797

98+
/**
99+
* Register a patching function to be run before User Function is executed.
100+
* Hook should return a patched version of User Function.
101+
*/
102+
public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void {
103+
this._invocationRequestBefore.push(beforeCb);
104+
}
105+
106+
/**
107+
* Register a function to be run after User Function resolves.
108+
*/
109+
public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void {
110+
this._invocationRequestAfter.push(afterCb);
111+
}
112+
88113
/**
89114
* Host sends capabilities/init data to worker and requests the worker to initialize itself
90115
* @param requestId gRPC message request id
@@ -252,19 +277,27 @@ export class WorkerChannel implements IWorkerChannel {
252277
requestId: requestId,
253278
invocationResponse: response
254279
});
280+
281+
this.runInvocationRequestAfter(context);
255282
}
256283

257284
let { context, inputs } = CreateContextAndInputs(info, msg, logCallback, resultCallback, this._v1WorkerBehavior);
258285
let userFunction = this._functionLoader.getFunc(<string>msg.functionId);
259286

287+
userFunction = this.runInvocationRequestBefore(context, userFunction);
288+
260289
// catch user errors from the same async context in the event loop and correlate with invocation
261290
// throws from asynchronous work (setTimeout, etc) are caught by 'unhandledException' and cannot be correlated with invocation
262291
try {
263-
let result = userFunction(context, ...inputs);
292+
let result = userFunction(context, ...inputs);
264293

265-
if (result && isFunction(result.then)) {
266-
result.then(result => (<any>context.done)(null, result, true))
267-
.catch(err => (<any>context.done)(err, null, true));
294+
if (result && isFunction(result.then)) {
295+
result.then(result => {
296+
(<any>context.done)(null, result, true)
297+
})
298+
.catch(err => {
299+
(<any>context.done)(err, null, true)
300+
});
268301
}
269302
} catch (err) {
270303
resultCallback(err);
@@ -277,7 +310,7 @@ export class WorkerChannel implements IWorkerChannel {
277310
public startStream(requestId: string, msg: rpc.StartStream): void {
278311
// Not yet implemented
279312
}
280-
313+
281314
/**
282315
* Message is empty by design - Will add more fields in future if needed
283316
*/
@@ -369,4 +402,18 @@ export class WorkerChannel implements IWorkerChannel {
369402

370403
return status;
371404
}
405+
406+
private runInvocationRequestBefore(context: Context, userFunction: Function): Function {
407+
let wrappedFunction = userFunction;
408+
for (let before of this._invocationRequestBefore) {
409+
wrappedFunction = before(context, wrappedFunction);
410+
}
411+
return wrappedFunction;
412+
}
413+
414+
private runInvocationRequestAfter(context: Context) {
415+
for (let after of this._invocationRequestAfter) {
416+
after(context);
417+
}
418+
}
372419
}

test/WorkerChannelTests.ts

Lines changed: 197 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import * as sinon from 'sinon';
66
import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc';
77
import 'mocha';
88
import { load } from 'grpc';
9+
<<<<<<< HEAD
910
import { FunctionInfo } from '../src/FunctionInfo';
1011

1112
describe('WorkerChannel', () => {
@@ -85,6 +86,75 @@ describe('WorkerChannel', () => {
8586
test: orchestrationTriggerBinding
8687
}
8788
};
89+
=======
90+
import { worker } from 'cluster';
91+
92+
describe('WorkerChannel', () => {
93+
var channel: WorkerChannel;
94+
var stream: TestEventStream;
95+
var loader: sinon.SinonStubbedInstance<FunctionLoader>;
96+
var functions;
97+
98+
const runInvokedFunction = () => {
99+
const triggerDataMock: { [k: string]: rpc.ITypedData } = {
100+
"Headers": {
101+
json: JSON.stringify({Connection: 'Keep-Alive'})
102+
},
103+
"Sys": {
104+
json: JSON.stringify({MethodName: 'test-js', UtcNow: '2018', RandGuid: '3212'})
105+
}
106+
};
107+
108+
const inputDataValue = {
109+
name: "req",
110+
data: {
111+
data: "http",
112+
http:
113+
{
114+
body:
115+
{
116+
data: "string",
117+
body: "blahh"
118+
},
119+
rawBody:
120+
{
121+
data: "string",
122+
body: "blahh"
123+
}
124+
}
125+
}
126+
};
127+
128+
const actualInvocationRequest: rpc.IInvocationRequest = <rpc.IInvocationRequest> {
129+
functionId: 'id',
130+
invocationId: '1',
131+
inputData: [inputDataValue],
132+
triggerMetadata: triggerDataMock,
133+
};
134+
135+
stream.addTestMessage({
136+
invocationRequest: actualInvocationRequest
137+
});
138+
139+
return [inputDataValue, actualInvocationRequest];
140+
}
141+
142+
const assertInvokedFunction = (inputDataValue, actualInvocationRequest) => {
143+
sinon.assert.calledWithMatch(stream.written, <rpc.IStreamingMessage> {
144+
invocationResponse: {
145+
invocationId: '1',
146+
result: {
147+
status: rpc.StatusResult.Status.Success
148+
},
149+
outputData: []
150+
}
151+
});
152+
153+
// triggerMedata will be augmented with inpuDataValue since "RpcHttpTriggerMetadataRemoved" capability is set to true and therefore not populated by the host.
154+
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.$request)).to.equal(JSON.stringify(inputDataValue.data));
155+
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.req)).to.equal(JSON.stringify(inputDataValue.data));
156+
}
157+
>>>>>>> 0ff6fb4... feature: add before/after InvocationRequest hooks (#339)
88158

89159
beforeEach(() => {
90160
stream = new TestEventStream();
@@ -378,7 +448,14 @@ describe('WorkerChannel', () => {
378448
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.$request)).to.equal(JSON.stringify(httpInputData.data));
379449
expect(JSON.stringify(actualInvocationRequest.triggerMetadata!.req)).to.equal(JSON.stringify(httpInputData.data));
380450
});
451+
452+
describe('#invocationRequestBefore, #invocationRequestAfter', () => {
453+
afterEach(() => {
454+
channel['_invocationRequestAfter'] = [];
455+
channel['_invocationRequestBefore'] = [];
456+
});
381457

458+
<<<<<<< HEAD
382459
it ('invokes function', () => {
383460
loader.getFunc.returns((context) => context.done());
384461
loader.getInfo.returns(new FunctionInfo(orchestratorBinding));
@@ -529,11 +606,102 @@ describe('WorkerChannel', () => {
529606
inputData: [httpInputData],
530607
triggerMetadata: getTriggerDataMock(),
531608
};
609+
=======
610+
it("should apply hook before user function is executed", () => {
611+
channel.registerBeforeInvocationRequest((context, userFunction) => {
612+
context['magic_flag'] = 'magic value';
613+
return userFunction.bind({ __wrapped: true });
614+
});
615+
616+
channel.registerBeforeInvocationRequest((context, userFunction) => {
617+
context["secondary_flag"] = 'magic value';
618+
return userFunction;
619+
});
532620

533-
stream.addTestMessage({
534-
invocationRequest: actualInvocationRequest
621+
loader.getFunc.returns(function (this: any, context) {
622+
expect(context['magic_flag']).to.equal('magic value');
623+
expect(context['secondary_flag']).to.equal('magic value');
624+
expect(this.__wrapped).to.equal(true);
625+
expect(channel['_invocationRequestBefore'].length).to.equal(2);
626+
expect(channel['_invocationRequestAfter'].length).to.equal(0);
627+
context.done();
628+
});
629+
loader.getInfo.returns({
630+
name: 'test',
631+
outputBindings: {}
632+
});
633+
634+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
635+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
535636
});
637+
638+
it('should apply hook after user function is executed (callback)', (done) => {
639+
let finished = false;
640+
let count = 0;
641+
channel.registerAfterInvocationRequest((context) => {
642+
expect(finished).to.equal(true);
643+
count += 1;
644+
});
645+
646+
loader.getFunc.returns(function (this: any, context) {
647+
finished = true;
648+
expect(channel['_invocationRequestBefore'].length).to.equal(0);
649+
expect(channel['_invocationRequestAfter'].length).to.equal(1);
650+
expect(count).to.equal(0);
651+
context.done();
652+
expect(count).to.equal(1);
653+
done();
654+
});
655+
loader.getInfo.returns({
656+
name: 'test',
657+
outputBindings: {}
658+
});
536659

660+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
661+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
662+
});
663+
664+
it('should apply hook after user function resolves (promise)', (done) => {
665+
let finished = false;
666+
let count = 0;
667+
let inputDataValue, actualInvocationRequest;
668+
channel.registerAfterInvocationRequest((context) => {
669+
expect(finished).to.equal(true);
670+
count += 1;
671+
expect(count).to.equal(1);
672+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
673+
done();
674+
});
675+
>>>>>>> 0ff6fb4... feature: add before/after InvocationRequest hooks (#339)
676+
677+
loader.getFunc.returns(() => new Promise((resolve) => {
678+
finished = true;
679+
expect(channel['_invocationRequestBefore'].length).to.equal(0);
680+
expect(channel['_invocationRequestAfter'].length).to.equal(1);
681+
expect(count).to.equal(0);
682+
resolve();
683+
}));
684+
loader.getInfo.returns({
685+
name: 'test',
686+
outputBindings: {}
687+
});
688+
689+
[inputDataValue, actualInvocationRequest] = runInvokedFunction();
690+
});
691+
692+
693+
it('should apply hook after user function rejects (promise)', (done) => {
694+
let finished = false;
695+
let count = 0;
696+
channel.registerAfterInvocationRequest((context) => {
697+
expect(finished).to.equal(true);
698+
count += 1;
699+
expect(count).to.equal(1);
700+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
701+
done();
702+
});
703+
704+
<<<<<<< HEAD
537705
sinon.assert.calledWithMatch(stream.written, <rpc.IStreamingMessage> {
538706
invocationResponse: {
539707
invocationId: '1',
@@ -682,6 +850,33 @@ describe('WorkerChannel', () => {
682850
}
683851
}
684852
});
853+
=======
854+
loader.getFunc.returns((context) => new Promise((_, reject) => {
855+
finished = true;
856+
expect(channel['_invocationRequestBefore'].length).to.equal(0);
857+
expect(channel['_invocationRequestAfter'].length).to.equal(1);
858+
expect(count).to.equal(0);
859+
reject();
860+
}));
861+
loader.getInfo.returns({
862+
name: 'test',
863+
outputBindings: {}
864+
});
865+
866+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
867+
});
868+
});
869+
870+
it ('invokes function', () => {
871+
loader.getFunc.returns((context) => context.done());
872+
loader.getInfo.returns({
873+
name: 'test',
874+
outputBindings: {}
875+
})
876+
877+
const [inputDataValue, actualInvocationRequest] = runInvokedFunction();
878+
assertInvokedFunction(inputDataValue, actualInvocationRequest);
879+
>>>>>>> 0ff6fb4... feature: add before/after InvocationRequest hooks (#339)
685880
});
686881

687882
it ('throws for malformed messages', () => {

0 commit comments

Comments
 (0)