Last active
August 6, 2023 15:31
-
-
Save ChristopherHaws/b1c54b95838f1513bfb74fa1c8e408f3 to your computer and use it in GitHub Desktop.
EFCore Azure AccessToken
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections.Generic; | |
namespace System.Threading.Tasks | |
{ | |
public static class AsyncUtilities
{
    /// <summary>
    /// Executes an async method that returns a non-generic <see cref="Task"/> synchronously.
    /// The task's continuations are pumped on the calling thread until it completes.
    /// </summary>
    /// <param name="task">Factory that starts the asynchronous operation to run.</param>
    /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
    /// <exception cref="AggregateException">
    /// The operation faulted; the original exception is the <see cref="Exception.InnerException"/>.
    /// </exception>
    public static void RunSync(Func<Task> task)
    {
        if (task == null)
        {
            throw new ArgumentNullException(nameof(task));
        }

        RunOnExclusiveContext(task);
    }

    /// <summary>
    /// Executes an async method that returns a <see cref="Task{TResult}"/> synchronously.
    /// The task's continuations are pumped on the calling thread until it completes.
    /// </summary>
    /// <typeparam name="T">Result type of the task.</typeparam>
    /// <param name="task">Factory that starts the asynchronous operation to run.</param>
    /// <returns>The result produced by the task.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
    /// <exception cref="AggregateException">
    /// The operation faulted; the original exception is the <see cref="Exception.InnerException"/>.
    /// </exception>
    public static T RunSync<T>(Func<Task<T>> task)
    {
        if (task == null)
        {
            throw new ArgumentNullException(nameof(task));
        }

        T result = default;
        RunOnExclusiveContext(async () => { result = await task(); });
        return result;
    }

    /// <summary>
    /// Runs <paramref name="work"/> under a private single-threaded synchronization context.
    /// That context is pumped on the calling thread until the work has finished.
    /// The caller's original context is restored afterwards, whether or not the work faulted.
    /// </summary>
    private static void RunOnExclusiveContext(Func<Task> work)
    {
        var previousContext = SynchronizationContext.Current;
        using (var context = new ExclusiveSynchronizationContext())
        {
            SynchronizationContext.SetSynchronizationContext(context);
            try
            {
                context.Post(async _ =>
                {
                    try
                    {
                        await work();
                    }
                    catch (Exception e)
                    {
                        // Recorded here and surfaced by BeginMessageLoop on the calling thread.
                        // Deliberately not rethrown: rethrowing from this async-void callback
                        // would only post a rethrow to the context that is never executed.
                        context.InnerException = e;
                    }
                    finally
                    {
                        context.EndMessageLoop();
                    }
                }, null);

                context.BeginMessageLoop();
            }
            finally
            {
                // Must run even when BeginMessageLoop throws. Otherwise the caller's thread
                // keeps this private context installed after a failed run.
                SynchronizationContext.SetSynchronizationContext(previousContext);
            }
        }
    }

    /// <summary>
    /// Synchronization context that queues posted callbacks. The thread calling
    /// <see cref="BeginMessageLoop"/> executes those callbacks one at a time.
    /// </summary>
    private sealed class ExclusiveSynchronizationContext : SynchronizationContext, IDisposable
    {
        private readonly AutoResetEvent workItemsWaiting = new AutoResetEvent(false);
        private readonly Queue<Tuple<SendOrPostCallback, object>> items = new Queue<Tuple<SendOrPostCallback, object>>();
        private bool done;
        private bool disposed;

        /// <summary>Exception captured from the awaited work, if it faulted.</summary>
        public Exception InnerException { get; set; }

        /// <summary>
        /// Releases the wait handle. Guarded by the queue lock so that a concurrent
        /// <see cref="Post"/> never signals a disposed event.
        /// </summary>
        public void Dispose()
        {
            lock (this.items)
            {
                if (this.disposed)
                {
                    return;
                }

                this.disposed = true;
                this.items.Clear();
                this.workItemsWaiting.Dispose();
            }
        }

        /// <summary>Synchronous dispatch is not supported by this context.</summary>
        public override void Send(SendOrPostCallback d, object state)
        {
            throw new NotSupportedException("We cannot send to our same thread");
        }

        /// <summary>Queues a callback for the message loop and wakes the loop.</summary>
        public override void Post(SendOrPostCallback d, object state)
        {
            lock (this.items)
            {
                // Callbacks arriving after the loop has finished would never run anyway.
                // Dropping them avoids signalling a disposed wait handle on the posting thread.
                if (this.disposed)
                {
                    return;
                }

                this.items.Enqueue(Tuple.Create(d, state));
                this.workItemsWaiting.Set();
            }
        }

        /// <summary>Queues a callback that stops the message loop once it is reached.</summary>
        public void EndMessageLoop()
        {
            this.Post(_ => this.done = true, null);
        }

        /// <summary>
        /// Executes queued callbacks on the current thread until <see cref="EndMessageLoop"/>'s
        /// callback runs. Throws an <see cref="AggregateException"/> as soon as a callback
        /// leaves <see cref="InnerException"/> set.
        /// </summary>
        public void BeginMessageLoop()
        {
            while (!this.done)
            {
                Tuple<SendOrPostCallback, object> task = null;
                lock (this.items)
                {
                    if (this.items.Count > 0)
                    {
                        task = this.items.Dequeue();
                    }
                }

                if (task == null)
                {
                    this.workItemsWaiting.WaitOne();
                    continue;
                }

                task.Item1(task.Item2);
                if (this.InnerException != null) // the awaited work threw an exception
                {
                    throw new AggregateException("AsyncUtilities.RunSync method threw an exception.", this.InnerException);
                }
            }
        }

        /// <summary>Returns this instance; the context is not meant to be duplicated.</summary>
        public override SynchronizationContext CreateCopy()
        {
            return this;
        }
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Data.Common; | |
using System.Data.SqlClient; | |
using System.Threading.Tasks; | |
using Microsoft.Azure.Services.AppAuthentication; | |
using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal; | |
using Microsoft.EntityFrameworkCore.Storage; | |
namespace Microsoft.EntityFrameworkCore | |
{ | |
/// <summary>
/// Options-builder extensions that switch EF Core's SQL Server connection to
/// <see cref="AzureSqlServerConnection"/>.
/// </summary>
public static class AzureSqlServerConnectionExtensions
{
    /// <summary>
    /// Registers <see cref="AzureSqlServerConnection"/> as the implementation of
    /// <see cref="ISqlServerConnection"/> for contexts built from <paramref name="options"/>.
    /// </summary>
    /// <param name="options">The options builder to modify.</param>
    public static void UseAzureAccessToken(this DbContextOptionsBuilder options) =>
        options.ReplaceService<ISqlServerConnection, AzureSqlServerConnection>();
}
/// <summary>
/// A <see cref="SqlServerConnection"/> whose underlying <see cref="SqlConnection"/> is given an
/// access token from <see cref="AzureServiceTokenProvider"/> when it is created.
/// </summary>
public class AzureSqlServerConnection : SqlServerConnection
{
    // Resource identifier requested from the token provider for Azure SQL.
    private const string AzureSqlResource = "https://database.windows.net/";

    // Compensate for slow SQL Server database creation.
    private const int DefaultMasterConnectionCommandTimeout = 60;

    // One provider instance shared by every connection rather than one per connection.
    private static readonly AzureServiceTokenProvider TokenProvider = new AzureServiceTokenProvider();

    /// <summary>Creates the connection from EF Core's relational connection dependencies.</summary>
    /// <param name="dependencies">Dependencies forwarded unchanged to the base connection.</param>
    public AzureSqlServerConnection(RelationalConnectionDependencies dependencies)
        : base(dependencies)
    {
    }

    /// <summary>
    /// Creates the <see cref="SqlConnection"/> for <see cref="RelationalConnection.ConnectionString"/>
    /// and assigns it an access token, fetched synchronously via <c>AsyncUtilities.RunSync</c>.
    /// </summary>
    /// <remarks>
    /// The token is read only here, when the connection object is created.
    /// NOTE(review): if the same DbConnection instance is kept for longer than the token's
    /// lifetime (e.g. a long-lived or pooled DbContext), confirm a stale token cannot be reused.
    /// </remarks>
    protected override DbConnection CreateDbConnection()
    {
        var connection = new SqlConnection(this.ConnectionString);
        connection.AccessToken = AsyncUtilities.RunSync(() => TokenProvider.GetAccessTokenAsync(AzureSqlResource));
        return connection;
    }

    /// <summary>
    /// Creates an <see cref="AzureSqlServerConnection"/> for the <c>master</c> database on the
    /// same server, so that master-level work also authenticates with an access token.
    /// </summary>
    /// <returns>A new connection configured for <c>master</c>.</returns>
    public override ISqlServerConnection CreateMasterConnection()
    {
        var masterBuilder = new SqlConnectionStringBuilder(this.ConnectionString);
        masterBuilder.InitialCatalog = "master";

        // Drop AttachDBFilename so the master connection does not try to attach a database file.
        masterBuilder.Remove("AttachDBFilename");

        var masterOptions = new DbContextOptionsBuilder()
            .UseSqlServer(
                masterBuilder.ConnectionString,
                sql => sql.CommandTimeout(this.CommandTimeout ?? DefaultMasterConnectionCommandTimeout))
            .Options;

        return new AzureSqlServerConnection(this.Dependencies.With(masterOptions));
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Application startup: registers a pooled <c>ApplicationContext</c> backed by SQL Server.
/// Outside the Development environment, that context uses Azure access-token authentication.
/// </summary>
public class Startup
{
    private readonly IConfiguration configuration;
    private readonly IHostingEnvironment env;

    /// <summary>Captures configuration and hosting environment for later use.</summary>
    public Startup(IConfiguration configuration, IHostingEnvironment env)
    {
        this.configuration = configuration;
        this.env = env;
    }

    /// <summary>Called by the runtime to add services to the container.</summary>
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddDbContextPool<ApplicationContext>(options =>
        {
            string connectionString = this.configuration.GetConnectionString("DefaultConnection");
            options.UseSqlServer(connectionString);

            // Development keeps whatever authentication the connection string specifies.
            if (this.env.IsDevelopment())
            {
                return;
            }

            // Every other environment swaps in the access-token-based SQL Server connection.
            options.UseAzureAccessToken();
        });

        // Removed unrelated code...
    }

    /// <summary>Called by the runtime to configure the HTTP request pipeline.</summary>
    public void Configure(IApplicationBuilder app)
    {
        // Removed unrelated code...
    }
}
Did it not work properly before? The token is already cached in a static field.
@OskarKlintrot I was not aware of that, thanks for the info. I suppose all that my update does then is remove a small allocation. ;)
That saves ~28 µs on my machine, if I remember correctly — I used BenchmarkDotNet to measure how long it took to create a new instance :) I ended up using an extension (`IDbContextOptionsExtension`) so that I could use EF's DI instead and mock it for unit testing. It's probably a lot slower, though.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
FYI, to anyone interested: I moved `TokenProvider` to a static readonly field so that token caching works properly.