[VK-API][Op Redesign][3/n] Expose new Context and Resource APIs (#121060) · pytorch/pytorch@eba28a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit eba28a6

Browse files
jorgep31415 authored and pytorchmergebot committed
[VK-API][Op Redesign][3/n] Expose new Context and Resource APIs (#121060)
Summary: For use in the next diff.
Test Plan: sc
Differential Revision: D54397862
Pull Request resolved: #121060
Approved by: https://github.com/SS-JIA
1 parent 70c23a5 commit eba28a6

File tree

2 files changed

+29
-28
lines changed

2 files changed

+29
-28
lines changed

aten/src/ATen/native/vulkan/api/Context.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ Context::~Context() {
5151
}
5252
}
5353

54-
DescriptorSet Context::submit_compute_prologue(
55-
CommandBuffer& command_buffer,
54+
DescriptorSet Context::get_descriptor_set(
5655
const ShaderInfo& shader_descriptor,
5756
const utils::uvec3& local_workgroup_size) {
5857
VkDescriptorSetLayout shader_layout =
@@ -66,21 +65,34 @@ DescriptorSet Context::submit_compute_prologue(
6665
shader_cache().retrieve(shader_descriptor),
6766
local_workgroup_size});
6867

69-
command_buffer.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
68+
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
7069

7170
return descriptor_pool().get_descriptor_set(
7271
shader_layout, shader_descriptor.kernel_layout);
7372
}
7473

75-
void Context::submit_compute_epilogue(
76-
CommandBuffer& command_buffer,
74+
void Context::register_shader_dispatch(
7775
const DescriptorSet& descriptors,
7876
PipelineBarrier& pipeline_barrier,
77+
const ShaderInfo& shader_descriptor,
7978
const utils::uvec3& global_workgroup_size) {
80-
command_buffer.bind_descriptors(descriptors.get_bind_handle());
81-
command_buffer.insert_barrier(pipeline_barrier);
82-
83-
command_buffer.dispatch(global_workgroup_size);
79+
// Adjust the global workgroup size based on the output tile size
80+
const utils::uvec3 effective_global_wg = {
81+
utils::div_up(
82+
global_workgroup_size.data[0u],
83+
shader_descriptor.out_tile_size.data[0u]),
84+
utils::div_up(
85+
global_workgroup_size.data[1u],
86+
shader_descriptor.out_tile_size.data[1u]),
87+
utils::div_up(
88+
global_workgroup_size.data[2u],
89+
shader_descriptor.out_tile_size.data[2u]),
90+
};
91+
92+
cmd_.bind_descriptors(descriptors.get_bind_handle());
93+
cmd_.insert_barrier(pipeline_barrier);
94+
95+
cmd_.dispatch(effective_global_wg);
8496
}
8597

8698
void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) {
@@ -164,12 +176,13 @@ namespace {
164176
void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) {
165177
MemoryMap dst_mapping(dst, MemoryAccessType::WRITE);
166178

167-
MemoryMap src_mapping(src, api::MemoryAccessType::READ);
179+
MemoryMap src_mapping(src, MemoryAccessType::READ);
168180
src_mapping.invalidate();
169181

170182
void* dst_ptr = dst_mapping.template data<void>();
171183
void* src_ptr = src_mapping.template data<void>();
172184

185+
// @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
173186
memcpy(dst_ptr, src_ptr, src.mem_size());
174187
}
175188

aten/src/ATen/native/vulkan/api/Context.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,19 +168,14 @@ class Context final {
168168
}
169169
}
170170

171-
private:
172-
DescriptorSet submit_compute_prologue(
173-
CommandBuffer&,
174-
const ShaderInfo&,
175-
const utils::uvec3&);
171+
DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);
176172

177-
void submit_compute_epilogue(
178-
CommandBuffer&,
173+
void register_shader_dispatch(
179174
const DescriptorSet&,
180175
PipelineBarrier&,
176+
const ShaderInfo&,
181177
const utils::uvec3&);
182178

183-
public:
184179
template <class S, class D>
185180
bool submit_copy(
186181
PipelineBarrier&,
@@ -502,23 +497,16 @@ inline bool Context::submit_compute_job(
502497

503498
// Factor out template parameter independent code to minimize code bloat.
504499
DescriptorSet descriptor_set =
505-
submit_compute_prologue(cmd_, shader, local_work_group_size);
500+
get_descriptor_set(shader, local_work_group_size);
506501

507502
detail::bind(
508503
descriptor_set,
509504
std::index_sequence_for<Arguments...>{},
510505
std::forward<Arguments>(arguments)...);
511506

512-
// Adjust the global workgroup size based on the output tile size
513-
const utils::uvec3 effective_global_wg = {
514-
utils::div_up(global_work_group.data[0u], shader.out_tile_size.data[0u]),
515-
utils::div_up(global_work_group.data[1u], shader.out_tile_size.data[1u]),
516-
utils::div_up(global_work_group.data[2u], shader.out_tile_size.data[2u]),
517-
};
518-
519507
// Factor out template parameter independent code to minimize code bloat.
520-
submit_compute_epilogue(
521-
cmd_, descriptor_set, pipeline_barrier, effective_global_wg);
508+
register_shader_dispatch(
509+
descriptor_set, pipeline_barrier, shader, global_work_group);
522510

523511
#ifdef USE_VULKAN_GPU_DIAGNOSTICS
524512
if (enable_op_profiling_) {

0 commit comments

Comments
 (0)
0