Skip to content

Instantly share code, notes, and snippets.

@ldl19691031
Created August 16, 2024 12:13
Show Gist options
  • Save ldl19691031/c91af2f11f93649669ee6f3a2f6e49c4 to your computer and use it in GitHub Desktop.
Save ldl19691031/c91af2f11f93649669ee6f3a2f6e49c4 to your computer and use it in GitHub Desktop.
UE DispatchComputeShaderBundle example
#include "/Engine/Private/Common.ush"
// Input record consumed by the `firstNode` broadcasting launch node; it is also the record type of
// firstNode's (unused) self-recursive output.
// SV_DispatchGrid marks DispatchGridSize as the number of thread groups to launch for this record.
// Because the field is a single uint, only the X dimension is supplied (Y and Z are 1).
struct FShaderBundleNodeRecord
{
uint DispatchGridSize : SV_DispatchGrid;
};
#include "/Engine/Shared/HLSLReservedSpaces.h"
// Root constants for each shader-bundle dispatch.
// They are bound at register b0 in the engine-reserved root-constant space
// (UE_HLSL_SPACE_SHADER_ROOT_CONSTANTS from HLSLReservedSpaces.h).
// NOTE(review): the field order presumably has to match what the RHI writes from the C++ side
// (Dispatch.RecordIndex / Dispatch.Constants). Do not reorder without confirming that mapping.
struct FUERootConstants
{
uint RecordIndex;
uint3 PassData;
};
// RecordIndex = presumably the index of the shader-bundle record being dispatched (the C++ side assigns
//               Dispatch.RecordIndex per record). It was previously labelled "Raster Bin" — TODO confirm.
// PassData .x = Unused, .y = Unused, .z = Unused
// Note: firstNode does not currently read UERootConstants.
ConstantBuffer<FUERootConstants> UERootConstants : UE_HLSL_REGISTER(b, 0, UE_HLSL_SPACE_SHADER_ROOT_CONSTANTS);
// Counter buffer incremented by firstNode.
// It is bound by name from the `UAV` member of FWorkGraphTestShader::FParameters.
RWStructuredBuffer<uint> UAV;
// Record type for a `secondNode` that does not exist in this file. No node reads or writes it.
// TODO(review): remove it, or add the consuming node if this example is meant to chain nodes.
struct secondNodeInput
{
uint entryRecordIndex;
};
// --------------------------------------------------------------------------------------------------------------------------------
// firstNode is the work-graph entry node, a broadcasting launch node.
//
// For each input record, a dispatch grid of inputData's DispatchGridSize thread groups is launched. The grid size is
// bounded by [NodeMaxDispatchGrid], and each group contains a single thread ([NumThreads(1, 1, 1)]).
//
// Grid size can also be fixed for the node instead of being part of the input record,
// using [NodeDispatchGrid(x,y,z)].
//
// Current behaviour: every thread group atomically adds 1 to UAV[threadIndex]. With NumThreads(1, 1, 1),
// SV_GroupIndex is always 0, so every group increments UAV[0]. Assuming UAV[0] starts at 0 and nothing else writes it,
// it ends up holding the number of groups launched.
// No output records are emitted. The declared self-recursive `firstNode` output is never written, and no
// `secondNode` exists in this file.
// --------------------------------------------------------------------------------------------------------------------------------
[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(16, 1, 1)] // Upper bound on DispatchGridSize. Every grid size actually requested must fit within it.
// This declaration should be as accurate as possible, but not too small (undefined behavior).
[NumThreads(1, 1, 1)]
[NodeMaxRecursionDepth(3)] // Needed because the node declares an output to itself (`firstNode` parameter below).
void firstNode(
DispatchNodeInputRecord< FShaderBundleNodeRecord > inputData,
uint threadIndex : SV_GroupIndex, // always 0 with NumThreads(1, 1, 1)
uint dispatchThreadID : SV_DispatchThreadID, // unused
[MaxRecords(8)] NodeOutput< FShaderBundleNodeRecord > firstNode // self-recursive output; never written
)
{
InterlockedAdd(UAV[threadIndex], 1);
}
class FWorkGraphTestShader : public FGlobalShader
{
DECLARE_GLOBAL_SHADER(FWorkGraphTestShader);
SHADER_USE_PARAMETER_STRUCT(FWorkGraphTestShader, FGlobalShader);
BEGIN_SHADER_PARAMETER_STRUCT(FParameters, )
SHADER_PARAMETER(FUintVector4, PassData)
SHADER_PARAMETER_RDG_BUFFER_UAV(RWBuffer<uint>, UAV)
RDG_BUFFER_ACCESS(IndirectArgs, ERHIAccess::IndirectArgs)
END_SHADER_PARAMETER_STRUCT()
public:
static bool ShouldCompilePermutation(const FGlobalShaderPermutationParameters& Parameters)
{
return FGenericDataDrivenShaderPlatformInfo::GetSupportsWorkGraphs(Parameters.Platform);
}
};
IMPLEMENT_GLOBAL_SHADER(FWorkGraphTestShader, "/Plugin/Runtime/AddMeshPassPlugin/WorkGraphExampleShader.usf", "firstNode", SF_WorkGraphComputeNode);
/**
 * Adds one RDG compute pass that launches the FWorkGraphTestShader work graph through
 * RHICmdList.DispatchComputeShaderBundle().
 *
 * Resources created per call:
 *  - "TransientBuffer": NUM_COUNTERS uint32s bound as the shader's `UAV`. It is zero-initialised because
 *    transient RDG memory has undefined contents and the shader only InterlockedAdd()s into it.
 *  - "WorkGraphTest.IndirectArgs": NUM_RECORDS * ARG_COUNT uint32s, i.e. exactly NumRecords * ArgStride bytes.
 *    It is passed to the bundle as Command.RecordArgBuffer.
 *
 * Does nothing on shader platforms without work-graph support, since the shader is not compiled there.
 *
 * @param GraphBuilder   Render graph the pass is added to.
 * @param View           View being rendered; supplies the shader platform and feature level.
 * @param RenderTargets  Unused.
 * @param SceneTextures  Unused.
 */
void FWorkGraphTest::PostRenderBasePassDeferred_RenderThread(FRDGBuilder& GraphBuilder, FSceneView& View, const FRenderTargetBindingSlots& RenderTargets, TRDGUniformBufferRef<FSceneTextureUniformParameters> SceneTextures)
{
	//PostRenderBasePassDeferred_RenderThread2(GraphBuilder, View, RenderTargets, SceneTextures);

	constexpr uint32 ARG_COUNT = 8u;     // uint32 arguments reserved per bundle record
	constexpr uint32 NUM_RECORDS = 1u;   // bundle records (= dispatches) per pass
	constexpr uint32 NUM_COUNTERS = 10u; // entries in the shader's UAV counter buffer

	// Same condition as FWorkGraphTestShader::ShouldCompilePermutation. On other platforms the shader
	// was never compiled, and looking it up in the global shader map would fail.
	if (!FGenericDataDrivenShaderPlatformInfo::GetSupportsWorkGraphs(View.GetShaderPlatform()))
	{
		return;
	}

	// Resolve the shader once here, instead of once per dispatch inside the RHI callback.
	TShaderMapRef<FWorkGraphTestShader> WorkGraphShader(GetGlobalShaderMap(View.GetFeatureLevel()));

	FShaderBundleCreateInfo BundleCreateInfo;
	BundleCreateInfo.ArgOffset = 0u;
	BundleCreateInfo.ArgStride = ARG_COUNT * 4u; // bytes per record in RecordArgBuffer
	BundleCreateInfo.NumRecords = NUM_RECORDS;
	BundleCreateInfo.Mode = ERHIShaderBundleMode::CS;
	FShaderBundleRHIRef ShaderBundle = RHICreateShaderBundle(BundleCreateInfo);

	FWorkGraphTestShader::FParameters* PassParameters = GraphBuilder.AllocParameters<FWorkGraphTestShader::FParameters>();
	// Set explicitly because it is copied into every dispatch's constants below.
	PassParameters->PassData = FUintVector4(0u, 0u, 0u, 0u);

	// Counter buffer written by the shader's InterlockedAdd. Zero it first so the results are deterministic.
	FRDGBufferRef CounterBuffer = GraphBuilder.CreateBuffer(FRDGBufferDesc::CreateStructuredDesc(sizeof(uint32), NUM_COUNTERS), TEXT("TransientBuffer"), ERDGBufferFlags::None);
	static const uint32 ZeroCounters[NUM_COUNTERS] = {};
	GraphBuilder.QueueBufferUpload(CounterBuffer, ZeroCounters, sizeof(ZeroCounters));
	PassParameters->UAV = GraphBuilder.CreateUAV(CounterBuffer, ERDGUnorderedAccessViewFlags::SkipBarrier);

	// Record argument buffer: one ARG_COUNT-uint32 slot per record. This is exactly what the bundle can
	// read (NumRecords * ArgStride bytes); previously 65535 dispatch-arg structs were allocated.
	FRDGBufferRef RecordArgBuffer = GraphBuilder.CreateBuffer(FRDGBufferDesc::CreateIndirectDesc(NUM_RECORDS * ARG_COUNT), TEXT("WorkGraphTest.IndirectArgs"));

	// Nothing else writes this buffer, so give it defined contents before the dispatch reads it.
	// NOTE(review): this assumes the bundle reads an (x, y, z) group count at ArgOffset of each record, and
	// that the work-graph path feeds it into FShaderBundleNodeRecord::DispatchGridSize. Confirm against the
	// RHI's shader-bundle implementation. (1, 1, 1) requests a single thread group per record.
	uint32 InitialRecordArgs[NUM_RECORDS * ARG_COUNT] = {};
	for (uint32 Record = 0u; Record < NUM_RECORDS; ++Record)
	{
		const uint32 First = Record * ARG_COUNT + BundleCreateInfo.ArgOffset / 4u;
		InitialRecordArgs[First + 0u] = 1u;
		InitialRecordArgs[First + 1u] = 1u;
		InitialRecordArgs[First + 2u] = 1u;
	}
	// The default ERDGInitialDataFlags copies the data, so the local array may go out of scope afterwards.
	GraphBuilder.QueueBufferUpload(RecordArgBuffer, InitialRecordArgs, sizeof(InitialRecordArgs));
	PassParameters->IndirectArgs = RecordArgBuffer;

	GraphBuilder.AddPass(
		RDG_EVENT_NAME("WorkGraphTest"),
		PassParameters,
		ERDGPassFlags::Compute | ERDGPassFlags::NeverCull,
		[ShaderBundle, PassParameters, WorkGraphShader](FRHIComputeCommandList& RHICmdList)
		{
			RHICmdList.DispatchComputeShaderBundle(
				[ShaderBundle, PassParameters, WorkGraphShader, &RHICmdList](FRHICommandDispatchComputeShaderBundle& Command)
				{
					Command.ShaderBundle = ShaderBundle;
					Command.bEmulated = false;
					// GetIndirectRHICallBuffer() is the RDG accessor for a buffer consumed as indirect arguments.
					// This matches the ERHIAccess::IndirectArgs declaration in FParameters.
					Command.RecordArgBuffer = PassParameters->IndirectArgs->GetIndirectRHICallBuffer();
					Command.Dispatches.SetNum(ShaderBundle->NumRecords);

					FRHIBatchedShaderParametersAllocator& ScratchAllocator = RHICmdList.GetScratchShaderParameters().Allocator;
					int32 RecordIndex = 0;
					for (FRHIShaderBundleComputeDispatch& Dispatch : Command.Dispatches)
					{
						Dispatch.RecordIndex = RecordIndex;
						// Per-dispatch constants. NOTE(review): presumably delivered to the shader's UERootConstants; confirm mapping.
						Dispatch.Constants = PassParameters->PassData;
						// Work-graph dispatch: no regular compute shader, only the work-graph node shader.
						Dispatch.Shader = nullptr;
						Dispatch.WorkGraphShader = WorkGraphShader.GetWorkGraphShader();
						Dispatch.Parameters.Emplace(ScratchAllocator);
						SetShaderParameters(*Dispatch.Parameters, WorkGraphShader, *PassParameters);
						Dispatch.Parameters->Finish();
						RecordIndex++;
					}
				}
			);
		}
	);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment