Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions apps/api/src/auth/auth.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ import { PermissionGuard } from './permission.guard';
auth,
// Don't register global auth guard - we use HybridAuthGuard
disableGlobalAuthGuard: true,
// CORS is already configured in main.ts — prevent the module from
// overriding it with its own trustedOrigins-based CORS.
disableTrustedOriginsCors: true,
// Body parsing for non-auth routes is handled in main.ts with a
// custom middleware that skips /api/auth paths. Disable the module's
// own SkipBodyParsingMiddleware to avoid conflicts.
disableBodyParser: true,
}),
],
controllers: [AuthController],
Expand Down
34 changes: 19 additions & 15 deletions apps/api/src/auth/auth.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const MAGIC_LINK_EXPIRES_IN_SECONDS = 60 * 60; // 1 hour
*/
function getCookieDomain(): string | undefined {
const baseUrl =
process.env.AUTH_BASE_URL || process.env.BETTER_AUTH_URL || '';
process.env.BASE_URL || '';

if (baseUrl.includes('staging.trycomp.ai')) {
return '.staging.trycomp.ai';
Expand Down Expand Up @@ -109,10 +109,10 @@ function validateSecurityConfig(): void {
// Warn about development defaults in production
if (process.env.NODE_ENV === 'production') {
const baseUrl =
process.env.AUTH_BASE_URL || process.env.BETTER_AUTH_URL || '';
process.env.BASE_URL || '';
if (baseUrl.includes('localhost')) {
console.warn(
'SECURITY WARNING: AUTH_BASE_URL contains "localhost" in production. ' +
'SECURITY WARNING: BASE_URL contains "localhost" in production. ' +
'This may cause issues with OAuth callbacks and cookies.',
);
}
Expand All @@ -125,23 +125,21 @@ validateSecurityConfig();
/**
* The auth server instance - single source of truth for authentication.
*
* IMPORTANT: For OAuth to work correctly with the app's auth proxy:
* - Set AUTH_BASE_URL to the app's URL (e.g., http://localhost:3000 in dev)
* - This ensures OAuth callbacks point to the app, which proxies to this API
* - Cookies will be set for the app's domain, not the API's domain
*
* In production, use the app's public URL (e.g., https://app.trycomp.ai)
* BASE_URL must point to the API (e.g., https://api.trycomp.ai).
* OAuth callbacks go directly to the API. Clients send absolute callbackURLs
* so better-auth redirects to the correct app after processing.
* Cross-subdomain cookies (.trycomp.ai) ensure the session works on all apps.
*/
export const auth = betterAuth({
database: prismaAdapter(db, {
provider: 'postgresql',
}),
// Use AUTH_BASE_URL pointing to the app (client), not the API itself
// This is critical for OAuth callbacks and cookie domains to work correctly
baseURL:
process.env.AUTH_BASE_URL ||
process.env.BETTER_AUTH_URL ||
'http://localhost:3000',
// baseURL must point to the API (e.g., https://api.trycomp.ai) so that
// OAuth callbacks go directly to the API regardless of which frontend
// initiated the flow. Clients must send absolute callbackURLs so that
// after OAuth processing, better-auth redirects to the correct app.
// Cross-subdomain cookies (.trycomp.ai) ensure the session works everywhere.
baseURL: process.env.BASE_URL || 'http://localhost:3333',
trustedOrigins: getTrustedOrigins(),
emailAndPassword: {
enabled: true,
Expand Down Expand Up @@ -322,6 +320,12 @@ export const auth = betterAuth({
enabled: true,
trustedProviders: ['google', 'github', 'microsoft'],
},
// Skip the state cookie CSRF check for OAuth flows.
// In our cross-origin setup (app/portal → API), the state cookie may not
// survive the OAuth redirect flow. The OAuth state parameter stored in the
// database already provides CSRF protection (random 32-char string validated
// against the DB). This is the same approach better-auth's oAuthProxy plugin uses.
skipStateCookieCheck: true,
},
verification: {
modelName: 'Verification',
Expand Down
11 changes: 4 additions & 7 deletions apps/api/src/config/better-auth.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ import { registerAs } from '@nestjs/config';
import { z } from 'zod';

const betterAuthConfigSchema = z.object({
url: z.string().url('AUTH_BASE_URL must be a valid URL'),
url: z.string().url('BASE_URL must be a valid URL'),
});

export type BetterAuthConfig = z.infer<typeof betterAuthConfigSchema>;

/**
* Better Auth configuration for the API.
*
* Since the API now runs the auth server, AUTH_BASE_URL should point to the API itself.
* BASE_URL should point to the API itself since the API is the auth server.
* For example:
* - Production: https://api.trycomp.ai
* - Staging: https://api.staging.trycomp.ai
Expand All @@ -19,17 +19,14 @@ export type BetterAuthConfig = z.infer<typeof betterAuthConfigSchema>;
export const betterAuthConfig = registerAs(
'betterAuth',
(): BetterAuthConfig => {
// AUTH_BASE_URL is the URL of the auth server (which is now the API)
// Fall back to BETTER_AUTH_URL for backwards compatibility during migration
const url = process.env.AUTH_BASE_URL || process.env.BETTER_AUTH_URL;
const url = process.env.BASE_URL;

if (!url) {
throw new Error('AUTH_BASE_URL or BETTER_AUTH_URL environment variable is required');
throw new Error('BASE_URL environment variable is required');
}

const config = { url };

// Validate configuration at startup
const result = betterAuthConfigSchema.safeParse(config);

if (!result.success) {
Expand Down
14 changes: 12 additions & 2 deletions apps/api/src/integration-platform/controllers/checks.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import { ApiTags, ApiSecurity } from '@nestjs/swagger';
import { HybridAuthGuard } from '../../auth/hybrid-auth.guard';
import { PermissionGuard } from '../../auth/permission.guard';
import { RequirePermission } from '../../auth/require-permission.decorator';
import { OrganizationId } from '../../auth/auth-context.decorator';
import {
getManifest,
getAvailableChecks,
runAllChecks,
} from '@comp/integration-platform';
import { ConnectionRepository } from '../repositories/connection.repository';
import { ConnectionService } from '../services/connection.service';
import { CredentialVaultService } from '../services/credential-vault.service';
import { ProviderRepository } from '../repositories/provider.repository';
import { CheckRunRepository } from '../repositories/check-run.repository';
Expand All @@ -40,6 +42,7 @@ export class ChecksController {
private readonly providerRepository: ProviderRepository,
private readonly credentialVaultService: CredentialVaultService,
private readonly checkRunRepository: CheckRunRepository,
private readonly connectionService: ConnectionService,
) {}

/**
Expand Down Expand Up @@ -68,7 +71,11 @@ export class ChecksController {
*/
@Get('connections/:connectionId')
@RequirePermission('integration', 'read')
async listConnectionChecks(@Param('connectionId') connectionId: string) {
async listConnectionChecks(
@Param('connectionId') connectionId: string,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(connectionId, organizationId);
const connection = await this.connectionRepository.findById(connectionId);
if (!connection) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
Expand Down Expand Up @@ -106,7 +113,9 @@ export class ChecksController {
async runConnectionChecks(
@Param('connectionId') connectionId: string,
@Body() body: RunChecksDto,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(connectionId, organizationId);
const connection = await this.connectionRepository.findById(connectionId);
if (!connection) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
Expand Down Expand Up @@ -306,7 +315,8 @@ export class ChecksController {
async runSingleCheck(
@Param('connectionId') connectionId: string,
@Param('checkId') checkId: string,
@OrganizationId() organizationId: string,
) {
return this.runConnectionChecks(connectionId, { checkId });
return this.runConnectionChecks(connectionId, { checkId }, organizationId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,11 @@ export class ConnectionsController {
*/
@Get(':id')
@RequirePermission('integration', 'read')
async getConnection(@Param('id') id: string) {
const connection = await this.connectionService.getConnection(id);
async getConnection(
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
const connection = await this.connectionService.getConnectionForOrg(id, organizationId);
const providerSlug = (connection as { provider?: { slug: string } })
.provider?.slug;

Expand Down Expand Up @@ -654,8 +657,11 @@ export class ConnectionsController {
*/
@Post(':id/test')
@RequirePermission('integration', 'update')
async testConnection(@Param('id') id: string) {
const connection = await this.connectionService.getConnection(id);
async testConnection(
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
const connection = await this.connectionService.getConnectionForOrg(id, organizationId);
const providerSlug = (connection as any).provider?.slug;

if (!providerSlug) {
Expand Down Expand Up @@ -744,7 +750,11 @@ export class ConnectionsController {
*/
@Post(':id/pause')
@RequirePermission('integration', 'update')
async pauseConnection(@Param('id') id: string) {
async pauseConnection(
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(id, organizationId);
const connection = await this.connectionService.pauseConnection(id);
return { id: connection.id, status: connection.status };
}
Expand All @@ -754,7 +764,11 @@ export class ConnectionsController {
*/
@Post(':id/resume')
@RequirePermission('integration', 'update')
async resumeConnection(@Param('id') id: string) {
async resumeConnection(
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(id, organizationId);
const connection = await this.connectionService.activateConnection(id);
return { id: connection.id, status: connection.status };
}
Expand All @@ -764,7 +778,11 @@ export class ConnectionsController {
*/
@Post(':id/disconnect')
@RequirePermission('integration', 'delete')
async disconnectConnection(@Param('id') id: string) {
async disconnectConnection(
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(id, organizationId);
const connection = await this.connectionService.disconnectConnection(id);
return { id: connection.id, status: connection.status };
}
Expand All @@ -774,7 +792,11 @@ export class ConnectionsController {
*/
@Delete(':id')
@RequirePermission('integration', 'delete')
async deleteConnection(@Param('id') id: string) {
async deleteConnection(
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(id, organizationId);
await this.connectionService.deleteConnection(id);
return { success: true };
}
Expand All @@ -789,13 +811,7 @@ export class ConnectionsController {
@OrganizationId() organizationId: string,
@Body() body: { metadata?: Record<string, unknown> },
) {
const connection = await this.connectionService.getConnection(id);
if (connection.organizationId !== organizationId) {
throw new HttpException(
'Connection does not belong to this organization',
HttpStatus.FORBIDDEN,
);
}
const connection = await this.connectionService.getConnectionForOrg(id, organizationId);

if (body.metadata && Object.keys(body.metadata).length > 0) {
// Merge with existing metadata
Expand Down Expand Up @@ -824,11 +840,7 @@ export class ConnectionsController {
@Param('id') id: string,
@OrganizationId() organizationId: string,
) {
const connection = await this.connectionService.getConnection(id);

if (connection.organizationId !== organizationId) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
}
const connection = await this.connectionService.getConnectionForOrg(id, organizationId);

if (connection.status !== 'active') {
throw new HttpException(
Expand Down Expand Up @@ -988,11 +1000,7 @@ export class ConnectionsController {
@OrganizationId() organizationId: string,
@Body() body: { credentials: Record<string, string | string[]> },
) {
const connection = await this.connectionService.getConnection(id);

if (connection.organizationId !== organizationId) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
}
const connection = await this.connectionService.getConnectionForOrg(id, organizationId);

const providerSlug = (connection as { provider?: { slug: string } })
.provider?.slug;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import { ApiTags, ApiSecurity } from '@nestjs/swagger';
import { HybridAuthGuard } from '../../auth/hybrid-auth.guard';
import { PermissionGuard } from '../../auth/permission.guard';
import { RequirePermission } from '../../auth/require-permission.decorator';
import { OrganizationId } from '../../auth/auth-context.decorator';
import { getManifest, type CheckVariable } from '@comp/integration-platform';
import { ConnectionRepository } from '../repositories/connection.repository';
import { ConnectionService } from '../services/connection.service';
import { ProviderRepository } from '../repositories/provider.repository';
import { CredentialVaultService } from '../services/credential-vault.service';
import { AutoCheckRunnerService } from '../services/auto-check-runner.service';
Expand Down Expand Up @@ -52,6 +54,7 @@ export class VariablesController {
private readonly providerRepository: ProviderRepository,
private readonly credentialVaultService: CredentialVaultService,
private readonly autoCheckRunnerService: AutoCheckRunnerService,
private readonly connectionService: ConnectionService,
) {}

/**
Expand Down Expand Up @@ -109,7 +112,12 @@ export class VariablesController {
*/
@Get('connections/:connectionId')
@RequirePermission('integration', 'read')
async getConnectionVariables(@Param('connectionId') connectionId: string) {
async getConnectionVariables(
@Param('connectionId') connectionId: string,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(connectionId, organizationId);

const connection = await this.connectionRepository.findById(connectionId);
if (!connection) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
Expand Down Expand Up @@ -179,7 +187,10 @@ export class VariablesController {
async fetchVariableOptions(
@Param('connectionId') connectionId: string,
@Param('variableId') variableId: string,
@OrganizationId() organizationId: string,
): Promise<{ options: VariableOption[] }> {
await this.connectionService.getConnectionForOrg(connectionId, organizationId);

const connection = await this.connectionRepository.findById(connectionId);
if (!connection) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
Expand Down Expand Up @@ -386,7 +397,10 @@ export class VariablesController {
async saveConnectionVariables(
@Param('connectionId') connectionId: string,
@Body() body: SaveVariablesDto,
@OrganizationId() organizationId: string,
) {
await this.connectionService.getConnectionForOrg(connectionId, organizationId);

const connection = await this.connectionRepository.findById(connectionId);
if (!connection) {
throw new HttpException('Connection not found', HttpStatus.NOT_FOUND);
Expand Down
11 changes: 11 additions & 0 deletions apps/api/src/integration-platform/services/connection.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ export class ConnectionService {
return connection;
}

async getConnectionForOrg(
connectionId: string,
organizationId: string,
): Promise<IntegrationConnection> {
const connection = await this.connectionRepository.findById(connectionId);
if (!connection || connection.organizationId !== organizationId) {
throw new NotFoundException(`Connection ${connectionId} not found`);
}
return connection;
}

async getConnectionByProviderSlug(
providerSlug: string,
organizationId: string,
Expand Down
17 changes: 15 additions & 2 deletions apps/api/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,21 @@ async function bootstrap(): Promise<void> {
// STEP 3: Configure body parser
// NOTE: Attachment uploads are sent as base64 in JSON, so request payloads are
// larger than the raw file size. Keep this above the user-facing max file size.
app.use(express.json({ limit: '150mb' }));
app.use(express.urlencoded({ limit: '150mb', extended: true }));
// IMPORTANT: Skip body parsing for /api/auth routes — better-auth needs the raw
// request stream to properly read the body (including OAuth callbackURL).
// Express-level middleware runs BEFORE NestJS module middleware, so without this
// skip, express.json() would consume the stream before better-auth's handler.
const jsonParser = express.json({ limit: '150mb' });
const urlencodedParser = express.urlencoded({ limit: '150mb', extended: true });
app.use((req: express.Request, res: express.Response, next: express.NextFunction) => {
if (req.path.startsWith('/api/auth')) {
return next();
}
jsonParser(req, res, (err?: unknown) => {
if (err) return next(err);
urlencodedParser(req, res, next);
});
});

// STEP 4: Enable global pipes and filters
app.useGlobalPipes(
Expand Down
Loading
Loading