[ET-VK] Add mechanism to trigger command buffer re-encode only when necessary #13184

Merged · 4 commits · Aug 13, 2025
8 changes: 1 addition & 7 deletions backends/vulkan/runtime/VulkanBackend.cpp
@@ -583,13 +583,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
}
}

// propagate_resize() will re-encode the command buffer so that push
// constants are updated and DynamicDispatchNode can update the compute
// shader, global workgroup size, and local workgroup size to perform the
// model inference.
if (should_propagate_resize ||
(compute_graph->graphconfig().expect_dynamic_shapes &&
compute_graph->execute_count() == 0u)) {
if (should_propagate_resize) {
compute_graph->propagate_resize();
}

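For context, a minimal sketch of how the delegate side is expected to drive this path. The loop body, the inputs() accessor, and the get_input_sizes() helper are illustrative assumptions rather than code from this PR; only should_propagate_resize, resize_input(), sizes_of(), and propagate_resize() come from the surrounding diff.

// Sketch only: hypothetical caller-side flow in the Vulkan delegate's
// execute() path. should_propagate_resize is assumed to be set by comparing
// each input's incoming sizes against the graph's current sizes.
bool should_propagate_resize = false;
for (size_t i = 0; i < compute_graph->inputs().size(); ++i) {
  const std::vector<int64_t> new_sizes =
      get_input_sizes(args, i); // hypothetical helper
  if (compute_graph->sizes_of(compute_graph->inputs()[i].value) != new_sizes) {
    compute_graph->resize_input(static_cast<int64_t>(i), new_sizes);
    should_propagate_resize = true;
  }
}
// After this PR, propagate_resize() itself decides whether a command buffer
// re-encode is needed, so the execute_count() == 0 / expect_dynamic_shapes
// special case is no longer required here.
if (should_propagate_resize) {
  compute_graph->propagate_resize();
}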
57 changes: 52 additions & 5 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -206,6 +206,29 @@ utils::StorageType ComputeGraph::suggested_storage_type() {
return utils::kTexture3D;
}

bool ComputeGraph::was_value_updated(const ValueRef idx) const noexcept {
if (!is_valid_value_idx(idx)) {
return false;
}

// Check if this ValueRef itself was updated
if (updated_values_.find(idx) != updated_values_.end()) {
return true;
}

// If this is a ValueList, check each ValueRef in the list
if (val_is_value_list(idx)) {
const auto& value_list = values_.at(idx).toConstValueList();
for (const auto& nested_idx : value_list) {
if (was_value_updated(nested_idx)) {
return true;
}
}
}

return false;
}

utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
const std::vector<int64_t>& sizes) {
if (config_.enable_memory_layout_override) {
@@ -236,6 +259,10 @@ void ComputeGraph::check_no_active_value_ptrs() {
"invalidated.");
}

bool ComputeGraph::is_valid_value_idx(const ValueRef idx) const noexcept {
return idx >= 0 && idx < static_cast<int>(values_.size());
}

std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
const Value& val = values_.at(idx);
if (val.isTensor()) {
@@ -569,7 +596,12 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
}

void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
get_symint(idx)->set(val);
int32_t cur_val = read_symint(idx);
if (cur_val != val) {
get_symint(idx)->set(val);
// Track that this ValueRef was updated
updated_values_.insert(idx);
}
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
@@ -951,6 +983,12 @@ void ComputeGraph::execute() {
}

execute_count_++;

// Clear the set of updated values at the end of inference
updated_values_.clear();

// Reset the re-encoding flag at the end of inference
requires_reencode_ = false;
}

void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) {
@@ -968,21 +1006,30 @@ void ComputeGraph::resize_input(
const int64_t idx,
const std::vector<int64_t>& new_sizes) {
IOValueRef io_val = inputs_.at(idx);
get_tensor(io_val.value)->virtual_resize(new_sizes);
virtual_resize(io_val.value, new_sizes);
updated_values_.insert(io_val.staging);
}

void ComputeGraph::virtual_resize(
const ValueRef idx,
const std::vector<int64_t>& new_sizes) {
get_tensor(idx)->virtual_resize(new_sizes);
std::vector<int64_t> cur_sizes = sizes_of(idx);
if (cur_sizes != new_sizes) {
get_tensor(idx)->virtual_resize(new_sizes);
// Track that this ValueRef was updated
updated_values_.insert(idx);
}
}

void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
}
// Only re-encode on resize if dynamic shapes are expected
if (config_.expect_dynamic_shapes) {
// A command buffer re-encode will be needed if:
// 1. Any push constant data (used for tensor metadata) was updated
// 2. Compute shader dispatch parameters (i.e. compute shader, global and
// local work group sizes) were updated
if (requires_reencode_) {
clear_deferred_cmds();
}
}
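A brief usage sketch of the new dirty tracking, assuming an already-built graph with one dynamic input at index 0 whose current sizes are {1, 8, 128} (illustrative values, not taken from this PR):

// Resizing to the same sizes is a no-op: nothing is inserted into
// updated_values_, so propagate_resize() leaves the recorded command
// buffers in place.
graph->resize_input(0, {1, 8, 128});
graph->propagate_resize();

// Resizing to different sizes marks the tensor and its staging value as
// updated. Nodes whose push constants or dispatch parameters depend on the
// resized value call set_requires_reencode(), and propagate_resize() then
// clears the deferred command buffers so they are encoded again before the
// next submission.
graph->resize_input(0, {1, 16, 128});
graph->propagate_resize();

// execute() runs inference, then clears updated_values_ and resets
// requires_reencode_ for the next iteration.
graph->execute();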
38 changes: 33 additions & 5 deletions backends/vulkan/runtime/graph/ComputeGraph.h
@@ -196,6 +196,12 @@ class ComputeGraph final {
// List of command buffers deferred for submission
std::vector<vkapi::CommandBuffer> deferred_cmd_list_;

// Set to track which ValueRefs were updated during inference
std::unordered_set<ValueRef> updated_values_;

// Flag to indicate if re-encoding is required
bool requires_reencode_ = false;

protected:
size_t values_in_use_ = 0;
size_t execute_count_ = 0;
@@ -244,6 +250,9 @@ class ComputeGraph final {
return config_;
}

// Check if the ComputeGraph has a value at the specified index
bool is_valid_value_idx(const ValueRef idx) const noexcept;

//
// Value Extraction
//
@@ -427,31 +436,41 @@ class ComputeGraph final {
}

inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo dim_order_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(),
api::kTensorDimOrder);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(),
api::kTensorStrides);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(),
api::kTensorLogicalLimits);
pc_data.set_value(idx);
return pc_data;
}

inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const {
return PushConstantDataInfo(
PushConstantDataInfo pc_data = PushConstantDataInfo(
values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel);
pc_data.set_value(idx);
return pc_data;
}

//
@@ -948,6 +967,15 @@ class ComputeGraph final {

void propagate_resize();

// Check if a specific ValueRef (or ValueList) was updated, with recursive
// handling
bool was_value_updated(const ValueRef idx) const noexcept;

// Set the flag to indicate that re-encoding is required
inline void set_requires_reencode() noexcept {
requires_reencode_ = true;
}

//
// Miscellaneous Utilities
//
17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/containers/PushConstantData.h
@@ -10,6 +10,8 @@

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

namespace vkcompute {

class ComputeGraph;
@@ -33,6 +35,9 @@ class PushConstantDataInfo {
};

Payload payload_;
// The value in a compute graph that this push constant data is associated
// with, if any.
ValueRef value_ = kDummyValueRef;

public:
explicit PushConstantDataInfo(
@@ -60,6 +65,18 @@ class PushConstantDataInfo {
void* dst,
const uint32_t dst_offset,
const uint32_t max_dst_size) const;

inline bool is_tensor_metadata() const noexcept {
return tensorUniformData != nullptr;
}

inline void set_value(ValueRef value) noexcept {
value_ = value;
}

inline ValueRef value() const noexcept {
return value_;
}
};

} // namespace vkcompute
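A small sketch of how the new tagging is meant to compose with ComputeGraph::was_value_updated(); out_ref stands in for some op's output tensor ValueRef and is purely illustrative:

// sizes_pc_of() now records which graph value the push constant data was
// derived from.
const vkcompute::ValueRef out_ref = /* some tensor value in the graph */;
vkcompute::PushConstantDataInfo pc = graph.sizes_pc_of(out_ref);
// pc.is_tensor_metadata() is true and pc.value() == out_ref, so after a
// resize a node can check whether that value changed and, if it did,
// request a command buffer re-encode:
if (pc.is_tensor_metadata() && graph.was_value_updated(pc.value())) {
  graph.set_requires_reencode();
}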
17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -89,4 +89,21 @@ void DispatchNode::write_push_constant_data() {
}
}

bool DispatchNode::trigger_resize(ComputeGraph* graph) {
const bool any_arg_updated = ExecuteNode::trigger_resize(graph);

if (any_arg_updated) {
// If this shader uses push constants, and the tensor metadata associated
// with the push constants has changed, then the command buffer needs to be
// re-encoded since push constants cannot be updated.
for (const auto& push_constant : push_constants_) {
if (push_constant.is_tensor_metadata() &&
graph->was_value_updated(push_constant.value())) {
graph->set_requires_reencode();
}
}
}
return any_arg_updated;
}

} // namespace vkcompute
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
@@ -44,6 +44,8 @@ class DispatchNode : public ExecuteNode {

void encode(ComputeGraph* graph) override;

bool trigger_resize(ComputeGraph* graph) override;

protected:
vkapi::ShaderInfo shader_;
utils::uvec3 global_workgroup_size_;
69 changes: 64 additions & 5 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
@@ -41,6 +41,12 @@ DynamicDispatchNode::DynamicDispatchNode(
pick_global_wg_fn(&graph, shader_, args, resize_args);
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
&graph, shader_, global_workgroup_size_, args, resize_args));

// Calculate dispatch grid similar to Context.cpp register_shader_dispatch
wg_dispatch_grid_ = {
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
}

DynamicDispatchNode::DynamicDispatchNode(
@@ -72,21 +78,74 @@ DynamicDispatchNode::DynamicDispatchNode(
pick_global_wg_fn(&graph, shader_, args, resize_args);
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
&graph, shader_, global_workgroup_size_, args, resize_args));
// Calculate the work group grid that will be dispatched
wg_dispatch_grid_ = {
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};
}

void DynamicDispatchNode::encode(ComputeGraph* graph) {
bool DynamicDispatchNode::trigger_resize(ComputeGraph* graph) {
// DispatchNode::trigger_resize() will return true if any of the values
// participating in this operation were updated.
const bool any_arg_updated = DispatchNode::trigger_resize(graph);
// Only re-compute the shader, global workgroup size, and local workgroup size
// if any of the values participating in this operation were updated.
// Otherwise, assume that these will not have changed.
if (!any_arg_updated) {
return false;
}

// Indicates if the shader dispatch should be changed since the last time the
// command buffer was encoded.
bool dispatch_params_changed = false;

if (pick_shader_fn_) {
shader_ = pick_shader_fn_(graph, args_, resize_args_);
vkapi::ShaderInfo new_shader = pick_shader_fn_(graph, args_, resize_args_);
// Compare shader kernel names as a proxy for shader equality
if (shader_.kernel_name != new_shader.kernel_name) {
shader_ = new_shader;
dispatch_params_changed = true;
}
}
if (pick_global_wg_fn_) {
// Note that if global workgroup size changes, then the dispatch params
// may not actually be different. The actual value to check is the
// work group grid size that will be dispatched, which is calculated
// below.
global_workgroup_size_ =
pick_global_wg_fn_(graph, shader_, args_, resize_args_);
}
if (pick_local_wg_fn_) {
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
graph, shader_, global_workgroup_size_, args_, resize_args_));
utils::uvec3 new_local_wg_uvec3 = pick_local_wg_fn_(
graph, shader_, global_workgroup_size_, args_, resize_args_);
utils::WorkgroupSize new_local_wg =
utils::WorkgroupSize(new_local_wg_uvec3);
if (local_workgroup_size_ != new_local_wg) {
local_workgroup_size_ = new_local_wg;
dispatch_params_changed = true;
}
}

// Always recompute the new dispatch grid and check if it's different
utils::uvec3 new_wg_dispatch_grid = {
utils::div_up(global_workgroup_size_[0], local_workgroup_size_[0]),
utils::div_up(global_workgroup_size_[1], local_workgroup_size_[1]),
utils::div_up(global_workgroup_size_[2], local_workgroup_size_[2])};

// Check if the new dispatch grid is different from the old one
if (wg_dispatch_grid_ != new_wg_dispatch_grid) {
dispatch_params_changed = true;
}
DispatchNode::encode(graph);
wg_dispatch_grid_ = new_wg_dispatch_grid;

// If any of the dispatch params have changed, then the command buffer must
// be re-encoded.
if (dispatch_params_changed) {
graph->set_requires_reencode();
}

return true;
}

} // namespace vkcompute
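A worked example of why the dispatched grid, rather than the raw global workgroup size, is the quantity compared (illustrative numbers; div_up below mirrors the usual round-up division performed by utils::div_up):

#include <cstdint>

constexpr uint32_t div_up(uint32_t a, uint32_t b) {
  return (a + b - 1u) / b;
}

// Before resize: global = {60, 1, 1}, local = {64, 1, 1} -> grid = {1, 1, 1}
// After resize:  global = {64, 1, 1}, local = {64, 1, 1} -> grid = {1, 1, 1}
// The global workgroup size changed, but the dispatched grid did not, so
// this node does not force a command buffer re-encode.
static_assert(div_up(60u, 64u) == 1u && div_up(64u, 64u) == 1u, "same grid");

// A larger resize: global = {70, 1, 1} -> grid = {2, 1, 1}. Now the dispatch
// parameters differ and trigger_resize() calls set_requires_reencode().
static_assert(div_up(70u, 64u) == 2u, "grid changed; re-encode required");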
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
@@ -68,13 +68,15 @@ class DynamicDispatchNode final : public DispatchNode {

~DynamicDispatchNode() override = default;

void encode(ComputeGraph* graph) override;
bool trigger_resize(ComputeGraph* graph) override;

protected:
const PickShaderFn pick_shader_fn_;
const PickGlobalFn pick_global_wg_fn_;
const PickLocalFn pick_local_wg_fn_;

utils::uvec3 wg_dispatch_grid_{1u, 1u, 1u};

public:
operator bool() const {
return shader_;
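Taken together, the override chain introduced by this PR looks roughly like the following (simplified sketch; the real declarations live in ExecuteNode.h, DispatchNode.h, and DynamicDispatchNode.h):

class ComputeGraph;

class ExecuteNode {
 public:
  virtual ~ExecuteNode() = default;
  // Runs the node's resize function and reports whether any value
  // participating in this node was updated since the last inference.
  virtual bool trigger_resize(ComputeGraph* graph);
};

class DispatchNode : public ExecuteNode {
 public:
  // Adds the push constant check: if tensor metadata baked into push
  // constants belongs to an updated value, request a re-encode.
  bool trigger_resize(ComputeGraph* graph) override;
};

class DynamicDispatchNode final : public DispatchNode {
 public:
  // Re-picks the shader and workgroup sizes, and requests a re-encode only
  // if the resulting dispatch parameters actually changed.
  bool trigger_resize(ComputeGraph* graph) override;
};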