|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <chrono> |
| 4 | +#include <iostream> |
| 5 | +#include <stack> |
| 6 | +#include <cassert> |
| 7 | + |
| 8 | +#define STB_IMAGE_WRITE_IMPLEMENTATION |
| 9 | +#include "dx/dxassist.h" |
| 10 | +#include "dx/shader_compiler.h" |
| 11 | +#include "mesh_data.h" |
| 12 | +#include "radeonrays_dx.h" |
| 13 | +#include "stb_image_write.h" |
| 14 | + |
| 15 | +#define CHECK_RR_CALL(x) do {if((x) != RR_SUCCESS) { throw std::runtime_error("Incorrect radeonrays call");}} while(false) |
| 16 | + |
| 17 | +using namespace std::chrono; |
| 18 | + |
| 19 | +class Application |
| 20 | +{ |
| 21 | +public: |
| 22 | + void Run(); |
| 23 | +private: |
| 24 | + DxAssist dxassist_; |
| 25 | +}; |
| 26 | + |
| 27 | +void Application::Run() |
| 28 | +{ |
| 29 | + auto dev = dxassist_.device(); |
| 30 | + |
| 31 | + ComPtr<ID3D12QueryHeap> query_heap; |
| 32 | + D3D12_QUERY_HEAP_DESC query_heap_desc = {}; |
| 33 | + query_heap_desc.Count = 2; |
| 34 | + query_heap_desc.Type = D3D12_QUERY_HEAP_TYPE_TIMESTAMP; |
| 35 | + dev->CreateQueryHeap(&query_heap_desc, IID_PPV_ARGS(&query_heap)); |
| 36 | + |
| 37 | + RRContext context = nullptr; |
| 38 | + CHECK_RR_CALL(rrCreateContextDX(RR_API_VERSION, dxassist_.device(), dxassist_.command_queue(), &context)); |
| 39 | + MeshData mesh_data("scene.obj"); |
| 40 | + |
| 41 | + auto vertex_buffer = |
| 42 | + dxassist_.CreateUploadBuffer(sizeof(float) * mesh_data.positions.size(), mesh_data.positions.data()); |
| 43 | + auto index_buffer = |
| 44 | + dxassist_.CreateUploadBuffer(sizeof(std::uint32_t) * mesh_data.indices.size(), mesh_data.indices.data()); |
| 45 | + auto timestamp_buffer = dxassist_.CreateUAVBuffer(sizeof(std::uint64_t) * 8); |
| 46 | + auto timestamp_readback_buffer = dxassist_.CreateReadBackBuffer(sizeof(std::uint64_t) * 8); |
| 47 | + |
| 48 | + RRDevicePtr vertex_ptr = nullptr; |
| 49 | + RRDevicePtr index_ptr = nullptr; |
| 50 | + rrGetDevicePtrFromD3D12Resource(context, vertex_buffer.Get(), 0, &vertex_ptr); |
| 51 | + rrGetDevicePtrFromD3D12Resource(context, index_buffer.Get(), 0, &index_ptr); |
| 52 | + |
| 53 | + auto triangle_count = static_cast<UINT>(index_buffer->GetDesc().Width) / sizeof(UINT32) / 3; |
| 54 | + |
| 55 | + RRGeometryBuildInput geometry_build_input = {}; |
| 56 | + RRTriangleMeshPrimitive mesh = {}; |
| 57 | + geometry_build_input.triangle_mesh_primitives = &mesh; |
| 58 | + geometry_build_input.primitive_type = RR_PRIMITIVE_TYPE_TRIANGLE_MESH; |
| 59 | + geometry_build_input.triangle_mesh_primitives->vertices = vertex_ptr; |
| 60 | + geometry_build_input.triangle_mesh_primitives->vertex_count = |
| 61 | + static_cast<UINT>(vertex_buffer->GetDesc().Width) / (3 * sizeof(float)); |
| 62 | + |
| 63 | + geometry_build_input.triangle_mesh_primitives->vertex_stride = 3 * sizeof(float); |
| 64 | + geometry_build_input.triangle_mesh_primitives->triangle_indices = index_ptr; |
| 65 | + geometry_build_input.triangle_mesh_primitives->triangle_count = (UINT)triangle_count; |
| 66 | + geometry_build_input.triangle_mesh_primitives->index_type = RR_INDEX_TYPE_UINT32; |
| 67 | + geometry_build_input.primitive_count = 1u; |
| 68 | + |
| 69 | + std::cout << "Triangle count " << triangle_count << "\n"; |
| 70 | + |
| 71 | + RRBuildOptions options; |
| 72 | + options.build_flags = 0u; |
| 73 | + |
| 74 | + RRMemoryRequirements geometry_reqs; |
| 75 | + CHECK_RR_CALL(rrGetGeometryBuildMemoryRequirements(context, &geometry_build_input, &options, &geometry_reqs)); |
| 76 | + |
| 77 | + D3D12_RESOURCE_STATES initialResourceState = D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE; |
| 78 | + |
| 79 | + auto geometry = dxassist_.CreateUAVBuffer(geometry_reqs.result_buffer_size); |
| 80 | + std::cout << "Geometry buffer size: " << geometry_reqs.result_buffer_size / 1000000 << "Mb\n"; |
| 81 | + |
| 82 | + RRDevicePtr geometry_ptr; |
| 83 | + CHECK_RR_CALL(rrGetDevicePtrFromD3D12Resource(context, geometry.Get(), 0, &geometry_ptr)); |
| 84 | + |
| 85 | + auto scratch_buffer = dxassist_.CreateUAVBuffer( |
| 86 | + max(geometry_reqs.temporary_build_buffer_size, geometry_reqs.temporary_update_buffer_size)); |
| 87 | + std::cout << "Scratch buffer size: " |
| 88 | + << max(geometry_reqs.temporary_build_buffer_size, geometry_reqs.temporary_update_buffer_size) / 1000000 |
| 89 | + << "Mb\n"; |
| 90 | + |
| 91 | + RRDevicePtr scratch_ptr = nullptr; |
| 92 | + CHECK_RR_CALL(rrGetDevicePtrFromD3D12Resource(context, scratch_buffer.Get(), 0, &scratch_ptr)); |
| 93 | + |
| 94 | + RRCommandStream command_stream = nullptr; |
| 95 | + CHECK_RR_CALL(rrAllocateCommandStream(context, &command_stream)); |
| 96 | + |
| 97 | + CHECK_RR_CALL(rrCmdBuildGeometry( |
| 98 | + context, RR_BUILD_OPERATION_BUILD, &geometry_build_input, &options, scratch_ptr, geometry_ptr, command_stream)); |
| 99 | + |
| 100 | + RREvent wait_event = nullptr; |
| 101 | + CHECK_RR_CALL(rrSumbitCommandStream(context, command_stream, nullptr, &wait_event)); |
| 102 | + CHECK_RR_CALL(rrWaitEvent(context, wait_event)); |
| 103 | + CHECK_RR_CALL(rrReleaseEvent(context, wait_event)); |
| 104 | + CHECK_RR_CALL(rrReleaseCommandStream(context, command_stream)); |
| 105 | + |
| 106 | + // built-in intersection |
| 107 | + using Ray = RRRay; |
| 108 | + using Hit = RRHit; |
| 109 | + |
| 110 | + constexpr uint32_t kResolution = 2048; |
| 111 | + std::vector<Ray> rays(kResolution * kResolution); |
| 112 | + |
| 113 | + for (int x = 0; x < kResolution; ++x) |
| 114 | + { |
| 115 | + for (int y = 0; y < kResolution; ++y) |
| 116 | + { |
| 117 | + auto i = kResolution * y + x; |
| 118 | + |
| 119 | + rays[i].origin[0] = 0.f; |
| 120 | + rays[i].origin[1] = 15.f; |
| 121 | + rays[i].origin[2] = 0.f; |
| 122 | + |
| 123 | + rays[i].direction[0] = -1.f; |
| 124 | + rays[i].direction[1] = -1.f + (2.f / kResolution) * y; |
| 125 | + rays[i].direction[2] = -1.f + (2.f / kResolution) * x; |
| 126 | + |
| 127 | + rays[i].min_t = 0.001f; |
| 128 | + rays[i].max_t = 100000.f; |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + auto temp_ray_buffer = dxassist_.CreateUploadBuffer(kResolution * kResolution * sizeof(Ray), rays.data()); |
| 133 | + auto ray_buffer = |
| 134 | + dxassist_.CreateUAVBuffer(kResolution * kResolution * sizeof(Ray), D3D12_RESOURCE_STATE_COPY_DEST); |
| 135 | + |
| 136 | + RRDevicePtr rays_ptr = nullptr; |
| 137 | + rrGetDevicePtrFromD3D12Resource(context, ray_buffer.Get(), 0, &rays_ptr); |
| 138 | + |
| 139 | + auto temp_hit_buffer = dxassist_.CreateReadBackBuffer(kResolution * kResolution * sizeof(Hit)); |
| 140 | + auto hit_buffer = dxassist_.CreateUAVBuffer(kResolution * kResolution * sizeof(Hit)); |
| 141 | + |
| 142 | + RRDevicePtr hits_ptr = nullptr; |
| 143 | + rrGetDevicePtrFromD3D12Resource(context, hit_buffer.Get(), 0, &hits_ptr); |
| 144 | + |
| 145 | + // Copy ray data. |
| 146 | + auto command_allocator = dxassist_.CreateCommandAllocator(); |
| 147 | + auto copy_rays_command_list = dxassist_.CreateCommandList(command_allocator.Get()); |
| 148 | + |
| 149 | + D3D12_RESOURCE_BARRIER barrier; |
| 150 | + barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION; |
| 151 | + barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE; |
| 152 | + barrier.Transition.pResource = ray_buffer.Get(); |
| 153 | + barrier.Transition.Subresource = 0; |
| 154 | + barrier.Transition.StateBefore = D3D12_RESOURCE_STATE_COPY_DEST; |
| 155 | + barrier.Transition.StateAfter = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; |
| 156 | + copy_rays_command_list->CopyBufferRegion( |
| 157 | + ray_buffer.Get(), 0, temp_ray_buffer.Get(), 0, kResolution * kResolution * sizeof(Ray)); |
| 158 | + copy_rays_command_list->ResourceBarrier(1, &barrier); |
| 159 | + copy_rays_command_list->Close(); |
| 160 | + |
| 161 | + auto copy_hits_command_list = dxassist_.CreateCommandList(command_allocator.Get()); |
| 162 | + barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION; |
| 163 | + barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE; |
| 164 | + barrier.Transition.pResource = hit_buffer.Get(); |
| 165 | + barrier.Transition.Subresource = 0; |
| 166 | + barrier.Transition.StateBefore = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; |
| 167 | + barrier.Transition.StateAfter = D3D12_RESOURCE_STATE_COPY_SOURCE; |
| 168 | + |
| 169 | + copy_hits_command_list->ResourceBarrier(1, &barrier); |
| 170 | + copy_hits_command_list->CopyBufferRegion( |
| 171 | + temp_hit_buffer.Get(), 0, hit_buffer.Get(), 0, kResolution * kResolution * sizeof(Hit)); |
| 172 | + copy_hits_command_list->Close(); |
| 173 | + |
| 174 | + ID3D12CommandList* lists[] = {copy_rays_command_list.Get()}; |
| 175 | + dxassist_.command_queue()->ExecuteCommandLists(1, lists); |
| 176 | + auto ray_copy_fence = dxassist_.CreateFence(); |
| 177 | + dxassist_.command_queue()->Signal(ray_copy_fence.Get(), 2000); |
| 178 | + while (ray_copy_fence->GetCompletedValue() != 2000) Sleep(0); |
| 179 | + |
| 180 | + RRCommandStream trace_command_stream = nullptr; |
| 181 | + CHECK_RR_CALL(rrAllocateCommandStream(context, &trace_command_stream)); |
| 182 | + |
| 183 | + size_t scratch_size = 0; |
| 184 | + CHECK_RR_CALL(rrGetTraceMemoryRequirements(context, kResolution * kResolution, &scratch_size)); |
| 185 | + auto scratch_trace_buffer = dxassist_.CreateUAVBuffer(scratch_size); |
| 186 | + RRDevicePtr scratch_trace_ptr = nullptr; |
| 187 | + rrGetDevicePtrFromD3D12Resource(context, scratch_trace_buffer.Get(), 0, &scratch_trace_ptr); |
| 188 | + |
| 189 | + CHECK_RR_CALL(rrCmdIntersect(context, |
| 190 | + geometry_ptr, |
| 191 | + RR_INTERSECT_QUERY_CLOSEST, |
| 192 | + rays_ptr, |
| 193 | + kResolution * kResolution, |
| 194 | + nullptr, |
| 195 | + RR_INTERSECT_QUERY_OUTPUT_FULL_HIT, |
| 196 | + hits_ptr, |
| 197 | + scratch_trace_ptr, |
| 198 | + trace_command_stream)); |
| 199 | + |
| 200 | + CHECK_RR_CALL(rrSumbitCommandStream(context, trace_command_stream, nullptr, &wait_event)); |
| 201 | + CHECK_RR_CALL(rrWaitEvent(context, wait_event)); |
| 202 | + |
| 203 | + CHECK_RR_CALL(rrReleaseEvent(context, wait_event)); |
| 204 | + CHECK_RR_CALL(rrReleaseCommandStream(context, trace_command_stream)); |
| 205 | + |
| 206 | + ID3D12CommandList* lists1[] = {copy_hits_command_list.Get()}; |
| 207 | + dxassist_.command_queue()->ExecuteCommandLists(1, lists1); |
| 208 | + auto hit_copy_fence = dxassist_.CreateFence(); |
| 209 | + dxassist_.command_queue()->Signal(hit_copy_fence.Get(), 3000); |
| 210 | + while (hit_copy_fence->GetCompletedValue() != 3000) Sleep(0); |
| 211 | + |
| 212 | + { |
| 213 | + Hit* mapped_ptr; |
| 214 | + temp_hit_buffer->Map(0, nullptr, (void**)&mapped_ptr); |
| 215 | + |
| 216 | + std::vector<uint32_t> data(kResolution * kResolution); |
| 217 | + |
| 218 | + for (int y = 0; y < kResolution; ++y) |
| 219 | + { |
| 220 | + for (int x = 0; x < kResolution; ++x) |
| 221 | + { |
| 222 | + int wi = kResolution * (kResolution - 1 - y) + x; |
| 223 | + int i = kResolution * y + x; |
| 224 | + |
| 225 | + if (mapped_ptr[i].inst_id != ~0u) |
| 226 | + { |
| 227 | + data[wi] = 0xff000000 | (uint32_t(mapped_ptr[i].uv[0] * 255) << 8) | |
| 228 | + (uint32_t(mapped_ptr[i].uv[1] * 255) << 16); |
| 229 | + } else |
| 230 | + { |
| 231 | + data[wi] = 0xff101010; |
| 232 | + } |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + stbi_write_jpg("isect_result.jpg", kResolution, kResolution, 4, data.data(), 120); |
| 237 | + |
| 238 | + temp_hit_buffer->Unmap(0, nullptr); |
| 239 | + } |
| 240 | + |
| 241 | + CHECK_RR_CALL(rrDestroyContext(context)); |
| 242 | +} |
0 commit comments