Skip to content

Instantly share code, notes, and snippets.

@01Vladimir10
Created December 23, 2024 16:33
Show Gist options
  • Select an option

  • Save 01Vladimir10/7a0ca0bedf21f8e1a98699e141d040f7 to your computer and use it in GitHub Desktop.

Select an option

Save 01Vladimir10/7a0ca0bedf21f8e1a98699e141d040f7 to your computer and use it in GitHub Desktop.
EF Core extensions
using System.Collections;
using Dapper;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using MySqlConnector;
using System.Data;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using Domain.Entities;
using Domain.Modules.Common.Data;
using Domain.Modules.Core.Query;
using JetBrains.Annotations;
using MR.EntityFrameworkCore.KeysetPagination;
namespace Infrastructure.Extensions;
public static class EntityFrameworkExtensions
{
/// <summary>
/// Searches the entities currently tracked by <paramref name="set"/> (its <c>Local</c> view) for one
/// whose primary-key values equal those of <paramref name="entity"/>.
/// </summary>
/// <param name="set">The set whose tracked entities are inspected.</param>
/// <param name="entity">The (possibly detached) entity providing the key values to look for.</param>
/// <typeparam name="T">The entity type.</typeparam>
/// <returns>The tracked instance with the same key, or <c>null</c> when no such instance is tracked.</returns>
private static T? FindAttachedEntry<T>(this DbSet<T> set, T entity) where T : class
{
    // The null-forgiving operator means an entity type without a primary key fails here.
    var keyParts = set.EntityType.FindPrimaryKey()!.Properties
        .Select(property => new { getter = property.GetGetter(), comparer = property.GetKeyValueComparer() })
        .ToArray();

    foreach (var tracked in set.Local)
    {
        var sameKey = keyParts.All(part =>
            part.comparer.Equals(part.getter.GetClrValue(tracked), part.getter.GetClrValue(entity)));
        if (sameKey)
        {
            return tracked;
        }
    }

    return null;
}
/// <summary>
/// Builds a delegate that reports whether two entities of type <typeparamref name="T"/> have equal
/// primary-key values, using each key property's own value comparer.
/// </summary>
/// <param name="set">The set whose entity metadata supplies the key definition.</param>
/// <typeparam name="T">The entity type.</typeparam>
/// <returns>
/// A comparison delegate. When the entity type has no primary key the key list is empty, so the
/// delegate returns <c>true</c> for every pair.
/// </returns>
private static Func<T, T, bool> GetKeyComparer<T>(this DbSet<T> set) where T : class
{
    var keyParts = set.EntityType.FindPrimaryKey()?.Properties
        .Select(property => new { getter = property.GetGetter(), comparer = property.GetKeyValueComparer() })
        .ToArray() ?? [];

    return SameKey;

    bool SameKey(T left, T right)
    {
        foreach (var part in keyParts)
        {
            var leftValue = part.getter.GetClrValue(left);
            var rightValue = part.getter.GetClrValue(right);
            if (!part.comparer.Equals(leftValue, rightValue))
            {
                return false;
            }
        }

        return true;
    }
}
/// <summary>
/// Applies the values of a disconnected <paramref name="entity"/> to the change tracker.
/// </summary>
/// <remarks>
/// When an instance with the same primary key is already tracked, the values of
/// <paramref name="entity"/> are copied onto it so only the changed properties are flagged.
/// Otherwise <paramref name="entity"/> itself is tracked in the <see cref="EntityState.Modified"/>
/// state, matching what <c>UpdatedDisconnectedRange</c> does for untracked entities.
/// </remarks>
/// <param name="context">The set the entity belongs to.</param>
/// <param name="entity">The disconnected entity carrying the new values.</param>
/// <typeparam name="T">The entity type.</typeparam>
public static void UpdateDisconnected<T>(this DbSet<T> context, T entity) where T : class
{
    var attachedEntry = context.FindAttachedEntry(entity);
    if (attachedEntry is not null)
    {
        context.Entry(attachedEntry).CurrentValues.SetValues(entity);
        return;
    }

    // Previously the entity's values were copied onto itself, which left it detached and
    // therefore never persisted. Tracking it as Modified makes the update take effect.
    context.Entry(entity).State = EntityState.Modified;
}
/// <summary>
/// Synchronises the tracked entities selected by <paramref name="parentEntityPredicate"/> with
/// <paramref name="updatedList"/>: items with a matching primary key receive the new values, items
/// without a match are added, and tracked items absent from the list are removed.
/// </summary>
/// <param name="context">The set holding the child entities.</param>
/// <param name="parentEntityPredicate">Selects, among the tracked entities, those taking part in the sync.</param>
/// <param name="updatedList">The desired final content of the collection.</param>
/// <typeparam name="T">The entity type.</typeparam>
public static void UpdateDisconnectedNavigationCollection<T>(
    this DbSet<T> context,
    Func<T, bool> parentEntityPredicate,
    ICollection<T> updatedList
) where T : class
{
    // Snapshot first: adding entities below changes the Local view.
    var trackedItems = context.Local.Where(parentEntityPredicate).ToArray();
    var sameKey = context.GetKeyComparer();

    foreach (var incoming in updatedList)
    {
        var tracked = Array.Find(trackedItems, candidate => sameKey(incoming, candidate));
        if (tracked is not null)
        {
            context.Entry(tracked).CurrentValues.SetValues(incoming);
            continue;
        }

        context.Add(incoming);
    }

    // Tracked items with no counterpart in the updated list are removed.
    var orphans = trackedItems.Where(tracked => !updatedList.Any(incoming => sameKey(tracked, incoming)));
    context.RemoveRange(orphans);
}
/// <summary>
/// Applies a batch of disconnected entities to the change tracker, indexing the tracked entities by
/// primary key once so each incoming entity is resolved with a dictionary lookup.
/// </summary>
/// <param name="context">The set the entities belong to.</param>
/// <param name="entities">The disconnected entities carrying the new values.</param>
/// <typeparam name="T">The entity type.</typeparam>
public static void UpdatedDisconnectedRange<T>(this DbSet<T> context, IEnumerable<T> entities) where T : class
{
    var keyGetters = context.EntityType.FindPrimaryKey()!.Properties
        .Select(property => property.GetGetter())
        .ToArray();

    object?[] KeyOf(T item) => keyGetters.Select(getter => getter.GetClrValue(item)).ToArray();

    var trackedByKey = context.Local.ToDictionary(
        item => KeyOf(item),
        comparer: ArraySequenceComparer<object?>.Default
    );

    foreach (var incoming in entities)
    {
        if (trackedByKey.TryGetValue(KeyOf(incoming), out var tracked))
        {
            // An instance with this key is already tracked: copy the new values onto it so EF
            // can work out which properties changed.
            context.Entry(tracked).CurrentValues.SetValues(incoming);
        }
        else
        {
            // No instance with this key is tracked by this context: track the incoming one as modified.
            context.Entry(incoming).State = EntityState.Modified;
        }
    }
}
/// <summary>
/// Detaches all tracked entries that match <paramref name="predicate"/> and then executes
/// the deletion on the database.
/// </summary>
/// <param name="set">The set to delete from.</param>
/// <param name="predicate">
/// Selects the rows to delete. It is compiled to filter the tracked entities in memory and is also
/// applied to the database query.
/// </param>
/// <param name="cancellationToken">Token forwarded to the database delete command.</param>
/// <typeparam name="T">The entity type.</typeparam>
/// <returns>A task that completes when the delete command has run.</returns>
public static Task BulkDeleteAsync<T>(
    this DbSet<T> set,
    Expression<Func<T, bool>> predicate,
    CancellationToken cancellationToken = default) where T : class
{
    var filter = predicate.Compile();

    // Materialise the matches before detaching: detaching changes the set of tracked entities,
    // and the Local view must not be modified while it is being enumerated.
    var trackedMatches = set.Local.Where(filter).ToArray();
    foreach (var item in trackedMatches)
    {
        set.Entry(item).State = EntityState.Detached;
    }

    return set.Where(predicate).ExecuteDeleteAsync(cancellationToken);
}
/// <summary>
/// Returns the page of <paramref name="source"/> described by the page index and page size of
/// <paramref name="query"/>.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="query">Supplies <c>PageIndex</c> and <c>PageSize</c>.</param>
/// <typeparam name="T">The element type.</typeparam>
public static Task<Page<T>> PaginateAsync<T>(this IQueryable<T> source, PaginationQuery query)
{
    var pageIndex = query.PageIndex;
    var pageSize = query.PageSize;
    return source.PaginateAsync(pageIndex, pageSize);
}
/// <summary>
/// Returns the page of <paramref name="source"/> described by <paramref name="query"/>, projecting
/// each loaded item through <paramref name="mapper"/>.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="query">Supplies <c>PageIndex</c> and <c>PageSize</c>.</param>
/// <param name="mapper">In-memory projection applied to every item of the page.</param>
/// <typeparam name="T">The element type of the query.</typeparam>
/// <typeparam name="TFinal">The element type of the resulting page.</typeparam>
public static Task<Page<TFinal>> PaginateAsync<T, TFinal>(
    this IQueryable<T> source,
    PaginationQuery query,
    Func<T, TFinal> mapper)
{
    var pageIndex = query.PageIndex;
    var pageSize = query.PageSize;
    return source.PaginateAsync(pageIndex, pageSize, mapper);
}
/// <summary>
/// Returns the requested page of <paramref name="source"/> without projecting the items
/// (delegates to the mapping overload with an identity mapper).
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="pageIndex">Page number; the first page is 1.</param>
/// <param name="pageSize">Number of items per page.</param>
/// <typeparam name="T">The element type.</typeparam>
public static Task<Page<T>> PaginateAsync<T>(
    this IQueryable<T> source,
    int pageIndex,
    int pageSize)
{
    return source.PaginateAsync<T, T>(pageIndex, pageSize, item => item);
}
/// <summary>
/// Counts <paramref name="source"/>, loads the requested page and projects every loaded item through
/// <paramref name="mapper"/>.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="pageIndex">Page number; the first page is 1.</param>
/// <param name="pageSize">Number of items per page.</param>
/// <param name="mapper">In-memory projection applied to every item of the page.</param>
/// <typeparam name="T">The element type of the query.</typeparam>
/// <typeparam name="TFinal">The element type of the resulting page.</typeparam>
/// <returns>
/// The requested page, or <c>Page&lt;TFinal&gt;.Empty</c> when <paramref name="pageIndex"/> or
/// <paramref name="pageSize"/> is not positive or the page lies beyond the last page.
/// </returns>
public static async Task<Page<TFinal>> PaginateAsync<T, TFinal>(
    this IQueryable<T> source,
    int pageIndex,
    int pageSize,
    Func<T, TFinal> mapper)
{
    // Reject invalid paging arguments up front (as Paginate does for the page number): a
    // non-positive page index would produce a negative Skip, and a zero page size would
    // divide by zero below. Doing it here also avoids a pointless COUNT query.
    if (pageIndex <= 0 || pageSize <= 0)
    {
        return Page<TFinal>.Empty;
    }

    var totalItems = await source.CountAsync().ConfigureAwait(false);
    var totalPages = (int)Math.Ceiling(totalItems / (decimal)pageSize);
    if (pageIndex > totalPages)
    {
        return Page<TFinal>.Empty;
    }

    var data = await source
        .Skip((pageIndex - 1) * pageSize)
        .Take(pageSize)
        .ToListAsync()
        .ConfigureAwait(false);

    return new Page<TFinal>
    {
        Result = data.Select(mapper).ToList(),
        PageSize = pageSize,
        TotalItems = totalItems,
        PageIndex = pageIndex
    };
}
/// <summary>
/// Looks up <paramref name="key"/> in <paramref name="source"/>.
/// </summary>
/// <param name="source">The dictionary to search.</param>
/// <param name="key">The key to look up.</param>
/// <typeparam name="TKey">The key type.</typeparam>
/// <typeparam name="T">The value type.</typeparam>
/// <returns>The stored value, or the default value of <typeparamref name="T"/> when the key is absent.</returns>
public static T? GetOrDefault<TKey, T>(this IDictionary<TKey, T> source, TKey key)
    where TKey : notnull
{
    if (source.TryGetValue(key, out var found))
    {
        return found;
    }

    return default;
}
/// <summary>
/// Looks up <paramref name="key"/> in <paramref name="source"/>, falling back to
/// <paramref name="defaultValue"/> when the key is absent.
/// </summary>
/// <param name="source">The dictionary to search.</param>
/// <param name="key">The key to look up.</param>
/// <param name="defaultValue">The value returned when the key is not present.</param>
/// <typeparam name="TKey">The key type.</typeparam>
/// <typeparam name="T">The value type.</typeparam>
public static T GetOrDefault<TKey, T>(this IDictionary<TKey, T> source, TKey key, T defaultValue)
    where TKey : notnull
{
    if (source.TryGetValue(key, out var found))
    {
        return found;
    }

    return defaultValue;
}
/// <summary>
/// Executes a stored procedure through Dapper and maps the returned rows to <typeparamref name="T"/>.
/// </summary>
/// <remarks>
/// Connection handling is delegated to <c>OpenConnectionAsync</c>: the context's connection is
/// reused when it is already open, otherwise a dedicated connection is created and released
/// asynchronously once the call completes.
/// </remarks>
/// <param name="context">The context supplying the connection and the default command timeout.</param>
/// <param name="storedProcedure">Name of the stored procedure.</param>
/// <param name="parameters">Dapper parameter object, or <c>null</c> for none.</param>
/// <param name="commandTimeout">
/// Timeout passed to Dapper; when <c>null</c> the context's configured command timeout is used.
/// </param>
/// <typeparam name="T">The row type.</typeparam>
/// <returns>The rows returned by the procedure.</returns>
public static async Task<IEnumerable<T>> ExecuteStoredProcedure<T>(this DbContext context, string storedProcedure,
    object? parameters = null, int? commandTimeout = null)
{
    var timeout = commandTimeout ?? context.Database.GetCommandTimeout();

    return await context.OpenConnectionAsync(connection => connection.QueryAsync<T>(
        storedProcedure,
        parameters,
        commandType: CommandType.StoredProcedure,
        commandTimeout: timeout));
}
/// <summary>
/// Executes a stored procedure through Dapper and discards whatever it returns.
/// </summary>
/// <remarks>
/// Connection handling is delegated to <c>OpenConnectionAsync</c>: the context's connection is
/// reused when it is already open, otherwise a dedicated connection is created and released
/// asynchronously once the call completes.
/// </remarks>
/// <param name="context">The context supplying the connection and the default command timeout.</param>
/// <param name="storedProcedure">Name of the stored procedure.</param>
/// <param name="parameters">Dapper parameter object, or <c>null</c> for none.</param>
/// <param name="commandTimeout">
/// Timeout passed to Dapper; when <c>null</c> the context's configured command timeout is used.
/// </param>
public static async Task ExecuteStoredProcedure(this DbContext context, string storedProcedure,
    object? parameters = null, int? commandTimeout = null)
{
    var timeout = commandTimeout ?? context.Database.GetCommandTimeout();

    // The scalar result is intentionally ignored; only the execution matters.
    await context.OpenConnectionAsync(connection => connection.ExecuteScalarAsync(
        storedProcedure,
        parameters,
        commandType: CommandType.StoredProcedure,
        commandTimeout: timeout));
}
/// <summary>
/// Runs <paramref name="client"/> against a database connection. The context's own connection is
/// used when it is already open; otherwise a new <see cref="MySqlConnection"/> is created from the
/// context's connection string and closed and disposed after <paramref name="client"/> completes.
/// </summary>
/// <param name="context">The context supplying the connection or connection string.</param>
/// <param name="client">The work to perform with the connection.</param>
/// <typeparam name="T">The result type of <paramref name="client"/>.</typeparam>
/// <returns>The value produced by <paramref name="client"/>.</returns>
private static async Task<T> OpenConnectionAsync<T>(this DbContext context, Func<IDbConnection, Task<T>> client)
{
    var contextConnection = context.Database.GetDbConnection();
    if (contextConnection.State == ConnectionState.Open)
    {
        // Borrowed connection: its lifetime stays with the context.
        return await client(contextConnection);
    }

    var ownedConnection = new MySqlConnection(context.Database.GetConnectionString());
    try
    {
        return await client(ownedConnection);
    }
    finally
    {
        await ownedConnection.CloseAsync();
        await ownedConnection.DisposeAsync();
    }
}
/// <summary>
/// Runs a raw SQL query through Dapper and maps the rows to <typeparamref name="T"/>.
/// </summary>
/// <param name="context">The context supplying the connection.</param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Dapper parameter object, or <c>null</c> for none.</param>
/// <param name="cancellationToken">Token carried by the Dapper command definition.</param>
/// <typeparam name="T">The row type.</typeparam>
/// <returns>The rows returned by the query.</returns>
public static Task<IEnumerable<T>> ExecuteDapperQueryAsync<
    [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors |
                                DynamicallyAccessedMemberTypes.PublicProperties)]
    T>(this DbContext context,
    [LanguageInjection("MySQL")] string query,
    object? parameters = null,
    CancellationToken cancellationToken = default)
{
    var command = new CommandDefinition(query, parameters, cancellationToken: cancellationToken);
    return context.OpenConnectionAsync(connection => connection.QueryAsync<T>(command));
}
/// <summary>
/// Runs a raw SQL query through Dapper and maps the rows to <typeparamref name="T"/>, where
/// <typeparamref name="T"/> is inferred from <paramref name="typeBuilder"/>.
/// </summary>
/// <param name="context">The context supplying the connection.</param>
/// <param name="typeBuilder">
/// Never invoked; it exists only so the compiler can infer <typeparamref name="T"/> from its return type.
/// </param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Dapper parameter object, or <c>null</c> for none.</param>
/// <param name="cancellationToken">Token carried by the Dapper command definition.</param>
/// <typeparam name="T">The row type.</typeparam>
/// <returns>The rows returned by the query.</returns>
public static Task<IEnumerable<T>> ExecuteDapperQueryAsync<T>(this DbContext context, Func<T> typeBuilder,
    string query,
    object? parameters = null,
    CancellationToken cancellationToken = default)
{
    // Same command shape as the sibling overload so cancellation is honoured here too.
    var command = new CommandDefinition(query, parameters, cancellationToken: cancellationToken);
    return context.OpenConnectionAsync(connection => connection.QueryAsync<T>(command));
}
/// <summary>
/// Streams the rows of a raw SQL query as <typeparamref name="T"/>, where <typeparamref name="T"/> is
/// inferred from <paramref name="builder"/>. Forwards to the overload without the builder.
/// </summary>
/// <param name="context">The context supplying the connection.</param>
/// <param name="builder">
/// Never invoked; it exists only so the compiler can infer <typeparamref name="T"/> from its return type.
/// </param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Dapper parameter object, or <c>null</c> for none.</param>
/// <param name="cancellationToken">Token forwarded to the streaming overload.</param>
/// <typeparam name="T">The row type.</typeparam>
public static IAsyncEnumerable<T> ExecuteDapperQueryAsAsyncEnumerable<T>(
    this DbContext context,
    // ReSharper disable once UnusedParameter.Global
    Func<T> builder,
    string query,
    object? parameters = null,
    CancellationToken cancellationToken = default)
{
    return context.ExecuteDapperQueryAsAsyncEnumerable<T>(query, parameters, cancellationToken);
}
/// <summary>
/// Executes a raw SQL query and streams its rows one at a time, mapped to <typeparamref name="T"/>
/// by Dapper's row parser.
/// </summary>
/// <remarks>
/// The context's connection is used when it is already open; otherwise a new
/// <see cref="MySqlConnection"/> is created and released when enumeration ends. The data reader is
/// always disposed when enumeration ends, including when the consumer stops early.
/// </remarks>
/// <param name="context">The context supplying the connection.</param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Dapper parameter object, or <c>null</c> for none.</param>
/// <param name="cancellationToken">Token applied to the command and to every read.</param>
/// <typeparam name="T">The row type.</typeparam>
public static async IAsyncEnumerable<T> ExecuteDapperQueryAsAsyncEnumerable<T>(this DbContext context, string query,
    object? parameters = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    var connection = context.Database.GetDbConnection();
    var useNewConnection = connection.State != ConnectionState.Open;
    if (useNewConnection)
    {
        connection = new MySqlConnection(context.Database.GetConnectionString());
    }

    try
    {
        // Carry the token on the command itself, not only on the individual reads.
        var command = new CommandDefinition(query, parameters, cancellationToken: cancellationToken);

        // Dispose the reader deterministically: when the context's own connection is borrowed it
        // is not closed afterwards, so an undisposed reader would stay open on that connection.
        await using var reader = await connection.ExecuteReaderAsync(command).ConfigureAwait(false);
        var parser = reader.GetRowParser<T>();

        while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return parser(reader);
        }

        // Drain any remaining result sets.
        while (await reader.NextResultAsync(cancellationToken).ConfigureAwait(false))
        {
        }
    }
    finally
    {
        if (useNewConnection)
        {
            await connection.CloseAsync();
            await connection.DisposeAsync();
        }
    }
}
/// <summary>
/// Keyset-paginates a query of aggregates, using <c>Id</c> both to locate the reference row for
/// <paramref name="startKey"/> and as an identity key appended to the requested ordering.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="startKey">Id of the reference row, or <c>null</c> to start from the beginning.</param>
/// <param name="pageSize">Maximum number of items to return.</param>
/// <param name="fetchForward"><c>true</c> to page forward from the reference, <c>false</c> to page backward.</param>
/// <param name="orderByDefinitions">The requested ordering.</param>
/// <typeparam name="T">The aggregate type.</typeparam>
public static Task<T[]> PaginateByKey<T>(
    this IQueryable<T> source,
    string? startKey,
    int pageSize,
    bool fetchForward = true,
    params OrderQueryDefinition[] orderByDefinitions
) where T : class, IAggregate
{
    return PaginateByKey(
        source,
        startKey,
        pageSize,
        fetchForward,
        (candidates, referenceId) => candidates.FirstAsync(aggregate => aggregate.Id == referenceId),
        orderByDefinitions,
        aggregate => aggregate.Id);
}
/// <summary>
/// Core keyset pagination: combines the requested ordering with the identity keys (first definition
/// per property name wins, compared case-insensitively), resolves the reference row when a start key
/// is given, and loads at most <paramref name="pageSize"/> items.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="startKey">Key of the reference row, or <c>null</c> for no reference.</param>
/// <param name="pageSize">Maximum number of items to return.</param>
/// <param name="fetchForward"><c>true</c> to page forward, <c>false</c> to page backward.</param>
/// <param name="startKeyFinder">Loads the reference row for a start key.</param>
/// <param name="orderByDefinitions">The requested ordering.</param>
/// <param name="identityKeys">Properties appended to the ordering as identity keys.</param>
/// <typeparam name="T">The element type.</typeparam>
/// <returns>The page; a backward fetch has its loaded items reversed before being returned.</returns>
private static async Task<T[]> PaginateByKey<T>(
    this IQueryable<T> source,
    string? startKey,
    int pageSize,
    bool fetchForward,
    Func<IQueryable<T>, string, Task<T>> startKeyFinder,
    IEnumerable<OrderQueryDefinition> orderByDefinitions,
    params Expression<Func<T, object>>[] identityKeys) where T : class
{
    var identityDefinitions = identityKeys.Select(key => new OrderQueryDefinition(key.GetPropertyName()));
    var keysetDefinition = BuildKeySetQuery<T>(
        orderByDefinitions
            .Concat(identityDefinitions)
            .DistinctBy(definition => definition.Property, StringComparer.OrdinalIgnoreCase)
    );

    var direction = fetchForward ? KeysetPaginationDirection.Forward : KeysetPaginationDirection.Backward;
    T? referenceRow = startKey == null ? null : await startKeyFinder(source, startKey);

    var page = await source
        .KeysetPaginateQuery(
            keysetQueryDefinition: keysetDefinition,
            direction: direction,
            reference: referenceRow
        )
        .Take(pageSize)
        .ToArrayAsync();

    if (page.Length == 0)
    {
        return Array.Empty<T>();
    }

    if (!fetchForward)
    {
        // The array was just materialised here, so reversing it in place is safe.
        Array.Reverse(page);
    }

    return page;
}
/// <summary>
/// Keyset-paginates <paramref name="source"/> as described by <paramref name="paginationQuery"/>.
/// The requested ordering is combined with <paramref name="identityKeys"/>; when a property name
/// appears more than once (compared case-insensitively) the first definition wins.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="paginationQuery">Supplies <c>OrderBy</c>, <c>StartKey</c>, <c>FetchForward</c> and <c>PageSize</c>.</param>
/// <param name="startKeyFinder">Loads the reference row for the start key.</param>
/// <param name="identityKeys">Properties appended to the ordering as identity keys.</param>
/// <typeparam name="T">The element type.</typeparam>
/// <returns>The page; a backward fetch has its loaded items reversed before being returned.</returns>
public static async Task<T[]> PaginateByKey<T>(
    this IQueryable<T> source,
    KeySetPaginationQuery paginationQuery,
    Func<IQueryable<T>, string, Task<T>> startKeyFinder,
    params Expression<Func<T, object>>[] identityKeys)
    where T : class
{
    var identityDefinitions = identityKeys.Select(key => new OrderQueryDefinition(key.GetPropertyName()));
    var keysetDefinition = BuildKeySetQuery<T>(
        paginationQuery.OrderBy
            .Concat(identityDefinitions)
            .DistinctBy(definition => definition.Property, StringComparer.OrdinalIgnoreCase)
    );

    T? referenceRow = null;
    if (paginationQuery.StartKey != null)
    {
        referenceRow = await startKeyFinder(source, paginationQuery.StartKey);
    }

    var forward = paginationQuery.FetchForward;
    var page = await source
        .KeysetPaginateQuery(
            keysetQueryDefinition: keysetDefinition,
            direction: forward ? KeysetPaginationDirection.Forward : KeysetPaginationDirection.Backward,
            reference: referenceRow
        )
        .Take(paginationQuery.PageSize)
        .ToArrayAsync();

    if (page.Length == 0)
    {
        return Array.Empty<T>();
    }

    if (!paginationQuery.FetchForward)
    {
        // The array was just materialised here, so reversing it in place is safe.
        Array.Reverse(page);
    }

    return page;
}
/// <summary>
/// Applies a dynamic ordering to <paramref name="source"/>. Each definition is matched against the
/// public properties of <typeparamref name="T"/> by name (case-insensitively); definitions that match
/// no property are skipped.
/// </summary>
/// <param name="source">The query to order.</param>
/// <param name="definitions">Property names and directions, in priority order.</param>
/// <typeparam name="T">The element type.</typeparam>
/// <returns>The ordered query, or <paramref name="source"/> itself when no definition was applied.</returns>
public static IQueryable<T> OrderBy<T>(this IQueryable<T> source, IEnumerable<OrderQueryDefinition> definitions)
{
    var candidates = typeof(T).GetProperties();
    var ordered = source.Expression;
    var applied = 0;

    foreach (var definition in definitions)
    {
        var match = Array.Find(
            candidates,
            candidate => candidate.Name.Equals(definition.Property, StringComparison.OrdinalIgnoreCase));
        if (match == null)
        {
            continue;
        }

        // The first applied key starts the ordering; later keys refine it.
        var isFirstKey = applied == 0;
        string methodName;
        if (definition.OrderDescending)
        {
            methodName = isFirstKey ? "OrderByDescending" : "ThenByDescending";
        }
        else
        {
            methodName = isFirstKey ? "OrderBy" : "ThenBy";
        }

        var row = Expression.Parameter(typeof(T), "x");
        var member = Expression.PropertyOrField(row, match.Name);
        var keySelector = Expression.Quote(Expression.Lambda(member, row));

        ordered = Expression.Call(
            typeof(Queryable),
            methodName,
            new[] { source.ElementType, member.Type },
            ordered,
            keySelector);
        applied++;
    }

    if (applied == 0)
    {
        return source;
    }

    return source.Provider.CreateQuery<T>(ordered);
}
/// <summary>
/// Turns a list of order definitions into a keyset query definition for <typeparamref name="T"/>.
/// </summary>
/// <remarks>
/// Property names are resolved case-insensitively; a name that matches no property makes the
/// dictionary lookup throw. Only <c>string</c>, <c>int</c>, <c>double</c>, <c>decimal</c>,
/// <c>bool</c> and <c>DateTime</c> properties are added; a property of any other type is silently skipped.
/// </remarks>
/// <param name="properties">The ordering to translate.</param>
/// <typeparam name="T">The element type.</typeparam>
private static KeysetQueryDefinition<T> BuildKeySetQuery<T>(IEnumerable<OrderQueryDefinition> properties)
{
    var row = Expression.Parameter(typeof(T), "x");
    var propertiesByName = typeof(T).GetProperties()
        .ToDictionary(property => property.Name, property => property, StringComparer.OrdinalIgnoreCase);

    return KeysetQuery.Build<T>(keyset =>
    {
        foreach (var definition in properties)
        {
            var property = propertiesByName[definition.Property];
            var descending = definition.OrderDescending;
            var selector = Expression.Lambda(Expression.MakeMemberAccess(row, property), row);

            // The lambda is built untyped; the cast restores the strongly typed selector.
            void Add<TProp>() => keyset.By((Expression<Func<T, TProp>>)selector, descending);

            var propertyType = property.PropertyType;
            if (propertyType == typeof(string))
            {
                Add<string>();
            }
            else if (propertyType == typeof(int))
            {
                Add<int>();
            }
            else if (propertyType == typeof(double))
            {
                Add<double>();
            }
            else if (propertyType == typeof(decimal))
            {
                Add<decimal>();
            }
            else if (propertyType == typeof(bool))
            {
                Add<bool>();
            }
            else if (propertyType == typeof(DateTime))
            {
                Add<DateTime>();
            }
        }
    });
}
/// <summary>
/// Adds <paramref name="propertyExpression"/> to the keyset, ascending or descending depending on
/// <paramref name="desc"/>.
/// </summary>
/// <param name="builder">The keyset builder.</param>
/// <param name="propertyExpression">Selector of the property to order by.</param>
/// <param name="desc"><c>true</c> for descending order; ascending otherwise.</param>
/// <returns>The result of the builder's <c>Descending</c> or <c>Ascending</c> call.</returns>
public static KeysetPaginationBuilder<T> By<T, TProp>(this KeysetPaginationBuilder<T> builder,
    Expression<Func<T, TProp>> propertyExpression, bool desc = false)
{
    if (desc)
    {
        return builder.Descending(propertyExpression);
    }

    return builder.Ascending(propertyExpression);
}
/// <summary>
/// Extracts the member name from a simple member-access lambda such as <c>x =&gt; x.Name</c>.
/// A member access wrapped in a single unary node (for example the boxing conversion the compiler
/// inserts for value-typed members) is also accepted.
/// </summary>
/// <param name="expression">The member-access lambda.</param>
/// <typeparam name="T">The type the lambda reads from.</typeparam>
/// <returns>The name of the accessed member.</returns>
/// <exception cref="NotSupportedException">
/// The lambda body is not a member access, nor a unary node directly wrapping one.
/// </exception>
public static string GetPropertyName<T>(this Expression<Func<T, object?>> expression)
{
    return expression.Body switch
    {
        MemberExpression member => member.Member.Name,
        // Match the operand by pattern rather than casting it, so an unsupported operand gets the
        // same NotSupportedException as every other unsupported shape, not an InvalidCastException.
        UnaryExpression { Operand: MemberExpression operand } => operand.Member.Name,
        _ => throw new NotSupportedException(
            $"Expression '{expression}' is not a simple member access.")
    };
}
/// <summary>
/// Orders <paramref name="source"/> by the query's <c>OrderBy</c> definitions, then returns the
/// requested page with each item projected through <paramref name="mapper"/>.
/// </summary>
/// <param name="source">The query to sort and paginate.</param>
/// <param name="query">Supplies <c>OrderBy</c>, <c>PageIndex</c> and <c>PageSize</c>.</param>
/// <param name="mapper">In-memory projection applied to every item of the page.</param>
/// <typeparam name="T">The element type of the query.</typeparam>
/// <typeparam name="TResult">The element type of the resulting page.</typeparam>
public static Task<Page<TResult>> PaginateAndSort<T, TResult>(this IQueryable<T> source, PaginationQuery query,
    Func<T, TResult> mapper)
{
    var sorted = source.OrderBy(query.OrderBy);
    return sorted.PaginateAsync(query.PageIndex, query.PageSize, mapper);
}
/// <summary>
/// Orders <paramref name="source"/> by the query's <c>OrderBy</c> definitions, then returns the
/// requested page.
/// </summary>
/// <param name="source">The query to sort and paginate.</param>
/// <param name="query">Supplies <c>OrderBy</c>, <c>PageIndex</c> and <c>PageSize</c>.</param>
/// <typeparam name="T">The element type.</typeparam>
public static Task<Page<T>> PaginateAndSort<T>(this IQueryable<T> source, PaginationQuery query)
{
    var sorted = source.OrderBy(query.OrderBy);
    return sorted.Paginate(query.PageIndex, query.PageSize);
}
/// <summary>
/// Counts <paramref name="source"/> and loads the requested page.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="page">Page number; the first page is 1.</param>
/// <param name="pageSize">Number of items per page.</param>
/// <typeparam name="T">The element type.</typeparam>
/// <returns>
/// The requested page, or <c>Page&lt;T&gt;.Empty</c> when <paramref name="page"/> or
/// <paramref name="pageSize"/> is not positive or the page lies beyond the last page.
/// </returns>
public static async Task<Page<T>> Paginate<T>(
    this IQueryable<T> source,
    int page,
    int pageSize
)
{
    // Validate before touching the database: a zero page size would divide by zero below, and
    // an invalid page needs no COUNT query at all.
    if (page <= 0 || pageSize <= 0)
    {
        return Page<T>.Empty;
    }

    var totalItems = await source.CountAsync().ConfigureAwait(false);
    var totalPages = (int)Math.Ceiling(totalItems / (decimal)pageSize);
    if (page > totalPages)
    {
        return Page<T>.Empty;
    }

    var result = await source
        .Skip((page - 1) * pageSize)
        .Take(pageSize)
        .ToListAsync()
        .ConfigureAwait(false);

    return new Page<T>
    {
        Result = result,
        PageSize = pageSize,
        TotalItems = totalItems,
        PageIndex = page
    };
}
/// <summary>
/// Invokes <paramref name="action"/> once for every item of <paramref name="source"/>, in
/// enumeration order.
/// </summary>
/// <param name="source">The collection to walk.</param>
/// <param name="action">The callback to run for each item.</param>
/// <typeparam name="TCollection">The concrete collection type.</typeparam>
/// <typeparam name="TItem">The item type.</typeparam>
private static void ForEachItem<TCollection, TItem>(this TCollection source, Action<TItem> action)
    where TCollection : IEnumerable<TItem>
{
    using var cursor = source.GetEnumerator();
    while (cursor.MoveNext())
    {
        action(cursor.Current);
    }
}
/// <summary>
/// Invokes <paramref name="action"/> once for every item of <paramref name="source"/>, in
/// enumeration order.
/// </summary>
/// <param name="source">The sequence to walk.</param>
/// <param name="action">The callback to run for each item.</param>
/// <typeparam name="T">The item type.</typeparam>
public static void ForEachItem<T>(this IEnumerable<T> source, Action<T> action)
{
    foreach (var item in source)
    {
        action.Invoke(item);
    }
}
/// <summary>
/// Reports whether <paramref name="source"/> contains no elements.
/// </summary>
/// <param name="source">The sequence to inspect.</param>
/// <typeparam name="T">The element type.</typeparam>
/// <returns><c>true</c> when the sequence has no elements; otherwise <c>false</c>.</returns>
public static bool IsEmpty<T>(this IEnumerable<T> source)
{
    if (source.Any())
    {
        return false;
    }

    return true;
}
#region BetterUpdates
/// <summary>
/// Reads the current values of the primary-key properties of a tracked or detached entry.
/// </summary>
/// <param name="entry">The entry whose key values are read.</param>
/// <returns>The key values, in the order of the key's properties.</returns>
/// <exception cref="InvalidOperationException">The entity type has no primary key.</exception>
private static object?[] GetPrimaryKeyValues(this EntityEntry entry)
{
    var primaryKey = entry.Metadata.FindPrimaryKey();
    if (primaryKey is null)
    {
        throw new InvalidOperationException("No Primary Key Found");
    }

    var keyProperties = primaryKey.Properties;
    var values = new object?[keyProperties.Count];
    for (var index = 0; index < keyProperties.Count; index++)
    {
        values[index] = entry.Property(keyProperties[index].Name).CurrentValue;
    }

    return values;
}
/// <summary>
/// Applies a detached aggregate to the change tracker, including its loaded navigations.
/// The tracked instance with the same primary key is looked up among the entities already tracked
/// by the context; scalar values are copied onto it and each loaded navigation is reconciled
/// recursively. When no tracked instance exists, <paramref name="newEntity"/> is marked as added.
/// </summary>
/// <typeparam name="T">The aggregate type.</typeparam>
/// <param name="context">The context tracking the aggregate.</param>
/// <param name="newEntity">The detached entity carrying the new state.</param>
/// <param name="maxDepth">How many levels of the graph are processed; the root is level 1.</param>
/// <param name="excludedProperties">Navigations that must not be reconciled.</param>
public static void UpdateGraph<T>(this DbContext context, T newEntity,
    int maxDepth = 2,
    params Expression<Func<T, object?>>[] excludedProperties) where T : class
{
    var trackedEntity = context.Set<T>().FindAttachedEntry(newEntity);

    var excludedNames = new HashSet<string>();
    foreach (var excluded in excludedProperties)
    {
        excludedNames.Add(excluded!.GetPropertyName());
    }

    UpdateGraph(context, newEntity, trackedEntity, excludedNames, null, maxDepth);
}
/// <summary>
/// Like <c>UpdateGraph</c>, but with an allow-list: only the navigations named by
/// <paramref name="properties"/> are reconciled; every other navigation of the entity type is excluded.
/// </summary>
/// <typeparam name="T">The aggregate type.</typeparam>
/// <param name="context">The context tracking the aggregate.</param>
/// <param name="newEntity">The detached entity carrying the new state.</param>
/// <param name="properties">The navigations to reconcile.</param>
/// <param name="maxDepth">How many levels of the graph are processed; the root is level 1.</param>
public static void UpdateGraphLimited<T>(this DbContext context,
    T newEntity,
    IEnumerable<Expression<Func<T, object?>>> properties,
    int maxDepth = 2
) where T : class
{
    var trackedEntity = context.Set<T>().FindAttachedEntry(newEntity);
    var allowedNames = properties.Select(property => property.GetPropertyName()).ToArray();

    // Everything that is a navigation but not on the allow-list gets excluded.
    var excludedNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
    foreach (var navigation in context.Set<T>().EntityType.GetNavigations())
    {
        if (Array.IndexOf(allowedNames, navigation.Name) < 0)
        {
            excludedNames.Add(navigation.Name);
        }
    }

    UpdateGraph(context, newEntity, trackedEntity, excludedNames, null, maxDepth);
}
/// <summary>
/// Applies several detached aggregates to the change tracker, reconciling each one (and its loaded
/// navigations) against the tracked instance with the same primary key.
/// </summary>
/// <remarks>
/// For small batches each tracked instance is found by scanning the tracked entities; for larger
/// batches the tracked entities are indexed by primary key once. Both paths honour
/// <paramref name="maxDepth"/>, <paramref name="depth"/> and <paramref name="excludedProperties"/>.
/// </remarks>
/// <typeparam name="T">The aggregate type.</typeparam>
/// <param name="context">The context tracking the aggregates.</param>
/// <param name="newEntities">The detached entities carrying the new state.</param>
/// <param name="maxDepth">The deepest graph level that is processed.</param>
/// <param name="depth">The level assigned to the root entities.</param>
/// <param name="excludedProperties">Navigations that must not be reconciled.</param>
public static void UpdateGraphs<T>(
    this DbContext context,
    ICollection<T> newEntities,
    int maxDepth = 2,
    int depth = 1,
    params Expression<Func<T, object>>[] excludedProperties) where T : class
{
    // Below this size a linear scan per entity is cheaper than building a key index.
    const int linearScanThreshold = 5;

    var excludedMembersHashSet = new HashSet<string>(excludedProperties.Select(x => x.GetPropertyName()));
    var set = context.Set<T>();

    if (newEntities.Count <= linearScanThreshold)
    {
        foreach (var newEntity in newEntities)
        {
            // Call the recursive worker directly so the exclusions and depth limits are applied;
            // the public UpdateGraph overload would silently fall back to its defaults.
            var trackedEntity = set.FindAttachedEntry(newEntity);
            UpdateGraph(context, newEntity, trackedEntity, excludedMembersHashSet, null, maxDepth, depth);
        }

        return;
    }

    var primaryKeyGetters = set.EntityType.FindPrimaryKey()!.Properties
        .Select(x => x.GetGetter())
        .ToArray();
    var trackedEntriesDictionary = set.Local.ToDictionary(
        x => primaryKeyGetters.Select(getter => getter.GetClrValue(x)).ToArray(),
        comparer: ArraySequenceComparer<object?>.Default
    );

    foreach (var entry in newEntities)
    {
        var key = primaryKeyGetters.Select(getter => getter.GetClrValue(entry)).ToArray();
        var existingEntity = trackedEntriesDictionary.TryGetValue(key, out var cachedEntry) ? cachedEntry : null;
        UpdateGraph(context, entry, existingEntity, excludedMembersHashSet, null, maxDepth, depth);
    }
}
/// <summary>
/// Recursive worker behind the graph-update methods. Reconciles one entity pair and then each
/// loaded, non-excluded navigation of the tracked entity.
/// </summary>
/// <remarks>
/// Nothing happens once <paramref name="depth"/> exceeds <paramref name="maxDepth"/>. Otherwise:
/// a new entity without a tracked counterpart is marked added; a tracked entity without a new
/// counterpart is marked deleted; when both exist the new values are copied onto the tracked
/// entity and its navigations are processed one level deeper. Navigations whose CLR type name
/// equals <paramref name="parentAggregateTypeName"/> are skipped.
/// </remarks>
/// <typeparam name="T">The static type of the pair (the recursive calls pass values typed as <c>object</c>).</typeparam>
/// <param name="context">The context tracking the graph.</param>
/// <param name="newEntity">The detached entity carrying the new state, or <c>null</c>.</param>
/// <param name="existingEntity">The tracked counterpart, or <c>null</c>.</param>
/// <param name="excludedProperties">Names of navigations that must not be reconciled.</param>
/// <param name="parentAggregateTypeName">Full CLR type name of the entity this call was reached from.</param>
/// <param name="maxDepth">The deepest level that is processed.</param>
/// <param name="depth">The level of the current pair.</param>
private static void UpdateGraph<T>(
    this DbContext context,
    T? newEntity,
    T? existingEntity,
    IReadOnlySet<string> excludedProperties,
    string? parentAggregateTypeName = null,
    int maxDepth = 2,
    int depth = 1)
    where T : class
{
    if (depth > maxDepth)
    {
        return;
    }

    if (existingEntity == null)
    {
        // Nothing tracked: either there is nothing to do, or the new entity is an insert.
        if (newEntity != null)
        {
            context.Entry(newEntity).State = EntityState.Added;
        }

        return;
    }

    if (newEntity == null)
    {
        // Tracked but absent from the new graph: delete it.
        context.Entry(existingEntity).State = EntityState.Deleted;
        return;
    }

    var trackedEntry = context.Entry(existingEntity);
    trackedEntry.CurrentValues.SetValues(newEntity);

    var ownerTypeName = trackedEntry.Metadata.ClrType.FullName;
    var ownerClrType = trackedEntry.Entity.GetType();

    foreach (var navigation in trackedEntry.Navigations)
    {
        var skip = !navigation.IsLoaded
                   || navigation.Metadata.ClrType.FullName == parentAggregateTypeName
                   || excludedProperties.Contains(navigation.Metadata.Name);
        if (skip)
        {
            continue;
        }

        var incomingValue = ownerClrType.GetProperty(navigation.Metadata.Name)?.GetValue(newEntity);
        var trackedValue = navigation.CurrentValue;

        if (navigation is not CollectionEntry)
        {
            // Reference navigation: reconcile the single related entity.
            UpdateGraph(context, incomingValue, trackedValue, excludedProperties, ownerTypeName, maxDepth,
                depth + 1);
            continue;
        }

        var incomingItems = incomingValue as IEnumerable<object> ?? Enumerable.Empty<object>();
        var unmatchedTracked = (trackedValue as IEnumerable<object>)?.ToList() ?? new List<object>();

        // New and updated items: pair each incoming item with the tracked item sharing its key.
        foreach (var incomingItem in incomingItems)
        {
            var incomingKey = context.Entry(incomingItem).GetPrimaryKeyValues();
            var trackedItem = unmatchedTracked.FirstOrDefault(candidate =>
                context.Entry(candidate).GetPrimaryKeyValues().SequenceEqual(incomingKey));
            if (trackedItem is not null)
            {
                unmatchedTracked.Remove(trackedItem);
            }

            UpdateGraph(context, incomingItem, trackedItem, excludedProperties, ownerTypeName, maxDepth,
                depth + 1);
        }

        // Whatever is still unmatched no longer exists in the new collection.
        foreach (var orphan in unmatchedTracked)
        {
            UpdateGraph(context, null, orphan, excludedProperties, ownerTypeName, maxDepth, depth + 1);
        }
    }
}
#endregion
}
/// <summary>
/// Equality comparer for collections that delegates both equality and hashing to
/// <see cref="StructuralComparisons.StructuralEqualityComparer"/>.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
/// <typeparam name="TCollection">The collection type being compared.</typeparam>
file class SequenceComparer<T, TCollection> : IEqualityComparer<TCollection> where TCollection : IEnumerable<T>
{
    private static readonly IEqualityComparer Structural = StructuralComparisons.StructuralEqualityComparer;

    /// <summary>Shared instance of the comparer.</summary>
    public static SequenceComparer<T, TCollection> Default { get; } = new();

    /// <summary>Compares two collections with the structural equality comparer.</summary>
    public bool Equals(TCollection? x, TCollection? y)
    {
        return Structural.Equals(x, y);
    }

    /// <summary>Computes a hash code with the structural equality comparer.</summary>
    public int GetHashCode(TCollection obj)
    {
        return Structural.GetHashCode(obj);
    }
}
/// <summary>
/// Shorthand for a <c>SequenceComparer</c> over arrays of <typeparamref name="T"/>; it only exists to
/// give access to the inherited <c>Default</c> instance.
/// </summary>
/// <typeparam name="T">The array element type.</typeparam>
file abstract class ArraySequenceComparer<T> : SequenceComparer<T, T[]>
{
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment