Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ RESEND_FROM_DEFAULT= # e.g., hello@mail.trycomp.ai
# Background checks
BACKGROUND_CHECK_API_BASE_URL=https://glad-sturgeon-729.convex.site
BACKGROUND_CHECK_API_KEY=
MACED_API_KEY=mc_dev_dummy_api_key
BACKGROUND_CHECK_WEBHOOK_SECRET=
BACKGROUND_WH_ENDPOINT=
STRIPE_BACKGROUND_CHECK_PRICE_ID=price_1TRWckCkFWhKYvHIA1GLv1sO
Expand Down
54 changes: 26 additions & 28 deletions apps/api/src/background-checks/background-check-billing.service.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import {
BadRequestException,
Injectable,
NotFoundException,
} from '@nestjs/common';
import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common';
import { db } from '@db';
import { StripeService } from '../stripe/stripe.service';

Expand All @@ -14,20 +10,32 @@ export class BackgroundCheckBillingService {
hasBilling: boolean;
hasPaymentMethod: boolean;
setupAt: Date | null;
usage: {
backgroundChecks: number;
penetrationTests: number;
};
}> {
const billing = await db.organizationBilling.findUnique({
where: { organizationId },
select: {
stripeCustomerId: true,
stripeBackgroundCheckPaymentMethodId: true,
backgroundCheckPaymentMethodSetupAt: true,
},
});
const [billing, backgroundChecks, penetrationTests] = await Promise.all([
db.organizationBilling.findUnique({
where: { organizationId },
select: {
stripeCustomerId: true,
stripeBackgroundCheckPaymentMethodId: true,
backgroundCheckPaymentMethodSetupAt: true,
},
}),
db.backgroundCheckRequest.count({ where: { organizationId } }),
db.securityPenetrationTestRun.count({ where: { organizationId } }),
]);

return {
hasBilling: !!billing,
hasPaymentMethod: !!billing?.stripeBackgroundCheckPaymentMethodId,
setupAt: billing?.backgroundCheckPaymentMethodSetupAt ?? null,
usage: {
backgroundChecks,
penetrationTests,
},
};
}

Expand Down Expand Up @@ -60,9 +68,7 @@ export class BackgroundCheckBillingService {
});

if (!session.url) {
throw new BadRequestException(
'Failed to create Stripe Checkout session.',
);
throw new BadRequestException('Failed to create Stripe Checkout session.');
}

return { url: session.url };
Expand Down Expand Up @@ -112,9 +118,7 @@ export class BackgroundCheckBillingService {

const paymentMethodId = this.extractStripeId(setupIntent.payment_method);
if (!paymentMethodId) {
throw new BadRequestException(
'Setup intent is missing a payment method.',
);
throw new BadRequestException('Setup intent is missing a payment method.');
}

await stripe.customers.update(stripeCustomerId, {
Expand Down Expand Up @@ -157,9 +161,7 @@ export class BackgroundCheckBillingService {
});

if (!billing) {
throw new NotFoundException(
'No billing record found for this organization.',
);
throw new NotFoundException('No billing record found for this organization.');
}

const portalSession = await stripe.billingPortal.sessions.create({
Expand Down Expand Up @@ -249,15 +251,11 @@ export class BackgroundCheckBillingService {
}

if (parsed.origin !== new URL(appUrl).origin) {
throw new BadRequestException(
'Redirect URL must belong to the application origin.',
);
throw new BadRequestException('Redirect URL must belong to the application origin.');
}
}

private extractStripeId(
value: string | { id?: string } | null,
): string | null {
private extractStripeId(value: string | { id?: string } | null): string | null {
if (!value) return null;
if (typeof value === 'string') return value;
return value.id ?? null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ describe('BackgroundCheckPaymentService', () => {
expect.objectContaining({
amount: 1250,
customer: 'cus_1',
description: 'Comp AI - Background Check x1',
payment_method: 'pm_1',
}),
{ idempotencyKey: 'background-check:org_1:mem_1:price_bg:pm_1' },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import { BackgroundCheckBillingService } from './background-check-billing.servic

@Injectable()
export class BackgroundCheckPaymentService {
private static readonly receiptDescription =
'Comp AI - Background Check x1';

private readonly logger = new Logger(BackgroundCheckPaymentService.name);

constructor(
Expand Down Expand Up @@ -43,6 +46,7 @@ export class BackgroundCheckPaymentService {
customer: billing.stripeCustomerId,
amount: price.unitAmount,
currency: price.currency,
description: BackgroundCheckPaymentService.receiptDescription,
payment_method: billing.stripeBackgroundCheckPaymentMethodId,
off_session: true,
confirm: true,
Expand Down
37 changes: 37 additions & 0 deletions apps/api/src/background-checks/background-checks.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jest.mock('@db', () => {
backgroundCheckRequest: {
findUnique: jest.fn(),
findFirst: jest.fn(),
count: jest.fn(),
create: jest.fn(),
upsert: jest.fn(),
update: jest.fn(),
Expand All @@ -42,6 +43,9 @@ jest.mock('@db', () => {
create: jest.fn(),
upsert: jest.fn(),
},
securityPenetrationTestRun: {
count: jest.fn(),
},
organization: {
findUnique: jest.fn(),
},
Expand Down Expand Up @@ -475,4 +479,37 @@ describe('background checks', () => {
}),
).resolves.toEqual({ url: 'https://checkout.stripe.com/c/session_1' });
});

it('includes background check and penetration test usage in billing status', async () => {
mockAsync<Awaited<ReturnType<typeof db.organizationBilling.findUnique>>>(
mockedDb.organizationBilling.findUnique,
).mockResolvedValueOnce({
stripeCustomerId: 'cus_1',
stripeBackgroundCheckPaymentMethodId: 'pm_1',
backgroundCheckPaymentMethodSetupAt: new Date('2026-04-29T12:00:00.000Z'),
} as Awaited<ReturnType<typeof db.organizationBilling.findUnique>>);
mockAsync<number>(mockedDb.backgroundCheckRequest.count).mockResolvedValueOnce(4);
mockAsync<number>(
mockedDb.securityPenetrationTestRun.count,
).mockResolvedValueOnce(2);

const service = new BackgroundCheckBillingService({
getClient: jest.fn(),
} as unknown as StripeService);

await expect(service.getStatus('org_1')).resolves.toMatchObject({
hasBilling: true,
hasPaymentMethod: true,
usage: {
backgroundChecks: 4,
penetrationTests: 2,
},
});
expect(mockedDb.backgroundCheckRequest.count).toHaveBeenCalledWith({
where: { organizationId: 'org_1' },
});
expect(mockedDb.securityPenetrationTestRun.count).toHaveBeenCalledWith({
where: { organizationId: 'org_1' },
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,58 @@ describe('applySync', () => {
}));
});

it('creates missing control row before reconciling RequirementMap drift when v1 and v2 both claim it', async () => {
const tx = mockTx();
tx.control.create.mockResolvedValue({
id: 'ctl_created',
controlTemplateId: 'ct_1',
organizationId: 'org_1',
name: 'C',
description: 'D',
archivedAt: null,
});
tx.control.findMany.mockResolvedValue([]);
tx.requirementMap.findMany.mockResolvedValue([]);

const sameControl = { id: 'ct_1', name: 'C', description: 'D', requirementIds: ['rq_1'], policyIds: [], taskIds: [] };
await applySync(tx, {
instance: baseInstance as any,
currentVersion: {
id: 'fvr_v1',
frameworkId: 'frk_soc2',
manifest: manifest({
requirements: [{ id: 'rq_1', identifier: 'CC1', name: 'X', description: null }],
controls: [sameControl],
}),
} as any,
targetVersion: {
id: 'fvr_v2',
frameworkId: 'frk_soc2',
manifest: manifest({
requirements: [{ id: 'rq_1', identifier: 'CC1', name: 'X', description: null }],
controls: [sameControl],
}),
} as any,
memberId: 'mem_1',
});

expect(tx.control.create).toHaveBeenCalledWith(expect.objectContaining({
data: expect.objectContaining({
organizationId: 'org_1',
controlTemplateId: 'ct_1',
name: 'C',
description: 'D',
}),
}));
expect(tx.requirementMap.create).toHaveBeenCalledWith(expect.objectContaining({
data: expect.objectContaining({
controlId: 'ctl_created',
requirementId: 'rq_1',
frameworkInstanceId: 'frm_1',
}),
}));
});

it('unarchives an existing archived RequirementMap row instead of creating a duplicate', async () => {
const tx = mockTx();
tx.control.findMany.mockResolvedValue([
Expand Down Expand Up @@ -414,6 +466,54 @@ describe('applySync', () => {
expect(cpInsertCalled).toBe(true);
});

it('creates missing task row before reconciling _ControlToTask drift when v1 and v2 both claim it', async () => {
const tx = mockTx();
tx.control.findMany.mockResolvedValue([
{ id: 'ctl_1', controlTemplateId: 'ct_1', organizationId: 'org_1', name: 'C', description: 'D', archivedAt: null },
]);
tx.task.findMany.mockResolvedValue([]);
tx.task.create.mockResolvedValue({
id: 'tsk_created',
taskTemplateId: 'tt_1',
organizationId: 'org_1',
title: 'T',
description: 'D',
frequency: null,
department: null,
archivedAt: null,
});
tx.$queryRaw.mockResolvedValue([]);

const sameControl = { id: 'ct_1', name: 'C', description: 'D', requirementIds: [], policyIds: [], taskIds: ['tt_1'] };
const sameTask = { id: 'tt_1', name: 'T', description: 'D', frequency: null, department: null };
await applySync(tx, {
instance: baseInstance as any,
currentVersion: {
id: 'fvr_v1',
frameworkId: 'frk_soc2',
manifest: manifest({ controls: [sameControl], tasks: [sameTask] }),
} as any,
targetVersion: {
id: 'fvr_v2',
frameworkId: 'frk_soc2',
manifest: manifest({ controls: [sameControl], tasks: [sameTask] }),
} as any,
memberId: 'mem_1',
});

expect(tx.task.create).toHaveBeenCalledWith(expect.objectContaining({
data: expect.objectContaining({
organizationId: 'org_1',
taskTemplateId: 'tt_1',
title: 'T',
description: 'D',
}),
}));
const calls = tx.$executeRaw.mock.calls.map((c: unknown[]) => String(c[0]?.[0] ?? ''));
const ctInsertCalled = calls.some((s: string) => s.includes('INSERT INTO "_ControlToTask"'));
expect(ctInsertCalled).toBe(true);
});

it('creates missing ControlDocumentType row when v1 and v2 both claim it but customer has no row (drift)', async () => {
const tx = mockTx();
tx.control.findMany.mockResolvedValue([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ export async function applySync(
};

// --- Controls ---
for (const added of diff.controls.added) {
if (ctlByTemplate.has(added.id)) continue;
for (const targetControl of to.controls) {
if (ctlByTemplate.has(targetControl.id)) continue;
const created = await tx.control.create({
data: {
organizationId: ctx.instance.organizationId,
controlTemplateId: added.id,
name: added.name,
description: added.description,
controlTemplateId: targetControl.id,
name: targetControl.name,
description: targetControl.description,
},
});
ctlByTemplate.set(added.id, created);
ctlByTemplate.set(targetControl.id, created);
undo.controls.created.push(created.id);
summary.controlsAdded += 1;
}
Expand All @@ -103,19 +103,19 @@ export async function applySync(
}

// --- Tasks ---
for (const added of diff.tasks.added) {
if (taskByTemplate.has(added.id)) continue;
for (const targetTask of to.tasks) {
if (taskByTemplate.has(targetTask.id)) continue;
const created = await tx.task.create({
data: {
organizationId: ctx.instance.organizationId,
taskTemplateId: added.id,
title: added.name,
description: added.description,
frequency: added.frequency as Frequency | null,
department: added.department as Departments | null,
taskTemplateId: targetTask.id,
title: targetTask.name,
description: targetTask.description,
frequency: targetTask.frequency as Frequency | null,
department: targetTask.department as Departments | null,
},
});
taskByTemplate.set(added.id, created);
taskByTemplate.set(targetTask.id, created);
undo.tasks.created.push(created.id);
summary.tasksAdded += 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
Textarea,
} from '@trycompai/design-system';
import type { UseFormReturn } from 'react-hook-form';
import { BillingCallout } from './BackgroundCheckWizardParts';
import { BackgroundCheckSummary, BillingCallout } from './BackgroundCheckWizardParts';
import type { BackgroundCheckFormValues } from './backgroundCheckForm';

export function BackgroundCheckDetailsForm({
Expand All @@ -37,6 +37,8 @@ export function BackgroundCheckDetailsForm({
return (
<form noValidate onSubmit={form.handleSubmit(onSubmit)}>
<Stack gap="lg">
<BackgroundCheckSummary />
<div className="border-t" />
{billingSetupComplete && (
<BillingCallout
title="Payment method saved"
Expand Down
Loading
Loading