# Kernel Extension for Variable Operations API

| Status        | Accepted |
| :------------ | :------- |
| **RFC #**     | [20210504-kernel-extension-variable-ops](https://github.com/tensorflow/community/pull/20210504-kernel-extension-variable-ops) |
| **Author(s)** | Kulin Seth (Apple), Charles Brissart (Apple) |
| **Sponsor**   | Saurabh Saxena ([email protected]) |
| **Updated**   | 2021-05-04 |

## Objective

The proposal extends the current [Kernel C API](https://github.com/tensorflow/community/blob/master/rfcs/20190814-kernel-and-op-registration.md) to enable plugin writers to add support for:

* Optimizer operations such as SGD, Adam, etc.
* Variable updates such as Assign and AssignUpdate

## Motivation

TensorFlow has proposed the [Modular TensorFlow design](https://github.com/tensorflow/community/blob/master/rfcs/20190305-modular-tensorflow.md), which allows plugin (e.g. GPU) writers to register a device in a [pluggable way](https://github.com/tensorflow/community/blob/master/rfcs/20200624-pluggable-device-for-tensorflow.md).
To register OpKernels, TensorFlow provides a [C API](https://github.com/tensorflow/community/blob/master/rfcs/20190814-kernel-and-op-registration.md) for implementing kernels and ops, creating a modular/plugin-based TF implementation with API and ABI surfaces.
To support operations like optimizers and resource-variable updates, which are used in all training networks, the Kernel C API needs to be extended, since support for them is currently missing. This proposal provides a high-level API to add support for variable ops and optimizer operations.

## User Benefit

Training support for pluggable device vendors.

## Design Proposal

The proposal extends the Kernel C API to add support for the variable updates used in optimizer and resource-variable operations such as “GradientDescent”. These operations show up in training graphs, for instance:

```
node {
  name: "SGD/SGD/AssignAddVariableOp"
  op: "AssignAddVariableOp"
  input: "sgd_sgd_assignaddvariableop_resource"
  input: "SGD/SGD/Const"
  device: "/job:localhost/replica:0/task:0/device:GPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
}
```

The above GraphDef snippet shows a resource-variable update operation being used as part of the optimizer op “SGD” (Stochastic Gradient Descent) while training a simple classification network. To perform this operation in a plugin, Core TensorFlow needs to expose a thread-safe way of updating the input tensors through variable updates. The API below proposes how to extend the current API to support that.

### Interface Changes

*Optimizer Operations*

The following are the interface changes we are proposing for optimizer operations. These operations follow the pattern of:

1. Locking the input variables to perform the updates in a thread-safe way.
2. Getting the corresponding input tensor from the variable.
3. Performing the optimizer update (implemented by the plugin).
4. Forwarding the reference from the input to the output.
5. Releasing the lock.

The APIs below provide the functionality to implement steps (1), (2), (4), and (5) of the above list in core. This offers a higher-level interface for implementing all the optimizer ops such as GradientDescent, Adam, Momentum, etc.

```
// This is a helper function which acquires mutexes in order to provide a
// thread-safe way of performing weight updates during the optimizer op. It
// returns an opaque LockHolder handle back to the plugin. This handle is
// passed to the Release API to release the locks when the weight update is
// done.
TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder(
    TF_OpKernelContext* ctx, bool do_lock, bool sparse,
    const int* const input_ids, size_t len,
    TF_VariableInputLockHolder** lockHolder, TF_Status* status);

// This interface returns, in `out`, the tensor corresponding to the variable
// passed at the given input index.
TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable(
    TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType,
    bool sparse,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    TF_Tensor** out, TF_Status* status);

// This interface forwards the reference from the input to the output tensor
// at the indices given by input_index and output_index.
TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput(
    TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index);

// Releases the opaque lock handle returned by
// TF_MaybeLockVariableInputMutexesInOrder.
TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder(
    TF_VariableInputLockHolder* lockHolder);
```
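
To make the intended call pattern concrete, below is a minimal sketch of how a plugin's Compute function for a GradientDescent-style kernel might chain these APIs, following steps (1)-(5) above. The kernel wiring, the input indices, and the `plugin_copy_tensor` functor are illustrative assumptions rather than part of the proposal, and error handling is abbreviated.

```
// Hypothetical copy functor supplied by the plugin; the device-specific copy
// of `source` into `dest` is elided.
static void plugin_copy_tensor(TF_OpKernelContext* ctx, TF_Tensor* source,
                               TF_Tensor* dest) { /* device copy */ }

// Sketch of a plugin Compute function for a GradientDescent-style kernel,
// assuming input 0 holds the variable resource.
static void GradientDescent_Compute(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();

  // (1) Lock the variable input in a thread-safe way.
  const int input_ids[] = {0};
  TF_VariableInputLockHolder* lock_holder = NULL;
  TF_MaybeLockVariableInputMutexesInOrder(ctx, /*do_lock=*/true,
                                          /*sparse=*/false, input_ids,
                                          /*len=*/1, &lock_holder, status);

  // (2) Get the tensor backing the variable.
  TF_Tensor* var = NULL;
  TF_GetInputTensorFromVariable(ctx, /*input=*/0, /*lock_held=*/true,
                                /*isVariantType=*/false, /*sparse=*/false,
                                plugin_copy_tensor, &var, status);

  // (3) Perform the optimizer update, e.g. var -= alpha * delta, on the
  //     plugin device (plugin-specific, elided).

  // (4) Forward the variable reference from input 0 to output 0.
  TF_OpKernelContext_ForwardRefInputToRefOutput(ctx, 0, 0);

  // (5) Release the locks acquired in step (1).
  TF_ReleaseVariableInputLockHolder(lock_holder);

  TF_DeleteTensor(var);
  TF_DeleteStatus(status);
}
```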

*Resource Variables*

The APIs below expose functionality in Core TensorFlow that allows plugins to perform the Assign and Update operations needed to implement the different ResourceVariable primitives. These are higher-level operations that can be easily integrated into a plugin by calling them directly from the Compute function of the corresponding ops.

```
// Expose a higher-level Assign operation for pluggable vendors to implement
// in the plugin for training. The API takes in the context with indices for
// the input and value tensors. It also accepts a copy functor provided by
// the pluggable vendor to do the copying of the tensors.
TF_CAPI_EXPORT extern void TF_AssignVariable(
    TF_OpKernelContext* ctx, int input_index, int value_index,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    TF_Status* status);

// Expose a higher-level AssignUpdate operation for pluggable vendors to
// implement in the plugin for training. The API takes in the context with
// indices for the input and value tensors. It also accepts an update functor
// provided by the pluggable vendor to perform the update.
TF_CAPI_EXPORT extern void TF_AssignUpdateVariable(
    TF_OpKernelContext* ctx, int input_index, int value_index, int Op,
    bool isVariantType,
    void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor,
                       TF_Tensor* value, int Op),
    TF_Status* status);
```
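
As an illustration, a plugin's AssignVariableOp kernel could reduce to a single call into this API from its Compute function. The sketch below assumes input 0 is the variable resource and input 1 is the value, and reuses the hypothetical `plugin_copy_tensor` functor from the earlier sketch; it is not a normative implementation.

```
// Sketch of a plugin Compute function for an AssignVariableOp-style kernel.
static void AssignVariableOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();
  // Assign the value at input 1 to the variable resource at input 0, using
  // the plugin-provided copy functor to move the data on-device.
  TF_AssignVariable(ctx, /*input_index=*/0, /*value_index=*/1,
                    plugin_copy_tensor, status);
  TF_DeleteStatus(status);
}
```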

*Helper Function*

We propose adding a simple helper function that allows plugins to get a tensor by providing its input name.

```
// Allows a plugin to get a TF_Tensor when passed its input_name.
TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx,
                                             const char* input_name,
                                             TF_Tensor** tensor,
                                             TF_Status* status);
```
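
For example, inside a Compute function a kernel could look up an input by name rather than by index. The input name "delta" below is purely an illustrative assumption:

```
// Hypothetical by-name lookup of an input tensor inside a Compute function.
TF_Tensor* delta = NULL;
TF_GetInputByName(ctx, "delta", &delta, status);
if (TF_GetCode(status) == TF_OK) {
  // ... use delta in the kernel computation ...
  TF_DeleteTensor(delta);
}
```
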
### Alternatives Considered

We considered two different ways to add support for resource variables and optimizer operations. Option #1 is to expose the required lower-level structures in TensorFlow core, such as TF_Mutex, TF_ResourceHandle, and TF_RefCountPtr, to the plugin vendors. This gives plugin writers the flexibility to construct higher-level optimizer operations from these lower-level primitives and scales to newer operations. Option #2 is to expose all the higher-level helper methods needed to implement the resource variable and optimizer ops. This reduces the complexity of the interface while keeping the lower-level structures intact in TensorFlow core. The current proposal pursues Option #2 to simplify the API design as a first step; as needed, this interface can be built upon in the future to expose lower-level primitives.

### Performance Implications

We don't expect a performance impact from this RFC. It enables functionality to update variables used in training graphs, which wasn't supported earlier.

### Dependencies

* This RFC doesn't add new dependencies to external libraries.
* It depends on the following Modular TensorFlow related RFCs:
  * [Modular TensorFlow RFC](https://github.com/tensorflow/community/pull/77)
  * [StreamExecutor C interface RFC](https://github.com/tensorflow/community/pull/257)
  * [Kernel and op registration and implementation API](https://github.com/tensorflow/community/blob/master/rfcs/20190814-kernel-and-op-registration.md)
  * [Pluggable device](https://github.com/tensorflow/community/pull/262)

### Engineering Impact

* The impact on binary size / startup time / build time / test times is minimal.
* The TensorFlow team will maintain this code.

### Platforms and Environments

This is an extension to the Kernel C API, so the change works on all platforms supported by the current implementation. The enhancements are platform-independent.

### Best Practices

This works with Modular TensorFlow, which is the direction for integrating new third-party vendors into the current TensorFlow stack.

### Tutorials and Examples

We will work with the TensorFlow core team to provide examples of how plugin vendors can use these APIs.

### Compatibility

* The RFC is an extension to the Kernel C API; it follows the same backwards/forwards compatibility requirements.
* This proposal will allow plugin vendors to train models in the TensorFlow ecosystem. Since the Modular API is the path forward for newer devices to integrate into the TensorFlow stack, it will enable these devices to train models.
  - The current API doesn't support TFLite.
  - It should not impede distribution strategies or serialization to SavedModel.

### User Impact

This is an extension of the current Kernel C API as part of the Modular design.
