// UE DispatchComputeShaderBundle example
// GitHub gist ldl19691031/c91af2f11f93649669ee6f3a2f6e49c4, created August 16, 2024 12:13.
// Note: GitHub warns that this file contains hidden or bidirectional Unicode text that may be
// interpreted or compiled differently than it appears below. To review it, open the file in an
// editor that reveals hidden Unicode characters.
#include "/Engine/Private/Common.ush"

// Input record consumed by the firstNode work-graph node (and by its recursive output).
struct FShaderBundleNodeRecord
{
	// Size of the dispatch grid launched for this record. It is a scalar, so the grid is
	// (DispatchGridSize, 1, 1). It must not exceed the node's NodeMaxDispatchGrid.
	uint DispatchGridSize : SV_DispatchGrid;
};
#include "/Engine/Shared/HLSLReservedSpaces.h"

// Root constants supplied per shader-bundle dispatch, bound at b0 in the
// UE_HLSL_SPACE_SHADER_ROOT_CONSTANTS register space.
struct FUERootConstants
{
	// Index of the shader-bundle record being dispatched.
	// NOTE(review): Nanite uses this as the raster bin index; here it is presumably just the
	// bundle record index set from FRHIShaderBundleComputeDispatch::RecordIndex. Confirm.
	uint RecordIndex;
	// Per-dispatch pass data. .x, .y and .z are unused by firstNode.
	uint3 PassData;
};

ConstantBuffer<FUERootConstants> UERootConstants : UE_HLSL_REGISTER(b, 0, UE_HLSL_SPACE_SHADER_ROOT_CONSTANTS);

// Counter buffer that firstNode increments with InterlockedAdd.
RWStructuredBuffer<uint> UAV;
// Record layout for a second work-graph node.
// NOTE(review): no node in this file consumes or produces it. Presumably it is reserved for a
// future secondNode; confirm before removing.
struct secondNodeInput
{
	uint entryRecordIndex;
};
// --------------------------------------------------------------------------------------------------------------------------------
// firstNode is the entry node, a broadcasting launch node.
//
// For each input record, a dispatch grid of (DispatchGridSize, 1, 1) thread groups is launched.
// DispatchGridSize is the SV_DispatchGrid field of FShaderBundleNodeRecord.
//
// The grid size could instead be fixed for the node, rather than carried in the input record,
// by using [NodeDispatchGrid(x,y,z)].
//
// Each thread group has a single thread (NumThreads(1,1,1)), so SV_GroupIndex is always 0.
// Every launched group therefore atomically adds 1 to UAV[0].
//
// The recursive output (also named firstNode, allowing up to NodeMaxRecursionDepth levels) is
// declared but never written, so this node produces no further records.
// --------------------------------------------------------------------------------------------------------------------------------
[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(16, 1, 1)] // Upper bound on DispatchGridSize. Keep it as accurate as possible, but never smaller than
                                // the largest grid the C++ side writes into the record arguments: exceeding it is undefined behavior.
[NumThreads(1, 1, 1)]
[NodeMaxRecursionDepth(3)]
void firstNode(
	DispatchNodeInputRecord< FShaderBundleNodeRecord > inputData,
	uint threadIndex : SV_GroupIndex,
	uint dispatchThreadID : SV_DispatchThreadID,
	[MaxRecords(8)] NodeOutput< FShaderBundleNodeRecord > firstNode
)
{
	// threadIndex is always 0 with a 1x1x1 group, so this counts launched thread groups in UAV[0].
	InterlockedAdd(UAV[threadIndex], 1);
}
// ===== C++ side: FWorkGraphTestShader and the render-thread dispatch =====
// Note: GitHub warns that this file contains hidden or bidirectional Unicode text that may be
// interpreted or compiled differently than it appears below. To review it, open the file in an
// editor that reveals hidden Unicode characters.
class FWorkGraphTestShader : public FGlobalShader | |
{ | |
DECLARE_GLOBAL_SHADER(FWorkGraphTestShader); | |
SHADER_USE_PARAMETER_STRUCT(FWorkGraphTestShader, FGlobalShader); | |
BEGIN_SHADER_PARAMETER_STRUCT(FParameters, ) | |
SHADER_PARAMETER(FUintVector4, PassData) | |
SHADER_PARAMETER_RDG_BUFFER_UAV(RWBuffer<uint>, UAV) | |
RDG_BUFFER_ACCESS(IndirectArgs, ERHIAccess::IndirectArgs) | |
END_SHADER_PARAMETER_STRUCT() | |
public: | |
static bool ShouldCompilePermutation(const FGlobalShaderPermutationParameters& Parameters) | |
{ | |
return FGenericDataDrivenShaderPlatformInfo::GetSupportsWorkGraphs(Parameters.Platform); | |
} | |
}; | |
IMPLEMENT_GLOBAL_SHADER(FWorkGraphTestShader, "/Plugin/Runtime/AddMeshPassPlugin/WorkGraphExampleShader.usf", "firstNode", SF_WorkGraphComputeNode); | |
/**
 * Render-thread hook that runs after the deferred base pass. It dispatches the firstNode
 * work-graph node through a single-record compute shader bundle.
 *
 * @param GraphBuilder   Render graph that receives the buffer uploads and the dispatch pass.
 * @param View           Not used by this test.
 * @param RenderTargets  Not used by this test.
 * @param SceneTextures  Not used by this test.
 */
void FWorkGraphTest::PostRenderBasePassDeferred_RenderThread(FRDGBuilder& GraphBuilder, FSceneView& View, const FRenderTargetBindingSlots& RenderTargets, TRDGUniformBufferRef<FSceneTextureUniformParameters> SceneTextures)
{
	//PostRenderBasePassDeferred_RenderThread2(GraphBuilder, View, RenderTargets, SceneTextures);

	// FWorkGraphTestShader is compiled only where work graphs are supported (same check as its
	// ShouldCompilePermutation). Skip the test elsewhere instead of looking up a shader that was
	// never compiled.
	if (!FGenericDataDrivenShaderPlatformInfo::GetSupportsWorkGraphs(GMaxRHIShaderPlatform))
	{
		return;
	}

	// Record-argument layout: NUM_RECORDS records of ARG_COUNT uint32 values each.
	constexpr uint32 ARG_COUNT = 8u;
	constexpr uint32 NUM_RECORDS = 1u;
	// Number of uint32 counters in the buffer that the shader increments.
	constexpr uint32 NUM_COUNTERS = 10u;

	FShaderBundleCreateInfo BundleCreateInfo;
	BundleCreateInfo.ArgOffset = 0u;
	BundleCreateInfo.ArgStride = ARG_COUNT * 4u;
	BundleCreateInfo.NumRecords = NUM_RECORDS;
	BundleCreateInfo.Mode = ERHIShaderBundleMode::CS;
	FShaderBundleRHIRef ShaderBundle = RHICreateShaderBundle(BundleCreateInfo);

	// The shader lookup does not depend on the record, so resolve it once here rather than per
	// dispatch inside the pass.
	TShaderMapRef<FWorkGraphTestShader> ComputeShader(GetGlobalShaderMap(GMaxRHIFeatureLevel));

	FWorkGraphTestShader::FParameters* passParameters = GraphBuilder.AllocParameters<FWorkGraphTestShader::FParameters>();
	// The shader leaves PassData unused. Zero it explicitly so the per-dispatch constants
	// are deterministic.
	passParameters->PassData = FUintVector4(0, 0, 0, 0);

	// Counter buffer. Transient RDG buffers start with undefined contents, and the shader
	// accumulates into this one with InterlockedAdd, so seed it with zeros.
	// ERDGInitialDataFlags::None makes RDG copy the data, so a stack array is safe.
	FRDGBufferRef Buffer = GraphBuilder.CreateBuffer(FRDGBufferDesc::CreateStructuredDesc(sizeof(uint32), NUM_COUNTERS), TEXT("TransientBuffer"), ERDGBufferFlags::None);
	const uint32 ZeroCounters[NUM_COUNTERS] = {};
	GraphBuilder.QueueBufferUpload(Buffer, ZeroCounters, sizeof(ZeroCounters), ERDGInitialDataFlags::None);
	passParameters->UAV = GraphBuilder.CreateUAV(Buffer, ERDGUnorderedAccessViewFlags::SkipBarrier);

	// Record-argument buffer, sized to exactly the records BundleCreateInfo describes
	// (NUM_RECORDS * ArgStride bytes). It must be written before the GPU reads it. Otherwise the
	// dispatch grid is undefined and may exceed the node's NodeMaxDispatchGrid.
	// NOTE(review): the first three uint32 values of each record are assumed to be the X/Y/Z
	// group counts (X feeding FShaderBundleNodeRecord::DispatchGridSize). Confirm against the
	// RHI's shader-bundle argument layout.
	FRDGBufferRef RecordArgs = GraphBuilder.CreateBuffer(FRDGBufferDesc::CreateIndirectDesc<uint32>(NUM_RECORDS * ARG_COUNT), TEXT("WorkGraphTest.IndirectArgs"));
	uint32 RecordArgData[NUM_RECORDS * ARG_COUNT] = {};
	for (uint32 Record = 0u; Record < NUM_RECORDS; ++Record)
	{
		uint32* Args = &RecordArgData[Record * ARG_COUNT];
		Args[0] = 1u;
		Args[1] = 1u;
		Args[2] = 1u;
	}
	GraphBuilder.QueueBufferUpload(RecordArgs, RecordArgData, sizeof(RecordArgData), ERDGInitialDataFlags::None);
	passParameters->IndirectArgs = RecordArgs;

	GraphBuilder.AddPass(
		RDG_EVENT_NAME("WorkGraphTest"),
		passParameters,
		ERDGPassFlags::Compute | ERDGPassFlags::NeverCull,
		[ShaderBundle, ComputeShader, passParameters](FRHIComputeCommandList& RHICmdList)
		{
			RHICmdList.DispatchComputeShaderBundle(
				[ShaderBundle, ComputeShader, passParameters, &RHICmdList](FRHICommandDispatchComputeShaderBundle& Command)
				{
					Command.ShaderBundle = ShaderBundle;
					Command.bEmulated = false;
					// Buffer from which the RHI reads each record's dispatch arguments, presumably at
					// ArgOffset + RecordIndex * ArgStride (the fields set in BundleCreateInfo).
					Command.RecordArgBuffer = passParameters->IndirectArgs->GetIndirectRHICallBuffer();
					Command.Dispatches.SetNum(ShaderBundle->NumRecords);

					FRHIBatchedShaderParametersAllocator& ScratchAllocator = RHICmdList.GetScratchShaderParameters().Allocator;
					int32 RecordIndex = 0;
					for (FRHIShaderBundleComputeDispatch& Dispatch : Command.Dispatches)
					{
						Dispatch.RecordIndex = RecordIndex;
						// Per-dispatch constants. NOTE(review): presumably forwarded to the shader's
						// UERootConstants root constants. Confirm.
						Dispatch.Constants = passParameters->PassData;
						// Work-graph dispatch: the node shader replaces the plain compute shader.
						Dispatch.Shader = nullptr;
						Dispatch.WorkGraphShader = ComputeShader.GetWorkGraphShader();
						Dispatch.Parameters.Emplace(ScratchAllocator);
						SetShaderParameters(*Dispatch.Parameters, ComputeShader, *passParameters);
						Dispatch.Parameters->Finish();
						++RecordIndex;
					}
				}
			);
		}
	);
}
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.