Skip to content

Instantly share code, notes, and snippets.

@phosphoer
Created April 9, 2026 00:22
Show Gist options
  • Select an option

  • Save phosphoer/c6641e6b8e5351368abfaddc2d023bfc to your computer and use it in GitHub Desktop.

Select an option

Save phosphoer/c6641e6b8e5351368abfaddc2d023bfc to your computer and use it in GitHub Desktop.
Compute Shader Grass Tutorial
#pragma kernel Scatter
// One scattered grass blade, produced by the Scatter kernel and appended to _Instances.
// Layout is 8 floats (32 bytes); it must stay in sync with the C# instance buffer stride
// (sizeof(float) * 8) and with the InstanceData struct read by the instancing shader.
struct InstanceData
{
float3 Position; // world-space point on the terrain (Scatter transforms it by _TerrainMatrix)
float3 Normal;   // world-space terrain normal at that point, normalized by Scatter
float Scale;     // uniform scale applied to the grass mesh (Scatter draws it from 0.25..0.5)
float Rotation;  // angle in radians used to orient the blade around its normal
};
// Output: one InstanceData appended per scattered blade. C# binds it as "_Instances", resets
// its counter each frame and copies the counter into the indirect draw arguments.
AppendStructuredBuffer<InstanceData> _Instances;
// Terrain vertex positions in the mesh's local space (C# uploads mesh.vertices).
StructuredBuffer<float3> _TerrainVertices;
// Per-vertex terrain normals, read with the same indices as _TerrainVertices.
StructuredBuffer<float3> _TerrainNormals;
// Triangle list: three vertex indices per triangle (C# uploads mesh.triangles).
StructuredBuffer<uint> _TerrainIndices;
// Number of triangles in _TerrainIndices (index count / 3).
uint _TriangleCount;
// Terrain local-to-world matrix (C# passes transform.localToWorldMatrix).
float4x4 _TerrainMatrix;
// Integer hash used as the source of all randomness in this shader: xor the input with a
// constant, then three rounds of multiply by the 32-bit golden-ratio constant (0x9E3779B9),
// with an xor-shift by 16 between rounds. All arithmetic wraps modulo 2^32.
uint Hash(uint v)
{
    const uint kGolden = 2654435769u;
    uint h = (v ^ 2747636419u) * kGolden;
    h = (h ^ (h >> 16)) * kGolden;
    h = (h ^ (h >> 16)) * kGolden;
    return h;
}
// Pseudo-random float in [0, 1] for a seed: the seed's hash divided by the largest 32-bit
// unsigned value (2^32 - 1). Only about 24 bits of the hash survive the float conversion.
float Random(uint seed)
{
    uint hashed = Hash(seed);
    return (float)hashed / 4294967295.0;
}
// Pseudo-random float between minValue and maxValue, mapped linearly from Random(seed).
// Parameter names avoid shadowing the min/max intrinsics.
float RandomRange(uint seed, float minValue, float maxValue)
{
    float span = maxValue - minValue;
    float t = Random(seed);
    return minValue + t * span;
}
// Candidate threads launched per terrain triangle. It is set from C# (ComputeGrass calls
// SetInt("_InstancesPerTriangle", ...)). Here it is used only to discard the padding
// threads of the last thread group.
uint _InstancesPerTriangle;

// Scatter kernel: each dispatch thread is one grass candidate. Thread `id` does the following:
//   - samples a uniformly distributed point on terrain triangle (id % _TriangleCount);
//   - interpolates the vertex normals at that point;
//   - transforms both into world space with _TerrainMatrix;
//   - appends the result to _Instances.
[numthreads(64,1,1)]
void Scatter(uint id : SV_DispatchThreadID)
{
    // The dispatch is rounded up to whole 64-thread groups, so surplus threads must not append.
    // This also returns before the modulo below when _TriangleCount is 0.
    if (id >= _TriangleCount * _InstancesPerTriangle)
        return;

    uint triIndex = id % _TriangleCount;
    uint i0 = _TerrainIndices[triIndex * 3 + 0];
    uint i1 = _TerrainIndices[triIndex * 3 + 1];
    uint i2 = _TerrainIndices[triIndex * 3 + 2];

    float3 n0 = _TerrainNormals[i0];
    float3 n1 = _TerrainNormals[i1];
    float3 n2 = _TerrainNormals[i2];
    float3 v0 = _TerrainVertices[i0];
    float3 v1 = _TerrainVertices[i1];
    float3 v2 = _TerrainVertices[i2];

    // Seed from the unique thread id alone. Mixing in triIndex (e.g. triIndex ^ id) cancels
    // to 0 for every id < _TriangleCount. Each random draw hashes its own input
    // (id * 4 + k), so position, scale and rotation are mutually independent.
    uint seed = id * 4u;
    float r1 = Random(seed + 0u);
    float r2 = Random(seed + 1u);

    // Uniform point on the triangle: barycentric weights from two uniform numbers using the
    // square-root warp, so samples do not cluster toward v0.
    float sqrtR1 = sqrt(r1);
    float randA = 1 - sqrtR1;
    float randB = sqrtR1 * (1 - r2);
    float randC = sqrtR1 * r2;
    float3 pos = v0 * randA + v1 * randB + v2 * randC;
    float3 normal = normalize(n0 * randA + n1 * randB + n2 * randC);

    float3 worldPos = mul(_TerrainMatrix, float4(pos, 1)).xyz;

    // Normals must be transformed by the inverse-transpose of the upper 3x3, otherwise they
    // tilt under non-uniform scale. The cofactor matrix equals inverse-transpose * det, so
    // after normalizing only the sign of det matters (it flips for mirrored transforms).
    float3x3 m = (float3x3)_TerrainMatrix;
    float3x3 cofactor = float3x3(cross(m[1], m[2]), cross(m[2], m[0]), cross(m[0], m[1]));
    float det = dot(m[0], cofactor[0]);
    float3 worldNormal = normalize(mul(cofactor, normal) * (det < 0 ? -1.0 : 1.0));

    InstanceData data;
    data.Position = worldPos;
    data.Normal = worldNormal;
    data.Scale = RandomRange(seed + 2u, 0.25, 0.5);
    data.Rotation = Random(seed + 3u) * 6.28318531; // full turn (2*pi radians)
    _Instances.Append(data);
}
using UnityEngine;
/// <summary>
/// Scatters grass over a terrain mesh with the "Scatter" compute kernel every frame and
/// draws the result with a single indirect instanced draw call.
/// </summary>
/// <remarks>
/// The kernel appends instances into an append buffer. The buffer's counter is copied into
/// the indirect draw arguments on the GPU, so the instance count is never read back to the CPU.
/// </remarks>
public class ComputeGrass : MonoBehaviour
{
    [SerializeField] private ComputeShader _computeShader = null;
    [SerializeField] private MeshFilter _terrainMesh = null;
    [SerializeField] private Mesh _grassMesh = null;
    [SerializeField] private Material _grassMaterial = null;

    private ComputeBuffer _drawArgsBuffer;
    private ComputeBuffer _instanceBuffer;
    private ComputeBuffer _terrainVertexBuffer;
    private ComputeBuffer _terrainIndexBuffer;
    private ComputeBuffer _terrainNormalBuffer;
    private MaterialPropertyBlock _materialProps;
    private uint[] _drawArgs;
    private int _triangleCount;
    private int _kernelScatterId;
    private uint _threadCountScatterX;
    private float _drawBoundsPadding;

    // Capacity of the instance append buffer. The dispatch is capped so it can never overflow.
    private const int kMaxInstanceCount = 1_000_000;
    // Candidate threads dispatched per terrain triangle (before the capacity cap).
    private const int kInstancesPerTriangle = 100;
    // Per-dimension thread-group limit of ComputeShader.Dispatch.
    private const int kMaxDispatchGroups = 65535;

    // Shader property IDs, hashed once instead of from strings every frame.
    private static readonly int s_instancesId = Shader.PropertyToID("_Instances");
    private static readonly int s_terrainVerticesId = Shader.PropertyToID("_TerrainVertices");
    private static readonly int s_terrainIndicesId = Shader.PropertyToID("_TerrainIndices");
    private static readonly int s_terrainNormalsId = Shader.PropertyToID("_TerrainNormals");
    private static readonly int s_terrainMatrixId = Shader.PropertyToID("_TerrainMatrix");
    private static readonly int s_triangleCountId = Shader.PropertyToID("_TriangleCount");
    private static readonly int s_instancesPerTriangleId = Shader.PropertyToID("_InstancesPerTriangle");

    // Uploads the terrain mesh to the GPU and allocates the instance and draw-args buffers.
    // Disables the component (with an error) when its inputs cannot be used.
    private void Start()
    {
        // UnityEngine.Object overloads ==, so missing or destroyed references compare equal to null.
        Mesh mesh = _terrainMesh != null ? _terrainMesh.sharedMesh : null;
        if (mesh == null || _computeShader == null || _grassMesh == null || _grassMaterial == null)
        {
            Debug.LogError($"{nameof(ComputeGrass)}: assign the compute shader, terrain mesh, grass mesh and grass material.", this);
            enabled = false;
            return;
        }

        Vector3[] verts = mesh.vertices;
        Vector3[] normals = mesh.normals;
        int[] indices = mesh.triangles;

        // The kernel reads one normal per vertex index, and a zero-length ComputeBuffer cannot
        // be created, so meshes without triangles or per-vertex normals are rejected.
        if (indices.Length == 0 || normals.Length != verts.Length)
        {
            Debug.LogError($"{nameof(ComputeGrass)}: terrain mesh '{mesh.name}' needs triangles and one normal per vertex.", this);
            enabled = false;
            return;
        }

        _triangleCount = indices.Length / 3;

        // Indirect args: index count per instance, instance count, start index, base vertex,
        // start instance. Slot 1 starts at 0; ComputeBuffer.CopyCount fills it every frame.
        _drawArgs = new uint[5]
        {
            _grassMesh.GetIndexCount(0),
            0,
            _grassMesh.GetIndexStart(0),
            _grassMesh.GetBaseVertex(0),
            0
        };
        _drawArgsBuffer = new ComputeBuffer(5, sizeof(uint), ComputeBufferType.IndirectArguments);
        _drawArgsBuffer.SetData(_drawArgs);

        // Stride of 8 floats matches the shader-side InstanceData (two float3 + two float).
        _instanceBuffer = new ComputeBuffer(kMaxInstanceCount, sizeof(float) * 8, ComputeBufferType.Append);
        _terrainVertexBuffer = new ComputeBuffer(verts.Length, sizeof(float) * 3);
        _terrainNormalBuffer = new ComputeBuffer(normals.Length, sizeof(float) * 3);
        _terrainIndexBuffer = new ComputeBuffer(indices.Length, sizeof(int));
        _terrainVertexBuffer.SetData(verts);
        _terrainIndexBuffer.SetData(indices);
        _terrainNormalBuffer.SetData(normals);

        _kernelScatterId = _computeShader.FindKernel("Scatter");
        _computeShader.GetKernelThreadGroupSizes(_kernelScatterId, out _threadCountScatterX, out uint _, out uint _);

        // Upper bound on how far a blade vertex can reach from its instance point. The
        // instancing shader lifts the grass mesh by 0.5 before scaling. This assumes instance
        // scales stay <= 1 (TODO: revisit if the scale range or the lift changes).
        Bounds grassBounds = _grassMesh.bounds;
        _drawBoundsPadding = grassBounds.center.magnitude + grassBounds.extents.magnitude + 0.5f;

        _materialProps = new MaterialPropertyBlock();
    }

    // Re-scatters the grass for the terrain's current transform and issues the indirect draw.
    private void Update()
    {
        // Start() bailed out; there is nothing to scatter or draw.
        if (_instanceBuffer == null)
        {
            return;
        }

        _instanceBuffer.SetCounterValue(0);

        // Bind kernel inputs and outputs.
        _computeShader.SetBuffer(_kernelScatterId, s_instancesId, _instanceBuffer);
        _computeShader.SetBuffer(_kernelScatterId, s_terrainVerticesId, _terrainVertexBuffer);
        _computeShader.SetBuffer(_kernelScatterId, s_terrainIndicesId, _terrainIndexBuffer);
        _computeShader.SetBuffer(_kernelScatterId, s_terrainNormalsId, _terrainNormalBuffer);
        _computeShader.SetMatrix(s_terrainMatrixId, _terrainMesh.transform.localToWorldMatrix);
        _computeShader.SetInt(s_triangleCountId, _triangleCount);
        _computeShader.SetInt(s_instancesPerTriangleId, kInstancesPerTriangle);

        _computeShader.Dispatch(_kernelScatterId, CalculateScatterGroupCount(), 1, 1);

        // Copy the append counter into slot 1 (instance count) of the indirect args.
        ComputeBuffer.CopyCount(_instanceBuffer, _drawArgsBuffer, sizeof(uint));

        _materialProps.SetBuffer(s_instancesId, _instanceBuffer);
        Graphics.DrawMeshInstancedIndirect(_grassMesh, 0, _grassMaterial, CalculateDrawBounds(), _drawArgsBuffer, 0, _materialProps);
    }

    // Releases the native GPU memory owned by the ComputeBuffers; the garbage collector
    // does not do this promptly.
    private void OnDestroy()
    {
        _drawArgsBuffer?.Release();
        _instanceBuffer?.Release();
        _terrainVertexBuffer?.Release();
        _terrainIndexBuffer?.Release();
        _terrainNormalBuffer?.Release();

        _drawArgsBuffer = null;
        _instanceBuffer = null;
        _terrainVertexBuffer = null;
        _terrainIndexBuffer = null;
        _terrainNormalBuffer = null;
    }

    // Thread groups for the scatter kernel. The count is enough to cover every candidate, but
    // the launched threads never exceed the append buffer's capacity or Dispatch's group limit.
    private int CalculateScatterGroupCount()
    {
        int threadsPerGroup = (int)_threadCountScatterX;
        long candidateCount = (long)_triangleCount * kInstancesPerTriangle;
        long groupsNeeded = (candidateCount + threadsPerGroup - 1) / threadsPerGroup;
        int groupsAllowed = System.Math.Min(kMaxInstanceCount / threadsPerGroup, kMaxDispatchGroups);
        return (int)System.Math.Min(groupsNeeded, groupsAllowed);
    }

    // World-space axis-aligned box around the terrain mesh, padded so blades standing on its
    // surface are not frustum-culled. It depends only on the terrain, never on the camera.
    private Bounds CalculateDrawBounds()
    {
        Bounds local = _terrainMesh.sharedMesh.bounds;
        Matrix4x4 localToWorld = _terrainMesh.transform.localToWorldMatrix;

        // Transform the local box's half-axes and sum their absolute components. This gives
        // the tightest axis-aligned box that contains the transformed local box.
        Vector3 axisX = localToWorld.MultiplyVector(new Vector3(local.extents.x, 0f, 0f));
        Vector3 axisY = localToWorld.MultiplyVector(new Vector3(0f, local.extents.y, 0f));
        Vector3 axisZ = localToWorld.MultiplyVector(new Vector3(0f, 0f, local.extents.z));
        Vector3 worldExtents = new Vector3(
            Mathf.Abs(axisX.x) + Mathf.Abs(axisY.x) + Mathf.Abs(axisZ.x),
            Mathf.Abs(axisX.y) + Mathf.Abs(axisY.y) + Mathf.Abs(axisZ.y),
            Mathf.Abs(axisX.z) + Mathf.Abs(axisY.z) + Mathf.Abs(axisZ.z));

        Bounds world = new Bounds(localToWorld.MultiplyPoint3x4(local.center), worldExtents * 2f);
        // Expand grows the total size, so pass twice the per-side padding.
        world.Expand(_drawBoundsPadding * 2f);
        return world;
    }
}
// Per-instance grass data written by the Scatter compute kernel. Field order and types must
// match that kernel's InstanceData and the 32-byte C# buffer stride (sizeof(float) * 8).
struct InstanceData
{
float3 Position; // instance position (Scatter stores the world-space terrain point)
float3 Normal;   // terrain normal at the instance; used as the blade's up axis below
float Scale;     // uniform scale applied to the grass mesh vertices
float Rotation;  // angle in radians; sin/cos of it give the blade's forward direction
};
// Instance buffer bound from C# via MaterialPropertyBlock.SetBuffer("_Instances", ...).
StructuredBuffer<InstanceData> _Instances;
// Vertex-stage helper, presumably called from a Shader Graph Custom Function node (the
// `_float` suffix matches that naming convention; confirm against the graph). It places one
// vertex of the grass mesh onto scattered instance `instanceId`.
//   instanceId       - index into _Instances
//   positionOS       - object-space vertex position of the grass mesh
//   normalOS         - object-space vertex normal; unused, kept for the node signature
//                      (declared float3 so a passed normal is no longer truncated)
//   instancePosition - vertex lifted by 0.5 in Y, scaled, rotated into the instance's
//                      normal-aligned basis, then offset by the instance position
//   instanceNormal   - the terrain normal stored for the instance
void ApplyInstanceData_float(uint instanceId, float3 positionOS, float3 normalOS, out float3 instancePosition, out float3 instanceNormal)
{
    InstanceData data = _Instances[instanceId];

    // Offset the quad by half a unit so its base is on the ground
    // (assumes a unit-height mesh centred on its origin; TODO confirm against the grass mesh).
    positionOS.y += 0.5;

    // Orthonormal basis: up is the terrain normal, and forward is the instance's random yaw
    // made perpendicular to up.
    float3 up = data.Normal;
    float3 forward = float3(sin(data.Rotation), 0, cos(data.Rotation));
    float3 right = cross(up, forward);
    // forward is horizontal, so the cross product only vanishes when up is horizontal and
    // parallel to forward (a vertical face). In that case cross(up, +Y) is non-zero, which
    // avoids normalizing a zero vector into NaNs.
    if (dot(right, right) < 1e-6)
    {
        right = cross(up, float3(0, 1, 0));
    }
    right = normalize(right);
    forward = cross(right, up);

    // Rows are the basis vectors, so mul(v, rotMatrix) = v.x*right + v.y*up + v.z*forward.
    float3x3 rotMatrix = float3x3(right, up, forward);
    float3 rotatedVert = mul(positionOS * data.Scale, rotMatrix);
    instancePosition = rotatedVert + data.Position;
    instanceNormal = data.Normal;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment