Created
March 17, 2024 11:53
-
-
Save chimame/f8ab9ae3172ded0e97f64010cad3d578 to your computer and use it in GitHub Desktop.
Automatic per-test database rollback for Vitest using Drizzle ORM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Client } from "pg"; | |
import { drizzle } from "drizzle-orm/node-postgres"; | |
import * as schema from "../drizzle/schema"; | |
import { Logger } from "drizzle-orm/logger"; | |
/**
 * Opens a dedicated `pg` connection and wraps it in a Drizzle instance bound
 * to the project schema.
 *
 * @param connectionString Postgres connection URI. Defaults to the
 *   `DATABASE_URL` environment variable, falling back to the placeholder
 *   string that must otherwise be edited in by hand.
 * @returns `db`, the Drizzle instance, and `client`, the underlying `pg`
 *   connection, so callers can `await client.end()` when they are finished.
 */
export async function createContext(
  connectionString: string = process.env.DATABASE_URL ??
    "your database connection string",
) {
  const client = new Client({ connectionString });
  await client.connect();

  const db = drizzle(client, {
    schema,
    // Log generated SQL only when running in development.
    logger: process.env.NODE_ENV === "development",
  });

  return { db, client };
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { createContext } from "../context"; | |
import { sql } from "drizzle-orm"; | |
import * as crypto from "crypto"; | |
import { NodePgTransaction } from "drizzle-orm/node-postgres"; | |
/** The value produced by `createContext()` once its promise resolves. */
type Context = Awaited<ReturnType<typeof createContext>>;

// Augment Vitest's per-test context so hooks and tests can read `context.ctx`
// (assigned by this file's beforeEach hook) with full typing.
declare module "vitest" {
  export interface TestContext {
    ctx: Context;
  }
}
/** Connection and Drizzle instance shared by every test in this file. */
let ctx: Context;

/**
 * Savepoints opened by `beforeEach`, most recent last; `afterEach` pops one
 * and rolls back to it.
 */
const savePoints: string[] = [];

/** Undoes the `NodePgTransaction.prototype.execute` spy for the current test. */
let restoreTransactionExecute: (() => void) | undefined;

/**
 * Builds a savepoint identifier from a UUID: hyphens are stripped and an "A"
 * prefix is added because a savepoint name must begin with a letter.
 */
function toSavepointName(uuid: string): string {
  return `A${uuid.replace(/-/g, "")}`;
}

// Open a single transaction for the whole file; afterAll rolls it back so
// nothing written by the tests is committed.
beforeAll(async () => {
  ctx = await createContext();
  await ctx.db.execute(sql`BEGIN`);
});

beforeEach(async (context) => {
  context.ctx = ctx;

  const savePoint = toSavepointName(crypto.randomUUID());
  await ctx.db.execute(sql.raw(`SAVEPOINT ${savePoint}`));
  // Record the savepoint only once it exists, so afterEach never tries to
  // roll back to one that was never created.
  savePoints.push(savePoint);

  // Stub NodePgTransaction#execute with an empty result.
  // NOTE(review): this appears intended to swallow the BEGIN/COMMIT/ROLLBACK
  // statements that `db.transaction()` sends through the transaction object,
  // keeping application transactions nested inside the test transaction.
  // Confirm against the drizzle-orm version in use. Any other `tx.execute(...)`
  // call made by code under test is stubbed too.
  const executeSpy = vi.spyOn(NodePgTransaction.prototype, "execute");
  executeSpy.mockImplementation(async () => ({
    rows: [],
    rowCount: 0,
    command: "SELECT",
    oid: 0,
    fields: [],
  }));
  restoreTransactionExecute = () => executeSpy.mockRestore();
});

afterEach(async (context) => {
  try {
    const savePoint = savePoints.pop();
    // Nothing to undo if beforeEach failed before creating the savepoint.
    if (savePoint === undefined) return;
    // ROLLBACK TO discards the test's writes and recovers from a failed
    // statement; RELEASE then drops the savepoint so they don't pile up.
    await context.ctx.db.execute(
      sql.raw(`ROLLBACK TO SAVEPOINT ${savePoint}`),
    );
    await context.ctx.db.execute(sql.raw(`RELEASE SAVEPOINT ${savePoint}`));
  } finally {
    // Always remove the per-test spy, even if the rollback failed.
    restoreTransactionExecute?.();
    restoreTransactionExecute = undefined;
  }
});

afterAll(async () => {
  // Nothing to roll back if beforeAll failed before creating the context.
  if (ctx === undefined) return;
  // Discard everything the tests wrote.
  // NOTE(review): the pg connection opened by createContext is never closed
  // here. Close it too if createContext exposes the client.
  await ctx.db.execute(sql`ROLLBACK`);
});
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment