Skip to content

Instantly share code, notes, and snippets.

@skarllot
Last active July 16, 2025 17:46
Show Gist options
  • Save skarllot/be66922ae18bffd576570891ca7b846c to your computer and use it in GitHub Desktop.
Save skarllot/be66922ae18bffd576570891ca7b846c to your computer and use it in GitHub Desktop.
Test EF Core (Entity Framework) migrations
using DotNet.Testcontainers.Builders;
using DotNet.Testcontainers.Configurations;
using DotNet.Testcontainers.Containers;
using Microsoft.EntityFrameworkCore;
using Testcontainers.MsSql;
using Xunit;
namespace Tests;
/// <summary>
/// Base class for xUnit tests that need a SQL Server instance. Each test class instance
/// owns one <see cref="MsSqlContainer"/>, started in <see cref="InitializeAsync"/> and
/// torn down in <see cref="DisposeAsync"/>.
/// </summary>
public abstract class MsSqlTestBase : IAsyncLifetime
{
    /// <summary>
    /// <c>true</c> when the <c>TF_BUILD</c> environment variable holds a non-blank value.
    /// NOTE(review): presumably this identifies an Azure Pipelines agent - confirm.
    /// </summary>
    private static bool IsCiSystem => !string.IsNullOrWhiteSpace(Environment.GetEnvironmentVariable("TF_BUILD"));

    /// <summary>
    /// The SQL Server container used by the tests. It is considered ready only after the
    /// readiness message was logged and <see cref="WaitSqlCmdSuccess"/> succeeded.
    /// </summary>
    protected MsSqlContainer MsSqlContainer { get; } =
        new MsSqlBuilder()
            .WithAutoRemove(true)
            .WithCleanUp(!IsCiSystem)
            .WithWaitStrategy(
                Wait.ForUnixContainer()
                    .UntilMessageIsLogged("SQL Server is now ready for client connections")
                    .AddCustomWaitStrategy(new WaitSqlCmdSuccess())
            )
            .Build();

    /// <summary>Starts the SQL Server container before the test runs.</summary>
    public Task InitializeAsync() => MsSqlContainer.StartAsync();

    /// <summary>
    /// Stops and disposes the SQL Server container. Teardown is best-effort: failures are
    /// swallowed so that a broken container never turns a passing test into a failing one.
    /// </summary>
    public virtual async Task DisposeAsync()
    {
        // Each step is guarded separately so that a failure while stopping the container
        // does not prevent the dispose call from running (which would leak the container).
        try
        {
            await MsSqlContainer.StopAsync();
        }
        catch
        {
            // Ignore stop errors
        }

        try
        {
            await MsSqlContainer.DisposeAsync();
        }
        catch
        {
            // Ignore dispose errors
        }
    }

    /// <summary>
    /// Builds <see cref="DbContextOptions{TContext}"/> configured to use SQL Server with
    /// the connection string of <see cref="MsSqlContainer"/>.
    /// </summary>
    /// <typeparam name="TContext">The <see cref="DbContext"/> type the options are for.</typeparam>
    /// <returns>The options built for <typeparamref name="TContext"/>.</returns>
    public DbContextOptions<TContext> CreateDbContextOptions<TContext>()
        where TContext : DbContext =>
        new DbContextOptionsBuilder<TContext>().UseSqlServer(MsSqlContainer.GetConnectionString()).Options;

    /// <summary>
    /// Wait strategy that reports readiness once a trivial query executed through
    /// <c>sqlcmd</c> inside the container exits with code 0.
    /// </summary>
    /// <remarks>
    /// Uses the <c>sqlcmd</c> utility scripting variables to detect readiness of the MsSql container:
    /// <see href="https://learn.microsoft.com/en-us/sql/tools/sqlcmd/sqlcmd-utility?view=sql-server-linux-ver15#sqlcmd-scripting-variables"/>.
    /// </remarks>
    private sealed class WaitSqlCmdSuccess : IWaitUntil
    {
        /// <inheritdoc />
        public Task<bool> UntilAsync(IContainer container)
        {
            // This strategy is only ever attached to the MsSqlContainer built above,
            // so a direct cast is used; any other container type fails loudly here.
            return UntilAsync((MsSqlContainer)container);
        }

        /// <inheritdoc cref="IWaitUntil.UntilAsync" />
        private static async Task<bool> UntilAsync(MsSqlContainer container)
        {
            var sqlCmdFilePath = await container.GetSqlCmdFilePathAsync().ConfigureAwait(false);

            var execResult = await container
                .ExecAsync(new[] { sqlCmdFilePath, "-C", "-Q", "SELECT 1;" })
                .ConfigureAwait(false);

            // Ready only when sqlcmd exited with code 0.
            return 0L.Equals(execResult.ExitCode);
        }
    }
}
using System.Diagnostics.CodeAnalysis;
using System.Text.RegularExpressions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Migrations;
using Microsoft.EntityFrameworkCore.Migrations.Design;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Tests;
/// <summary>
/// Verifies the EF Core migrations of <c>MyDbContext</c> against a real SQL Server container:
/// they apply, they roll back, and the model has no changes missing from a migration.
/// </summary>
[Collection(SqlServerTestGroup.Name)]
public class MyDbContextTest : MsSqlTestBase
{
    /// <summary>Tables expected to exist after migrating and to be gone after rolling back.</summary>
    private static readonly string[] _tablesToCheck = { "Users", "Contacts" };

    /// <summary>
    /// Applies all migrations to the empty database and checks that every pending migration
    /// became an applied one and that the expected tables were created.
    /// </summary>
    [Fact]
    public async Task ShouldRunMigrations()
    {
        await using var dbContext = new MyDbContext(CreateDbContextOptions<MyDbContext>());

        // Migration lists are materialized (instead of only counted) so that a failing
        // assertion reports the offending migration names.
        var allMigrations = dbContext.Database.GetMigrations().ToList();
        var initialPendingMigrations = (await dbContext.Database.GetPendingMigrationsAsync()).ToList();
        var initialAppliedMigrations = (await dbContext.Database.GetAppliedMigrationsAsync()).ToList();

        await dbContext.Database.MigrateAsync();

        var finalPendingMigrations = (await dbContext.Database.GetPendingMigrationsAsync()).ToList();
        var finalAppliedMigrations = (await dbContext.Database.GetAppliedMigrationsAsync()).ToList();

        Assert.NotEmpty(allMigrations);
        Assert.NotEmpty(initialPendingMigrations);
        Assert.Empty(initialAppliedMigrations);
        Assert.Empty(finalPendingMigrations);
        Assert.NotEmpty(finalAppliedMigrations);

        foreach (var tableName in _tablesToCheck)
        {
            Assert.True(await TableExistsAsync(dbContext, tableName), $"Table {tableName} does not exist");
        }
    }

    /// <summary>
    /// Applies all migrations, then migrates down to the target <c>"0"</c>, and checks
    /// that no migration stays applied and that the expected tables were dropped.
    /// </summary>
    [Fact]
    public async Task ShouldRollbackMigrations()
    {
        await using var serviceProvider = new ServiceCollection()
            .AddDbContext<MyDbContext>(builder => builder.UseSqlServer(MsSqlContainer.GetConnectionString()))
            .BuildServiceProvider();
        await using var scope = serviceProvider.CreateAsyncScope();
        await using var dbContext = scope.ServiceProvider.GetRequiredService<MyDbContext>();
        var migrator = dbContext.GetInfrastructure().GetRequiredService<IMigrator>();

        await migrator.MigrateAsync();

        var pendingMigrationsAfterMigrate = (await dbContext.Database.GetPendingMigrationsAsync()).ToList();
        var appliedMigrationsAfterMigrate = (await dbContext.Database.GetAppliedMigrationsAsync()).ToList();

        await migrator.MigrateAsync("0");

        var pendingMigrationsAfterRollback = (await dbContext.Database.GetPendingMigrationsAsync()).ToList();
        var appliedMigrationsAfterRollback = (await dbContext.Database.GetAppliedMigrationsAsync()).ToList();

        Assert.Empty(pendingMigrationsAfterMigrate);
        Assert.NotEmpty(appliedMigrationsAfterMigrate);
        Assert.NotEmpty(pendingMigrationsAfterRollback);
        Assert.Empty(appliedMigrationsAfterRollback);

        // Everything that was applied must be pending again after the rollback.
        Assert.Equal(appliedMigrationsAfterMigrate.Count, pendingMigrationsAfterRollback.Count);

        foreach (var tableName in _tablesToCheck)
        {
            Assert.False(await TableExistsAsync(dbContext, tableName), $"Table {tableName} does exist");
        }
    }

    /// <summary>
    /// Scaffolds a new migration with the design-time services and checks that its
    /// <c>Up</c> method is empty, i.e. the model has no changes missing from a migration.
    /// </summary>
    [Fact]
    [SuppressMessage("Usage", "EF1001:Internal EF Core API usage.")]
    public async Task ShouldCreateABlankMigration()
    {
        await using var dbContext = new MyDbContext(CreateDbContextOptions<MyDbContext>());
        await using var serviceProvider = new DesignTimeServicesBuilder(
            typeof(MyDbContext).Assembly,
            typeof(WebApp).Assembly,
            new OperationReporter(null),
            Array.Empty<string>()
        )
            .CreateServiceCollection(dbContext)
            .BuildServiceProvider();
        var migrationsScaffolder = serviceProvider.GetRequiredService<IMigrationsScaffolder>();

        string migrationCode = migrationsScaffolder
            .ScaffoldMigration("test", typeof(MyDbContext).FullName)
            .MigrationCode;

        // Whitespace is stripped so the check does not depend on how the scaffolder formats code.
        var migrationWithoutWhitespace = Regex.Replace(migrationCode, @"\s+", string.Empty);

        Assert.True(
            migrationWithoutWhitespace.Contains("Up(MigrationBuildermigrationBuilder){}", StringComparison.Ordinal),
            $"Model changes were found which are not yet reflected in a migration. Proposed migration to add them: {migrationCode}"
        );
    }

    /// <summary>
    /// Checks whether <c>OBJECT_ID</c> resolves <paramref name="tableName"/> as a user
    /// table (<c>N'U'</c>) in the database of <paramref name="dbContext"/>.
    /// </summary>
    /// <param name="dbContext">The context whose database is queried.</param>
    /// <param name="tableName">The name of the table to look for.</param>
    /// <returns><c>true</c> when the table exists; otherwise <c>false</c>.</returns>
    private static async Task<bool> TableExistsAsync(DbContext dbContext, string tableName)
    {
        // The interpolated string must stay inline: it is passed to SqlQuery as an
        // interpolated argument rather than as a pre-built string.
        var result = await dbContext
            .Database.SqlQuery<int>(
                $"IF OBJECT_ID ({tableName}, N'U') IS NOT NULL SELECT 1 AS Value ELSE SELECT 0 AS Value"
            )
            .ToListAsync();

        return result.Single() != 0;
    }
}
namespace Tests;
/// <summary>
/// Holds the name of the xUnit test collection used by the SQL Server tests.
/// </summary>
public static class SqlServerTestGroup
{
    /// <summary>
    /// Collection name. It must remain a compile-time constant because it is used
    /// as an attribute argument.
    /// </summary>
    public const string Name = "SqlServerTest";
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment