Skip to content

Instantly share code, notes, and snippets.

@newdigate
Last active December 28, 2021 15:33
Show Gist options
  • Save newdigate/117c9443767570d7c9b8952374b937d3 to your computer and use it in GitHub Desktop.
Save newdigate/117c9443767570d7c9b8952374b937d3 to your computer and use it in GitHub Desktop.
C# — remove type inference (`var`) using Roslyn and/or MSBuild
using System.Reflection;
using Microsoft.Build.Locator;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.MSBuild;
using Xunit;
namespace type_deinference;
/// <summary>
/// Replaces implicitly typed (<c>var</c>) declarations in C# source with explicitly typed ones.
/// </summary>
public interface ITypeDeInference
{
    /// <summary>
    /// Rewrites syntax trees that belong to <paramref name="compilation"/>.
    /// </summary>
    /// <param name="trees">Each syntax tree paired with its compilation-unit root.</param>
    /// <param name="compilation">Compilation used to obtain semantic models for <paramref name="trees"/>.</param>
    /// <param name="fnGetKeyForSyntaxTree">Chooses the key under which a rewritten tree is reported.</param>
    /// <returns>Rewritten source text of each changed tree, keyed by <paramref name="fnGetKeyForSyntaxTree"/>.</returns>
    IDictionary<string, string> Process(
        IDictionary<SyntaxTree, CompilationUnitSyntax> trees,
        Compilation compilation,
        Func<SyntaxTree, string> fnGetKeyForSyntaxTree);

    /// <summary>
    /// Rewrites a single source text.
    /// </summary>
    /// <param name="source">C# source code.</param>
    /// <param name="references">Metadata references used to bind <paramref name="source"/>.</param>
    /// <returns>The rewritten text, or <c>null</c> when there was nothing to rewrite.</returns>
    string? RemoveTypeInference(
        string source,
        IEnumerable<MetadataReference> references);

    /// <summary>
    /// Rewrites several sources that are compiled together.
    /// </summary>
    /// <param name="sourceIdentifier">Identifiers of the sources to process.</param>
    /// <param name="references">Metadata references used to bind the sources.</param>
    /// <param name="getSourceFromIdentifier">Returns the source text for an identifier.</param>
    /// <returns>Rewritten text of each changed source, keyed by its identifier.</returns>
    IDictionary<string, string> RemoveTypeInference(
        IEnumerable<string> sourceIdentifier,
        IEnumerable<MetadataReference> references,
        Func<string, string> getSourceFromIdentifier);
}
/// <summary>
/// Rewrites implicitly typed (<c>var</c>) local declarations and <c>foreach</c> iteration
/// variables into explicitly typed ones, using Roslyn's semantic model to find the inferred type.
/// </summary>
/// <remarks>
/// Simple, top-level, non-generic named types are written with their short metadata name
/// (e.g. <c>Int32</c>, <c>Type</c>), so the rewritten file may need a matching <c>using</c>
/// directive (for example <c>using System;</c>) to compile. Arrays, generic types (including
/// <c>Nullable&lt;T&gt;</c> and tuples), nested types and other constructed types are written with
/// their C# display string (e.g. <c>int[]</c>, <c>System.Collections.Generic.List&lt;int&gt;</c>).
/// Anonymous types and types that could not be resolved are left as <c>var</c>.
/// </remarks>
public class TypeDeInference : ITypeDeInference
{
    /// <summary>
    /// Removes type inference from a single source text.
    /// </summary>
    /// <param name="source">C# source code to rewrite.</param>
    /// <param name="references">Metadata references used to bind <paramref name="source"/>.</param>
    /// <returns>The rewritten source, or <c>null</c> when nothing was changed.</returns>
    public string? RemoveTypeInference(string source, IEnumerable<MetadataReference> references)
    {
        IDictionary<string, string> result = RemoveTypeInference(new[] { string.Empty }, references, _ => source);
        return result.TryGetValue(string.Empty, out string? rewritten) ? rewritten : null;
    }

    /// <summary>
    /// Removes type inference from several sources compiled together, so that types declared in
    /// one source can be resolved from another.
    /// </summary>
    /// <param name="sourceIdentifiers">Unique identifiers (e.g. file names) of the sources.</param>
    /// <param name="references">Metadata references used to bind the sources.</param>
    /// <param name="getSourceFromIdentifier">Returns the source text for an identifier.</param>
    /// <returns>Rewritten source keyed by identifier; only sources that changed are included.</returns>
    /// <exception cref="ArgumentException">An identifier occurs more than once.</exception>
    public IDictionary<string, string> RemoveTypeInference(
        IEnumerable<string> sourceIdentifiers,
        IEnumerable<MetadataReference> references,
        Func<string, string> getSourceFromIdentifier)
    {
        var roots = new Dictionary<SyntaxTree, CompilationUnitSyntax>();
        var identifierByTree = new Dictionary<SyntaxTree, string>();
        var seenIdentifiers = new HashSet<string>();

        foreach (string sourceIdentifier in sourceIdentifiers)
        {
            if (!seenIdentifiers.Add(sourceIdentifier))
                throw new ArgumentException($"Duplicate source identifier '{sourceIdentifier}'.", nameof(sourceIdentifiers));

            SyntaxTree tree = CSharpSyntaxTree.ParseText(getSourceFromIdentifier(sourceIdentifier));
            roots[tree] = tree.GetCompilationUnitRoot();
            // Map each tree back to its identifier by reference. Matching on source text would
            // pick the wrong key (and then fail on a duplicate key) when two sources are identical.
            identifierByTree[tree] = sourceIdentifier;
        }

        CSharpCompilation compilation = CompileCSharp(roots.Keys, references);
        return Process(roots, compilation, tree => identifierByTree[tree]);
    }

    /// <summary>
    /// Rewrites every tree in <paramref name="trees"/> that contains implicitly typed declarations.
    /// </summary>
    /// <param name="trees">
    /// Syntax trees of <paramref name="compilation"/>, each mapped to its own compilation-unit root
    /// (the root must be the tree's root so the semantic model can bind its nodes).
    /// </param>
    /// <param name="compilation">Compilation containing the trees; used to obtain semantic models.</param>
    /// <param name="fnGetKeyForSyntaxTree">Produces the result key for a tree.</param>
    /// <returns>Rewritten source keyed by <paramref name="fnGetKeyForSyntaxTree"/>, for changed trees only.</returns>
    public IDictionary<string, string> Process(IDictionary<SyntaxTree, CompilationUnitSyntax> trees, Compilation compilation, Func<SyntaxTree, string> fnGetKeyForSyntaxTree)
    {
        var result = new Dictionary<string, string>();
        foreach ((SyntaxTree tree, CompilationUnitSyntax root) in trees)
        {
            // Materialize once: the candidates are both tested for emptiness and handed to ReplaceNodes.
            List<SyntaxNode> candidates =
                GetVariableDeclarationSyntaxUsingTypeInference(root)
                    .Cast<SyntaxNode>()
                    .Concat(GetForEachSyntaxUsingTypeInference(root))
                    .ToList();
            if (candidates.Count == 0)
                continue;

            SemanticModel model = compilation.GetSemanticModel(tree);
            CompilationUnitSyntax newRoot = root.ReplaceNodes(
                candidates,
                // Bind against 'original' (it belongs to the tree the model knows about) but edit
                // 'rewritten', which already contains the replacements made inside this node.
                // Editing 'original' would drop nested rewrites, e.g. a 'var' inside a foreach body.
                (original, rewritten) => (original, rewritten) switch
                {
                    (VariableDeclarationSyntax o, VariableDeclarationSyntax r) =>
                        ReplaceVariableDeclarationSyntaxWithDeinferedType(model, o, r) ?? rewritten,
                    (ForEachStatementSyntax o, ForEachStatementSyntax r) =>
                        ReplaceForeachSyntaxWithDeinferedType(model, o, r) ?? rewritten,
                    _ => rewritten,
                });

            string newSource = newRoot.ToString();
            if (root.ToString() != newSource)
                result.Add(fnGetKeyForSyntaxTree(tree), newSource);
        }
        return result;
    }

    /// <summary>
    /// Creates a library compilation over <paramref name="syntaxTrees"/> and writes its
    /// diagnostics to the console.
    /// </summary>
    private static CSharpCompilation CompileCSharp(
        IEnumerable<SyntaxTree> syntaxTrees,
        IEnumerable<MetadataReference> references)
    {
        CSharpCompilation compilation = CSharpCompilation.Create(
            "assemblyName",
            syntaxTrees,
            references,
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

        foreach (Diagnostic diagnostic in compilation.GetDiagnostics())
            Console.WriteLine(CSharpDiagnosticFormatter.Instance.Format(diagnostic));

        return compilation;
    }

    /// <summary>
    /// Builds an explicitly typed version of a <c>var</c> declaration from its initializer's type.
    /// </summary>
    /// <param name="model">Semantic model for the tree containing <paramref name="a"/>.</param>
    /// <param name="a">The original declaration, used for binding.</param>
    /// <param name="b">The declaration as rewritten so far; it receives the new type.</param>
    /// <returns>The rewritten declaration, or <c>null</c> to leave it unchanged.</returns>
    private static SyntaxNode? ReplaceVariableDeclarationSyntaxWithDeinferedType(SemanticModel model, VariableDeclarationSyntax a, VariableDeclarationSyntax b)
    {
        ExpressionSyntax? initializer = a.Variables.FirstOrDefault()?.Initializer?.Value;
        if (initializer is null)
            return null;

        TypeSyntax? replacement = CreateTypeSyntax(model.GetTypeInfo(initializer).ConvertedType, b.Type);
        return replacement is null ? null : b.WithType(replacement);
    }

    /// <summary>
    /// Builds an explicitly typed version of a <c>foreach (var ...)</c> statement.
    /// </summary>
    /// <param name="model">Semantic model for the tree containing <paramref name="a"/>.</param>
    /// <param name="a">The original statement, used for binding.</param>
    /// <param name="b">The statement as rewritten so far; it receives the new type.</param>
    /// <returns>The rewritten statement, or <c>null</c> to leave it unchanged.</returns>
    private static SyntaxNode? ReplaceForeachSyntaxWithDeinferedType(SemanticModel model, ForEachStatementSyntax a, ForEachStatementSyntax b)
    {
        // The compiler already knows the element type of any foreach (generic and non-generic
        // enumerables, arrays, strings, user collections), so no runtime reflection is needed.
        ITypeSymbol? elementType = model.GetForEachStatementInfo(a).ElementType;
        TypeSyntax? replacement = CreateTypeSyntax(elementType, b.Type);
        return replacement is null ? null : b.WithType(replacement);
    }

    /// <summary>
    /// Creates the type syntax that replaces <c>var</c>, carrying over the trivia of the original type.
    /// </summary>
    /// <returns><c>null</c> when <paramref name="type"/> is missing or cannot be written in source.</returns>
    private static TypeSyntax? CreateTypeSyntax(ITypeSymbol? type, TypeSyntax originalType)
    {
        // Anonymous types cannot be spelled, and unresolved types would produce bogus code.
        if (type is null || type.IsAnonymousType || type.TypeKind == TypeKind.Error)
            return null;

        string typeName =
            type is INamedTypeSymbol { IsGenericType: false, IsTupleType: false, ContainingType: null } named
                // Short metadata name (e.g. Int32, Type) for simple top-level types.
                ? named.Name
                // Arrays, generics, nullable value types, tuples, nested types, pointers, type
                // parameters: the display string is a valid type expression, whereas Name alone
                // (e.g. "List" for List<int>) is not.
                : type.ToDisplayString();

        return SyntaxFactory.ParseTypeName(typeName).WithTriviaFrom(originalType);
    }

    /// <summary>Finds variable declarations under <paramref name="root"/> whose type is <c>var</c>.</summary>
    private static IEnumerable<VariableDeclarationSyntax> GetVariableDeclarationSyntaxUsingTypeInference(SyntaxNode root) =>
        root.DescendantNodes()
            .OfType<VariableDeclarationSyntax>()
            .Where(declaration => declaration.Type.IsVar);

    /// <summary>Finds <c>foreach</c> statements under <paramref name="root"/> whose iteration variable is <c>var</c>.</summary>
    private static IEnumerable<ForEachStatementSyntax> GetForEachSyntaxUsingTypeInference(SyntaxNode root) =>
        root.DescendantNodes()
            .OfType<ForEachStatementSyntax>()
            .Where(statement => statement.Type.IsVar);
}
/// <summary>
/// Loads MSBuild solutions and projects through <see cref="MSBuildWorkspace"/> and runs an
/// <see cref="ITypeDeInference"/> over the C# syntax trees of each project's compilation.
/// </summary>
public class SolutionDeInference
{
    private readonly ITypeDeInference _typeDeInference;

    public SolutionDeInference(ITypeDeInference typeDeInference)
    {
        _typeDeInference = typeDeInference;
    }

    /// <summary>
    /// Removes type inference from every project of a solution, visiting projects in dependency order.
    /// </summary>
    /// <param name="solutionPath">Path to the solution file.</param>
    /// <returns>
    /// Rewritten source keyed by file path for every file that changed. A file shared by several
    /// projects keeps the result of the last project visited.
    /// </returns>
    public async Task<IDictionary<string, string>> RemoveTypeInferenceFromSolution(string solutionPath)
    {
        EnsureMsBuildRegistration();
        using MSBuildWorkspace workspace = MSBuildWorkspace.Create();
        workspace.WorkspaceFailed += (sender, workspaceFailedArgs) => Console.WriteLine(workspaceFailedArgs.Diagnostic.Message);

        Solution solution = await workspace.OpenSolutionAsync(solutionPath);
        var results = new Dictionary<string, string>();
        foreach (ProjectId projectId in solution.GetProjectDependencyGraph().GetTopologicallySortedProjects())
        {
            Project? project = solution.GetProject(projectId);
            if (project is null)
                continue;

            // Collect the per-project results so callers can actually use them.
            foreach ((string filePath, string rewrittenSource) in await RemoveTypeInferenceFromProject(project))
                results[filePath] = rewrittenSource;
        }
        return results;
    }

    /// <summary>
    /// Loads a single project file and removes type inference from it.
    /// </summary>
    /// <param name="projectPath">Path to the project file.</param>
    /// <returns>Rewritten source keyed by file path for every file that changed.</returns>
    public async Task<IDictionary<string, string>> RemoveTypeInferenceFromProject(string projectPath)
    {
        EnsureMsBuildRegistration();
        using MSBuildWorkspace workspace = MSBuildWorkspace.Create();
        workspace.WorkspaceFailed += (sender, workspaceFailedArgs) => Console.WriteLine(workspaceFailedArgs.Diagnostic.Message);

        Project project = await workspace.OpenProjectAsync(projectPath);
        return await RemoveTypeInferenceFromProject(project);
    }

    /// <summary>
    /// Removes type inference from an already loaded project. Compilation diagnostics are
    /// written to the console.
    /// </summary>
    /// <param name="project">The project to process.</param>
    /// <returns>
    /// Rewritten source keyed by file path for every file that changed; empty for non-C# projects
    /// or when no compilation is available.
    /// </returns>
    public async Task<IDictionary<string, string>> RemoveTypeInferenceFromProject(Project project)
    {
        // Only C# trees have a CompilationUnitSyntax root; a VB project would otherwise throw
        // from GetCompilationUnitRoot.
        if (project.Language != LanguageNames.CSharp)
            return new Dictionary<string, string>();

        Compilation? compilation = await project.GetCompilationAsync();
        if (compilation is null)
            return new Dictionary<string, string>();

        foreach (Diagnostic diagnostic in compilation.GetDiagnostics())
            Console.WriteLine(CSharpDiagnosticFormatter.Instance.Format(diagnostic));

        IDictionary<SyntaxTree, CompilationUnitSyntax> roots =
            compilation.SyntaxTrees.ToDictionary(tree => tree, tree => tree.GetCompilationUnitRoot());

        return _typeDeInference.Process(roots, compilation, tree => tree.FilePath);
    }

    /// <summary>
    /// Registers the default MSBuild instance with MSBuildLocator once, before any workspace is created.
    /// </summary>
    private static void EnsureMsBuildRegistration()
    {
        if (!MSBuildLocator.IsRegistered)
            MSBuildLocator.RegisterDefaults();
    }
}
/// <summary>
/// xUnit tests for <see cref="TypeDeInference"/> and <see cref="SolutionDeInference"/>.
/// </summary>
public class TestTypeDeInference
{
    private readonly ITypeDeInference typeDeInference = new TypeDeInference();
    private readonly IEnumerable<MetadataReference> defaultReferences;

    public TestTypeDeInference()
    {
        // System.Runtime.dll is loaded from the same directory as the System.Linq assembly.
        string coreDir = Path.GetDirectoryName(typeof(Enumerable).GetTypeInfo().Assembly.Location)!;
        defaultReferences = new[]
        {
            MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
            MetadataReference.CreateFromFile(typeof(System.Linq.Enumerable).Assembly.Location),
            MetadataReference.CreateFromFile(typeof(System.Console).Assembly.Location),
            MetadataReference.CreateFromFile(Path.Combine(coreDir, "System.Runtime.dll")),
        };
    }

    [Fact]
    public void TestDeInferVariableDeclarationStatement()
    {
        const string source = @"public class NumberWang {
public void Wang() {
var x = 1;
var z = (x == 1)? 10 : 100;
var t = GetType();
var t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string expectedResult = @"public class NumberWang {
public void Wang() {
Int32 x = 1;
Int32 z = (x == 1)? 10 : 100;
Type t = GetType();
int[] t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";

        string? deinferedSource = typeDeInference.RemoveTypeInference(source, defaultReferences);

        Assert.Equal(expectedResult, deinferedSource);
    }

    [Fact]
    public void TestDeInferForEachStatement()
    {
        const string source = @"using System.Linq;
public class T {
public void M() {
foreach(var s in Enumerable.Range(1, 10)) {
System.Console.WriteLine(s);
}
}
}";
        const string expectedResult = @"using System.Linq;
public class T {
public void M() {
foreach(Int32 s in Enumerable.Range(1, 10)) {
System.Console.WriteLine(s);
}
}
}";

        string? deinferedSource = typeDeInference.RemoveTypeInference(source, defaultReferences);

        Assert.Equal(expectedResult, deinferedSource);
    }

    [Fact]
    public void TestDeinferMultipleCodeFiles()
    {
        const string source1 = @"public class NumberWang {
public void Wang() {
var x = 1;
var z = (x == 1)? 10 : 100;
var t = GetType();
var t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string source2 = @"public class NumberWong {
public void Wong() {
var x = 1;
var z = (x == 1)? 10 : 100;
var t = GetType();
var t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string expectedSource1Result = @"public class NumberWang {
public void Wang() {
Int32 x = 1;
Int32 z = (x == 1)? 10 : 100;
Type t = GetType();
int[] t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string expectedSource2Result = @"public class NumberWong {
public void Wong() {
Int32 x = 1;
Int32 z = (x == 1)? 10 : 100;
Type t = GetType();
int[] t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        var sourceByIdentifier = new Dictionary<string, string> { { "a", source1 }, { "b", source2 } };

        IDictionary<string, string> results = typeDeInference.RemoveTypeInference(sourceByIdentifier.Keys, defaultReferences, ident => sourceByIdentifier[ident]);

        // xUnit's Assert.Equal takes (expected, actual); the order matters for failure messages.
        Assert.Equal(expectedSource1Result, results["a"]);
        Assert.Equal(expectedSource2Result, results["b"]);
    }

    [Fact]
    public async Task TestDeinferProject()
    {
        const string projectPath = "/Users/nicholasnewdigate/Development/github/DelegateDecompiler/src/DelegateDecompiler/DelegateDecompiler.csproj";
        // Machine-specific fixture: nothing to exercise when it is absent. When it is present,
        // failures must surface instead of being swallowed.
        if (!File.Exists(projectPath))
            return;

        var solutionDeInference = new SolutionDeInference(new TypeDeInference());

        IDictionary<string, string> results = await solutionDeInference.RemoveTypeInferenceFromProject(projectPath);

        Assert.NotNull(results);
    }

    [Fact]
    public async Task TestDeinferSolution()
    {
        // NOTE: the fixture is a .csproj, so this exercises the project entry point.
        const string projectPath = "/Users/nicholasnewdigate/Development/github/newdigate/entity-lang/lib1/lib1.csproj";
        if (!File.Exists(projectPath))
            return;

        var solutionDeInference = new SolutionDeInference(new TypeDeInference());

        IDictionary<string, string> results = await solutionDeInference.RemoveTypeInferenceFromProject(projectPath);

        Assert.NotNull(results);
    }
}
<Project Sdk="Microsoft.NET.Sdk">
<!-- Test project: Roslyn + MSBuildWorkspace based type de-inference, tested with xUnit. -->
<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<RootNamespace>type_deinference</RootNamespace>
<!-- The C# source uses IEnumerable, Task, Console and Path without explicit using directives; ImplicitUsings supplies them. -->
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<!-- ExcludeAssets="runtime": MSBuild assemblies are not copied to the output; at run time they
     come from the MSBuild installation registered through Microsoft.Build.Locator (RegisterDefaults). -->
<PackageReference Include="Microsoft.Build" Version="17.0.0" ExcludeAssets="runtime" />
<PackageReference Include="Microsoft.Build.Framework" Version="17.0.0" ExcludeAssets="runtime" />
<PackageReference Include="Microsoft.Build.Utilities.Core" Version="17.0.0" ExcludeAssets="runtime" />
<PackageReference Include="Microsoft.Build.Locator" Version="1.4.1" />
<!-- Roslyn compiler APIs and the MSBuild-backed workspace used to open projects and solutions. -->
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.0.1" />
<PackageReference Include="Microsoft.CodeAnalysis.Workspaces.MSBuild" Version="4.0.1" />
<!-- Test framework, runner and coverage collector. -->
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" />
<PackageReference Include="NuGet.Frameworks" Version="6.0.0" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="3.1.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
</Project>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment