Skip to content

Instantly share code, notes, and snippets.

@vanbukin
Created June 3, 2020 17:49
Show Gist options
  • Save vanbukin/35678a171cbf61831826a77ed1966b38 to your computer and use it in GitHub Desktop.
Save vanbukin/35678a171cbf61831826a77ed1966b38 to your computer and use it in GitHub Desktop.
Client credentials flow in ASP.NET Core with background token refresh
using System;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using IdentityModel.Client;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
namespace Auth
{
/// <summary>
/// Produces new access tokens on demand.
/// </summary>
public interface IAccessTokenFactory
{
    /// <summary>Obtains a new access token.</summary>
    /// <param name="cancellationToken">Token used to abort the operation.</param>
    /// <returns>The access token string.</returns>
    Task<string> CreateAsync(CancellationToken cancellationToken = default);
}
/// <summary>
/// Holds the access token currently attached to outgoing requests and lets it be replaced.
/// </summary>
public interface IAccessTokenProvider
{
    /// <summary>Returns the currently stored access token.</summary>
    string GetAccessToken();

    /// <summary>Replaces the stored access token.</summary>
    /// <param name="newAccessToken">The token to store.</param>
    void SetAccessToken(string newAccessToken);
}
/// <summary>
/// Immutable connection settings for requesting client-credentials tokens.
/// </summary>
public class AccessTokenFactoryOptions
{
    /// <summary>
    /// Creates the settings after checking that every value contains non-whitespace text.
    /// </summary>
    /// <param name="authority">Authority whose discovery document is queried.</param>
    /// <param name="clientId">Client identifier sent with the token request.</param>
    /// <param name="clientSecret">Client secret sent with the token request.</param>
    /// <exception cref="ArgumentException">
    /// Any argument is null, empty or whitespace. Arguments are checked in declaration order,
    /// so the first offending one is reported.
    /// </exception>
    public AccessTokenFactoryOptions(string authority, string clientId, string clientSecret)
    {
        EnsureHasText(authority, nameof(authority));
        EnsureHasText(clientId, nameof(clientId));
        EnsureHasText(clientSecret, nameof(clientSecret));

        Authority = authority;
        ClientId = clientId;
        ClientSecret = clientSecret;
    }

    /// <summary>Authority whose discovery document is queried.</summary>
    public string Authority { get; }

    /// <summary>Client identifier sent with the token request.</summary>
    public string ClientId { get; }

    /// <summary>Client secret sent with the token request.</summary>
    public string ClientSecret { get; }

    // Shared guard for the constructor: rejects null, empty and whitespace-only values.
    private static void EnsureHasText(string value, string parameterName)
    {
        if (!string.IsNullOrWhiteSpace(value))
        {
            return;
        }

        throw new ArgumentException("Value cannot be null or whitespace.", parameterName);
    }
}
/// <summary>
/// <see cref="IAccessTokenFactory"/> that uses the OAuth 2.0 client credentials grant.
/// It resolves the token endpoint from the authority's discovery document on every call.
/// </summary>
public class AccessTokenFactory : IAccessTokenFactory
{
    private readonly AccessTokenFactoryOptions _options;
    private readonly HttpClient _client;

    /// <param name="options">Authority and client credentials to use.</param>
    /// <param name="client">HTTP client used for the discovery and token requests.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> or <paramref name="client"/> is null.</exception>
    public AccessTokenFactory(AccessTokenFactoryOptions options, HttpClient client)
    {
        _options = options ?? throw new ArgumentNullException(nameof(options));
        _client = client ?? throw new ArgumentNullException(nameof(client));
    }

    /// <summary>
    /// Fetches the discovery document, then requests a client-credentials token from its token endpoint.
    /// </summary>
    /// <param name="cancellationToken">Token used to abort both HTTP requests.</param>
    /// <returns>The access token from the token response.</returns>
    /// <exception cref="InvalidOperationException">
    /// The discovery request fails, the token request fails, or the token response carries no access token.
    /// </exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
    public async Task<string> CreateAsync(CancellationToken cancellationToken = default)
    {
        var discoveryDoc = await _client.GetDiscoveryDocumentAsync(_options.Authority, cancellationToken);
        if (discoveryDoc.IsError)
        {
            // NOTE(review): IdentityModel may report a cancelled request as an error response rather than
            // throwing; surface cancellation as cancellation so callers don't treat it as a failure.
            cancellationToken.ThrowIfCancellationRequested();
            throw new InvalidOperationException(
                $"Failed to load the discovery document from '{_options.Authority}': {discoveryDoc.Error}",
                discoveryDoc.Exception);
        }

        var response = await _client.RequestClientCredentialsTokenAsync(
            new ClientCredentialsTokenRequest
            {
                Address = discoveryDoc.TokenEndpoint,
                ClientId = _options.ClientId,
                ClientSecret = _options.ClientSecret
            },
            cancellationToken);
        if (response.IsError)
        {
            cancellationToken.ThrowIfCancellationRequested();
            throw new InvalidOperationException(
                $"Token request to '{discoveryDoc.TokenEndpoint}' failed: {response.Error}",
                response.Exception);
        }

        // Guard against a "successful" response that carries no token. Without this check, an empty
        // value would be handed to callers.
        if (string.IsNullOrEmpty(response.AccessToken))
        {
            throw new InvalidOperationException(
                $"Token endpoint '{discoveryDoc.TokenEndpoint}' returned a response without an access token.");
        }

        return response.AccessToken;
    }
}
/// <summary>
/// In-memory <see cref="IAccessTokenProvider"/> that keeps the most recently set access token.
/// </summary>
/// <remarks>
/// The token is replaced by a background refresh loop and read from the HTTP request pipeline,
/// which run on different threads.
/// </remarks>
public class AccessTokenProvider : IAccessTokenProvider
{
    // volatile: the writer and the readers run on different threads. It guarantees readers observe the
    // latest published token rather than a value the JIT has cached. Reference assignment is already
    // atomic, so no lock is needed.
    private volatile string _accessToken = string.Empty;

    /// <summary>
    /// Returns the most recently set token, or <see cref="string.Empty"/> if none has been set yet.
    /// </summary>
    public string GetAccessToken() => _accessToken;

    /// <summary>Publishes <paramref name="newAccessToken"/> to subsequent readers.</summary>
    /// <param name="newAccessToken">The token to store; stored as-is.</param>
    public void SetAccessToken(string newAccessToken) => _accessToken = newAccessToken;
}
/// <summary>
/// Delegating handler that stamps every outgoing request with a bearer token read from an
/// <see cref="IAccessTokenProvider"/> at send time.
/// </summary>
public class DevicesApiClientAccessTokenInjectionDelegatingHandler : DelegatingHandler
{
    private readonly IAccessTokenProvider _tokenSource;

    /// <param name="accessTokenProvider">Supplies the token attached to each request.</param>
    /// <exception cref="ArgumentNullException"><paramref name="accessTokenProvider"/> is null.</exception>
    public DevicesApiClientAccessTokenInjectionDelegatingHandler(
        IAccessTokenProvider accessTokenProvider)
    {
        if (accessTokenProvider is null)
        {
            throw new ArgumentNullException(nameof(accessTokenProvider));
        }

        _tokenSource = accessTokenProvider;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Replaces any Authorization header already present on <paramref name="request"/>.
    /// The token is read from the provider on every call, so refreshed tokens take effect immediately.
    /// </remarks>
    protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();

        request.Headers.Authorization =
            new AuthenticationHeaderValue("Bearer", _tokenSource.GetAccessToken());

        return base.SendAsync(request, cancellationToken);
    }
}
/// <summary>
/// Hosted service that obtains an access token at startup and then refreshes it periodically.
/// Each new token is published to an <see cref="IAccessTokenProvider"/>.
/// </summary>
/// <remarks>
/// A failed acquisition is logged and retried after the failed-attempts delay until it succeeds.
/// The refresh interval is measured from the last successful acquisition, so the failure delay is
/// no longer added after a success.
/// </remarks>
public class BackgroundAccessTokenProviderHostedService : IHostedService
{
    private readonly IAccessTokenFactory _accessTokenFactory;
    private readonly IAccessTokenProvider _accessTokenProvider;
    private readonly TimeSpan _accessTokenRefreshInterval;
    private readonly TimeSpan _failedAttemptsDelay;
    private readonly ILogger<BackgroundAccessTokenProviderHostedService> _logger;
    private readonly CancellationTokenSource _stoppingCts = new CancellationTokenSource();
    private Task _executingTask = null!;

    /// <param name="accessTokenProvider">Receives every newly acquired token.</param>
    /// <param name="accessTokenFactory">Produces new tokens.</param>
    /// <param name="accessTokenRefreshInterval">
    /// Wait between a successful acquisition and the next refresh. Must be positive or
    /// <see cref="Timeout.InfiniteTimeSpan"/>.
    /// </param>
    /// <param name="logger">Logger for failed acquisitions.</param>
    /// <param name="failedAttemptsDelay">
    /// Wait between consecutive failed attempts. Must be non-negative or <see cref="Timeout.InfiniteTimeSpan"/>.
    /// </param>
    /// <exception cref="ArgumentException"><paramref name="accessTokenRefreshInterval"/> is zero.</exception>
    /// <exception cref="ArgumentOutOfRangeException">A delay is negative (other than infinite).</exception>
    /// <exception cref="ArgumentNullException">A reference argument is null.</exception>
    public BackgroundAccessTokenProviderHostedService(
        IAccessTokenProvider accessTokenProvider,
        IAccessTokenFactory accessTokenFactory,
        TimeSpan accessTokenRefreshInterval,
        ILogger<BackgroundAccessTokenProviderHostedService> logger,
        TimeSpan failedAttemptsDelay)
    {
        if (accessTokenRefreshInterval == default)
        {
            throw new ArgumentException("AccessTokenRefreshInterval can't be a zero or default", nameof(accessTokenRefreshInterval));
        }

        if (!IsValidDelay(accessTokenRefreshInterval))
        {
            throw new ArgumentOutOfRangeException(
                nameof(accessTokenRefreshInterval),
                accessTokenRefreshInterval,
                "AccessTokenRefreshInterval must be positive or Timeout.InfiniteTimeSpan.");
        }

        if (!IsValidDelay(failedAttemptsDelay))
        {
            throw new ArgumentOutOfRangeException(
                nameof(failedAttemptsDelay),
                failedAttemptsDelay,
                "FailedAttemptsDelay must be non-negative or Timeout.InfiniteTimeSpan.");
        }

        _accessTokenProvider = accessTokenProvider ?? throw new ArgumentNullException(nameof(accessTokenProvider));
        _accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory));
        _accessTokenRefreshInterval = accessTokenRefreshInterval;
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _failedAttemptsDelay = failedAttemptsDelay;
    }

    /// <summary>
    /// Starts the refresh loop without waiting for the first token. If the loop completes synchronously
    /// (for example by faulting), that task is returned so the host observes it.
    /// </summary>
    public Task StartAsync(CancellationToken cancellationToken)
    {
        _executingTask = StartRefreshingAccessTokensInBackgroundAsync(_stoppingCts.Token);
        return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask;
    }

    /// <summary>
    /// Signals the loop to stop. Then waits for it to finish or for <paramref name="cancellationToken"/> to fire.
    /// </summary>
    public async Task StopAsync(CancellationToken cancellationToken)
    {
        if (_executingTask is null)
        {
            // StartAsync was never called.
            return;
        }

        try
        {
            _stoppingCts.Cancel();
        }
        finally
        {
            // WhenAny never throws, so a cancelled/faulted loop does not fail shutdown.
            await Task.WhenAny(
                _executingTask,
                Task.Delay(Timeout.Infinite, cancellationToken));
        }
    }

    // Task.Delay accepts any non-negative duration or Timeout.InfiniteTimeSpan. Any other value would
    // throw inside the background loop and silently end the refresh cycle.
    private static bool IsValidDelay(TimeSpan delay) =>
        delay >= TimeSpan.Zero || delay == Timeout.InfiniteTimeSpan;

    // Acquires the initial token, then re-acquires one every refresh interval until cancelled.
    private async Task StartRefreshingAccessTokensInBackgroundAsync(CancellationToken cancellationToken)
    {
        cancellationToken.ThrowIfCancellationRequested();

        await AcquireAccessTokenWithRetriesAsync(cancellationToken);

        while (!cancellationToken.IsCancellationRequested)
        {
            await Task.Delay(_accessTokenRefreshInterval, cancellationToken);
            await AcquireAccessTokenWithRetriesAsync(cancellationToken);
        }
    }

    // Calls the factory until a token is obtained and published, waiting _failedAttemptsDelay only after
    // failures. Cancellation of cancellationToken propagates instead of being logged as an error.
    private async Task AcquireAccessTokenWithRetriesAsync(CancellationToken cancellationToken)
    {
        while (true)
        {
            try
            {
                var accessToken = await _accessTokenFactory.CreateAsync(cancellationToken);
                _accessTokenProvider.SetAccessToken(accessToken);
                return;
            }
            catch (Exception ex) when (!(ex is OperationCanceledException && cancellationToken.IsCancellationRequested))
            {
                // Static template (CA2254): the exception's own message is captured via the exception argument.
                _logger.LogError(ex, "Failed to obtain an access token. Retrying in {RetryDelay}.", _failedAttemptsDelay);
            }

            await Task.Delay(_failedAttemptsDelay, cancellationToken);
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment