Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 19 additions & 35 deletions apps/sim/app/api/billing/update-cost/route.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { db } from '@sim/db'
import { userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { checkInternalApiKey } from '@/lib/copilot/utils'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
Expand Down Expand Up @@ -87,53 +85,39 @@ export async function POST(req: NextRequest) {
source,
})

// Check if user stats record exists (same as ExecutionLogger)
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))

if (userStatsRecords.length === 0) {
logger.error(
`[${requestId}] User stats record not found - should be created during onboarding`,
{
userId,
}
)
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
}

const totalTokens = inputTokens + outputTokens

const updateFields: Record<string, unknown> = {
totalCost: sql`total_cost + ${cost}`,
currentPeriodCost: sql`current_period_cost + ${cost}`,
const additionalStats: Record<string, ReturnType<typeof sql>> = {
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
currentPeriodCopilotCost: sql`current_period_copilot_cost + ${cost}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
lastActive: new Date(),
}

if (isMcp) {
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
updateFields.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
additionalStats.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
additionalStats.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
additionalStats.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
}

await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))

logger.info(`[${requestId}] Updated user stats record`, {
await recordUsage({
userId,
addedCost: cost,
source,
entries: [
{
category: 'model',
source,
description: model,
cost,
metadata: { inputTokens, outputTokens },
},
],
additionalStats,
})

// Log usage for complete audit trail with the original source for visibility
await logModelUsage({
logger.info(`[${requestId}] Recorded usage`, {
userId,
addedCost: cost,
source,
model,
inputTokens,
outputTokens,
cost,
})

// Check if user has hit overage threshold and bill incrementally
Expand Down
41 changes: 19 additions & 22 deletions apps/sim/app/api/wand/route.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { db } from '@sim/db'
import { userStats, workflow } from '@sim/db/schema'
import { workflow } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getBYOKKey } from '@/lib/api-key/byok'
import { getSession } from '@/lib/auth'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
Expand Down Expand Up @@ -134,23 +134,20 @@ async function updateUserStatsForWand(
costToStore = modelCost * costMultiplier
}

await db
.update(userStats)
.set({
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
lastActive: new Date(),
})
.where(eq(userStats.userId, userId))

await logModelUsage({
await recordUsage({
userId,
source: 'wand',
model: modelName,
inputTokens: promptTokens,
outputTokens: completionTokens,
cost: costToStore,
entries: [
{
category: 'model',
source: 'wand',
description: modelName,
cost: costToStore,
metadata: { inputTokens: promptTokens, outputTokens: completionTokens },
},
],
additionalStats: {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
},
})

await checkAndBillOverageThreshold(userId)
Expand Down Expand Up @@ -341,7 +338,7 @@ export async function POST(req: NextRequest) {
let finalUsage: any = null
let usageRecorded = false

const recordUsage = async () => {
const flushUsage = async () => {
if (usageRecorded || !finalUsage) {
return
}
Expand All @@ -360,7 +357,7 @@ export async function POST(req: NextRequest) {

if (done) {
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
await recordUsage()
await flushUsage()
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
controller.close()
break
Expand Down Expand Up @@ -390,7 +387,7 @@ export async function POST(req: NextRequest) {
if (data === '[DONE]') {
logger.info(`[${requestId}] Received [DONE] signal`)

await recordUsage()
await flushUsage()

controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
Expand Down Expand Up @@ -468,7 +465,7 @@ export async function POST(req: NextRequest) {
})

try {
await recordUsage()
await flushUsage()
} catch (usageError) {
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
}
Expand Down
Loading
Loading