Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions app/(authenticated)/discounts/actions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"use server";

import { z } from "zod";
import { requireAdmin } from "@/lib/auth/require-admin";
import {
applyProductDiscountToCustomer,
removeProductDiscountFromCustomer,
} from "@/lib/stripe";

export type DiscountActionState = {
errors?: Record<string, string[]>;
message?: string;
data?: Record<string, string>;
} | null;

const applyDiscountSchema = z
.object({
product_id: z.string().min(1, "Product is required"),
customer_id: z.string().min(1, "Customer is required"),
usage_limit: z.coerce
.number()
.int("Usage limit must be an integer")
.min(1, "Usage limit must be at least 1"),
discount_type: z.enum(["percent", "amount"]),
discount_value: z.coerce
.number()
.int("Discount value must be an integer")
.positive("Discount value must be positive"),
currency: z.string().trim().length(3).optional(),
})
.superRefine((value, ctx) => {
if (value.discount_type === "percent" && value.discount_value > 100) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ["discount_value"],
message: "Percent discount cannot exceed 100",
});
}

if (value.discount_type === "amount" && !value.currency) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
path: ["currency"],
message: "Currency is required for amount discounts",
});
}
});

const removeDiscountSchema = z.object({
product_id: z.string().min(1, "Product is required"),
customer_id: z.string().min(1, "Customer is required"),
});

export async function applyDiscountToCustomerProduct(
_prev: DiscountActionState,
formData: FormData,
): Promise<DiscountActionState> {
try {
await requireAdmin();
} catch {
return { errors: { _form: ["Unauthorized"] } };
}

const parsed = applyDiscountSchema.safeParse({
product_id: formData.get("product_id"),
customer_id: formData.get("customer_id"),
usage_limit: formData.get("usage_limit"),
discount_type: formData.get("discount_type"),
discount_value: formData.get("discount_value"),
currency: formData.get("currency") ?? undefined,
});

if (!parsed.success) {
return { errors: parsed.error.flatten().fieldErrors };
}

const discount =
parsed.data.discount_type === "percent"
? { percentOff: parsed.data.discount_value }
: {
amountOffCents: parsed.data.discount_value,
currency: (parsed.data.currency ?? "cad").toLowerCase(),
};

try {
const result = await applyProductDiscountToCustomer({
productId: parsed.data.product_id,
customerId: parsed.data.customer_id,
usageLimit: parsed.data.usage_limit,
discount,
});

return {
message: "Discount applied.",
data: {
coupon_id: result.couponId,
},
};
} catch (error) {
return {
errors: {
_form: [
error instanceof Error ? error.message : "Could not apply discount",
],
},
};
}
}

export async function removeDiscountFromCustomerProduct(
_prev: DiscountActionState,
formData: FormData,
): Promise<DiscountActionState> {
try {
await requireAdmin();
} catch {
return { errors: { _form: ["Unauthorized"] } };
}

const parsed = removeDiscountSchema.safeParse({
product_id: formData.get("product_id"),
customer_id: formData.get("customer_id"),
});

if (!parsed.success) {
return { errors: parsed.error.flatten().fieldErrors };
}

try {
const { removed } = await removeProductDiscountFromCustomer({
productId: parsed.data.product_id,
customerId: parsed.data.customer_id,
});

return {
message:
removed > 0
? "Discount removed."
: "No active discount found for this customer and product.",
};
} catch (error) {
return {
errors: {
_form: [
error instanceof Error ? error.message : "Could not remove discount",
],
},
};
}
}
17 changes: 16 additions & 1 deletion app/api/checkout/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { NextRequest, NextResponse } from "next/server";
import { stripe, getOrCreateStripeCustomer } from "@/lib/stripe";
import {
getActiveCouponForCustomerProduct,
getOrCreateStripeCustomer,
stripe,
} from "@/lib/stripe";
import { createClient } from "@/utils/supabase/server";

export async function POST(request: NextRequest) {
Expand All @@ -26,11 +30,22 @@ export async function POST(request: NextRequest) {
user.email!,
);

const price = await stripe.prices.retrieve(priceId);
const stripeProductId = price.product as string;

const couponId = await getActiveCouponForCustomerProduct({
customerId: stripeCustomerId,
productId: stripeProductId,
});

const session = await stripe.checkout.sessions.create({
customer: stripeCustomerId,
mode: mode as "subscription" | "payment",
payment_method_types: ["card"],
line_items: [{ price: priceId, quantity: 1 }],
...(couponId
? { discounts: [{ coupon: couponId }] }
: {}),
success_url: `${request.nextUrl.origin}/checkout/success`,
cancel_url: `${request.nextUrl.origin}/checkout/cancel`,
});
Expand Down
10 changes: 9 additions & 1 deletion app/api/webhooks/stripe/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { NextRequest, NextResponse } from "next/server";
import { stripe, syncStripeData } from "@/lib/stripe";
import { deleteCouponIfExhausted, stripe, syncStripeData } from "@/lib/stripe";
import { db } from "@/lib/db";
import { profiles, purchases } from "@/lib/db/schema";
import { eq } from "drizzle-orm";
Expand Down Expand Up @@ -44,6 +44,14 @@ export async function POST(request: NextRequest) {
if (event.type === "checkout.session.completed") {
const session = event.data.object as Stripe.Checkout.Session;

for (const d of session.discounts ?? []) {
const couponId =
typeof d.coupon === "string" ? d.coupon : d.coupon?.id;
if (couponId) {
await deleteCouponIfExhausted(couponId);
}
}

if (session.mode === "payment" && session.customer) {
const customerId = session.customer as string;

Expand Down
120 changes: 120 additions & 0 deletions lib/stripe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ export type StripeServiceData = {
priceCurrency: string | null;
};

type ManagedProductDiscount = {
couponId: string;
};

export async function getStripeServiceData(
productId: string,
): Promise<StripeServiceData | null> {
Expand All @@ -212,3 +216,119 @@ export async function getStripeServiceData(
priceCurrency: latestPrice?.currency ?? null,
};
}

async function getManagedDiscountForCustomerProduct(
customerId: string,
productId: string,
): Promise<ManagedProductDiscount | null> {
const coupons = await stripe.coupons.list({
limit: 100,
});

for (const coupon of coupons.data) {
const metadata = coupon.metadata ?? {};
if (metadata.customerId !== customerId) continue;
if (!coupon.valid) continue;

const appliesToProducts = coupon.applies_to?.products ?? [];
if (!appliesToProducts.includes(productId)) continue;

return {
couponId: coupon.id,
};
}

return null;
}

export type ProductDiscountConfig =
| {
percentOff: number;
amountOffCents?: never;
currency?: string;
}
| {
percentOff?: never;
amountOffCents: number;
currency: string;
};

export async function applyProductDiscountToCustomer(input: {
productId: string;
customerId: string;
usageLimit: number;
discount: ProductDiscountConfig;
}): Promise<{
couponId: string;
}> {
const existing = await getManagedDiscountForCustomerProduct(
input.customerId,
input.productId,
);
if (existing) {
throw new Error(
"A discount is already active for this customer and product. Remove it before applying a new one.",
);
}

const couponParams: Stripe.CouponCreateParams = {
duration: "forever",
applies_to: { products: [input.productId] },
max_redemptions: input.usageLimit,
metadata: {
productId: input.productId,
customerId: input.customerId,
},
};

if (input.discount.percentOff !== undefined) {
couponParams.percent_off = input.discount.percentOff;
} else {
couponParams.amount_off = input.discount.amountOffCents;
couponParams.currency = input.discount.currency;
}

const coupon = await stripe.coupons.create(couponParams);

return {
couponId: coupon.id,
};
}

export async function getActiveCouponForCustomerProduct(input: {
customerId: string;
productId: string;
}): Promise<string | null> {
const managed = await getManagedDiscountForCustomerProduct(
input.customerId,
input.productId,
);
return managed?.couponId ?? null;
}

export async function removeProductDiscountFromCustomer(input: {
productId: string;
customerId: string;
}): Promise<{ removed: number }> {
const managed = await getManagedDiscountForCustomerProduct(
input.customerId,
input.productId,
);
if (!managed) {
return { removed: 0 };
}

await stripe.coupons.del(managed.couponId);
return { removed: 1 };
}

export async function deleteCouponIfExhausted(couponId: string): Promise<void> {
try {
const coupon = await stripe.coupons.retrieve(couponId);
if (!coupon.valid) {
await stripe.coupons.del(couponId);
}
} catch (err) {
console.error(`[STRIPE] Coupon cleanup failed for ${couponId}:`, err);
}
}
Loading