Last active
February 7, 2023 22:34
-
-
Save SabinT/deed2a263c3fe4f99fdaa39bdb19bcc2 to your computer and use it in GitHub Desktop.
Unity3D: a generic template for passing options to a compute shader and dispatching it; avoids the usual boilerplate of Shader.PropertyToID, shader.SetInt, etc.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using EasyButtons; | |
namespace Lumic.Compute | |
{ | |
using System; | |
using System.Linq; | |
using System.Reflection; | |
using System.Collections.Generic; | |
using UnityEngine; | |
/// <summary>
/// Example options container consumed by <see cref="ExampleComputeShaderDriver"/>.
/// Fields carrying the <c>PassToShader</c> attribute are uploaded to the compute
/// shader by <see cref="GenericComputeShaderDriver{T}"/>.
/// </summary>
[Serializable]
public class ExampleComputeShaderParams
{
    /// <summary>Uploaded under the explicit shader variable name "_Number".</summary>
    [PassToShader(Name = "_Number")]
    public float Num;

    /// <summary>No explicit name is given, so the field name "Vec" is used in the shader.</summary>
    [PassToShader]
    public Vector4 Vec;
}
/// <summary>
/// Example driver: on every Unity <c>Update</c> it uploads the fields of
/// <see cref="ExampleComputeShaderParams"/> and dispatches the "TestKernel" kernel.
/// </summary>
public class ExampleComputeShaderDriver : GenericComputeShaderDriver<ExampleComputeShaderParams>
{
    /// <summary>MonoBehaviour per-frame message: apply options, then dispatch "TestKernel".</summary>
    public void Update() => this.ApplyPropertiesAndDispatch("TestKernel");
}
/// <summary>
/// Base MonoBehaviour that uploads the fields of <see cref="Options"/> marked with
/// <c>[PassToShader]</c> to <see cref="ComputeShader"/> and dispatches kernels. This
/// removes the usual <c>Shader.PropertyToID</c> / <c>ComputeShader.SetXxx</c> boilerplate.
/// </summary>
/// <typeparam name="T">
/// Options type. Its public and non-public instance fields decorated with
/// <c>[PassToShader]</c> are uploaded. The shader variable name is the attribute's
/// <c>Name</c> when it is non-blank, otherwise the field name.
/// </typeparam>
public class GenericComputeShaderDriver<T> : MonoBehaviour where T : new()
{
    /// <summary>The compute shader to configure and dispatch. All operations are skipped while this is null.</summary>
    public ComputeShader ComputeShader;

    /// <summary>Number of thread groups dispatched along x, y and z when no override is given.</summary>
    [Tooltip(
        "Set this to match the resolution of the buffers divided by 'numthreads' in the compute shader kernels.")]
    public Vector3Int ThreadGroupSize = Vector3Int.one;

    /// <summary>Values uploaded to the shader by <see cref="ApplyProperties"/>.</summary>
    public T Options = new T();

    /// <summary>
    /// Decorated fields of <see cref="Options"/> mapped to their shader property IDs
    /// (from <c>Shader.PropertyToID</c>). Built lazily on first use, then cached.
    /// </summary>
    private Dictionary<FieldInfo, int> fieldsWithShaderIds = null;

    /// <summary>
    /// Fields already reported as having an unsupported type. The error is logged once
    /// per field instead of on every call, which may be every frame.
    /// </summary>
    private readonly HashSet<FieldInfo> reportedUnsupportedFields = new HashSet<FieldInfo>();

    /// <summary>
    /// Unity start-up message. Warns when no shader is assigned; otherwise pre-builds
    /// the field-to-shader-ID map so the first dispatch does no reflection.
    /// </summary>
    protected virtual void Start()
    {
        if (this.ComputeShader == null)
        {
            Debug.LogWarning("No compute shader assigned!", this);
            return;
        }

        this.BuildShaderAttributesMapIfNeeded();
    }

    /// <summary>
    /// Populates <see cref="fieldsWithShaderIds"/> via reflection, once.
    /// </summary>
    private void BuildShaderAttributesMapIfNeeded()
    {
        if (this.fieldsWithShaderIds != null)
        {
            return;
        }

        // Fall back to the declared type so a null Options does not throw here.
        Type optionsType = this.Options != null ? this.Options.GetType() : typeof(T);

        IEnumerable<FieldInfo> fields = optionsType
            .GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
            .Where(field => Attribute.IsDefined(field, typeof(PassToShader)));

        this.fieldsWithShaderIds = new Dictionary<FieldInfo, int>();
        foreach (FieldInfo field in fields)
        {
            PassToShader attr = field.GetCustomAttribute<PassToShader>();
            string shaderVariableName = !string.IsNullOrWhiteSpace(attr.Name)
                ? attr.Name
                : field.Name;

            this.fieldsWithShaderIds.Add(field, Shader.PropertyToID(shaderVariableName));
        }
    }

    /// <summary>
    /// Uploads every decorated, non-null field of <see cref="Options"/> to the shader.
    /// Textures and buffers are bound to <paramref name="kernel"/>; other values are set globally on the shader.
    /// Does nothing while <see cref="ComputeShader"/> or <see cref="Options"/> is null.
    /// </summary>
    /// <param name="kernel">Kernel name, resolved with <c>ComputeShader.FindKernel</c>.</param>
    [Button]
    public void ApplyProperties(string kernel)
    {
        this.BuildShaderAttributesMapIfNeeded();

        if (this.ComputeShader == null || this.Options == null)
        {
            return;
        }

        int kernelIndex = this.ComputeShader.FindKernel(kernel);
        foreach (KeyValuePair<FieldInfo, int> pair in this.fieldsWithShaderIds)
        {
            object value = pair.Key.GetValue(this.Options);
            if (value == null)
            {
                continue;
            }

            this.SetShaderValue(kernelIndex, pair.Value, pair.Key, value);
        }
    }

    /// <summary>
    /// Dispatches one boxed field value to the matching <c>ComputeShader.SetXxx</c> call.
    /// Logs an error (once per field) for unsupported types.
    /// </summary>
    private void SetShaderValue(int kernelIndex, int nameId, FieldInfo field, object value)
    {
        switch (value)
        {
            case bool flag:
                this.ComputeShader.SetBool(nameId, flag);
                break;
            case int integer:
                this.ComputeShader.SetInt(nameId, integer);
                break;
            case float number:
                this.ComputeShader.SetFloat(nameId, number);
                break;
            case Vector2 v2:
                this.ComputeShader.SetVector(nameId, v2);
                break;
            case Vector2Int v2i:
                this.ComputeShader.SetInts(nameId, v2i.x, v2i.y);
                break;
            case Vector3 v3:
                this.ComputeShader.SetVector(nameId, v3);
                break;
            case Vector3Int v3i:
                this.ComputeShader.SetInts(nameId, v3i.x, v3i.y, v3i.z);
                break;
            case Vector4 v4:
                this.ComputeShader.SetVector(nameId, v4);
                break;
            case Vector4[] vectors:
                this.ComputeShader.SetVectorArray(nameId, vectors);
                break;
            case Color c:
                this.ComputeShader.SetVector(nameId, new Vector4(c.r, c.g, c.b, c.a));
                break;
            case Color[] colors:
                this.ComputeShader.SetVectorArray(
                    nameId,
                    colors.Select(c => new Vector4(c.r, c.g, c.b, c.a)).ToArray());
                break;
            case ComputeBuffer buffer:
                this.ComputeShader.SetBuffer(kernelIndex, nameId, buffer);
                break;
            case Texture texture:
                // Covers RenderTexture too. Unity overloads == for its objects: a destroyed
                // or missing texture is a non-null C# reference that only compares equal to
                // null through that overload, so the boxed null check in the caller misses it.
                if (texture != null)
                {
                    this.ComputeShader.SetTexture(kernelIndex, nameId, texture);
                }
                break;
            // Add more cases here to support more types
            default:
                if (this.reportedUnsupportedFields.Add(field))
                {
                    Debug.LogError(
                        $"Not passing unsupported type to Compute Shader: " +
                        $"{field.Name} of type {field.FieldType}",
                        this);
                }
                break;
        }
    }

    /// <summary>
    /// Sets <see cref="ThreadGroupSize"/> to the number of thread groups needed to cover
    /// an <paramref name="x"/> by <paramref name="y"/> by <paramref name="z"/> domain, given
    /// the kernel's <c>numthreads</c>. Counts are rounded up so a partial final tile is
    /// still dispatched; kernels should therefore bounds-check their thread IDs. Each
    /// axis gets at least one group.
    /// </summary>
    /// <param name="kernel">Kernel whose <c>numthreads</c> is queried.</param>
    /// <param name="x">Domain size along x (e.g. texture width).</param>
    /// <param name="y">Domain size along y (e.g. texture height).</param>
    /// <param name="z">Domain size along z.</param>
    [Button]
    public void AutoCalculateThreadGroups2D(
        string kernel,
        int x = 1,
        int y = 1,
        int z = 1)
    {
        if (this.ComputeShader == null)
        {
            Debug.LogWarning("No compute shader assigned!", this);
            return;
        }

        int ki = this.ComputeShader.FindKernel(kernel);
        this.ComputeShader.GetKernelThreadGroupSizes(ki, out uint kx, out uint ky, out uint kz);
        this.ThreadGroupSize =
            new Vector3Int(
                DivideRoundingUp(x, kx),
                DivideRoundingUp(y, ky),
                DivideRoundingUp(z, kz));
    }

    /// <summary>Ceiling of <paramref name="count"/> / <paramref name="groupSize"/>, clamped to at least 1.</summary>
    private static int DivideRoundingUp(int count, uint groupSize)
    {
        uint size = Math.Max(groupSize, 1u);
        long groups = ((long)count + size - 1) / size;
        return (int)Math.Max(1L, groups);
    }

    /// <summary>
    /// Public/editor entry point: applies <see cref="Options"/> to the shader, then
    /// dispatches <paramref name="kernel"/> with <see cref="ThreadGroupSize"/>.
    /// </summary>
    [Button]
    public void Dispatch(string kernel)
    {
        this.ApplyPropertiesAndDispatch(kernel);
    }

    /// <summary>Applies <see cref="Options"/>, then dispatches <paramref name="kernel"/>.</summary>
    /// <param name="threadGroupSizeOverride">Group counts to use instead of <see cref="ThreadGroupSize"/>, if given.</param>
    protected void ApplyPropertiesAndDispatch(string kernel, Vector3Int? threadGroupSizeOverride = null)
    {
        this.ApplyProperties(kernel);
        this.Dispatch(kernel, threadGroupSizeOverride);
    }

    /// <summary>
    /// Dispatches <paramref name="kernel"/> without uploading properties. Does nothing
    /// while <see cref="ComputeShader"/> is null.
    /// </summary>
    /// <param name="threadGroupSizeOverride">Group counts to use instead of <see cref="ThreadGroupSize"/>, if given.</param>
    protected void Dispatch(string kernel, Vector3Int? threadGroupSizeOverride = null)
    {
        if (this.ComputeShader != null)
        {
            int kernelIndex = this.ComputeShader.FindKernel(kernel);
            Vector3Int tgSize = threadGroupSizeOverride ?? this.ThreadGroupSize;
            this.ComputeShader.Dispatch(
                kernelIndex,
                tgSize.x,
                tgSize.y,
                tgSize.z);
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment