Created
December 23, 2024 16:33
-
-
Save 01Vladimir10/7a0ca0bedf21f8e1a98699e141d040f7 to your computer and use it in GitHub Desktop.
EF Core extensions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using System.Collections; | |
| using Dapper; | |
| using Microsoft.EntityFrameworkCore; | |
| using Microsoft.EntityFrameworkCore.ChangeTracking; | |
| using MySqlConnector; | |
| using System.Data; | |
| using System.Diagnostics.CodeAnalysis; | |
| using System.Linq.Expressions; | |
| using System.Runtime.CompilerServices; | |
| using Domain.Entities; | |
| using Domain.Modules.Common.Data; | |
| using Domain.Modules.Core.Query; | |
| using JetBrains.Annotations; | |
| using MR.EntityFrameworkCore.KeysetPagination; | |
| namespace Infrastructure.Extensions; | |
| public static class EntityFrameworkExtensions | |
| { | |
/// <summary>
/// Looks through the entities currently tracked in <paramref name="set"/> for one whose primary-key
/// values equal those of <paramref name="entity"/>.
/// </summary>
/// <param name="set">The set whose local (tracked) entities are searched.</param>
/// <param name="entity">The instance supplying the key values to look for.</param>
/// <typeparam name="T">The entity type.</typeparam>
/// <returns>The first tracked instance with a matching key, or <c>null</c> when there is none.</returns>
private static T? FindAttachedEntry<T>(this DbSet<T> set, T entity) where T : class
{
    // One (getter, comparer) pair per primary-key property. The null-forgiving operator means an
    // entity type without a primary key fails right here.
    var keyParts = set.EntityType.FindPrimaryKey()!.Properties
        .Select(property => (Getter: property.GetGetter(), Comparer: property.GetKeyValueComparer()))
        .ToArray();

    return set.Local.FirstOrDefault(tracked => HasSameKey(tracked));

    // True when every key property of the tracked instance equals the one on the probe entity.
    bool HasSameKey(T tracked)
    {
        foreach (var (getter, comparer) in keyParts)
        {
            if (!comparer.Equals(getter.GetClrValue(tracked), getter.GetClrValue(entity)))
            {
                return false;
            }
        }

        return true;
    }
}
/// <summary>
/// Builds a predicate that tells whether two instances of <typeparamref name="T"/> carry the same
/// primary-key values, using the key metadata of <paramref name="set"/>.
/// </summary>
/// <param name="set">The set whose entity type supplies the key definition.</param>
/// <typeparam name="T">The entity type.</typeparam>
/// <returns>
/// A comparison delegate. When the entity type has no primary key the delegate has nothing to
/// compare and therefore returns <c>true</c> for every pair.
/// </returns>
private static Func<T, T, bool> GetKeyComparer<T>(this DbSet<T> set) where T : class
{
    var keyParts = (set.EntityType.FindPrimaryKey()?.Properties ?? [])
        .Select(property => (Getter: property.GetGetter(), Comparer: property.GetKeyValueComparer()))
        .ToArray();

    return KeysMatch;

    bool KeysMatch(T left, T right)
    {
        foreach (var (getter, comparer) in keyParts)
        {
            if (!comparer.Equals(getter.GetClrValue(left), getter.GetClrValue(right)))
            {
                return false;
            }
        }

        return true;
    }
}
/// <summary>
/// Applies the values of a disconnected <paramref name="entity"/> to the context.
/// </summary>
/// <remarks>
/// When an instance with the same primary key is already tracked, the values are copied onto that
/// tracked instance. Otherwise the passed instance itself is marked as
/// <see cref="EntityState.Modified"/>, matching what <c>UpdatedDisconnectedRange</c> does for
/// untracked entities.
/// </remarks>
/// <param name="context">The set the entity belongs to.</param>
/// <param name="entity">The disconnected entity carrying the new values.</param>
/// <typeparam name="T">The entity type.</typeparam>
public static void UpdateDisconnected<T>(this DbSet<T> context, T entity) where T : class
{
    var attachedEntry = context.FindAttachedEntry(entity);
    if (attachedEntry is not null)
    {
        context.Entry(attachedEntry).CurrentValues.SetValues(entity);
        return;
    }

    // Previously this branch copied the entity's values onto itself, which changes nothing and
    // leaves the entity untracked. Marking it Modified is what makes the update take effect.
    context.Entry(entity).State = EntityState.Modified;
}
/// <summary>
/// Synchronises the tracked entities selected by <paramref name="parentEntityPredicate"/> with
/// <paramref name="updatedList"/>: matching items receive the new values, unmatched incoming items
/// are added, and tracked items missing from the list are removed.
/// </summary>
/// <param name="context">The set holding the child entities.</param>
/// <param name="parentEntityPredicate">Selects which locally tracked entities take part in the sync.</param>
/// <param name="updatedList">The desired final content of the collection.</param>
/// <typeparam name="T">The entity type.</typeparam>
public static void UpdateDisconnectedNavigationCollection<T>(
    this DbSet<T> context,
    Func<T, bool> parentEntityPredicate,
    ICollection<T> updatedList
) where T : class
{
    // Snapshot of the tracked candidates, taken before anything is added below.
    var trackedItems = context.Local.Where(parentEntityPredicate).ToArray();
    var sameKey = context.GetKeyComparer();

    foreach (var incoming in updatedList)
    {
        var tracked = Array.Find(trackedItems, candidate => sameKey(incoming, candidate));
        if (tracked is not null)
        {
            context.Entry(tracked).CurrentValues.SetValues(incoming);
            continue;
        }

        context.Add(incoming);
    }

    // Tracked items without a counterpart in the updated list are removed.
    var orphans = trackedItems.Where(tracked => !updatedList.Any(incoming => sameKey(tracked, incoming)));
    context.RemoveRange(orphans);
}
/// <summary>
/// Applies a batch of disconnected entities to the context, looking tracked instances up by
/// primary key through a dictionary built once from the set's local view.
/// </summary>
/// <param name="context">The set the entities belong to.</param>
/// <param name="entities">The disconnected entities carrying the new values.</param>
/// <typeparam name="T">The entity type.</typeparam>
public static void UpdatedDisconnectedRange<T>(this DbSet<T> context, IEnumerable<T> entities) where T : class
{
    var keyGetters = context.EntityType.FindPrimaryKey()!.Properties
        .Select(property => property.GetGetter())
        .ToArray();

    // Index every tracked instance by its primary-key values.
    var trackedByKey = context.Local.ToDictionary(
        tracked => KeyOf(tracked),
        ArraySequenceComparer<object?>.Default);

    foreach (var incoming in entities)
    {
        if (trackedByKey.TryGetValue(KeyOf(incoming), out var tracked))
        {
            // An instance with this key is already tracked: copy the new values onto it so that
            // EF can work out which properties changed.
            context.Entry(tracked).CurrentValues.SetValues(incoming);
        }
        else
        {
            // Nothing with this key is tracked by this context: mark the incoming instance Modified.
            context.Entry(incoming).State = EntityState.Modified;
        }
    }

    // Reads the primary-key values of an instance, in key-property order.
    object?[] KeyOf(T instance) => keyGetters.Select(getter => getter.GetClrValue(instance)).ToArray();
}
/// <summary>
/// Detaches all tracked entries that match <paramref name="predicate"/> and then executes
/// the deletion on the database.
/// </summary>
/// <param name="set">The set to delete from.</param>
/// <param name="predicate">
/// Selects the rows to delete. It is compiled to filter the locally tracked entities and is also
/// passed, as an expression, to the database delete.
/// </param>
/// <param name="cancellationToken">Token forwarded to the database delete.</param>
/// <typeparam name="T">The entity type.</typeparam>
/// <returns>A task that completes when the delete statement has been executed.</returns>
public static Task BulkDeleteAsync<T>(
    this DbSet<T> set,
    Expression<Func<T, bool>> predicate,
    CancellationToken cancellationToken = default) where T : class
{
    var filter = predicate.Compile();

    // Snapshot the matches before touching their state: detaching an entry changes the change
    // tracker, which can invalidate a lazy enumeration of Local that is still in progress.
    var trackedMatches = set.Local.Where(filter).ToList();
    foreach (var item in trackedMatches)
    {
        set.Entry(item).State = EntityState.Detached;
    }

    return set.Where(predicate).ExecuteDeleteAsync(cancellationToken);
}
/// <summary>
/// Returns the page of <paramref name="source"/> described by the page index and page size of
/// <paramref name="query"/>.
/// </summary>
public static Task<Page<T>> PaginateAsync<T>(this IQueryable<T> source, PaginationQuery query)
{
    return source.PaginateAsync(query.PageIndex, query.PageSize);
}

/// <summary>
/// Returns the page of <paramref name="source"/> described by <paramref name="query"/>, passing
/// every loaded item through <paramref name="mapper"/>.
/// </summary>
public static Task<Page<TFinal>> PaginateAsync<T, TFinal>(
    this IQueryable<T> source,
    PaginationQuery query,
    Func<T, TFinal> mapper)
{
    return source.PaginateAsync(query.PageIndex, query.PageSize, mapper);
}

/// <summary>
/// Returns the requested page of <paramref name="source"/> without mapping the items.
/// </summary>
public static Task<Page<T>> PaginateAsync<T>(
    this IQueryable<T> source,
    int pageIndex,
    int pageSize)
{
    // Delegates to the mapping overload with an identity mapper.
    return source.PaginateAsync(pageIndex, pageSize, item => item);
}
/// <summary>
/// Counts <paramref name="source"/>, loads the requested page and maps each loaded item with
/// <paramref name="mapper"/>.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="pageIndex">The 1-based page number.</param>
/// <param name="pageSize">The number of items per page.</param>
/// <param name="mapper">Projection applied in memory to every loaded item.</param>
/// <typeparam name="T">The queried item type.</typeparam>
/// <typeparam name="TFinal">The item type of the returned page.</typeparam>
/// <returns>
/// The mapped page, or <c>Page&lt;TFinal&gt;.Empty</c> when <paramref name="pageIndex"/> or
/// <paramref name="pageSize"/> is not positive, or when the page lies beyond the last page.
/// </returns>
public static async Task<Page<TFinal>> PaginateAsync<T, TFinal>(
    this IQueryable<T> source,
    int pageIndex,
    int pageSize,
    Func<T, TFinal> mapper)
{
    // A non-positive index would produce a negative Skip and a zero size would divide by zero;
    // treat both as "no such page", in line with the page <= 0 check in Paginate.
    if (pageIndex <= 0 || pageSize <= 0)
    {
        return Page<TFinal>.Empty;
    }

    var totalItems = await source.CountAsync().ConfigureAwait(false);
    var totalPages = (int)Math.Ceiling(totalItems / (decimal)pageSize);
    if (pageIndex > totalPages)
    {
        return Page<TFinal>.Empty;
    }

    var data = await source
        .Skip((pageIndex - 1) * pageSize)
        .Take(pageSize)
        .ToListAsync()
        .ConfigureAwait(false);

    return new Page<TFinal>
    {
        Result = data.Select(mapper).ToList(),
        PageSize = pageSize,
        TotalItems = totalItems,
        PageIndex = pageIndex
    };
}
/// <summary>
/// Returns the value stored under <paramref name="key"/>, or the default value of
/// <typeparamref name="T"/> when the key is absent.
/// </summary>
public static T? GetOrDefault<TKey, T>(this IDictionary<TKey, T> source, TKey key)
    where TKey : notnull
{
    if (source.TryGetValue(key, out var found))
    {
        return found;
    }

    return default;
}

/// <summary>
/// Returns the value stored under <paramref name="key"/>, or <paramref name="defaultValue"/> when
/// the key is absent.
/// </summary>
public static T GetOrDefault<TKey, T>(this IDictionary<TKey, T> source, TKey key, T defaultValue)
    where TKey : notnull
{
    if (source.TryGetValue(key, out var found))
    {
        return found;
    }

    return defaultValue;
}
/// <summary>
/// Executes a stored procedure through Dapper and returns its rows mapped to
/// <typeparamref name="T"/>.
/// </summary>
/// <remarks>
/// Connection handling is delegated to <c>OpenConnectionAsync</c>, which reuses the context's
/// connection when it is already open and otherwise runs on a separate MySQL connection that it
/// disposes afterwards.
/// </remarks>
/// <param name="context">The context supplying the connection and the default command timeout.</param>
/// <param name="storedProcedure">The name of the stored procedure.</param>
/// <param name="parameters">Optional parameter object passed to Dapper.</param>
/// <param name="commandTimeout">
/// Optional timeout; when <c>null</c> the context's command timeout is used, as in the non-generic
/// overload.
/// </param>
/// <typeparam name="T">The row type.</typeparam>
/// <returns>The rows returned by the procedure.</returns>
public static async Task<IEnumerable<T>> ExecuteStoredProcedure<T>(this DbContext context, string storedProcedure,
    object? parameters = null, int? commandTimeout = null)
{
    return await context.OpenConnectionAsync(connection => connection.QueryAsync<T>(
        storedProcedure,
        parameters,
        commandType: CommandType.StoredProcedure,
        commandTimeout: commandTimeout ?? context.Database.GetCommandTimeout()));
}
/// <summary>
/// Executes a stored procedure through Dapper and discards whatever it returns.
/// </summary>
/// <remarks>
/// Connection handling is delegated to <c>OpenConnectionAsync</c>, which reuses the context's
/// connection when it is already open and otherwise runs on a separate MySQL connection that it
/// disposes afterwards.
/// </remarks>
/// <param name="context">The context supplying the connection and the default command timeout.</param>
/// <param name="storedProcedure">The name of the stored procedure.</param>
/// <param name="parameters">Optional parameter object passed to Dapper.</param>
/// <param name="commandTimeout">Optional timeout; when <c>null</c> the context's command timeout is used.</param>
public static async Task ExecuteStoredProcedure(this DbContext context, string storedProcedure,
    object? parameters = null, int? commandTimeout = null)
{
    // The scalar result is intentionally ignored; only completion matters to callers.
    await context.OpenConnectionAsync(connection => connection.ExecuteScalarAsync(
        storedProcedure,
        parameters,
        commandType: CommandType.StoredProcedure,
        commandTimeout: commandTimeout ?? context.Database.GetCommandTimeout()));
}
/// <summary>
/// Runs <paramref name="client"/> against a database connection and returns its result.
/// </summary>
/// <remarks>
/// The context's own connection is used when it is already open. Otherwise a new
/// <see cref="MySqlConnection"/> is created from the context's connection string, and it is closed
/// and disposed once <paramref name="client"/> has finished, whether it succeeded or threw.
/// </remarks>
/// <param name="context">The context supplying the connection or connection string.</param>
/// <param name="client">The work to run with the connection.</param>
/// <typeparam name="T">The result type of the work.</typeparam>
private static async Task<T> OpenConnectionAsync<T>(this DbContext context, Func<IDbConnection, Task<T>> client)
{
    var contextConnection = context.Database.GetDbConnection();
    if (contextConnection.State == ConnectionState.Open)
    {
        // The context owns this connection, so it is left untouched afterwards.
        return await client(contextConnection);
    }

    var dedicatedConnection = new MySqlConnection(context.Database.GetConnectionString());
    try
    {
        return await client(dedicatedConnection);
    }
    finally
    {
        await dedicatedConnection.CloseAsync();
        await dedicatedConnection.DisposeAsync();
    }
}
/// <summary>
/// Runs a raw SQL query through Dapper and returns the rows mapped to <typeparamref name="T"/>.
/// </summary>
/// <param name="context">The context supplying the connection.</param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Optional parameter object passed to Dapper.</param>
/// <param name="cancellationToken">Token carried by the Dapper command definition.</param>
/// <typeparam name="T">The row type.</typeparam>
public static Task<IEnumerable<T>> ExecuteDapperQueryAsync<
    [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors |
                                DynamicallyAccessedMemberTypes.PublicProperties)]
    T>(this DbContext context,
    [LanguageInjection("MySQL")] string query,
    object? parameters = null,
    CancellationToken cancellationToken = default)
{
    var command = new CommandDefinition(query, parameters, cancellationToken: cancellationToken);
    return context.OpenConnectionAsync(connection => connection.QueryAsync<T>(command));
}
/// <summary>
/// Runs a raw SQL query through Dapper and returns the rows mapped to <typeparamref name="T"/>.
/// </summary>
/// <param name="context">The context supplying the connection.</param>
/// <param name="typeBuilder">
/// Never invoked by this method; it only takes part in the signature, so <typeparamref name="T"/>
/// can be inferred from it.
/// </param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Optional parameter object passed to Dapper.</param>
/// <typeparam name="T">The row type.</typeparam>
public static Task<IEnumerable<T>> ExecuteDapperQueryAsync<T>(this DbContext context, Func<T> typeBuilder,
    string query,
    object? parameters = null)
{
    return context.OpenConnectionAsync(connection => connection.QueryAsync<T>(query, parameters));
}
/// <summary>
/// Streams the rows of a raw SQL query, mapped to <typeparamref name="T"/>.
/// </summary>
/// <param name="context">The context supplying the connection.</param>
/// <param name="builder">
/// Never invoked by this method; it only takes part in the signature, so <typeparamref name="T"/>
/// can be inferred from it.
/// </param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Optional parameter object passed to Dapper.</param>
/// <param name="cancellationToken">Token forwarded to the streaming overload.</param>
/// <typeparam name="T">The row type.</typeparam>
public static IAsyncEnumerable<T> ExecuteDapperQueryAsAsyncEnumerable<T>(
    this DbContext context,
    // ReSharper disable once UnusedParameter.Global
    Func<T> builder,
    string query,
    object? parameters = null,
    CancellationToken cancellationToken = default)
{
    // Forward to the overload without the builder argument.
    return ExecuteDapperQueryAsAsyncEnumerable<T>(context, query, parameters, cancellationToken);
}
/// <summary>
/// Streams the rows of a raw SQL query one at a time, mapped to <typeparamref name="T"/> with
/// Dapper's row parser.
/// </summary>
/// <remarks>
/// The context's connection is used when it is already open; otherwise a new
/// <see cref="MySqlConnection"/> is created and disposed when enumeration ends. The data reader is
/// always disposed when enumeration ends, completes or is abandoned.
/// </remarks>
/// <param name="context">The context supplying the connection or connection string.</param>
/// <param name="query">The SQL text.</param>
/// <param name="parameters">Optional parameter object passed to Dapper.</param>
/// <param name="cancellationToken">Token applied to the command and to every read.</param>
/// <typeparam name="T">The row type.</typeparam>
public static async IAsyncEnumerable<T> ExecuteDapperQueryAsAsyncEnumerable<T>(this DbContext context, string query,
    object? parameters = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    var connection = context.Database.GetDbConnection();
    var useNewConnection = connection.State != ConnectionState.Open;
    if (useNewConnection)
    {
        connection = new MySqlConnection(context.Database.GetConnectionString());
    }

    try
    {
        // Carry the token in the command definition so the command itself is cancellable too.
        var command = new CommandDefinition(query, parameters, cancellationToken: cancellationToken);

        // Dispose the reader with the iterator; it was previously left open after enumeration.
        await using var reader = await connection.ExecuteReaderAsync(command).ConfigureAwait(false);
        var parser = reader.GetRowParser<T>();
        while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return parser(reader);
        }

        // Advance past any further result sets.
        while (await reader.NextResultAsync(cancellationToken).ConfigureAwait(false))
        {
        }
    }
    finally
    {
        if (useNewConnection)
        {
            await connection.CloseAsync();
            await connection.DisposeAsync();
        }
    }
}
/// <summary>
/// Keyset-paginates a query over aggregates, using <c>Id</c> both to locate the reference row for
/// <paramref name="startKey"/> and as the identity column appended to the ordering.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="startKey">The id of the reference row, or <c>null</c> to start without a reference.</param>
/// <param name="pageSize">The maximum number of items to return.</param>
/// <param name="fetchForward"><c>true</c> to page forward, <c>false</c> to page backward.</param>
/// <param name="orderByDefinitions">The requested ordering.</param>
/// <typeparam name="T">The aggregate type.</typeparam>
public static Task<T[]> PaginateByKey<T>(
    this IQueryable<T> source,
    string? startKey,
    int pageSize,
    bool fetchForward = true,
    params OrderQueryDefinition[] orderByDefinitions
) where T : class, IAggregate
{
    return PaginateByKey(
        source,
        startKey,
        pageSize,
        fetchForward,
        (query, id) => query.FirstAsync(aggregate => aggregate.Id == id),
        orderByDefinitions,
        aggregate => aggregate.Id);
}
/// <summary>
/// Keyset-paginates <paramref name="source"/> using the given ordering plus the identity columns.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="startKey">The key of the reference row, or <c>null</c> to start without a reference.</param>
/// <param name="pageSize">The maximum number of items to return.</param>
/// <param name="fetchForward"><c>true</c> to page forward, <c>false</c> to page backward.</param>
/// <param name="startKeyFinder">Loads the reference row for a start key.</param>
/// <param name="orderByDefinitions">The requested ordering.</param>
/// <param name="identityKeys">Identity columns appended to the ordering.</param>
/// <typeparam name="T">The item type.</typeparam>
/// <returns>The page; a backward page is reversed before it is returned.</returns>
private static async Task<T[]> PaginateByKey<T>(
    this IQueryable<T> source,
    string? startKey,
    int pageSize,
    bool fetchForward,
    Func<IQueryable<T>, string, Task<T>> startKeyFinder,
    IEnumerable<OrderQueryDefinition> orderByDefinitions,
    params Expression<Func<T, object>>[] identityKeys) where T : class
{
    // Requested ordering first, then the identity columns; a property named twice
    // (case-insensitively) keeps only its first definition.
    var identityDefinitions = identityKeys.Select(key => new OrderQueryDefinition(key.GetPropertyName()));
    var distinctDefinitions = orderByDefinitions
        .Concat(identityDefinitions)
        .GroupBy(definition => definition.Property, StringComparer.OrdinalIgnoreCase)
        .Select(group => group.First());
    var keysetDefinition = BuildKeySetQuery<T>(distinctDefinitions);

    var direction = fetchForward ? KeysetPaginationDirection.Forward : KeysetPaginationDirection.Backward;

    T? reference = null;
    if (startKey != null)
    {
        reference = await startKeyFinder(source, startKey);
    }

    var page = await source
        .KeysetPaginateQuery(
            keysetQueryDefinition: keysetDefinition,
            direction: direction,
            reference: reference)
        .Take(pageSize)
        .ToArrayAsync();

    if (page.Length == 0)
    {
        return Array.Empty<T>();
    }

    if (!fetchForward)
    {
        // The array was just materialised for this call, so it can be reversed in place.
        Array.Reverse(page);
    }

    return page;
}
/// <summary>
/// Keyset-paginates <paramref name="source"/> as described by <paramref name="paginationQuery"/>,
/// appending <paramref name="identityKeys"/> to the requested ordering.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="paginationQuery">Supplies the ordering, start key, direction and page size.</param>
/// <param name="startKeyFinder">Loads the reference row for the start key.</param>
/// <param name="identityKeys">Identity columns appended to the ordering.</param>
/// <typeparam name="T">The item type.</typeparam>
/// <returns>The page; a backward page is reversed before it is returned.</returns>
public static async Task<T[]> PaginateByKey<T>(
    this IQueryable<T> source,
    KeySetPaginationQuery paginationQuery,
    Func<IQueryable<T>, string, Task<T>> startKeyFinder,
    params Expression<Func<T, object>>[] identityKeys)
    where T : class
{
    // Requested ordering first, then the identity columns; a property named twice
    // (case-insensitively) keeps only its first definition.
    var identityDefinitions = identityKeys.Select(key => new OrderQueryDefinition(key.GetPropertyName()));
    var distinctDefinitions = paginationQuery.OrderBy
        .Concat(identityDefinitions)
        .GroupBy(definition => definition.Property, StringComparer.OrdinalIgnoreCase)
        .Select(group => group.First());
    var keysetDefinition = BuildKeySetQuery<T>(distinctDefinitions);

    T? reference = null;
    if (paginationQuery.StartKey != null)
    {
        reference = await startKeyFinder(source, paginationQuery.StartKey);
    }

    var direction = paginationQuery.FetchForward
        ? KeysetPaginationDirection.Forward
        : KeysetPaginationDirection.Backward;

    var page = await source
        .KeysetPaginateQuery(
            keysetQueryDefinition: keysetDefinition,
            direction: direction,
            reference: reference)
        .Take(paginationQuery.PageSize)
        .ToArrayAsync();

    if (page.Length == 0)
    {
        return Array.Empty<T>();
    }

    if (!paginationQuery.FetchForward)
    {
        // The array was just materialised for this call, so it can be reversed in place.
        Array.Reverse(page);
    }

    return page;
}
/// <summary>
/// Applies a dynamic ordering to <paramref name="source"/>. Each definition names a public
/// property of <typeparamref name="T"/> (matched case-insensitively); definitions naming no such
/// property are skipped.
/// </summary>
/// <param name="source">The query to order.</param>
/// <param name="definitions">The ordering, in priority order.</param>
/// <typeparam name="T">The item type.</typeparam>
/// <returns>The ordered query, or <paramref name="source"/> itself when no definition applied.</returns>
public static IQueryable<T> OrderBy<T>(this IQueryable<T> source, IEnumerable<OrderQueryDefinition> definitions)
{
    var candidates = typeof(T).GetProperties();
    var ordered = source.Expression;
    var applied = 0;

    foreach (var definition in definitions)
    {
        var match = candidates.FirstOrDefault(
            candidate => candidate.Name.Equals(definition.Property, StringComparison.OrdinalIgnoreCase));
        if (match == null)
        {
            continue;
        }

        var row = Expression.Parameter(typeof(T), "x");
        var member = Expression.PropertyOrField(row, match.Name);

        // The first applied definition starts the ordering; later ones refine it.
        string methodName;
        if (applied == 0)
        {
            methodName = definition.OrderDescending ? "OrderByDescending" : "OrderBy";
        }
        else
        {
            methodName = definition.OrderDescending ? "ThenByDescending" : "ThenBy";
        }

        var keySelector = Expression.Quote(Expression.Lambda(member, row));
        ordered = Expression.Call(
            typeof(Queryable),
            methodName,
            new[] { source.ElementType, member.Type },
            ordered,
            keySelector);
        applied++;
    }

    return applied == 0 ? source : source.Provider.CreateQuery<T>(ordered);
}
/// <summary>
/// Builds a keyset query definition for <typeparamref name="T"/> from property names and sort
/// directions.
/// </summary>
/// <remarks>
/// Property names are resolved case-insensitively against the public properties of
/// <typeparamref name="T"/>; an unknown name makes the dictionary lookup throw. Only properties of
/// type <c>string</c>, <c>int</c>, <c>double</c>, <c>decimal</c>, <c>bool</c> and
/// <c>DateTime</c> become keyset columns; properties of any other type are skipped silently.
/// </remarks>
/// <param name="properties">The columns, in priority order.</param>
/// <typeparam name="T">The item type.</typeparam>
private static KeysetQueryDefinition<T> BuildKeySetQuery<T>(IEnumerable<OrderQueryDefinition> properties)
{
    var row = Expression.Parameter(typeof(T), "x");
    var propertiesByName = typeof(T).GetProperties()
        .ToDictionary(property => property.Name, property => property, StringComparer.OrdinalIgnoreCase);

    return KeysetQuery.Build<T>(keyset =>
    {
        foreach (var definition in properties)
        {
            var property = propertiesByName[definition.Property];
            var descending = definition.OrderDescending;

            // Expression.Lambda yields an Expression<Func<T, TProp>> whose TProp is the property
            // type, so matching on the closed generic type selects exactly that property type.
            LambdaExpression accessor = Expression.Lambda(Expression.MakeMemberAccess(row, property), row);
            switch (accessor)
            {
                case Expression<Func<T, string>> text:
                    keyset.By(text, descending);
                    break;
                case Expression<Func<T, int>> integer:
                    keyset.By(integer, descending);
                    break;
                case Expression<Func<T, double>> floating:
                    keyset.By(floating, descending);
                    break;
                case Expression<Func<T, decimal>> money:
                    keyset.By(money, descending);
                    break;
                case Expression<Func<T, bool>> flag:
                    keyset.By(flag, descending);
                    break;
                case Expression<Func<T, DateTime>> moment:
                    keyset.By(moment, descending);
                    break;
            }
        }
    });
}
/// <summary>
/// Adds a keyset column to <paramref name="builder"/>, descending when <paramref name="desc"/> is
/// <c>true</c> and ascending otherwise.
/// </summary>
/// <returns>The value returned by the builder's <c>Ascending</c>/<c>Descending</c> call.</returns>
public static KeysetPaginationBuilder<T> By<T, TProp>(this KeysetPaginationBuilder<T> builder,
    Expression<Func<T, TProp>> propertyExpression, bool desc = false)
{
    if (desc)
    {
        return builder.Descending(propertyExpression);
    }

    return builder.Ascending(propertyExpression);
}
/// <summary>
/// Extracts the member name from a simple accessor expression such as <c>x =&gt; x.Name</c>.
/// </summary>
/// <param name="expression">
/// A lambda whose body is a member access, optionally wrapped in a single unary node (for example
/// the conversion to <c>object</c> the compiler adds for value-typed members).
/// </param>
/// <typeparam name="T">The type the member is read from.</typeparam>
/// <returns>The name of the accessed member.</returns>
/// <exception cref="NotSupportedException">The body is not a (wrapped) member access.</exception>
public static string GetPropertyName<T>(this Expression<Func<T, object?>> expression)
{
    return expression.Body switch
    {
        MemberExpression member => member.Member.Name,
        // Match the operand by pattern instead of casting it, so an unexpected operand reports
        // NotSupportedException like every other unsupported shape rather than InvalidCastException.
        UnaryExpression { Operand: MemberExpression member } => member.Member.Name,
        _ => throw new NotSupportedException(
            $"Expression '{expression}' is not a simple member access.")
    };
}
/// <summary>
/// Orders <paramref name="source"/> by the ordering in <paramref name="query"/>, then returns the
/// requested page with every item passed through <paramref name="mapper"/>.
/// </summary>
public static Task<Page<TResult>> PaginateAndSort<T, TResult>(this IQueryable<T> source, PaginationQuery query,
    Func<T, TResult> mapper)
{
    var sorted = source.OrderBy(query.OrderBy);
    return sorted.PaginateAsync(query.PageIndex, query.PageSize, mapper);
}

/// <summary>
/// Orders <paramref name="source"/> by the ordering in <paramref name="query"/>, then returns the
/// requested page.
/// </summary>
public static Task<Page<T>> PaginateAndSort<T>(this IQueryable<T> source, PaginationQuery query)
{
    var sorted = source.OrderBy(query.OrderBy);
    return sorted.Paginate(query.PageIndex, query.PageSize);
}
/// <summary>
/// Counts <paramref name="source"/> and loads the requested page.
/// </summary>
/// <param name="source">The query to paginate.</param>
/// <param name="page">The 1-based page number.</param>
/// <param name="pageSize">The number of items per page.</param>
/// <typeparam name="T">The item type.</typeparam>
/// <returns>
/// The page, or <c>Page&lt;T&gt;.Empty</c> when <paramref name="page"/> is not positive or lies
/// beyond the last page.
/// </returns>
public static async Task<Page<T>> Paginate<T>(
    this IQueryable<T> source,
    int page,
    int pageSize
)
{
    var itemCount = await source.CountAsync().ConfigureAwait(false);
    var pageCount = (int)Math.Ceiling(itemCount / (decimal)pageSize);

    var outOfRange = page <= 0 || page > pageCount;
    if (outOfRange)
    {
        return Page<T>.Empty;
    }

    var offset = (page - 1) * pageSize;
    var items = await source
        .Skip(offset)
        .Take(pageSize)
        .ToListAsync()
        .ConfigureAwait(false);

    return new Page<T>
    {
        Result = items,
        PageSize = pageSize,
        TotalItems = itemCount,
        PageIndex = page
    };
}
/// <summary>
/// Invokes <paramref name="action"/> for every item of <paramref name="source"/>, in enumeration
/// order.
/// </summary>
private static void ForEachItem<TCollection, TItem>(this TCollection source, Action<TItem> action)
    where TCollection : IEnumerable<TItem>
{
    foreach (var item in source)
    {
        action(item);
    }
}

/// <summary>
/// Invokes <paramref name="action"/> for every item of <paramref name="source"/>, in enumeration
/// order.
/// </summary>
public static void ForEachItem<T>(this IEnumerable<T> source, Action<T> action)
{
    // Explicit type arguments select the two-parameter overload above.
    ForEachItem<IEnumerable<T>, T>(source, action);
}
/// <summary>
/// Tells whether <paramref name="source"/> yields no items.
/// </summary>
public static bool IsEmpty<T>(this IEnumerable<T> source)
{
    var hasItems = source.Any();
    return !hasItems;
}
| #region BetterUpdates | |
/// <summary>
/// Reads the current values of the primary-key properties of a tracked entry, in key-property
/// order.
/// </summary>
/// <param name="entry">The entry whose key values are read.</param>
/// <returns>One value per primary-key property.</returns>
/// <exception cref="InvalidOperationException">The entity type has no primary key.</exception>
private static object?[] GetPrimaryKeyValues(this EntityEntry entry)
{
    var primaryKey = entry.Metadata.FindPrimaryKey();
    if (primaryKey is null)
    {
        throw new InvalidOperationException("No Primary Key Found");
    }

    var keyProperties = primaryKey.Properties;
    var values = new object?[keyProperties.Count];
    for (var index = 0; index < values.Length; index++)
    {
        values[index] = entry.Property(keyProperties[index].Name).CurrentValue;
    }

    return values;
}
/// <summary>
/// Updates a tracked aggregate from a detached copy of it, for example one received from the API
/// layer. The tracked instance with the same primary key as <paramref name="newEntity"/> is looked
/// up in the context, and the two graphs are then reconciled. Only navigations that are loaded on
/// the tracked instance are followed.
/// </summary>
/// <typeparam name="T">The aggregate root type.</typeparam>
/// <param name="context">The context tracking the existing aggregate.</param>
/// <param name="newEntity">The detached entity carrying the new values.</param>
/// <param name="maxDepth">How many levels of the graph are processed, counting the root as level 1.</param>
/// <param name="excludedProperties">Navigations that are left untouched at every level.</param>
public static void UpdateGraph<T>(this DbContext context, T newEntity,
    int maxDepth = 2,
    params Expression<Func<T, object?>>[] excludedProperties) where T : class
{
    var trackedEntity = context.Set<T>().FindAttachedEntry(newEntity);

    var excludedNames = new HashSet<string>();
    foreach (var excludedProperty in excludedProperties)
    {
        excludedNames.Add(excludedProperty!.GetPropertyName());
    }

    UpdateGraph(context, newEntity, trackedEntity, excludedNames, null, maxDepth);
}
/// <summary>
/// Updates a tracked aggregate from a detached copy of it, following only the root-level
/// navigations named in <paramref name="properties"/>. Every other navigation declared on the
/// entity type is put on the exclusion list.
/// </summary>
/// <typeparam name="T">The aggregate root type.</typeparam>
/// <param name="context">The context tracking the existing aggregate.</param>
/// <param name="newEntity">The detached entity carrying the new values.</param>
/// <param name="properties">The navigations that may be updated.</param>
/// <param name="maxDepth">How many levels of the graph are processed, counting the root as level 1.</param>
public static void UpdateGraphLimited<T>(this DbContext context,
    T newEntity,
    IEnumerable<Expression<Func<T, object?>>> properties,
    int maxDepth = 2
) where T : class
{
    var trackedEntity = context.Set<T>().FindAttachedEntry(newEntity);
    var allowedNames = properties.Select(GetPropertyName).ToArray();

    // The allow-list is matched with ordinal (case-sensitive) comparison, while the resulting
    // exclusion set compares names case-insensitively.
    var excludedNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
    foreach (var navigation in context.Set<T>().EntityType.GetNavigations())
    {
        var isAllowed = allowedNames.Any(allowed => allowed == navigation.Name);
        if (!isAllowed)
        {
            excludedNames.Add(navigation.Name);
        }
    }

    UpdateGraph(context, newEntity, trackedEntity, excludedNames, null, maxDepth);
}
/// <summary>
/// Updates several tracked aggregates from detached copies of them.
/// </summary>
/// <remarks>
/// For up to five entities each tracked counterpart is looked up individually. For larger batches
/// the tracked entities are indexed by primary key once. In both cases
/// <paramref name="maxDepth"/>, <paramref name="depth"/> and
/// <paramref name="excludedProperties"/> are applied to every entity.
/// </remarks>
/// <typeparam name="T">The aggregate root type.</typeparam>
/// <param name="context">The context tracking the existing aggregates.</param>
/// <param name="newEntities">The detached entities carrying the new values.</param>
/// <param name="maxDepth">The deepest graph level that is processed.</param>
/// <param name="depth">The level assigned to the root entities.</param>
/// <param name="excludedProperties">Navigations that are left untouched at every level.</param>
public static void UpdateGraphs<T>(
    this DbContext context,
    ICollection<T> newEntities,
    int maxDepth = 2,
    int depth = 1,
    params Expression<Func<T, object>>[] excludedProperties) where T : class
{
    var excludedMembersHashSet = new HashSet<string>(excludedProperties.Select(x => x.GetPropertyName()));
    var set = context.Set<T>();

    if (newEntities.Count <= 5)
    {
        foreach (var newEntity in newEntities)
        {
            // Call the recursive worker directly: the public single-entity overload used here
            // before would drop the exclusions and the depth arguments.
            var trackedEntity = set.FindAttachedEntry(newEntity);
            UpdateGraph(context, newEntity, trackedEntity, excludedMembersHashSet, null, maxDepth, depth);
        }

        return;
    }

    var primaryKeyGetters = set.EntityType.FindPrimaryKey()!.Properties
        .Select(x => x.GetGetter())
        .ToArray();

    // Index the tracked entities by key once instead of scanning the local view per entity.
    var trackedEntriesDictionary = set.Local.ToDictionary(
        x => primaryKeyGetters.Select(getter => getter.GetClrValue(x)).ToArray(),
        comparer: ArraySequenceComparer<object?>.Default
    );

    foreach (var entry in newEntities)
    {
        var key = primaryKeyGetters.Select(getter => getter.GetClrValue(entry)).ToArray();
        var existingEntity = trackedEntriesDictionary.TryGetValue(key, out var cachedEntry) ? cachedEntry : null;
        UpdateGraph(context, entry, existingEntity, excludedMembersHashSet, null, maxDepth, depth);
    }
}
/// <summary>
/// Recursive worker that reconciles one tracked entity with its detached counterpart and then
/// descends into the loaded navigations of the tracked entity.
/// </summary>
/// <typeparam name="T">The static type of the two entities (<c>object</c> for nested levels).</typeparam>
/// <param name="context">The context tracking the existing graph.</param>
/// <param name="newEntity">The detached entity, or <c>null</c> when the tracked one has no counterpart.</param>
/// <param name="existingEntity">The tracked entity, or <c>null</c> when the detached one is new.</param>
/// <param name="excludedProperties">Navigation names that are never followed.</param>
/// <param name="parentAggregateTypeName">
/// Full CLR type name of the entity this call descended from; navigations whose CLR type has this
/// name are not followed.
/// </param>
/// <param name="maxDepth">The deepest level that is processed.</param>
/// <param name="depth">The level of this call.</param>
private static void UpdateGraph<T>(
    this DbContext context,
    T? newEntity,
    T? existingEntity,
    IReadOnlySet<string> excludedProperties,
    string? parentAggregateTypeName = null,
    int maxDepth = 2,
    int depth = 1)
    where T : class
{
    // Beyond the depth limit nothing is touched, not even additions or deletions.
    if (depth > maxDepth)
    {
        return;
    }

    if (existingEntity is null)
    {
        // Nothing tracked: either there is nothing to do, or the incoming entity is new.
        if (newEntity is not null)
        {
            context.Entry(newEntity).State = EntityState.Added;
        }

        return;
    }

    if (newEntity is null)
    {
        // Tracked but absent from the incoming graph: delete it.
        context.Entry(existingEntity).State = EntityState.Deleted;
        return;
    }

    var trackedEntry = context.Entry(existingEntity);
    trackedEntry.CurrentValues.SetValues(newEntity);

    var nextDepth = depth + 1;
    var ownerTypeName = trackedEntry.Metadata.ClrType.FullName;

    foreach (var navigation in trackedEntry.Navigations)
    {
        var skip = !navigation.IsLoaded
                   || navigation.Metadata.ClrType.FullName == parentAggregateTypeName
                   || excludedProperties.Contains(navigation.Metadata.Name);
        if (skip)
        {
            continue;
        }

        // The incoming value is read by reflection, using the tracked entity's runtime type.
        var incomingValue = trackedEntry.Entity.GetType()
            .GetProperty(navigation.Metadata.Name)
            ?.GetValue(newEntity);
        var trackedValue = navigation.CurrentValue;

        if (navigation is not CollectionEntry)
        {
            // Reference navigation: reconcile the single related entity.
            UpdateGraph(context, incomingValue, trackedValue, excludedProperties, ownerTypeName, maxDepth,
                nextDepth);
            continue;
        }

        var incomingItems = incomingValue as IEnumerable<object> ?? [];
        var unmatchedTracked = (trackedValue as IEnumerable<object>)?.ToList() ?? [];

        // Pair every incoming item with a tracked item that has the same key, if there is one.
        foreach (var incomingItem in incomingItems)
        {
            var incomingKey = context.Entry(incomingItem).GetPrimaryKeyValues();
            var match = unmatchedTracked.FirstOrDefault(
                candidate => context.Entry(candidate).GetPrimaryKeyValues().SequenceEqual(incomingKey));
            if (match is not null)
            {
                unmatchedTracked.Remove(match);
            }

            UpdateGraph(context, incomingItem, match, excludedProperties, ownerTypeName, maxDepth, nextDepth);
        }

        // Whatever was never matched is no longer part of the incoming collection.
        foreach (var orphan in unmatchedTracked)
        {
            UpdateGraph(context, null, orphan, excludedProperties, ownerTypeName, maxDepth, nextDepth);
        }
    }
}
| #endregion | |
| } | |
/// <summary>
/// Equality comparer that delegates both equality and hashing to
/// <see cref="StructuralComparisons.StructuralEqualityComparer"/>.
/// </summary>
/// <typeparam name="T">The element type.</typeparam>
/// <typeparam name="TCollection">The collection type being compared.</typeparam>
file class SequenceComparer<T, TCollection> : IEqualityComparer<TCollection> where TCollection : IEnumerable<T>
{
    /// <summary>Shared instance.</summary>
    public static SequenceComparer<T, TCollection> Default { get; } = new();

    /// <inheritdoc />
    public bool Equals(TCollection? x, TCollection? y)
    {
        return StructuralComparisons.StructuralEqualityComparer.Equals(x, y);
    }

    /// <inheritdoc />
    public int GetHashCode(TCollection obj)
    {
        return StructuralComparisons.StructuralEqualityComparer.GetHashCode(obj);
    }
}

/// <summary>
/// Shorthand for a <see cref="SequenceComparer{T,TCollection}"/> over arrays; it is used only
/// through the inherited static <c>Default</c> instance.
/// </summary>
file abstract class ArraySequenceComparer<T> : SequenceComparer<T, T[]>;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment