Last active
December 28, 2021 15:33
-
-
Save newdigate/117c9443767570d7c9b8952374b937d3 to your computer and use it in GitHub Desktop.
C# — remove type inference (replace `var` with explicit types) using Roslyn and/or MSBuild
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Reflection; | |
using Microsoft.Build.Locator; | |
using Microsoft.CodeAnalysis; | |
using Microsoft.CodeAnalysis.CSharp; | |
using Microsoft.CodeAnalysis.CSharp.Syntax; | |
using Microsoft.CodeAnalysis.MSBuild; | |
using Xunit; | |
namespace type_deinference; | |
/// <summary>
/// Rewrites C# source so that implicitly typed (<c>var</c>) declarations spell out their types.
/// </summary>
public interface ITypeDeInference
{
    /// <summary>Rewrites a single source text compiled against <paramref name="references"/>.</summary>
    /// <param name="source">C# source code.</param>
    /// <param name="references">Metadata references used to compile <paramref name="source"/>.</param>
    /// <returns>The rewritten source, or <c>null</c> when no rewrite was produced.</returns>
    string? RemoveTypeInference(string source, IEnumerable<MetadataReference> references);

    /// <summary>Rewrites several sources that are compiled together as a single compilation.</summary>
    /// <param name="sourceIdentifier">Keys identifying each source.</param>
    /// <param name="references">Metadata references for the compilation.</param>
    /// <param name="getSourceFromIdentifier">Resolves an identifier to its source text.</param>
    /// <returns>Rewritten source text keyed by identifier.</returns>
    IDictionary<string, string> RemoveTypeInference(
        IEnumerable<string> sourceIdentifier,
        IEnumerable<MetadataReference> references,
        Func<string, string> getSourceFromIdentifier);

    /// <summary>Rewrites already-parsed trees, using <paramref name="compilation"/> for semantic information.</summary>
    /// <param name="trees">Each syntax tree paired with the root to rewrite.</param>
    /// <param name="compilation">Compilation containing every tree in <paramref name="trees"/>.</param>
    /// <param name="fnGetKeyForSyntaxTree">Produces the result key for a tree.</param>
    /// <returns>Rewritten source text keyed by the value of <paramref name="fnGetKeyForSyntaxTree"/>.</returns>
    IDictionary<string, string> Process(
        IDictionary<SyntaxTree, CompilationUnitSyntax> trees,
        Compilation compilation,
        Func<SyntaxTree, string> fnGetKeyForSyntaxTree);
}
/// <summary>
/// Roslyn-based <see cref="ITypeDeInference"/>: replaces <c>var</c> in local variable declarations
/// and <c>foreach</c> statements with the type the compiler inferred for the declared local.
/// Simple top-level types are written by their metadata name (e.g. <c>Int32</c>, <c>Type</c>);
/// arrays, generic and nested types are written with <see cref="ISymbol.ToDisplayString"/>.
/// Types that cannot be written explicitly (anonymous or error types) are left as <c>var</c>.
/// </summary>
public class TypeDeInference : ITypeDeInference
{
    /// <summary>Removes type inference from a single source text.</summary>
    /// <param name="source">C# source code to rewrite.</param>
    /// <param name="references">Metadata references used to compile <paramref name="source"/>.</param>
    /// <returns>The rewritten source, or <c>null</c> if nothing was replaced.</returns>
    public string? RemoveTypeInference(string source, IEnumerable<MetadataReference> references)
    {
        IDictionary<string, string> result = RemoveTypeInference(new[] { string.Empty }, references, _ => source);
        return result.TryGetValue(string.Empty, out var rewritten) ? rewritten : null;
    }

    /// <summary>Removes type inference from several sources compiled together.</summary>
    /// <param name="sourceIdentifiers">Unique keys identifying each source.</param>
    /// <param name="references">Metadata references for the compilation.</param>
    /// <param name="getSourceFromIdentifier">Resolves an identifier to its source text.</param>
    /// <returns>Rewritten source keyed by identifier; unchanged sources are omitted.</returns>
    /// <exception cref="ArgumentException">Thrown when <paramref name="sourceIdentifiers"/> contains duplicates.</exception>
    public IDictionary<string, string> RemoveTypeInference(
        IEnumerable<string> sourceIdentifiers,
        IEnumerable<MetadataReference> references,
        Func<string, string> getSourceFromIdentifier)
    {
        Dictionary<string, string> sourceCodeByIdentifier =
            sourceIdentifiers.ToDictionary(identifier => identifier, getSourceFromIdentifier);

        CSharpCompilation compilation =
            CompileCSharp(sourceCodeByIdentifier, references, out IDictionary<SyntaxTree, CompilationUnitSyntax> trees);

        // Every tree is parsed with its identifier as the file path, so the key is recovered
        // exactly, even when two identifiers have identical source text.
        return Process(trees, compilation, tree => tree.FilePath);
    }

    /// <summary>Rewrites every <c>var</c> declaration / <c>foreach</c> in <paramref name="trees"/>.</summary>
    /// <param name="trees">Syntax trees (belonging to <paramref name="compilation"/>) paired with their roots.</param>
    /// <param name="compilation">Compilation used to obtain semantic models.</param>
    /// <param name="fnGetKeyForSyntaxTree">Produces the result key for a tree.</param>
    /// <returns>Rewritten text for each tree whose text changed, keyed by <paramref name="fnGetKeyForSyntaxTree"/>.</returns>
    public IDictionary<string, string> Process(IDictionary<SyntaxTree, CompilationUnitSyntax> trees, Compilation compilation, Func<SyntaxTree, string> fnGetKeyForSyntaxTree)
    {
        var result = new Dictionary<string, string>();
        foreach ((SyntaxTree syntaxTree, CompilationUnitSyntax root) in trees)
        {
            // Materialized once: it is both checked for emptiness and handed to ReplaceNodes.
            List<CSharpSyntaxNode> candidates =
                GetVariableDeclarationSyntaxUsingTypeInference(syntaxTree)
                    .Cast<CSharpSyntaxNode>()
                    .Concat(GetForEachSyntaxUsingTypeInference(syntaxTree))
                    .ToList();
            if (candidates.Count == 0)
                continue;

            SemanticModel model = compilation.GetSemanticModel(syntaxTree);

            // ReplaceNodes passes (original, rewritten). `original` still belongs to the tree the
            // semantic model was built from, so it is used for binding. `rewritten` already contains
            // replacements made inside it (e.g. a `var` local in a foreach body), so the new type is
            // applied to it; otherwise those nested edits would be discarded.
            CompilationUnitSyntax newRoot = root.ReplaceNodes(
                candidates,
                (original, rewritten) => original switch
                {
                    VariableDeclarationSyntax declaration =>
                        ReplaceVariableDeclarationSyntaxWithDeinferedType(model, declaration, (VariableDeclarationSyntax)rewritten) ?? rewritten,
                    ForEachStatementSyntax forEach =>
                        ReplaceForeachSyntaxWithDeinferedType(model, forEach, (ForEachStatementSyntax)rewritten) ?? rewritten,
                    _ => rewritten,
                });

            string newText = newRoot.ToString();
            if (root.ToString() != newText)
            {
                result.Add(fnGetKeyForSyntaxTree(syntaxTree), newText);
            }
        }
        return result;
    }

    /// <summary>
    /// Parses each source (using its identifier as the tree's file path) and compiles them into
    /// one library compilation. Diagnostics are echoed to the console but do not abort the rewrite.
    /// </summary>
    private CSharpCompilation CompileCSharp(
        IEnumerable<KeyValuePair<string, string>> sourceCodeByIdentifier,
        IEnumerable<MetadataReference> references,
        out IDictionary<SyntaxTree, CompilationUnitSyntax> roots)
    {
        var syntaxTrees = new Dictionary<SyntaxTree, CompilationUnitSyntax>();
        foreach ((string identifier, string sourceCode) in sourceCodeByIdentifier)
        {
            SyntaxTree tree = CSharpSyntaxTree.ParseText(sourceCode, path: identifier);
            syntaxTrees[tree] = tree.GetCompilationUnitRoot();
        }
        roots = syntaxTrees;

        CSharpCompilation compilation = CSharpCompilation.Create(
            "assemblyName",
            syntaxTrees.Keys,
            references,
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

        foreach (Diagnostic diagnostic in compilation.GetDiagnostics())
        {
            Console.WriteLine(CSharpDiagnosticFormatter.Instance.Format(diagnostic));
        }

        return compilation;
    }

    /// <summary>
    /// Returns <paramref name="b"/> with <c>var</c> replaced by the declared local's type, or
    /// <c>null</c> when the type cannot be determined or written explicitly.
    /// </summary>
    /// <param name="a">Original declaration (bound against <paramref name="model"/>).</param>
    /// <param name="b">Same declaration with nested replacements already applied.</param>
    private SyntaxNode? ReplaceVariableDeclarationSyntaxWithDeinferedType(SemanticModel model, VariableDeclarationSyntax a, VariableDeclarationSyntax b)
    {
        // `var` permits a single declarator, so the first one determines the type.
        VariableDeclaratorSyntax? declarator = a.Variables.FirstOrDefault();
        if (declarator is null || model.GetDeclaredSymbol(declarator) is not ILocalSymbol local)
            return null;

        TypeSyntax? explicitType = CreateExplicitType(local.Type, b.Type);
        return explicitType is null ? null : b.WithType(explicitType);
    }

    /// <summary>
    /// Returns <paramref name="b"/> with the <c>foreach</c> <c>var</c> replaced by the iteration
    /// variable's type, or <c>null</c> when the type cannot be determined or written explicitly.
    /// </summary>
    /// <param name="a">Original statement (bound against <paramref name="model"/>).</param>
    /// <param name="b">Same statement with nested replacements already applied.</param>
    private SyntaxNode? ReplaceForeachSyntaxWithDeinferedType(SemanticModel model, ForEachStatementSyntax a, ForEachStatementSyntax b)
    {
        // The iteration variable's symbol gives the element type for arrays, generic and
        // non-generic enumerables alike; no runtime reflection is needed.
        if (model.GetDeclaredSymbol(a) is not ILocalSymbol iterationVariable)
            return null;

        TypeSyntax? explicitType = CreateExplicitType(iterationVariable.Type, b.Type);
        return explicitType is null ? null : b.WithType(explicitType);
    }

    /// <summary>
    /// Builds type syntax naming <paramref name="type"/>, carrying over the trivia of
    /// <paramref name="replaced"/>; <c>null</c> if the type cannot be written in source.
    /// </summary>
    private static TypeSyntax? CreateExplicitType(ITypeSymbol type, TypeSyntax replaced)
    {
        if (IsUnspeakable(type))
            return null;

        // Simple top-level named types keep their metadata name (Int32, Type, ...), matching
        // the established output; everything else needs the full display form to stay valid
        // (type arguments, containing types, array ranks).
        string typeText = type is INamedTypeSymbol { IsGenericType: false, ContainingType: null } named
            ? named.Name
            : type.ToDisplayString();

        return SyntaxFactory.ParseTypeName(typeText).WithTriviaFrom(replaced);
    }

    /// <summary>True when <paramref name="type"/> is, or is built from, an anonymous or error type.</summary>
    private static bool IsUnspeakable(ITypeSymbol type) => type switch
    {
        IErrorTypeSymbol => true,
        { IsAnonymousType: true } => true,
        IArrayTypeSymbol array => IsUnspeakable(array.ElementType),
        IPointerTypeSymbol pointer => IsUnspeakable(pointer.PointedAtType),
        INamedTypeSymbol named => named.TypeArguments.Any(argument => IsUnspeakable(argument))
            || (named.ContainingType is { } container && IsUnspeakable(container)),
        _ => false,
    };

    /// <summary>Local variable declarations in <paramref name="tree"/> whose type is written as <c>var</c>.</summary>
    private IEnumerable<VariableDeclarationSyntax> GetVariableDeclarationSyntaxUsingTypeInference(SyntaxTree tree) =>
        tree.GetRoot()
            .DescendantNodes()
            .OfType<VariableDeclarationSyntax>()
            .Where(declaration => declaration.Type.IsVar);

    /// <summary><c>foreach</c> statements in <paramref name="tree"/> whose iteration variable is declared <c>var</c>.</summary>
    private IEnumerable<ForEachStatementSyntax> GetForEachSyntaxUsingTypeInference(SyntaxTree tree) =>
        tree.GetRoot()
            .DescendantNodes()
            .OfType<ForEachStatementSyntax>()
            .Where(forEach => forEach.Type.IsVar);
}
/// <summary>
/// Applies an <see cref="ITypeDeInference"/> to projects and solutions loaded through
/// <see cref="MSBuildWorkspace"/>.
/// </summary>
public class SolutionDeInference
{
    private readonly ITypeDeInference _typeDeInference;

    /// <param name="typeDeInference">Rewriter applied to each project's syntax trees.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="typeDeInference"/> is null.</exception>
    public SolutionDeInference(ITypeDeInference typeDeInference)
    {
        _typeDeInference = typeDeInference ?? throw new ArgumentNullException(nameof(typeDeInference));
    }

    /// <summary>
    /// Opens the solution at <paramref name="solutionPath"/> and de-infers every project in
    /// dependency order.
    /// </summary>
    /// <returns>
    /// Rewritten source keyed by file path across all projects. A file shared by several projects
    /// appears once; the last project processed wins.
    /// </returns>
    public async Task<IDictionary<string, string>> RemoveTypeInferenceFromSolution(string solutionPath)
    {
        EnsureMsBuildRegistration();
        using MSBuildWorkspace workspace = MSBuildWorkspace.Create();
        workspace.WorkspaceFailed += (_, failedArgs) => Console.WriteLine(failedArgs.Diagnostic.Message);

        Solution solution = await workspace.OpenSolutionAsync(solutionPath);

        var results = new Dictionary<string, string>();
        foreach (ProjectId projectId in solution.GetProjectDependencyGraph().GetTopologicallySortedProjects())
        {
            Project? project = solution.GetProject(projectId);
            if (project is null)
                continue;

            IDictionary<string, string> projectResults = await RemoveTypeInferenceFromProject(project);
            foreach ((string filePath, string rewritten) in projectResults)
            {
                // Indexer (not Add): linked or multi-targeted files can occur in several projects.
                results[filePath] = rewritten;
            }
        }
        return results;
    }

    /// <summary>Opens the project at <paramref name="projectPath"/> and de-infers it.</summary>
    /// <returns>Rewritten source keyed by file path; unchanged files are omitted.</returns>
    public async Task<IDictionary<string, string>> RemoveTypeInferenceFromProject(string projectPath)
    {
        EnsureMsBuildRegistration();
        using MSBuildWorkspace workspace = MSBuildWorkspace.Create();
        workspace.WorkspaceFailed += (_, failedArgs) => Console.WriteLine(failedArgs.Diagnostic.Message);

        Project project = await workspace.OpenProjectAsync(projectPath);
        return await RemoveTypeInferenceFromProject(project);
    }

    /// <summary>De-infers every syntax tree of an already-loaded C# project.</summary>
    /// <returns>
    /// Rewritten source keyed by file path; empty when the project is not C# or has no compilation.
    /// </returns>
    public async Task<IDictionary<string, string>> RemoveTypeInferenceFromProject(Project project)
    {
        // GetCompilationUnitRoot only works for C# trees; other languages are skipped.
        if (project.Language != LanguageNames.CSharp)
            return new Dictionary<string, string>();

        Compilation? compilation = await project.GetCompilationAsync();
        if (compilation is null)
            return new Dictionary<string, string>();

        // Diagnostics are echoed for troubleshooting; they do not stop the rewrite.
        foreach (Diagnostic diagnostic in compilation.GetDiagnostics())
        {
            Console.WriteLine(CSharpDiagnosticFormatter.Instance.Format(diagnostic));
        }

        IDictionary<SyntaxTree, CompilationUnitSyntax> roots =
            compilation.SyntaxTrees.ToDictionary(tree => tree, tree => tree.GetCompilationUnitRoot());

        return _typeDeInference.Process(roots, compilation, tree => tree.FilePath);
    }

    /// <summary>Registers the default MSBuild instance once per process.</summary>
    private void EnsureMsBuildRegistration()
    {
        if (!MSBuildLocator.IsRegistered)
            MSBuildLocator.RegisterDefaults();
    }
}
/// <summary>xUnit tests for <see cref="TypeDeInference"/> and <see cref="SolutionDeInference"/>.</summary>
public class TestTypeDeInference
{
    private readonly ITypeDeInference typeDeInference = new TypeDeInference();
    private readonly IEnumerable<MetadataReference> defaultReferences;

    public TestTypeDeInference()
    {
        // System.Runtime.dll is taken from the same directory as the assembly defining Enumerable.
        string coreDir = Path.GetDirectoryName(typeof(Enumerable).Assembly.Location)!;
        defaultReferences =
            new[] {
                MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
                MetadataReference.CreateFromFile(typeof(System.Linq.Enumerable).Assembly.Location),
                MetadataReference.CreateFromFile(typeof(System.Console).Assembly.Location),
                MetadataReference.CreateFromFile(Path.Combine(coreDir, "System.Runtime.dll")),
            };
    }

    [Fact]
    public void TestDeInferVariableDeclarationStatement()
    {
        const string source = @"public class NumberWang {
public void Wang() {
var x = 1;
var z = (x == 1)? 10 : 100;
var t = GetType();
var t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string expectedResult = @"public class NumberWang {
public void Wang() {
Int32 x = 1;
Int32 z = (x == 1)? 10 : 100;
Type t = GetType();
int[] t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        string? deinferedSource = typeDeInference.RemoveTypeInference(source, defaultReferences);
        Assert.Equal(expectedResult, deinferedSource);
    }

    [Fact]
    public void TestDeInferForEachStatement()
    {
        const string source = @"using System.Linq;
public class T {
public void M() {
foreach(var s in Enumerable.Range(1, 10)) {
System.Console.WriteLine(s);
}
}
}";
        const string expectedResult = @"using System.Linq;
public class T {
public void M() {
foreach(Int32 s in Enumerable.Range(1, 10)) {
System.Console.WriteLine(s);
}
}
}";
        string? deinferedSource = typeDeInference.RemoveTypeInference(source, defaultReferences);
        Assert.Equal(expectedResult, deinferedSource);
    }

    [Fact]
    public void TestDeinferMultipleCodeFiles()
    {
        const string source1 = @"public class NumberWang {
public void Wang() {
var x = 1;
var z = (x == 1)? 10 : 100;
var t = GetType();
var t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string source2 = @"public class NumberWong {
public void Wong() {
var x = 1;
var z = (x == 1)? 10 : 100;
var t = GetType();
var t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string expectedSource1Result = @"public class NumberWang {
public void Wang() {
Int32 x = 1;
Int32 z = (x == 1)? 10 : 100;
Type t = GetType();
int[] t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        const string expectedSource2Result = @"public class NumberWong {
public void Wong() {
Int32 x = 1;
Int32 z = (x == 1)? 10 : 100;
Type t = GetType();
int[] t2 = new int[] {0, 1, 2, 3, 4, 5};
}
}";
        var sourceByIdentifier = new Dictionary<string, string>() { { "a", source1 }, { "b", source2 } };

        IDictionary<string, string> results = typeDeInference.RemoveTypeInference(sourceByIdentifier.Keys, defaultReferences, ident => sourceByIdentifier[ident]);

        // xUnit's Assert.Equal takes (expected, actual); the order matters for failure messages.
        Assert.Equal(expectedSource1Result, results["a"]);
        Assert.Equal(expectedSource2Result, results["b"]);
    }

    [Fact]
    public async Task TestDeinferProject()
    {
        const string projectPath = "/Users/nicholasnewdigate/Development/github/DelegateDecompiler/src/DelegateDecompiler/DelegateDecompiler.csproj";
        // Machine-specific fixture. xunit 2.4 cannot skip at run time, so return early when the
        // project is absent instead of swallowing every exception (which made the test unfailable).
        if (!File.Exists(projectPath))
            return;

        var solutionDeInference = new SolutionDeInference(new TypeDeInference());
        IDictionary<string, string> results = await solutionDeInference.RemoveTypeInferenceFromProject(projectPath);
        Assert.NotNull(results);
    }

    [Fact]
    public async Task TestDeinferSolution()
    {
        const string projectPath = "/Users/nicholasnewdigate/Development/github/newdigate/entity-lang/lib1/lib1.csproj";
        // Machine-specific fixture; see TestDeinferProject for why this returns rather than skips.
        if (!File.Exists(projectPath))
            return;

        var solutionDeInference = new SolutionDeInference(new TypeDeInference());
        IDictionary<string, string> results = await solutionDeInference.RemoveTypeInferenceFromProject(projectPath);
        Assert.NotNull(results);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net6.0</TargetFramework>
    <RootNamespace>type_deinference</RootNamespace>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <!-- MSBuild and Roslyn packages used by the de-inference code.
       The Microsoft.Build* references exclude runtime assets; presumably so that the instance
       registered via Microsoft.Build.Locator supplies MSBuild at run time. Confirm against the
       MSBuildLocator documentation before changing. -->
  <ItemGroup>
    <PackageReference Include="Microsoft.Build" Version="17.0.0" ExcludeAssets="runtime" />
    <PackageReference Include="Microsoft.Build.Framework" Version="17.0.0" ExcludeAssets="runtime" />
    <PackageReference Include="Microsoft.Build.Utilities.Core" Version="17.0.0" ExcludeAssets="runtime" />
    <PackageReference Include="Microsoft.Build.Locator" Version="1.4.1" />
    <PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.0.1" />
    <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.0.1" />
    <PackageReference Include="Microsoft.CodeAnalysis.Workspaces.MSBuild" Version="4.0.1" />
    <!-- NOTE(review): not referenced directly by the C# code; confirm why it is required. -->
    <PackageReference Include="NuGet.Frameworks" Version="6.0.0" />
  </ItemGroup>

  <!-- Test SDK, xUnit, its runner and the coverage collector. -->
  <ItemGroup>
    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" />
    <PackageReference Include="xunit" Version="2.4.1" />
    <PackageReference Include="xunit.runner.visualstudio" Version="2.4.3"
                      IncludeAssets="runtime; build; native; contentfiles; analyzers; buildtransitive"
                      PrivateAssets="all" />
    <PackageReference Include="coverlet.collector" Version="3.1.0"
                      IncludeAssets="runtime; build; native; contentfiles; analyzers; buildtransitive"
                      PrivateAssets="all" />
  </ItemGroup>

</Project>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment