Created
March 8, 2023 02:13
-
-
Save mttchpmn/6a5d923afa1cdb2d403fc9655398d48a to your computer and use it in GitHub Desktop.
Unit Test Src Generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
using GenerationAssembly; | |
using Microsoft.CodeAnalysis; | |
using Microsoft.CodeAnalysis.CSharp.Syntax; | |
using Microsoft.CodeAnalysis.Text; | |
namespace CDD.Unit.Tests.SourceGenerators
{
    /// <summary>
    /// Source generator that emits a partial test class for every field collected by
    /// <c>TestSubjectSyntaxReceiver</c>. The generated part instantiates the test subject
    /// with a Moq <c>Mock&lt;T&gt;</c> for each constructor parameter and adds
    /// <c>SetupXXX()</c>/<c>VerifyXXX()</c> helpers for the mocked methods.
    /// </summary>
    [Generator]
    public class PartialTestClassGenerator : ISourceGenerator
    {
        /// <summary>
        /// Reported at the offending field when a partial test class cannot be generated,
        /// instead of letting the exception abort the whole generator run.
        /// </summary>
        private static readonly DiagnosticDescriptor GenerationFailed = new DiagnosticDescriptor(
            id: "CDDGEN001",
            title: "Unable to generate partial test class",
            messageFormat: "Unable to generate partial test class: {0}",
            category: "CDD.SourceGenerators",
            defaultSeverity: DiagnosticSeverity.Warning,
            isEnabledByDefault: true);

        /// <summary>Registers the syntax receiver that collects test-subject field declarations.</summary>
        public void Initialize(GeneratorInitializationContext context)
        {
            context.RegisterForSyntaxNotifications(() => new TestSubjectSyntaxReceiver());
        }

        /// <summary>
        /// Generates one partial test class per collected field. A field that fails validation
        /// produces a diagnostic at its location and does not stop the remaining fields.
        /// </summary>
        public void Execute(GeneratorExecutionContext context)
        {
            var receiver = context.SyntaxReceiver as TestSubjectSyntaxReceiver;
            var fields = receiver?.FieldDeclarations;
            if (fields is null)
            {
                return;
            }

            foreach (var field in fields)
            {
                try
                {
                    GeneratePartialTestClass(context, field);
                }
                catch (InvalidOperationException ex)
                {
                    // All validation failures in this generator are raised as
                    // InvalidOperationException; surface them where the user can act on them.
                    context.ReportDiagnostic(Diagnostic.Create(GenerationFailed, field.GetLocation(), ex.Message));
                }
            }
        }

        /// <summary>
        /// Generates a partial test class, instantiating the test subject,
        /// and generating Mocks for the required constructor parameters.
        /// <br /><br />
        /// Each Mock that is instantiated will also have `SetupXXX()` and
        /// `VerifyXXX()` methods generated for every ordinary instance method defined on the type.
        /// <br /><br />
        /// For any test related setup, please define a method called 'Initialize' with the following signature:
        /// <code>partial void Initialize()</code>
        /// The generated test class will create a parameterless constructor which
        /// then calls this Initialize method. Note - creating an Initialize method is optional.
        /// </summary>
        /// <exception cref="InvalidOperationException">The field or its type is not supported.</exception>
        private void GeneratePartialTestClass(GeneratorExecutionContext context, FieldDeclarationSyntax field)
        {
            var semanticModel = GetSemanticModel(context, field);
            var testSubjectTypeSymbol = GetTestSubjectTypeSymbol(semanticModel, field);
            var testSubjectVariableDeclaration = GetTestSubjectVariableDeclaration(semanticModel, field);
            var testSubjectConstructorParameters = GetTestSubjectConstructorParameters(testSubjectTypeSymbol).ToList();

            // Generate text components required for partial class
            var namespaceName = GetNamespaceName(testSubjectVariableDeclaration);
            var usingStatements = GetUsingStatements(testSubjectTypeSymbol, testSubjectConstructorParameters);
            var className = GetClassName(testSubjectVariableDeclaration);
            var fieldDeclarations = GetFieldDeclarations(testSubjectConstructorParameters);
            var constructorInstantiation = GetConstructorInstantiation(testSubjectVariableDeclaration, testSubjectTypeSymbol, testSubjectConstructorParameters);
            var helperMethods = GetHelperMethods(testSubjectConstructorParameters);

            // Generate partial class
            var sourceText = GenerateSourceText(
                usingStatements,
                namespaceName,
                className,
                fieldDeclarations,
                constructorInstantiation,
                helperMethods);

            // Qualify the hint name with the namespace so identically named test classes in
            // different namespaces do not collide (AddSource rejects duplicate hint names).
            var hintName = string.IsNullOrEmpty(namespaceName) ? className : $"{namespaceName}.{className}";
            context.AddSource($"{hintName}.generated.cs", sourceText);
        }

        /// <summary>Returns the parameters of the test subject's single instance constructor.</summary>
        /// <exception cref="InvalidOperationException">The type has no instance constructor, or more than one.</exception>
        private IEnumerable<IParameterSymbol> GetTestSubjectConstructorParameters(INamedTypeSymbol testSubjectTypeSymbol)
        {
            // InstanceConstructors (unlike Constructors) excludes a declared static constructor,
            // which cannot create an instance and must not trip the "single constructor" rule.
            var constructors = testSubjectTypeSymbol.InstanceConstructors;
            if (constructors.IsEmpty)
            {
                throw new InvalidOperationException("Unable to obtain constructor for test subject. Ensure you are using a concrete type and not an interface");
            }

            if (constructors.Length > 1)
            {
                throw new InvalidOperationException("Encountered more than one constructor for test subject");
            }

            return constructors[0].Parameters;
        }

        /// <summary>Gets the semantic model for the syntax tree that declares <paramref name="field"/>.</summary>
        private SemanticModel GetSemanticModel(GeneratorExecutionContext context, FieldDeclarationSyntax field)
            => context.Compilation.GetSemanticModel(field.Declaration.Type.SyntaxTree);

        /// <summary>Resolves the declared type of the test-subject field.</summary>
        private INamedTypeSymbol GetTestSubjectTypeSymbol(SemanticModel semanticModel, FieldDeclarationSyntax field)
            => semanticModel.GetTypeInfo(field.Declaration.Type).Type as INamedTypeSymbol
               ?? throw new InvalidOperationException("Unable to obtain type symbol for test subject");

        /// <summary>Resolves the symbol of the single variable declared by the test-subject field.</summary>
        private ISymbol GetTestSubjectVariableDeclaration(SemanticModel semanticModel, FieldDeclarationSyntax field)
        {
            var variables = field.Declaration.Variables;
            if (variables.Count > 1)
            {
                throw new InvalidOperationException("Encountered more than one variable for field declaration");
            }

            var variable = variables.FirstOrDefault();
            var result = variable is null ? null : semanticModel.GetDeclaredSymbol(variable);
            if (result is null)
            {
                throw new InvalidOperationException("Unable to obtain test subject variable declaration");
            }

            return result;
        }

        /// <summary>
        /// Builds the distinct <c>using</c> directives for the namespaces of the constructor
        /// parameter types followed by the test subject's namespace, newline-separated.
        /// </summary>
        private string GetUsingStatements(ISymbol testSubjectSymbol, IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
        {
            var usings = testSubjectConstructorParameters
                .Select(x => x.Type.ContainingNamespace)
                .Concat(new[] { testSubjectSymbol.ContainingNamespace })
                .Select(GetUsingStatement)
                .Where(x => x != null)
                .Distinct();
            return string.Join("\n", usings);
        }

        /// <summary>
        /// Returns <c>using Ns;</c> for the namespace, or <c>null</c> when there is nothing to
        /// import. The namespace may be null (e.g. for array types). The global namespace would
        /// otherwise render as the invalid <c>using &lt;global namespace&gt;;</c>.
        /// </summary>
        private static string GetUsingStatement(INamespaceSymbol namespaceSymbol)
            => namespaceSymbol == null || namespaceSymbol.IsGlobalNamespace
                ? null
                : $"using {namespaceSymbol.ToDisplayString()};";

        /// <summary>
        /// Namespace of the test class, or an empty string when it is declared in the
        /// global namespace (and so needs no namespace declaration).
        /// </summary>
        private string GetNamespaceName(ISymbol declaration)
        {
            var namespaceSymbol = declaration.ContainingNamespace;
            return namespaceSymbol == null || namespaceSymbol.IsGlobalNamespace
                ? string.Empty
                : namespaceSymbol.ToDisplayString();
        }

        /// <summary>Name of the test class that declares the test-subject field.</summary>
        private string GetClassName(ISymbol testSubjectVariableDeclaration)
            => testSubjectVariableDeclaration.ContainingType.Name;

        /// <summary>Generates one <c>private Mock&lt;T&gt; _name = new();</c> field per constructor parameter.</summary>
        private string GetFieldDeclarations(IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
        {
            var fields = testSubjectConstructorParameters.Select(
                x => $"private Mock<{GetTypeText(x.Type)}> {GetFieldName(x.Type.Name)} = new();");
            return string.Join("\n\t", fields);
        }

        /// <summary>
        /// Renders a type for generated code: qualified, including every type argument, and
        /// without a top-level nullable annotation (<c>Mock&lt;T?&gt;</c> and <c>new T?(...)</c> are invalid).
        /// </summary>
        private static string GetTypeText(ITypeSymbol type)
            => type.WithNullableAnnotation(NullableAnnotation.NotAnnotated).ToDisplayString();

        /// <summary>
        /// Derives the mock field name from a dependency's type name. An interface-style
        /// leading "I" (followed by an upper-case letter) is dropped: <c>IUserRepo</c> → <c>_userRepo</c>,
        /// <c>Logger</c> → <c>_logger</c>.
        /// </summary>
        /// <exception cref="InvalidOperationException">The type name is empty (e.g. an array type).</exception>
        private string GetFieldName(string typeName)
        {
            if (string.IsNullOrEmpty(typeName))
            {
                throw new InvalidOperationException("Unable to derive a mock field name from an unnamed type");
            }

            var hasInterfacePrefix = typeName.Length > 1 && typeName[0] == 'I' && char.IsUpper(typeName[1]);
            var baseName = hasInterfacePrefix ? typeName.Substring(1) : typeName;
            // Invariant casing: culture-sensitive ToLower() maps 'I' to dotless 'ı' under tr-TR.
            return "_" + char.ToLowerInvariant(baseName[0]) + baseName.Substring(1);
        }

        /// <summary>Generates <c>_sut = new Subject(_mockA.Object, _mockB.Object, ...);</c>.</summary>
        private string GetConstructorInstantiation(ISymbol testSubjectVariableDeclaration, INamedTypeSymbol testSubjectTypeSymbol, IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
        {
            var arguments = testSubjectConstructorParameters.Select(x => $"{GetFieldName(x.Type.Name)}.Object");
            var argumentList = string.Join(", ", arguments);
            return $"{testSubjectVariableDeclaration.Name} = new {GetTypeText(testSubjectTypeSymbol)}({argumentList});";
        }

        /// <summary>Concatenates the helper-method regions of all mocked dependencies, skipping empty ones.</summary>
        private string GetHelperMethods(IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
        {
            var helperSections = testSubjectConstructorParameters
                .Select(GenerateHelperMethodsForParameter)
                .Where(x => !string.IsNullOrWhiteSpace(x));
            return string.Join("\n\n\t", helperSections);
        }

        /// <summary>
        /// Generates a <c>#region</c> with a Setup helper per non-void method and a Verify
        /// helper per method of the dependency's type, or "" when there are none.
        /// </summary>
        /// <exception cref="InvalidOperationException">The parameter's type is not a named type.</exception>
        private string GenerateHelperMethodsForParameter(IParameterSymbol parameter)
        {
            if (!(parameter.Type is INamedTypeSymbol parameterType))
            {
                throw new InvalidOperationException($"Unable to obtain methods for parameter: {parameter}");
            }

            var mockableMethods = parameterType.GetMembers()
                .OfType<IMethodSymbol>()
                .Where(IsMockableMethod)
                .ToList();

            var setupMethods = mockableMethods.Where(x => !x.ReturnsVoid).Select(x => GenerateSetupMethod(parameter, x));
            var verifyMethods = mockableMethods.Select(x => GenerateVerifyMethod(parameter, x));
            var setupText = string.Join("\n\n\t", setupMethods);
            var verifyText = string.Join("\n\n\t", verifyMethods);
            if (string.IsNullOrWhiteSpace(setupText) && string.IsNullOrWhiteSpace(verifyText))
            {
                return "";
            }

            return $"#region {parameter.Type.Name} helper methods:\n\t" + setupText + "\n\n\t" + verifyText + "\n\t#endregion";
        }

        /// <summary>
        /// True when the method can be written as <c>x => x.Method(...)</c> in a Moq expression.
        /// Excluded: accessors (get_/set_/add_/remove_), constructors and operators, which yield
        /// uncompilable calls; static methods; generic methods, whose type parameters would be
        /// undeclared in the helper; and ref/out parameters, which cannot take It.Is/It.IsAny values.
        /// </summary>
        private static bool IsMockableMethod(IMethodSymbol method)
            => method.MethodKind == MethodKind.Ordinary
               && !method.IsStatic
               && !method.IsGenericMethod
               && method.Parameters.All(p => p.RefKind != RefKind.Ref && p.RefKind != RefKind.Out);

        /// <summary>
        /// Generates <c>SetupXxx(returnValue, param1 = null, ...)</c>. A null filter parameter
        /// matches any argument; a non-null one matches only an equal argument.
        /// </summary>
        private string GenerateSetupMethod(IParameterSymbol parameter, IMethodSymbol method)
        {
            var matchers = method.Parameters.Select(
                (x, i) => $"It.Is<{x.Type.ToDisplayString()}>(y => param{i + 1} == null || y == param{i + 1})");
            var parametersText = string.Join(", ", matchers);
            var nullableParametersText = GetNullableParametersText(method);
            var fieldName = GetFieldName(parameter.Type.Name);
            var returnType = method.ReturnType.ToDisplayString();
            return $"private void Setup{method.Name}({returnType} returnValue{nullableParametersText})\n" +
                   "\t{\n" +
                   $"\t\t{fieldName}.Setup(x => x.{method.Name}({parametersText})).Returns(returnValue);\n" +
                   "\t}";
        }

        /// <summary>
        /// Generates <c>VerifyXxx(timesCalled = null, param1 = null, ...)</c>. A null filter
        /// parameter becomes <c>It.IsAny</c>; <c>timesCalled</c> defaults to <c>Times.AtLeastOnce()</c>.
        /// </summary>
        private string GenerateVerifyMethod(IParameterSymbol parameter, IMethodSymbol method)
        {
            var matchers = method.Parameters.Select((x, i) => $"param{i + 1} ?? It.IsAny<{x.Type.ToDisplayString()}>()");
            var parametersText = string.Join(", ", matchers);
            var nullableParametersText = GetNullableParametersText(method);
            var fieldName = GetFieldName(parameter.Type.Name);
            return $"private void Verify{method.Name}(Times? timesCalled = null{nullableParametersText})\n" +
                   "\t{\n" +
                   $"\t\t{fieldName}.Verify(x => x.{method.Name}({parametersText}), timesCalled ?? Times.AtLeastOnce());\n" +
                   "\t}";
        }

        /// <summary>
        /// Builds the trailing <c>, T1? param1 = null, T2? param2 = null</c> parameter list, or an
        /// empty string for a parameterless method.
        /// </summary>
        private static string GetNullableParametersText(IMethodSymbol method)
        {
            if (method.Parameters.Length == 0)
            {
                return string.Empty;
            }

            var nullableParameters = method.Parameters.Select(
                (x, i) => $"{GetNullableTypeText(x.Type)} param{i + 1} = null");
            return ", " + string.Join(", ", nullableParameters);
        }

        /// <summary>
        /// Renders the type so that it admits null, as each generated filter parameter defaults
        /// to null. The "?" is added unless the type is already annotated or is Nullable&lt;T&gt;.
        /// This applies regardless of whether the original parameter is optional, because an
        /// optional non-nullable value type (e.g. <c>int x = 0</c>) cannot default to null.
        /// </summary>
        private static string GetNullableTypeText(ITypeSymbol type)
        {
            var text = type.ToDisplayString();
            var alreadyNullable = type.NullableAnnotation == NullableAnnotation.Annotated
                                  || type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;
            return alreadyNullable ? text : text + "?";
        }

        /// <summary>
        /// Assembles the generated file. The file contains the usings, an optional file-scoped
        /// namespace, and the partial class with mock fields. The class also gets a parameterless
        /// constructor that instantiates the subject and calls <c>Initialize()</c>, plus the helper methods.
        /// </summary>
        private SourceText GenerateSourceText(
            string usingStatements,
            string namespaceName,
            string className,
            string fieldDeclarations,
            string constructorInstantiation,
            string helperMethods)
        {
            var source = new StringBuilder()
                .Append("// <auto-generated>\n")
                .Append("#pragma warning disable CS8073\n")
                .Append("#nullable enable\n")
                .Append("using System;\n")
                .Append("using Moq;\n")
                .Append(usingStatements).Append('\n');

            // A test class in the global namespace gets no namespace declaration:
            // `namespace <global namespace>;` would not compile.
            if (!string.IsNullOrEmpty(namespaceName))
            {
                source.Append($"namespace {namespaceName};\n");
            }

            source
                .Append($"public partial class {className}\n")
                .Append("{\n")
                .Append($"\t{fieldDeclarations}\n")
                .Append($"\tpublic {className}()\n")
                .Append("\t{\n")
                .Append($"\t\t{constructorInstantiation}\n")
                .Append("\t\tInitialize();\n")
                .Append("\t}\n")
                .Append("\tpartial void Initialize();\n")
                .Append($"\t{helperMethods}\n")
                .Append("}\n")
                .Append("#pragma warning restore CS8073\n");

            return SourceText.From(source.ToString(), Encoding.UTF8);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment