Skip to content

[ET-VK] Better work group sizes for matmul #13185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,27 @@ utils::uvec3 default_pick_local_wg_size(
return graph->create_local_wg_size(global_workgroup_size);
}

utils::uvec3 pick_hw_square_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const utils::uvec3& global_workgroup_size,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  // Only the global work group size is consulted; the remaining parameters are
  // intentionally unused.
  (void)graph;
  (void)shader;
  (void)args;
  (void)resize_args;

  // A few inactive invocations are acceptable, so an extent of 6 (rather than
  // the full 8) already counts as large enough for one side of a square work
  // group.
  constexpr unsigned int kSquareThreshold = 6u;
  const bool width_large_enough =
      global_workgroup_size[0u] >= kSquareThreshold;
  const bool height_large_enough =
      global_workgroup_size[1u] >= kSquareThreshold;

  // Narrow width: bias the work group towards the height dim to reduce the
  // number of inactive invocations.
  if (!width_large_enough) {
    return {4u, 16u, 1u};
  }
  // Width is large enough but height is not: bias towards the width dim.
  if (!height_large_enough) {
    return {16u, 4u, 1u};
  }
  // Both dims are large enough: use the square work group.
  return {8u, 8u, 1u};
}

} // namespace vkcompute
18 changes: 18 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,22 @@ utils::uvec3 default_pick_local_wg_size(
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args);

/**
 * Constructs a local work group size with the shape {W, H, 1}. The function
 * will try to set W == H == sqrt(num_invocations), where num_invocations is
 * typically 64. This configuration is good for ops like matrix multiplication
 * as it reduces the total volume of unique data that the entire work group
 * will need to read from input tensors in order to produce the output data.
 * To compute an output tile of {W, H, 1}, the work group will need to read
 * H unique rows = H * K unique elements from the input tensor and W unique cols
 * = W * K elements from the weight tensor, resulting in (W + H) * K unique
 * elements in total.
 *
 * Selection rule (every returned size contains 64 invocations):
 *   - {8, 8, 1} when both the first (width) and second (height) components of
 *     global_workgroup_size are at least 6; a few inactive invocations are
 *     considered acceptable.
 *   - {4, 16, 1} when the width component is below 6, biasing towards the
 *     height dim to reduce the number of inactive invocations.
 *   - {16, 4, 1} otherwise, i.e. the width component is at least 6 but the
 *     height component is below 6.
 *
 * Only global_workgroup_size influences the result; graph, shader, args and
 * resize_args are accepted but not used by the implementation.
 */
utils::uvec3 pick_hw_square_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args);

} // namespace vkcompute
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ void add_addmm_naive_texture_node(
graph,
VK_KERNEL_FROM_STR(kernel_name),
addmm_naive_texture_global_wg_size,
default_pick_local_wg_size,
pick_hw_square_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
// Shader params buffers
Expand Down Expand Up @@ -245,7 +245,7 @@ void add_addmm_naive_buffer_node(
graph,
VK_KERNEL_FROM_STR(kernel_name),
addmm_naive_buffer_global_wg_size,
default_pick_local_wg_size,
pick_hw_square_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
// Shader params buffers
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void add_matmul_naive_buffer_node(
graph,
VK_KERNEL_FROM_STR(kernel_name),
matmul_naive_buffer_global_wg_size,
default_pick_local_wg_size,
pick_hw_square_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
// Shader params buffers
Expand Down Expand Up @@ -158,7 +158,7 @@ void add_matmul_naive_texture3d_node(
graph,
pick_matmul_naive_texture3d_shader,
default_pick_global_wg_size,
default_pick_local_wg_size,
pick_hw_square_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
// Shader params buffers
Expand Down Expand Up @@ -273,7 +273,7 @@ void add_matmul_optimized_node(
graph,
pick_matmul_optimized_shader,
matmul_optimized_global_wg_size,
default_pick_local_wg_size,
pick_hw_square_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
// Shader params buffers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ utils::uvec3 linear_qga4w_local_wg_size(
if (use_coop_algorithm) {
return {64, 1, 1};
} else {
return graph->create_local_wg_size(global_workgroup_size);
return pick_hw_square_wg_size(
graph, shader, global_workgroup_size, args, resize_args);
}
}

Expand Down
Loading