diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 4c3c16417b5..6c701224f7f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -33,4 +33,27 @@ utils::uvec3 default_pick_local_wg_size( return graph->create_local_wg_size(global_workgroup_size); } +utils::uvec3 pick_hw_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)args; + (void)resize_args; + // Some inactive invocations are okay; set 6 as the threshold to use the + // a square wg size. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[1u] >= 6) { + return {8u, 8u, 1u}; + } + // If width dim is sufficiently small, then bias towards height dim to reduce + // the number of inactive invocations. + if (global_workgroup_size[0u] < 6u) { + return {4u, 16u, 1u}; + } + return {16u, 4u, 1u}; +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h index 662fb07095a..1831ab2a845 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.h +++ b/backends/vulkan/runtime/graph/ops/impl/Common.h @@ -36,4 +36,22 @@ utils::uvec3 default_pick_local_wg_size( const std::vector& args, const std::vector& resize_args); +/** + * Constructs a local work group size with the shape {W, H, 1}. The function + * will try to set W == H == sqrt(num_invocations), where num_invocations is + * typically 64. This configuration is good for ops like matrix multiplication + * as it reduces the total volume of unique data that the entire work group + * will need to read from input tensors in order to produce the output data. + * To compute an output tile of {W, H, 1}, the work group will need to read + * H unique rows = H * K unique elements from the input tensor and W unique cols + * = W * K elements from the weight tensor, resulting in (W + H) * K unique + * elements in total. + */ +utils::uvec3 pick_hw_square_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 7ca31599cdf..38d70271f4f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -178,7 +178,7 @@ void add_addmm_naive_texture_node( graph, VK_KERNEL_FROM_STR(kernel_name), addmm_naive_texture_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers @@ -245,7 +245,7 @@ void add_addmm_naive_buffer_node( graph, VK_KERNEL_FROM_STR(kernel_name), addmm_naive_buffer_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 0f5556060a2..47ecf5f18d2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -102,7 +102,7 @@ void add_matmul_naive_buffer_node( graph, VK_KERNEL_FROM_STR(kernel_name), matmul_naive_buffer_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -158,7 +158,7 @@ void add_matmul_naive_texture3d_node( graph, pick_matmul_naive_texture3d_shader, default_pick_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -273,7 +273,7 @@ void add_matmul_optimized_node( graph, pick_matmul_optimized_shader, matmul_optimized_global_wg_size, - default_pick_local_wg_size, + pick_hw_square_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index 8c7c6b0cdf9..52cf75e28b5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -158,7 +158,8 @@ utils::uvec3 linear_qga4w_local_wg_size( if (use_coop_algorithm) { return {64, 1, 1}; } else { - return graph->create_local_wg_size(global_workgroup_size); + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } }