@@ -51,8 +51,7 @@ Context::~Context() {
51
51
}
52
52
}
53
53
54
- DescriptorSet Context::submit_compute_prologue (
55
- CommandBuffer& command_buffer,
54
+ DescriptorSet Context::get_descriptor_set (
56
55
const ShaderInfo& shader_descriptor,
57
56
const utils::uvec3& local_workgroup_size) {
58
57
VkDescriptorSetLayout shader_layout =
@@ -66,21 +65,34 @@ DescriptorSet Context::submit_compute_prologue(
66
65
shader_cache ().retrieve (shader_descriptor),
67
66
local_workgroup_size});
68
67
69
- command_buffer .bind_pipeline (pipeline, pipeline_layout, local_workgroup_size);
68
+ cmd_ .bind_pipeline (pipeline, pipeline_layout, local_workgroup_size);
70
69
71
70
return descriptor_pool ().get_descriptor_set (
72
71
shader_layout, shader_descriptor.kernel_layout );
73
72
}
74
73
75
- void Context::submit_compute_epilogue (
76
- CommandBuffer& command_buffer,
74
+ void Context::register_shader_dispatch (
77
75
const DescriptorSet& descriptors,
78
76
PipelineBarrier& pipeline_barrier,
77
+ const ShaderInfo& shader_descriptor,
79
78
const utils::uvec3& global_workgroup_size) {
80
- command_buffer.bind_descriptors (descriptors.get_bind_handle ());
81
- command_buffer.insert_barrier (pipeline_barrier);
82
-
83
- command_buffer.dispatch (global_workgroup_size);
79
+ // Adjust the global workgroup size based on the output tile size
80
+ const utils::uvec3 effective_global_wg = {
81
+ utils::div_up (
82
+ global_workgroup_size.data [0u ],
83
+ shader_descriptor.out_tile_size .data [0u ]),
84
+ utils::div_up (
85
+ global_workgroup_size.data [1u ],
86
+ shader_descriptor.out_tile_size .data [1u ]),
87
+ utils::div_up (
88
+ global_workgroup_size.data [2u ],
89
+ shader_descriptor.out_tile_size .data [2u ]),
90
+ };
91
+
92
+ cmd_.bind_descriptors (descriptors.get_bind_handle ());
93
+ cmd_.insert_barrier (pipeline_barrier);
94
+
95
+ cmd_.dispatch (effective_global_wg);
84
96
}
85
97
86
98
void Context::submit_cmd_to_gpu (VkFence fence_handle, const bool final_use) {
@@ -164,12 +176,13 @@ namespace {
164
176
void memcpy_to_buffer (const VulkanBuffer& src, VulkanBuffer& dst) {
165
177
MemoryMap dst_mapping (dst, MemoryAccessType::WRITE);
166
178
167
- MemoryMap src_mapping (src, api:: MemoryAccessType::READ);
179
+ MemoryMap src_mapping (src, MemoryAccessType::READ);
168
180
src_mapping.invalidate ();
169
181
170
182
void * dst_ptr = dst_mapping.template data <void >();
171
183
void * src_ptr = src_mapping.template data <void >();
172
184
185
+ // @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
173
186
memcpy (dst_ptr, src_ptr, src.mem_size ());
174
187
}
175
188
0 commit comments