[ET-VK] Add mechanism to trigger command buffer re-encode only when necessary #13379

Merged
merged 2 commits into from
Aug 13, 2025

pytorchbot
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #13184 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/271/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/271/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/272/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/271/orig
@diff-train-skip-merge

ssjia added 2 commits August 13, 2025 06:53
Pull Request resolved: #13185

## Context

Currently, `default_pick_local_wg_size()` (which internally calls `ComputeGraph::create_local_wg_size`) is used to select the local work group size for matrix multiplication ops. However, these functions bias the local work group size towards the largest dim of the global work group, producing local wg sizes like

```
shader                                                                          globalwg size            localwg size
===========                                                                     =====================    ====================                 =============
linear_qga4w_tiled_texture3d_texture3d_texture2d_float                          {256, 29, 1}             {32, 2, 1}                                    1487
matmul_naive_texture3d_float                                                    {29, 115, 32}            {4, 2, 8}                                      712
```

for matrix multiplication shaders. This behaviour was introduced in D64418632 / #6409.

However, experimental testing shows that a "square" work group size of `{8, 8, 1}` works a lot better for matrix multiplication shaders. The theoretical explanation is that the local work group size determines which memory locations must be loaded to compute the outputs of the whole work group. For a work group with size `{W, H, 1}`, the data required to compute the output is `W * OUTPUT_TILE_W` columns of the weight tensor and `H * OUTPUT_TILE_H` rows of the input tensor. Note that all work group items with the same W index request the same columns from the weight tensor, and all work group items with the same H index request the same rows from the input tensor.

If `H==W`, then that "balances" the amount of data that needs to be loaded from each input tensor and may result in better data sharing behaviour among all work group items. Assuming `OUTPUT_TILE_W == OUTPUT_TILE_H == 1`, a local work group of size `{64, 1, 1}` would require 1 unique row from the input tensor and 64 unique columns from the weight tensor, resulting in `(1 + 64) * K = 65K` elements to be loaded in total, where K is the size of the shared reduction dim. Conversely, a local work group of size `{8, 8, 1}` would require 8 unique rows and 8 unique columns, resulting in only `(8 + 8) * K = 16K` unique elements to be loaded.
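
The arithmetic above can be checked with a small helper. This is a minimal sketch: `unique_elements_loaded` is a hypothetical name used only for illustration and is not part of ExecuTorch.

```cpp
#include <cstdint>

// Hypothetical helper (not an ExecuTorch API): number of unique input
// elements a {wg_w, wg_h, 1} local work group has to read, following the
// reasoning in the text above.
constexpr uint64_t unique_elements_loaded(
    uint64_t wg_w,
    uint64_t wg_h,
    uint64_t tile_w,
    uint64_t tile_h,
    uint64_t k) {
  // wg_w * tile_w columns of the weight tensor plus wg_h * tile_h rows of
  // the input tensor, each of length k (the shared reduction dim).
  return (wg_w * tile_w + wg_h * tile_h) * k;
}

// With 1x1 output tiles and K = 1 this reproduces the 65K vs 16K comparison.
static_assert(unique_elements_loaded(64, 1, 1, 1, 1) == 65, "{64, 1, 1}");
static_assert(unique_elements_loaded(8, 8, 1, 1, 1) == 16, "{8, 8, 1}");
```

Both shapes contain 64 invocations, so the difference comes only from how the work group is laid out over the output.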

This highlights the need to use dedicated logic to compute work group sizes for matrix multiplication shaders.

## Changes

* Introduce `pick_hw_square_wg_size`
* Use the new local work group size determination function for Quantized Linear, Matmul, and Linear
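
As an illustration of what a "square" picker could look like, here is a minimal sketch. It is not the actual `pick_hw_square_wg_size` implementation: the function name, the `uvec3` alias, the threshold of 6, and the skewed fallback sizes are all assumptions made for this example.

```cpp
#include <array>
#include <cstdint>

using uvec3 = std::array<uint32_t, 3>;

// Hypothetical sketch, NOT the ExecuTorch implementation. Prefers a square
// {8, 8, 1} local work group and only skews the shape when one of the two
// leading global dims is too small to fill it.
inline uvec3 pick_square_wg_size_sketch(const uvec3& global_wg) {
  constexpr uint32_t kMinExtent = 6; // illustrative threshold (assumption)
  if (global_wg[0] >= kMinExtent && global_wg[1] >= kMinExtent) {
    return {8u, 8u, 1u};
  }
  // Fallback shapes are assumptions; they keep 64 invocations per group
  // while leaving fewer of them idle along the small dim.
  if (global_wg[0] < kMinExtent) {
    return {4u, 16u, 1u};
  }
  return {16u, 4u, 1u};
}
```

For the two global work group sizes in the table above, `{256, 29, 1}` and `{29, 115, 32}`, this sketch returns `{8, 8, 1}`.
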
ghstack-source-id: 302703877

Differential Revision: [D79813236](https://our.internmc.facebook.com/intern/diff/D79813236/)
…ecessary

Pull Request resolved: #13184

## Context

Dynamic shape models currently require the command buffer to be re-encoded on every inference, which introduces significant overhead.

In practice, a command buffer re-encode may not be needed on every inference. A re-encode is only needed when:

1. Shader dispatch parameters change; i.e. new tensor sizes require a completely different compute shader, require new local work group sizing, or require new work group grid size (i.e. global work group size / local work group size)
2. Push constants containing tensor metadata need to be updated

This diff aims to reduce the overhead of triggering tensor shape change by detecting when a command buffer re-encode is actually needed.
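
The two conditions above can be sketched as a single predicate. This is a simplified illustration: `DispatchParams`, `compute_grid`, and `needs_reencode` are hypothetical names, not ExecuTorch APIs.

```cpp
#include <array>
#include <cstdint>
#include <string>

using uvec3 = std::array<uint32_t, 3>;

// Hypothetical record of what was baked into the command buffer for one
// dispatch when it was last encoded.
struct DispatchParams {
  std::string shader_name;
  uvec3 local_wg_size;
  uvec3 wg_grid_size;
};

inline uint32_t div_up(uint32_t n, uint32_t d) {
  return (n + d - 1) / d;
}

// Work group grid size = ceil(global work group size / local work group size).
inline uvec3 compute_grid(const uvec3& global_wg, const uvec3& local_wg) {
  return {
      div_up(global_wg[0], local_wg[0]),
      div_up(global_wg[1], local_wg[1]),
      div_up(global_wg[2], local_wg[2])};
}

// Condition 1: any dispatch parameter differs from what was encoded.
// Condition 2: push constants holding tensor metadata must be rewritten.
inline bool needs_reencode(
    const DispatchParams& encoded,
    const DispatchParams& current,
    bool push_constants_changed) {
  return encoded.shader_name != current.shader_name ||
      encoded.local_wg_size != current.local_wg_size ||
      encoded.wg_grid_size != current.wg_grid_size || push_constants_changed;
}
```

Note that two different global work group sizes can map to the same grid: with a local size of `{8, 8, 1}`, both `{256, 29, 1}` and `{250, 27, 1}` give a grid of `{32, 4, 1}`. In that case the first condition does not fire, and a re-encode is needed only if the shader reads the changed sizes through push constants.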

## Changes

`ComputeGraph`:
* Introduce `requires_reencode` flag to `ComputeGraph` to indicate when a command buffer re-encode is needed.
* Introduce a `std::set<ValueRef>` tracking which values were updated when propagating tensor sizes
  * "update" can be one of two things: 1) tensor sizes changed 2) symint value changed
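
A minimal sketch of this kind of bookkeeping follows. `GraphSketch` and its methods are hypothetical, and `ValueRef` is assumed here to be a plain integer handle; the real `ComputeGraph` may store and update values differently.

```cpp
#include <cstdint>
#include <map>
#include <set>
#include <vector>

using ValueRef = int32_t; // assumption: values are referenced by integer handle

// Hypothetical stand-in for the graph's bookkeeping. A value is recorded as
// "updated" only when its tensor sizes or symint value actually change.
class GraphSketch {
 public:
  void resize_tensor(ValueRef ref, const std::vector<int64_t>& sizes) {
    auto it = tensor_sizes_.find(ref);
    if (it != tensor_sizes_.end() && it->second == sizes) {
      return; // same sizes as before: not an update
    }
    tensor_sizes_[ref] = sizes;
    updated_.insert(ref);
  }

  void set_symint(ValueRef ref, int32_t value) {
    auto it = symints_.find(ref);
    if (it != symints_.end() && it->second == value) {
      return; // same value as before: not an update
    }
    symints_[ref] = value;
    updated_.insert(ref);
  }

  bool was_updated(ValueRef ref) const {
    return updated_.count(ref) > 0;
  }

  // Intended to be called once size propagation for an inference is done.
  void clear_updated() {
    updated_.clear();
  }

 private:
  std::map<ValueRef, std::vector<int64_t>> tensor_sizes_;
  std::map<ValueRef, int32_t> symints_;
  std::set<ValueRef> updated_;
};
```

In this sketch the first assignment to a value also counts as an update, since there is no previous state to compare against.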

`DispatchNode`:
* When propagating new tensor sizes, only execute the resize function if any of the values participating in the computation have been updated
* Mark `requires_reencode` if any push constants associated with tensor metadata need to be updated
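
The gating described for `DispatchNode` can be sketched as follows. `DispatchNodeSketch`, its members, and the `trigger_resize` signature are hypothetical names for illustration, and push constants are modelled as a plain vector of integers.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <set>
#include <vector>

using ValueRef = int32_t; // assumption: values are referenced by integer handle

// Hypothetical sketch of a dispatch node, not the ExecuTorch class.
struct DispatchNodeSketch {
  std::vector<ValueRef> args; // values participating in the computation
  std::function<void()> resize_fn; // recomputes output sizes
  std::function<std::vector<int32_t>()> read_push_constants;
  std::vector<int32_t> encoded_push_constants; // as last encoded

  void trigger_resize(
      const std::set<ValueRef>& updated, bool& requires_reencode) {
    const bool any_updated =
        std::any_of(args.begin(), args.end(), [&](ValueRef ref) {
          return updated.count(ref) > 0;
        });
    if (!any_updated) {
      return; // no participating value changed: skip the resize function
    }
    if (resize_fn) {
      resize_fn();
    }
    // Flag a re-encode only if the tensor metadata in the push constants
    // differs from what the command buffer currently contains.
    if (read_push_constants) {
      std::vector<int32_t> current = read_push_constants();
      if (current != encoded_push_constants) {
        encoded_push_constants = std::move(current);
        requires_reencode = true;
      }
    }
  }
};
```

The flag is passed by reference so that any node can set it; in this sketch the caller owns it and is expected to reset it after re-encoding.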

`DynamicDispatchNode`:
* Only recompute compute shader dispatch params if any of the values participating in the computation have been updated
* Mark `requires_reencode` if 1) a new compute shader is required, 2) local work group size changed, 3) work group grid size changed
ghstack-source-id: 302703876
@exported-using-ghexport

Differential Revision: [D79813237](https://our.internmc.facebook.com/intern/diff/D79813237/)
@pytorchbot pytorchbot requested a review from SS-JIA as a code owner August 13, 2025 17:52
pytorch-bot bot commented Aug 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/13379

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 85 Pending

As of commit 1fd109c with merge base b36d6b6 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 13, 2025
Base automatically changed from gh/SS-JIA/272/orig to main August 13, 2025 18:13
@SS-JIA SS-JIA merged commit 3254ddf into main Aug 13, 2025
96 of 102 checks passed
@SS-JIA SS-JIA deleted the gh/SS-JIA/271/orig branch August 13, 2025 18:14