-
-
Save ChristopherHaws/b1c54b95838f1513bfb74fa1c8e408f3 to your computer and use it in GitHub Desktop.
using System.Collections.Generic; | |
namespace System.Threading.Tasks | |
{ | |
/// <summary>
/// Helpers for running asynchronous delegates to completion on the calling thread.
/// The delegate and its continuations are pumped through a private, single-threaded
/// <see cref="SynchronizationContext"/> instead of blocking on the returned task.
/// </summary>
public static class AsyncUtilities
{
    /// <summary>
    /// Synchronously executes an async method that returns a <see cref="Task"/> (no result value).
    /// </summary>
    /// <param name="task">Delegate that starts the asynchronous operation.</param>
    /// <exception cref="AggregateException">
    /// The operation threw; the original exception is available as the inner exception.
    /// </exception>
    public static void RunSync(Func<Task> task)
    {
        RunOnMessageLoop(task);
    }

    /// <summary>
    /// Synchronously executes an async method that returns a <see cref="Task{TResult}"/> and
    /// returns its result.
    /// </summary>
    /// <typeparam name="T">Result type of the asynchronous operation.</typeparam>
    /// <param name="task">Delegate that starts the asynchronous operation.</param>
    /// <returns>The value produced by the awaited task.</returns>
    /// <exception cref="AggregateException">
    /// The operation threw; the original exception is available as the inner exception.
    /// </exception>
    public static T RunSync<T>(Func<Task<T>> task)
    {
        T ret = default;
        RunOnMessageLoop(async () => { ret = await task(); });
        return ret;
    }

    /// <summary>
    /// Installs a temporary <see cref="ExclusiveSynchronizationContext"/>, runs
    /// <paramref name="work"/> on it until it completes, and then always restores the
    /// previous context and releases the temporary one.
    /// </summary>
    private static void RunOnMessageLoop(Func<Task> work)
    {
        var oldContext = SynchronizationContext.Current;
        var sync = new ExclusiveSynchronizationContext();
        SynchronizationContext.SetSynchronizationContext(sync);
        try
        {
            sync.Post(async _ =>
            {
                try
                {
                    await work();
                }
                catch (Exception e)
                {
                    // Do not rethrow: this lambda is async void, and the message loop
                    // already surfaces the failure by inspecting InnerException.
                    sync.InnerException = e;
                }
                finally
                {
                    sync.EndMessageLoop();
                }
            }, null);

            sync.BeginMessageLoop();
        }
        finally
        {
            // Restore the caller's context even when the loop throws; otherwise the
            // calling thread would keep a context whose loop is no longer running.
            SynchronizationContext.SetSynchronizationContext(oldContext);
            sync.Dispose();
        }
    }

    /// <summary>
    /// Synchronization context that queues posted callbacks and executes them on the
    /// thread that calls <see cref="BeginMessageLoop"/>.
    /// </summary>
    private sealed class ExclusiveSynchronizationContext : SynchronizationContext, IDisposable
    {
        private readonly AutoResetEvent workItemsWaiting = new AutoResetEvent(false);
        private readonly Queue<Tuple<SendOrPostCallback, object>> items = new Queue<Tuple<SendOrPostCallback, object>>();
        private bool done;
        private bool disposed;

        /// <summary>Exception thrown by the pumped operation, if any.</summary>
        public Exception InnerException { get; set; }

        /// <summary>Releases the wait handle; callbacks posted afterwards are ignored.</summary>
        public void Dispose()
        {
            lock (this.items)
            {
                if (this.disposed)
                {
                    return;
                }

                this.disposed = true;
                this.workItemsWaiting.Dispose();
            }
        }

        /// <summary>Not supported: a synchronous send to the pumping thread cannot be serviced.</summary>
        public override void Send(SendOrPostCallback d, object state)
        {
            throw new NotSupportedException("We cannot send to our same thread");
        }

        /// <summary>Queues a callback for the message loop and wakes the loop.</summary>
        public override void Post(SendOrPostCallback d, object state)
        {
            lock (this.items)
            {
                if (this.disposed)
                {
                    // The loop has ended, so the callback could never run; dropping it
                    // avoids signalling a disposed wait handle.
                    return;
                }

                this.items.Enqueue(Tuple.Create(d, state));

                // Signalled under the lock so it cannot race with Dispose.
                this.workItemsWaiting.Set();
            }
        }

        /// <summary>Queues a callback that makes <see cref="BeginMessageLoop"/> return.</summary>
        public void EndMessageLoop()
        {
            this.Post(_ => this.done = true, null);
        }

        /// <summary>
        /// Executes queued callbacks on the current thread until <see cref="EndMessageLoop"/>
        /// has been processed, waiting whenever the queue is empty.
        /// </summary>
        /// <exception cref="AggregateException">
        /// <see cref="InnerException"/> was set while running a callback.
        /// </exception>
        public void BeginMessageLoop()
        {
            while (!this.done)
            {
                Tuple<SendOrPostCallback, object> task = null;
                lock (this.items)
                {
                    if (this.items.Count > 0)
                    {
                        task = this.items.Dequeue();
                    }
                }

                if (task != null)
                {
                    task.Item1(task.Item2);
                    if (this.InnerException != null) // the method threw an exception
                    {
                        throw new AggregateException("AsyncHelpers.Run method threw an exception.", this.InnerException);
                    }
                }
                else
                {
                    this.workItemsWaiting.WaitOne();
                }
            }
        }

        /// <summary>Returns this instance so copies share the same queue.</summary>
        public override SynchronizationContext CreateCopy()
        {
            return this;
        }
    }
}
} |
using System.Data.Common; | |
using System.Data.SqlClient; | |
using System.Threading.Tasks; | |
using Microsoft.Azure.Services.AppAuthentication; | |
using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal; | |
using Microsoft.EntityFrameworkCore.Storage; | |
namespace Microsoft.EntityFrameworkCore | |
{ | |
/// <summary>
/// <see cref="DbContextOptionsBuilder"/> extensions for enabling access-token based
/// SQL Server connections.
/// </summary>
public static class AzureSqlServerConnectionExtensions
{
    /// <summary>
    /// Replaces the registered <see cref="ISqlServerConnection"/> service with
    /// <see cref="AzureSqlServerConnection"/>.
    /// </summary>
    /// <param name="options">The options builder to configure.</param>
    public static void UseAzureAccessToken(this DbContextOptionsBuilder options)
        => options.ReplaceService<ISqlServerConnection, AzureSqlServerConnection>();
}
/// <summary>
/// <see cref="SqlServerConnection"/> whose underlying <see cref="SqlConnection"/> is given
/// an access token obtained from <see cref="AzureServiceTokenProvider"/>.
/// </summary>
public class AzureSqlServerConnection : SqlServerConnection
{
    // Compensate for slow SQL Server database creation
    private const int DefaultMasterConnectionCommandTimeout = 60;

    // Resource identifier the access token is requested for.
    private const string TokenResource = "https://database.windows.net/";

    private static readonly AzureServiceTokenProvider TokenProvider = new AzureServiceTokenProvider();

    /// <summary>Creates the connection from the EF Core relational connection dependencies.</summary>
    /// <param name="dependencies">Dependencies forwarded to the base connection.</param>
    public AzureSqlServerConnection(RelationalConnectionDependencies dependencies)
        : base(dependencies)
    {
    }

    /// <summary>
    /// Creates a <see cref="SqlConnection"/> for the configured connection string and
    /// assigns it an access token fetched synchronously.
    /// </summary>
    protected override DbConnection CreateDbConnection()
    {
        var connection = new SqlConnection(this.ConnectionString);

        // AzureServiceTokenProvider handles caching the token and refreshing it before it expires
        connection.AccessToken = AsyncUtilities.RunSync(() => TokenProvider.GetAccessTokenAsync(TokenResource));

        return connection;
    }

    /// <summary>
    /// Creates a connection to the <c>master</c> database that is also an
    /// <see cref="AzureSqlServerConnection"/>, so it uses the same token handling.
    /// </summary>
    public override ISqlServerConnection CreateMasterConnection()
    {
        var masterConnectionString = new SqlConnectionStringBuilder(this.ConnectionString);
        masterConnectionString.InitialCatalog = "master";
        masterConnectionString.Remove("AttachDBFilename");

        var optionsBuilder = new DbContextOptionsBuilder();
        optionsBuilder.UseSqlServer(
            masterConnectionString.ConnectionString,
            sqlServer => sqlServer.CommandTimeout(this.CommandTimeout ?? DefaultMasterConnectionCommandTimeout));

        var masterDependencies = this.Dependencies.With(optionsBuilder.Options);
        return new AzureSqlServerConnection(masterDependencies);
    }
}
} |
/// <summary>
/// Application startup: registers the pooled <c>ApplicationContext</c> and configures
/// the HTTP request pipeline.
/// </summary>
public class Startup
{
    private readonly IConfiguration configuration;
    private readonly IHostingEnvironment env;

    /// <summary>Stores the configuration and hosting environment used during service registration.</summary>
    public Startup(IConfiguration configuration, IHostingEnvironment env)
    {
        this.env = env;
        this.configuration = configuration;
    }

    // Called by the runtime to add services to the container.
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddDbContextPool<ApplicationContext>(options =>
        {
            var connectionString = this.configuration.GetConnectionString("DefaultConnection");
            options.UseSqlServer(connectionString);

            // Access-token authentication is only switched on outside of development.
            if (this.env.IsDevelopment())
            {
                return;
            }

            options.UseAzureAccessToken();
        });

        // Removed unrelated code...
    }

    // Called by the runtime to configure the HTTP request pipeline.
    public void Configure(IApplicationBuilder app)
    {
        // Removed unrelated code...
    }
}
Why is it necessary to override CreateMasterConnection()?
@OskarKlintrot Because it returns AzureSqlServerConnection instead of SqlServerConnection.
I missed that one, thanks for the clarification!
FYI, to anyone interested, I moved TokenProvider to be a static readonly field so that the caching of tokens works properly.
Did it not work properly before? The token is already cached in a static field.
@OskarKlintrot I was not aware of that, thanks for the info. I suppose all that my update does then is remove a small allocation. ;)
If I remember correctly, that saves ~28µs on my machine; I used BenchmarkDotNet to measure how long it took to create a new instance :) I ended up using an extension (IDbContextOptionsExtension) so that I could use EF's DI instead and mock it for unit testing purposes. It's probably a lot slower, though.
For the sake of completeness 😊: the NuGet package Microsoft.Azure.Services.AppAuthentication is needed.