Skip to content

Instantly share code, notes, and snippets.

@mgray88
Last active June 6, 2025 21:02
Show Gist options
  • Save mgray88/6924116b821a58a6d34b4502cc0713f2 to your computer and use it in GitHub Desktop.
Save mgray88/6924116b821a58a6d34b4502cc0713f2 to your computer and use it in GitHub Desktop.
Expo API Routes Middleware handler
/* eslint-disable @typescript-eslint/no-empty-object-type */
import log from "@/lib/logger"; // This is using adze
/**
 * Shape of the object a middleware contributes to the handler context.
 * `apiHandler` shallow-merges every middleware result into one object.
 */
type BaseContext = Record<string, any>;
/**
 * A middleware receives the incoming request and returns — synchronously or
 * via a Promise — an object whose properties are merged into the handler
 * context. Returning nothing (`void`) contributes no properties, because
 * spreading `undefined` is a no-op. `apiHandler` runs middlewares inside its
 * try/catch, so a middleware can reject the request by throwing an `ApiError`.
 */
export type Middleware<T extends BaseContext = BaseContext> = (
request: Request,
) => T | Promise<T> | void;
/**
 * Extracts the context type `T` from a `Middleware<T>`; anything that is not
 * a middleware contributes `{}`.
 */
type ExtractMiddlewareContext<M> = M extends Middleware<infer T> ? T : {};
/**
 * Walks a tuple of middlewares head-first and intersects their context types,
 * e.g. `[Middleware<{ a: 1 }>, Middleware<{ b: 2 }>]` → `{ a: 1 } & { b: 2 }`.
 * An empty tuple, or a plain (non-tuple) array type, yields `{}` because it
 * does not match `[infer First, ...infer Rest]`.
 */
type CombineMiddlewareContexts<M extends Middleware<any>[]> = M extends [
infer First,
...infer Rest,
]
? First extends Middleware<any>
? Rest extends Middleware<any>[]
? ExtractMiddlewareContext<First> & CombineMiddlewareContexts<Rest>
: ExtractMiddlewareContext<First>
: {}
: {};
/**
 * Final route handler invoked by `apiHandler` after all middlewares succeed.
 *
 * @param request - The incoming request.
 * @param pathParams - Path parameters passed through from the route.
 * @param context - The merged results of all middlewares (type `C`).
 * @returns The response to send.
 */
export type Handler<C extends BaseContext> = (
request: Request,
pathParams: Record<string, string>,
context: C,
) => Promise<Response>;
/**
 * Creates an API route handler that runs a series of middlewares and then
 * calls `handler` with the merged context.
 *
 * Middlewares run strictly one after another, in array order. Each one is
 * awaited before the next is invoked, so a middleware that throws (for
 * example an auth check) prevents every later middleware and the handler
 * from running. Results are shallow-merged left to right: a later
 * middleware's property overrides an earlier one with the same key.
 *
 * Thrown errors are converted to JSON responses:
 * - `ApiError` → `{ error: { message, code } }` with the error's `status`.
 * - anything else → logged, then `{ error: { message: "Internal Server Error" } }`
 *   with status 500.
 *
 * @param middlewares - Middleware functions, executed sequentially in order.
 * @param handler - Final handler, called with the request, path params and merged context.
 * @returns A route function accepting the request and its path params.
 */
export function apiHandler<M extends Middleware<any>[]>(
  middlewares: [...M],
  handler: Handler<CombineMiddlewareContexts<M>>,
): (request: Request, pathParams: Record<string, string>) => Promise<Response> {
  return async (
    request: Request,
    pathParams: Record<string, string>,
  ): Promise<Response> => {
    const url = new URL(request.url);
    const logLabel = `${request.method} ${url.pathname}`;
    log.label(logLabel).time.info();
    try {
      // Run middlewares sequentially. A `reduce` with async callbacks would
      // invoke every middleware before the previous one settled: later
      // middlewares would run even after an earlier one threw, and an early
      // rejection could sit unhandled while a later middleware was pending.
      let context = {} as CombineMiddlewareContexts<M>;
      for (const middleware of middlewares) {
        const current = await middleware(request);
        // Spreading `undefined` (a middleware that returns nothing) is a no-op.
        context = { ...context, ...current };
      }
      return await handler(request, pathParams, context);
    } catch (error) {
      if (error instanceof ApiError) {
        const body: ApiErrorResponse = {
          error: { message: error.message, code: error.code },
        };
        return Response.json(body, { status: error.status });
      }
      log.error("Unexpected error:", error);
      // Same envelope as ApiError responses so clients can parse every error uniformly.
      const body: ApiErrorResponse = {
        error: { message: "Internal Server Error" },
      };
      return Response.json(body, { status: 500 });
    } finally {
      log.label(logLabel).timeEnd.info();
    }
  };
}
/**
 * Error for expected API failures. When thrown inside `apiHandler` (from a
 * middleware or the handler), it is returned to the client as
 * `{ error: { message, code } }` with HTTP status `status`.
 *
 * @example
 * ```
 * new ApiError("Some error message", 404);
 * ```
 * or
 * ```
 * new ApiError({
 *   message: "Some error message",
 *   code: "NOT_FOUND",
 *   status: 404,
 * });
 * ```
 *
 * `status` is optional and defaults to 500. In the object form, `params.status`
 * takes precedence over the positional `status` argument.
 */
export class ApiError extends Error {
  /** Optional machine-readable error code, included in the JSON response. */
  public code?: string;
  /** HTTP status code for the response (default 500). */
  public status: number;

  /**
   * @param params - Either the error message, or an object with `message`
   *   and optional `code` / `status`.
   * @param status - HTTP status; used when `params` is a string, or as a
   *   fallback when the object form omits `status`.
   */
  constructor(
    params:
      | {
          message: string;
          code?: string;
          status?: number;
        }
      | string,
    status?: number,
  ) {
    if (typeof params === "string") {
      super(params);
      this.status = status ?? 500;
    } else {
      super(params.message);
      this.code = params.code;
      // Previously a destructured `status` shadowed the positional argument,
      // silently discarding it; now it is honoured as a fallback.
      this.status = params.status ?? status ?? 500;
    }
    this.name = "ApiError";
  }
}
/**
 * JSON body `apiHandler` returns when an `ApiError` is thrown.
 * `code` is omitted from the serialized JSON when the error has none
 * (`JSON.stringify` drops `undefined` properties).
 */
export type ApiErrorResponse = {
error: {
message: string;
code?: string;
};
};
import { db } from "@/db/server";
import { companyTable, UserRole } from "@/db/server/schema";
import { ApiError, apiHandler } from "@/server/api-handler";
import { getSubscriptionById } from "@/server/stripe";
import type { StripeSubscription } from "@/server/stripe";
import { withAuthMiddleware } from "@/server/with-auth-middleware";
import { eq } from "drizzle-orm";
/**
 * Response body of this route's `GET`: the caller's company row, plus
 * `subscription` filled from `getSubscriptionById` when the company has a
 * truthy `stripeSubscriptionId`.
 *
 * NOTE(review): the `| null` on `subscription` presumably mirrors the return
 * type of `getSubscriptionById` — confirm against `@/server/stripe`.
 */
export type CompanyResponse = {
id: string;
name: string;
stripeCustomerId: string | null;
stripeSubscriptionId: string | null;
subscription?: StripeSubscription | null;
};
/**
 * GET: returns the authenticated user's company, with its Stripe
 * subscription attached when a subscription id is stored.
 *
 * Only users with the OWNER role get past the auth middleware.
 * Responds 404 when no company row matches the user's `companyId`.
 */
export const GET = apiHandler(
  [withAuthMiddleware({ allowedRoles: [UserRole.OWNER] })],
  async (_req: Request, _params, { user }) => {
    const rows = await db
      .select()
      .from(companyTable)
      .where(eq(companyTable.id, user.companyId));

    if (!rows?.length) {
      throw new ApiError("Company not found", 404);
    }

    // NOTE(review): this cast is unchecked; `.select()` with no arguments
    // presumably returns every companyTable column, all of which get serialized.
    const company = rows[0] as CompanyResponse;

    const { stripeSubscriptionId } = company;
    if (stripeSubscriptionId) {
      company.subscription = await getSubscriptionById(stripeSubscriptionId);
    }

    return Response.json(company, { status: 200 });
  },
);
import type { UserApi } from "@/app/api/user/index+api";
import { db } from "@/db/server";
import { companyTable, UserRole, userTable } from "@/db/server/schema";
import log from "@/lib/logger";
import { eq, getTableColumns } from "drizzle-orm";
import { FirebaseAuthError } from "firebase-admin/auth";
import { firebaseAuth } from "./firebase";
import { ApiError } from "./api-handler";
/** Options for {@link withAuthMiddleware}. */
interface Opts {
  /**
   * Roles permitted to call the route. An empty list allows any
   * authenticated user.
   */
  allowedRoles: readonly UserRole[];
}

/**
 * Middleware factory that authenticates the caller from a Firebase ID token
 * sent as `Authorization: Bearer <token>`.
 *
 * The token is verified with Firebase. The user row whose `firebaseId`
 * matches (joined with its company's name) is then loaded. When
 * `allowedRoles` is non-empty, the user's role must be one of them.
 *
 * @param opts.allowedRoles - Roles allowed through; `[]` (the default) allows every role.
 * @returns A middleware resolving to `{ user }`, which `apiHandler` merges into the handler context.
 * @throws {ApiError} 401 if the header is missing/malformed or the token fails verification;
 *   403 if no single user matches the token or the user's role is not allowed.
 */
export function withAuthMiddleware(
  { allowedRoles }: Opts = { allowedRoles: [] },
): (request: Request) => Promise<{ user: UserApi }> {
  // RFC 7235/9110: the auth scheme is case-insensitive and is separated from
  // the credentials by one or more spaces. No `g` flag, so sharing is safe.
  const bearerPattern = /^Bearer +(\S+)$/i;

  return async (request: Request): Promise<{ user: UserApi }> => {
    const authHeader = request.headers.get("Authorization");
    const token = authHeader ? bearerPattern.exec(authHeader)?.[1] : undefined;
    if (!token) {
      throw new ApiError("Unauthorized", 401);
    }

    let firebaseId: string;
    try {
      const decodedToken = await firebaseAuth.verifyIdToken(token);
      firebaseId = decodedToken.uid;
    } catch (error) {
      if (error instanceof FirebaseAuthError) {
        log.debug("Error verifying token:", error);
        // Surface Firebase's message and code to the client.
        throw new ApiError({
          message: error.message,
          code: error.code,
          status: 401,
        });
      }
      log.error("Unknown error verifying token:", error);
      throw new ApiError("Unauthorized", 401);
    }

    const userRows = await db
      .select({
        ...getTableColumns(userTable),
        companyName: companyTable.name,
      })
      .from(userTable)
      .innerJoin(companyTable, eq(userTable.companyId, companyTable.id))
      .where(eq(userTable.firebaseId, firebaseId));

    // Exactly one row is required; zero or several rows are both rejected.
    const user = userRows?.length === 1 ? userRows[0] : undefined;
    if (!user) {
      throw new ApiError("User not found", 403);
    }

    if (allowedRoles.length > 0 && !allowedRoles.includes(user.role)) {
      log.debug("User role not allowed:", user.role);
      throw new ApiError("Forbidden", 403);
    }

    return { user };
  };
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment