Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 64 additions & 43 deletions apps/api/src/cloud-security/azure-remediation.service.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { Injectable, Logger } from '@nestjs/common';
import { db, Prisma } from '@db';
import { getManifest } from '@trycompai/integration-platform';
import { CredentialVaultService } from '../integration-platform/services/credential-vault.service';
import { OAuthCredentialsService } from '../integration-platform/services/oauth-credentials.service';
import { AiRemediationService } from './ai-remediation.service';
import { AzureSecurityService } from './providers/azure-security.service';
import { parseAzurePermissionError } from './remediation-error.utils';
Expand Down Expand Up @@ -36,6 +38,7 @@ export class AzureRemediationService {

constructor(
private readonly credentialVaultService: CredentialVaultService,
private readonly oauthCredentialsService: OAuthCredentialsService,
private readonly aiRemediationService: AiRemediationService,
private readonly azureSecurityService: AzureSecurityService,
) {}
Expand Down Expand Up @@ -455,31 +458,13 @@ export class AzureRemediationService {
throw new Error('No rollback steps available for this action.');
}

// Get fresh access token
const credentials = await this.resolveCredentials(
// Get fresh access token (auto-refreshes if expired)
const accessToken = await this.getValidAzureToken(
action.connectionId,
action.organizationId,
);
if (!credentials) {
throw new Error('Cannot retrieve Azure credentials for rollback.');
}

// OAuth flow: token from vault; legacy: SP client credentials
let accessToken = credentials.access_token as string | undefined;
if (
!accessToken &&
credentials.tenantId &&
credentials.clientId &&
credentials.clientSecret
) {
accessToken = await this.azureSecurityService.getAccessToken(
credentials.tenantId as string,
credentials.clientId as string,
credentials.clientSecret as string,
);
}
if (!accessToken) {
throw new Error('Cannot obtain Azure access token for rollback.');
throw new Error('Cannot obtain Azure access token for rollback. Please reconnect the integration.');
}

this.logger.log(
Expand Down Expand Up @@ -638,6 +623,56 @@ export class AzureRemediationService {

// --- Private helpers ---

/**
* Get a valid Azure access token, refreshing if expired.
*/
private async getValidAzureToken(
connectionId: string,
organizationId: string,
): Promise<string | null> {
const manifest = getManifest('azure');
const oauthConfig = manifest?.auth?.type === 'oauth2' ? manifest.auth.config : null;

if (oauthConfig) {
const oauthCreds = await this.oauthCredentialsService.getCredentials(
'azure',
organizationId,
);
if (oauthCreds) {
const token = await this.credentialVaultService.getValidAccessToken(
connectionId,
{
tokenUrl: oauthConfig.tokenUrl,
clientId: oauthCreds.clientId,
clientSecret: oauthCreds.clientSecret,
clientAuthMethod: oauthConfig.clientAuthMethod,
},
);
if (token) return token;
}
}

// Fallback: try raw credentials (legacy SP or expired token)
const credentials =
await this.credentialVaultService.getDecryptedCredentials(connectionId);
if (!credentials) return null;

if (credentials.access_token) {
return credentials.access_token as string;
}

// Legacy service principal flow
if (credentials.tenantId && credentials.clientId && credentials.clientSecret) {
return this.azureSecurityService.getAccessToken(
credentials.tenantId as string,
credentials.clientId as string,
credentials.clientSecret as string,
);
}

return null;
}

private async resolveCredentials(
connectionId: string,
organizationId: string,
Expand All @@ -655,30 +690,16 @@ export class AzureRemediationService {
organizationId: string,
checkResultId: string,
) {
const credentials = await this.resolveCredentials(
connectionId,
organizationId,
);

let accessToken: string | null = null;
// OAuth flow: token from vault
if (credentials?.access_token) {
accessToken = credentials.access_token as string;
}
// Legacy SP flow fallback
if (
!accessToken &&
credentials?.tenantId &&
credentials?.clientId &&
credentials?.clientSecret
) {
accessToken = await this.azureSecurityService.getAccessToken(
credentials.tenantId as string,
credentials.clientId as string,
credentials.clientSecret as string,
);
const connection = await db.integrationConnection.findFirst({
where: { id: connectionId, organizationId, status: 'active' },
include: { provider: true },
});
if (!connection || connection.provider.slug !== 'azure') {
throw new Error('Azure connection not found or not active');
}

const accessToken = await this.getValidAzureToken(connectionId, organizationId);

const checkResult = await db.integrationCheckResult.findFirst({
where: {
id: checkResultId,
Expand Down
193 changes: 121 additions & 72 deletions apps/api/src/integration-platform/services/credential-vault.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,28 @@ export class CredentialVaultService {
}

/**
* Get decrypted credentials for a connection
* Get decrypted credentials for a connection.
* Prefers the explicitly marked active version, falls back to latest by version number.
*/
async getDecryptedCredentials(
connectionId: string,
): Promise<Record<string, string | string[]> | null> {
const latestVersion =
await this.credentialRepository.findLatestByConnection(connectionId);
if (!latestVersion) return null;
// Prefer the active credential version set during token storage/refresh
const connection = await this.connectionRepository.findById(connectionId);
let version = connection?.activeCredentialVersionId
? await this.credentialRepository.findById(
connection.activeCredentialVersionId,
)
: null;

// Fall back to latest version by version number
if (!version) {
version =
await this.credentialRepository.findLatestByConnection(connectionId);
}
if (!version) return null;

const latestVersion = version;

const encryptedPayload = latestVersion.encryptedPayload as Record<
string,
Expand Down Expand Up @@ -297,8 +311,66 @@ export class CredentialVaultService {
}

/**
* Refresh OAuth tokens using the refresh token
* Returns the new access token, or null if refresh failed
* Attempt a single token refresh request to the OAuth provider.
* Returns the new access token on success, or null on failure.
*/
private async attemptTokenRefresh(
connectionId: string,
refreshToken: string,
config: TokenRefreshConfig,
): Promise<{ token?: string; status?: number; errorBody?: string }> {
const body = new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: refreshToken,
});

const headers: Record<string, string> = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
};

// Per OAuth 2.0 RFC 6749 Section 2.3.1, when using HTTP Basic auth (header),
// client credentials should NOT be included in the request body
if (config.clientAuthMethod === 'header') {
const credentials = Buffer.from(
`${config.clientId}:${config.clientSecret}`,
).toString('base64');
headers['Authorization'] = `Basic ${credentials}`;
} else {
body.set('client_id', config.clientId);
body.set('client_secret', config.clientSecret);
}

const refreshEndpoint = config.refreshUrl || config.tokenUrl;
const response = await fetch(refreshEndpoint, {
method: 'POST',
headers,
body: body.toString(),
});

if (!response.ok) {
const errorBody = await response.text();
return { status: response.status, errorBody };
}

const tokens: OAuthTokens = await response.json();

const tokensToStore: OAuthTokens = {
access_token: tokens.access_token,
refresh_token: tokens.refresh_token || refreshToken,
token_type: tokens.token_type,
expires_in: tokens.expires_in,
scope: tokens.scope,
};

await this.storeOAuthTokens(connectionId, tokensToStore);
return { token: tokens.access_token };
}

/**
* Refresh OAuth tokens using the refresh token.
* Retries once after a short delay before marking the connection as error.
* Returns the new access token, or null if refresh failed.
*/
async refreshOAuthTokens(
connectionId: string,
Expand All @@ -315,76 +387,55 @@ export class CredentialVaultService {
try {
this.logger.log(`Refreshing OAuth tokens for connection ${connectionId}`);

// Build the token request
const body = new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: refreshToken,
});

const headers: Record<string, string> = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
};

// Add client credentials based on auth method
// Per OAuth 2.0 RFC 6749 Section 2.3.1, when using HTTP Basic auth (header),
// client credentials should NOT be included in the request body
if (config.clientAuthMethod === 'header') {
const credentials = Buffer.from(
`${config.clientId}:${config.clientSecret}`,
).toString('base64');
headers['Authorization'] = `Basic ${credentials}`;
} else {
// Default: send in body
body.set('client_id', config.clientId);
body.set('client_secret', config.clientSecret);
// First attempt
const first = await this.attemptTokenRefresh(
connectionId,
refreshToken,
config,
);
if (first.token) {
this.logger.log(
`Successfully refreshed OAuth tokens for connection ${connectionId}`,
);
return first.token;
}

// Use refreshUrl if provided, otherwise fall back to tokenUrl
const refreshEndpoint = config.refreshUrl || config.tokenUrl;

const response = await fetch(refreshEndpoint, {
method: 'POST',
headers,
body: body.toString(),
});
// Retry once after 2 seconds for transient failures (rate limits, network blips)
this.logger.warn(
`Token refresh attempt 1 failed for connection ${connectionId}: HTTP ${first.status} — ${first.errorBody ?? '(no body)'}. Retrying in 2s...`,
);
await new Promise((r) => setTimeout(r, 2000));

if (!response.ok) {
await response.text(); // consume body
this.logger.error(
`Token refresh failed for connection ${connectionId}: ${response.status}`,
const second = await this.attemptTokenRefresh(
connectionId,
refreshToken,
config,
);
if (second.token) {
this.logger.log(
`Successfully refreshed OAuth tokens for connection ${connectionId} on retry`,
);

// If refresh token is invalid/expired, mark connection as error
if (response.status === 400 || response.status === 401) {
await this.connectionRepository.update(connectionId, {
status: 'error',
errorMessage:
'OAuth token expired. Please reconnect the integration.',
});
}

return null;
return second.token;
}

const tokens: OAuthTokens = await response.json();

// Store the new tokens
// Note: Some providers return a new refresh token, some don't
const tokensToStore: OAuthTokens = {
access_token: tokens.access_token,
refresh_token: tokens.refresh_token || refreshToken, // Keep old refresh token if not provided
token_type: tokens.token_type,
expires_in: tokens.expires_in,
scope: tokens.scope,
};
// Both attempts failed — log the full error and mark connection
this.logger.error(
`Token refresh failed for connection ${connectionId} after 2 attempts: HTTP ${second.status} — ${second.errorBody ?? '(no body)'}`,
);

await this.storeOAuthTokens(connectionId, tokensToStore);
if (
second.status === 400 ||
second.status === 401 ||
second.status === 403
) {
await this.connectionRepository.update(connectionId, {
status: 'error',
errorMessage:
'OAuth token expired. Please reconnect the integration.',
});
}

this.logger.log(
`Successfully refreshed OAuth tokens for connection ${connectionId}`,
);
return tokens.access_token;
return null;
} catch (error) {
this.logger.error(
`Error refreshing tokens for connection ${connectionId}:`,
Expand All @@ -402,10 +453,9 @@ export class CredentialVaultService {
connectionId: string,
refreshConfig?: TokenRefreshConfig,
): Promise<string | null> {
// Check if we need to refresh
const needsRefresh = await this.needsRefresh(connectionId);
const shouldRefresh = await this.needsRefresh(connectionId);

if (needsRefresh && refreshConfig) {
if (shouldRefresh && refreshConfig) {
const newToken = await this.refreshOAuthTokens(
connectionId,
refreshConfig,
Expand All @@ -416,7 +466,6 @@ export class CredentialVaultService {
// If refresh failed, try to use existing token (might still work briefly)
}

// Get current credentials
const credentials = await this.getDecryptedCredentials(connectionId);
return typeof credentials?.access_token === 'string'
? credentials.access_token
Expand Down
Loading
Loading