Skip to content

Commit 9396e17

Browse files
authored
Add optimised 'Indirect BGEMM' binary convolution kernels. (#516)
To start, add portable 4x2 C++ kernels for float/int8/bitpacked output. Facilitate easy implementation of new indirect bgemm kernels, including architecture-specific variations.
1 parent 6744fc2 commit 9396e17

File tree

11 files changed

+717
-56
lines changed

11 files changed

+717
-56
lines changed

larq_compute_engine/core/bconv2d/BUILD

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ cc_library(
3131
)
3232

3333
cc_library(
34-
name = "optimized",
34+
name = "optimized_bgemm",
3535
hdrs = [
36-
"optimized.h",
36+
"optimized_bgemm.h",
3737
],
3838
deps = [
3939
":zero_padding_correction",
@@ -45,3 +45,20 @@ cc_library(
4545
"@ruy//ruy/profiler:instrumentation",
4646
],
4747
)
48+
49+
cc_library(
50+
name = "optimized_indirect_bgemm",
51+
hdrs = [
52+
"optimized_indirect_bgemm.h",
53+
],
54+
deps = [
55+
":zero_padding_correction",
56+
"//larq_compute_engine/core/indirect_bgemm:kernels",
57+
"//larq_compute_engine/core/indirect_bgemm:prepare",
58+
"@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context",
59+
"@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm",
60+
"@org_tensorflow//tensorflow/lite/kernels:padding",
61+
"@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base",
62+
"@ruy//ruy/profiler:instrumentation",
63+
],
64+
)

larq_compute_engine/core/bconv2d/optimized.h renamed to larq_compute_engine/core/bconv2d/optimized_bgemm.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_H_
2-
#define COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_H_
1+
#ifndef COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_BGEMM_H_
2+
#define COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_BGEMM_H_
33

44
#include "larq_compute_engine/core/bconv2d/zero_padding_correction.h"
55
#include "larq_compute_engine/core/bgemm/bgemm.h"
@@ -61,7 +61,7 @@ inline void im2col(const ConvParams& params, const RuntimeShape& input_shape,
6161
}
6262

6363
template <typename AccumScalar, typename DstScalar>
64-
inline void BConv2DOptimized(
64+
inline void BConv2DOptimizedBGEMM(
6565
const ConvParams& params, const RuntimeShape& input_shape,
6666
const TBitpacked* input_data, const RuntimeShape& filter_shape,
6767
const TBitpacked* packed_filter_data,
@@ -152,6 +152,8 @@ inline void BConv2DOptimized(
152152

153153
if (std::is_same<DstScalar, float>::value &&
154154
params.padding_type == PaddingType::kSame && pad_value == 0) {
155+
ruy::profiler::ScopeLabel label("Zero padding correction");
156+
155157
const int stride_width = params.stride_width;
156158
const int stride_height = params.stride_height;
157159
const int dilation_width_factor = params.dilation_width_factor;
@@ -166,20 +168,17 @@ inline void BConv2DOptimized(
166168
const int output_width = output_shape.Dims(2);
167169
const int output_height = output_shape.Dims(1);
168170

169-
{
170-
ruy::profiler::ScopeLabel label("Zero padding correction");
171-
zero_padding_correction::ApplyCorrection(
172-
batches, input_height, input_width, input_depth, filter_height,
173-
filter_width, output_depth, stride_height, stride_width,
174-
dilation_height_factor, dilation_width_factor,
175-
reinterpret_cast<float*>(output_data), output_height, output_width,
176-
padding_buffer);
177-
}
171+
zero_padding_correction::ApplyCorrection(
172+
batches, input_height, input_width, input_depth, filter_height,
173+
filter_width, output_depth, stride_height, stride_width,
174+
dilation_height_factor, dilation_width_factor,
175+
reinterpret_cast<float*>(output_data), output_height, output_width,
176+
padding_buffer);
178177
}
179178
}
180179

181180
} // namespace bconv2d
182181
} // namespace core
183182
} // namespace compute_engine
184183

185-
#endif // COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_H_
184+
#endif // COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_BGEMM_H_
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
#ifndef COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_
#define COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_

#include <cstdint>
#include <type_traits>

#include "larq_compute_engine/core/bconv2d/zero_padding_correction.h"
#include "larq_compute_engine/core/indirect_bgemm/kernel.h"
#include "ruy/profiler/instrumentation.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace compute_engine {
namespace core {
namespace bconv2d {

// Runs a binary 2D convolution with the "indirect BGEMM" strategy: the
// selected micro kernel reads input pixels through `indirection_buffer`
// rather than from an im2col copy of the input.
//
// When the destination is float, "same" padding is used, and the padded
// elements were treated as zeros by the kernel (`pad_value == 0`), a
// post-pass corrects the affected output border pixels using the
// pre-computed `padding_buffer`.
//
// Note: `kernel` is taken by const reference (it carries a function pointer
// and two block-size fields; there is no reason to copy it per call).
template <typename AccumScalar, typename DstScalar>
inline void BConv2DOptimizedIndirectBGEMM(
    const indirect_bgemm::IndirectBGEMMKernel<DstScalar>& kernel,
    const compute_engine::tflite::bconv2d::TfLiteBConv2DParams* conv_params,
    const RuntimeShape& bitpacked_input_shape, const RuntimeShape& output_shape,
    const OutputTransform<DstScalar>& output_transform,
    const TBitpacked* packed_weights, const TBitpacked** indirection_buffer,
    DstScalar* output_data, const float* padding_buffer, const int pad_value) {
  TF_LITE_ASSERT_EQ(bitpacked_input_shape.DimensionsCount(), 4);
  TF_LITE_ASSERT_EQ(output_shape.DimensionsCount(), 4);

  ruy::profiler::ScopeLabel label("BConv2D (optimized, indirect BGEMM)");

  // Flattened GEMM dimensions: each output pixel is one "row", each output
  // channel one "column"; the reduction runs over filter positions times
  // bitpacked input channels.
  const std::int32_t conv_kernel_size =
      conv_params->filter_height * conv_params->filter_width;
  const std::int32_t bitpacked_input_channels = bitpacked_input_shape.Dims(3);
  const std::int32_t output_size = output_shape.Dims(1) * output_shape.Dims(2);
  const std::int32_t output_channels = conv_params->channels_out;

  indirect_bgemm::RunKernel(kernel, conv_kernel_size, bitpacked_input_channels,
                            output_size, output_channels, output_transform,
                            packed_weights, indirection_buffer, output_data);

  if (std::is_same<DstScalar, float>::value &&
      conv_params->padding_type == TfLitePadding::kTfLitePaddingSame &&
      pad_value == 0) {
    ruy::profiler::ScopeLabel label("Zero padding correction");

    const int stride_width = conv_params->stride_width;
    const int stride_height = conv_params->stride_height;
    const int dilation_width_factor = conv_params->dilation_width_factor;
    const int dilation_height_factor = conv_params->dilation_height_factor;
    const int batches = MatchingDim(bitpacked_input_shape, 0, output_shape, 0);
    const int input_depth = conv_params->channels_in;
    const int input_width = bitpacked_input_shape.Dims(2);
    const int input_height = bitpacked_input_shape.Dims(1);
    const int filter_height = conv_params->filter_height;
    const int filter_width = conv_params->filter_width;
    const int output_depth = output_shape.Dims(3);
    const int output_width = output_shape.Dims(2);
    const int output_height = output_shape.Dims(1);

    zero_padding_correction::ApplyCorrection(
        batches, input_height, input_width, input_depth, filter_height,
        filter_width, output_depth, stride_height, stride_width,
        dilation_height_factor, dilation_width_factor,
        reinterpret_cast<float*>(output_data), output_height, output_width,
        padding_buffer);
  }
}

}  // namespace bconv2d
}  // namespace core
}  // namespace compute_engine

#endif  // COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])

# Header-only target with the indirect-BGEMM setup code (prepare.h —
# presumably builds the indirection buffer; confirm against prepare.h).
cc_library(
    name = "prepare",
    hdrs = [
        "prepare.h",
    ],
    deps = [
        "//larq_compute_engine/core:types",
        "//larq_compute_engine/tflite/kernels:bconv2d_params",
        "@org_tensorflow//tensorflow/lite/kernels/internal:types",
    ],
)

# Indirect-BGEMM kernel selection (kernel.h) plus the portable 4x2
# micro-kernel implementation (kernel_4x2_portable.h).
cc_library(
    name = "kernels",
    hdrs = [
        "kernel.h",
        "kernel_4x2_portable.h",
    ],
    deps = [
        "//larq_compute_engine/core:types",
        "//larq_compute_engine/core/bconv2d:output_transform",
        "//larq_compute_engine/tflite/kernels:bconv2d_params",
        "@org_tensorflow//tensorflow/lite/kernels/internal:types",
        "@ruy//ruy/profiler:instrumentation",
    ],
)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
2+
#ifndef COMPUTE_ENGINE_INDIRECT_BGEMM_KERNEL_H_
3+
#define COMPUTE_ENGINE_INDIRECT_BGEMM_KERNEL_H_
4+
5+
#include <cstdint>
6+
#include <type_traits>
7+
8+
#include "larq_compute_engine/core/indirect_bgemm/kernel_4x2_portable.h"
9+
#include "larq_compute_engine/core/types.h"
10+
#include "larq_compute_engine/tflite/kernels/bconv2d_params.h"
11+
#include "tensorflow/lite/c/builtin_op_data.h"
12+
#include "tensorflow/lite/kernels/internal/types.h"
13+
14+
using namespace tflite;
15+
16+
namespace compute_engine {
17+
namespace core {
18+
namespace indirect_bgemm {
19+
20+
using compute_engine::tflite::bconv2d::TfLiteBConv2DParams;
21+
22+
template <typename DstScalar>
struct IndirectBGEMMKernel {
  // Signature shared by all indirect BGEMM micro kernels. The arguments, in
  // the order supplied by `RunKernel` below, are: number of output pixels in
  // this block, conv kernel size (filter_height * filter_width), bitpacked
  // input channels, output channels, the output transform, the packed
  // weights, the (offset) indirection buffer, and the (offset) output
  // pointer.
  using MicroKernelFunction = void(const std::int32_t, const std::int32_t,
                                   const std::int32_t, const std::int32_t,
                                   const bconv2d::OutputTransform<DstScalar>&,
                                   const TBitpacked*, const TBitpacked**,
                                   DstScalar*);
  // The micro kernel invoked for each block of output pixels.
  MicroKernelFunction* micro_kernel_function;
  // Register-blocking dimensions of the micro kernel.
  // `block_size_output_channels` presumably is the output-channel blocking
  // (the "4" in kernel_4x2) — confirm against the kernel implementations.
  const std::int32_t block_size_output_channels;
  // Number of output pixels consumed per micro-kernel invocation; `RunKernel`
  // steps through the output in increments of this value.
  const std::int32_t block_size_pixels;
};
33+
34+
// This function allows us to select which kernel to use at runtime based on any
35+
// parameter we choose: destination scalar; conv params; input/output shapes;
36+
// even detected CPU features.
37+
// It is very important that this function is deterministic, as we rely on
38+
// the fact that the same kernel is selected for each call to `Eval` (as long as
39+
// the input shape doesn't change).
40+
template <typename DstScalar>
41+
inline IndirectBGEMMKernel<DstScalar> SelectRuntimeKernel(
42+
const TfLiteBConv2DParams* conv_params,
43+
const RuntimeShape& bitpacked_input_shape,
44+
const RuntimeShape& output_shape) {
45+
// For now there is only one kernel available.
46+
return IndirectBGEMMKernel<DstScalar>{
47+
&kernel_4x2_portable::RunKernel<DstScalar>, 4, 2};
48+
}
49+
50+
template <typename DstScalar>
51+
void RunKernel(const IndirectBGEMMKernel<DstScalar>& kernel,
52+
const std::int32_t conv_kernel_size,
53+
const std::int32_t bitpacked_input_channels,
54+
const std::int32_t output_size,
55+
const std::int32_t output_channels,
56+
const bconv2d::OutputTransform<DstScalar>& output_transform,
57+
const TBitpacked* packed_weights_ptr,
58+
const TBitpacked** indirection_buffer, DstScalar* output_ptr) {
59+
// TODO: implement multithreading here.
60+
for (std::int32_t pixel_start = 0; pixel_start < output_size;
61+
pixel_start += kernel.block_size_pixels) {
62+
const std::int32_t output_stride =
63+
std::is_same<DstScalar, TBitpacked>::value
64+
? bitpacking::GetBitpackedSize(output_channels)
65+
: output_channels;
66+
kernel.micro_kernel_function(
67+
std::min(output_size - pixel_start, kernel.block_size_pixels),
68+
conv_kernel_size, bitpacked_input_channels, output_channels,
69+
output_transform, packed_weights_ptr,
70+
indirection_buffer + pixel_start * conv_kernel_size,
71+
output_ptr + pixel_start * output_stride);
72+
}
73+
}
74+
75+
} // namespace indirect_bgemm
76+
} // namespace core
77+
} // namespace compute_engine
78+
79+
#endif // COMPUTE_ENGINE_INDIRECT_BGEMM_KERNEL_H_

0 commit comments

Comments (0)