Skip to content

Commit 91eca89

Browse files
authored
[ET-VK] 5/n Split dispatches between multiple command buffers. Track previously submitted command buffers in context and add function to execute all previous command buffers. (#12650)
The diff adds changes to store command buffers submitted with final_use set to false. Storing these buffers is necessary for `execute()` function. Since, `encode_execute()` function is typically called once but `execute()` can be called multiple times, `submit_all_non_final_cmds` function is added so all recorded command buffers with `final_use = False` can be called multiple times in `execute()`. #### Key Changes * Added a flag `execute_pending_first_submission` to the `ComputeGraph` class to track whether execute nodes have been freshly encoded and need to be submitted first. * Added a new function `submit_all_non_final_cmds` to the `Context` class, which submits all non-final command buffers to the GPU. * Modified the `submit_cmd_to_gpu` function to add the submitted command buffer to the `non_final_cmds_` list if it's not marked as final use. * Updated the `execute` function in `ComputeGraph` to submit all non-final command buffers before executing the graph. Differential Revision: [D78360038](https://our.internmc.facebook.com/intern/diff/D78360038/)
1 parent c9df2aa commit 91eca89

File tree

4 files changed

+81
-7
lines changed

4 files changed

+81
-7
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) {
217217
}
218218

219219
void Context::flush() {
220-
VK_CHECK(vkQueueWaitIdle(queue()));
220+
VK_CHECK(vkQueueWaitIdle(queue().handle));
221221

222222
command_pool_.flush();
223223
descriptor_pool_.flush();

backends/vulkan/runtime/api/Context.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ class Context final {
9090
return device_;
9191
}
9292

93-
inline VkQueue queue() {
94-
return queue_.handle;
93+
inline vkapi::Adapter::Queue& queue() {
94+
return queue_;
9595
}
9696

9797
// Device Caches
@@ -230,6 +230,10 @@ class Context final {
230230
VkFence fence_handle = VK_NULL_HANDLE,
231231
const bool final_use = false);
232232

233+
vkapi::CommandBuffer& extract_cmd() {
234+
return cmd_;
235+
}
236+
233237
void flush();
234238

235239
#ifdef VULKAN_DEBUG

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ ComputeGraph::~ComputeGraph() {
158158

159159
prepack_nodes_.clear();
160160
execute_nodes_.clear();
161+
clear_deferred_cmds();
161162

162163
context_->flush();
163164
}
@@ -775,6 +776,53 @@ void ComputeGraph::submit_current_cmd_and_wait(const bool final_use) {
775776
context_->fences().return_fence(fence);
776777
}
777778

779+
void ComputeGraph::submit_cmd(
780+
vkapi::CommandBuffer& cmd_buf,
781+
VkSemaphore wait_semaphore,
782+
VkSemaphore signal_semaphore,
783+
VkFence fence) {
784+
if (cmd_buf) {
785+
cmd_buf.end();
786+
context_->adapter_ptr()->submit_cmd(
787+
context_->queue(),
788+
cmd_buf.get_submit_handle(false),
789+
fence,
790+
wait_semaphore,
791+
signal_semaphore);
792+
}
793+
}
794+
795+
void ComputeGraph::submit_deferred_cmds_and_wait() {
796+
VkSemaphore prev_semaphore = VK_NULL_HANDLE;
797+
vkapi::VulkanFence fence = context_->fences().get_fence();
798+
799+
for (uint32_t i = 0; i < deferred_cmd_list_.size(); i++) {
800+
auto& cmd = deferred_cmd_list_[i];
801+
VkSemaphore wait_semaphore = prev_semaphore;
802+
VkSemaphore signal_semaphore = cmd.get_signal_semaphore();
803+
prev_semaphore = signal_semaphore;
804+
805+
submit_cmd(
806+
cmd,
807+
wait_semaphore,
808+
signal_semaphore,
809+
i == (deferred_cmd_list_.size() - 1) ? fence.get_submit_handle()
810+
: VK_NULL_HANDLE);
811+
}
812+
fence.wait();
813+
context_->fences().return_fence(fence);
814+
}
815+
816+
void ComputeGraph::clear_deferred_cmds() {
817+
for (auto& cmd : deferred_cmd_list_) {
818+
if (cmd) {
819+
cmd.end();
820+
cmd.invalidate();
821+
}
822+
}
823+
deferred_cmd_list_.clear();
824+
}
825+
778826
void ComputeGraph::prepack() {
779827
int i = 0;
780828
bool submitted = false;
@@ -813,6 +861,7 @@ void ComputeGraph::prepack() {
813861
}
814862

815863
void ComputeGraph::encode_execute() {
864+
clear_deferred_cmds();
816865
context_->flush();
817866
context_->set_cmd(/*reusable = */ true);
818867

@@ -821,13 +870,12 @@ void ComputeGraph::encode_execute() {
821870
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
822871
node->encode(this);
823872
}
873+
874+
deferred_cmd_list_.emplace_back(std::move(context_->extract_cmd()));
824875
}
825876

826877
void ComputeGraph::execute() {
827-
vkapi::VulkanFence fence = context_->fences().get_fence();
828-
context_->submit_cmd_to_gpu(fence.get_submit_handle());
829-
fence.wait();
830-
context_->fences().return_fence(fence);
878+
submit_deferred_cmds_and_wait();
831879
execute_count_++;
832880
}
833881

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ class ComputeGraph final {
193193
// Utility constexpr to express byte quantities
194194
constexpr static size_t MB = 1024 * 1024;
195195

196+
// List of command buffers deferred for submission
197+
std::vector<vkapi::CommandBuffer> deferred_cmd_list_;
198+
196199
protected:
197200
size_t values_in_use_ = 0;
198201
size_t execute_count_ = 0;
@@ -851,6 +854,25 @@ class ComputeGraph final {
851854
*/
852855
void submit_current_cmd_and_wait(const bool final_use = false);
853856

857+
/*
858+
* Submit one command buffer to the GPU.
859+
*/
860+
void submit_cmd(
861+
vkapi::CommandBuffer& cmd_buf,
862+
VkSemaphore wait_semaphore,
863+
VkSemaphore signal_semaphore,
864+
VkFence fence);
865+
866+
/*
867+
* Submits all the commands gathered in deferred_cmd_bufs_ to the GPU.
868+
*/
869+
void submit_deferred_cmds_and_wait();
870+
871+
/*
872+
* Ends and invalidates all deferred commands.
873+
*/
874+
void clear_deferred_cmds();
875+
854876
public:
855877
//
856878
// Graph Prepacking

0 commit comments

Comments
 (0)