Skip to content

Commit 58dd753

Browse files
committed
fix(billing): atomize usage_log and userStats writes via central recordUsage()
1 parent 8caaf01 commit 58dd753

File tree

11 files changed

+199
-366
lines changed

11 files changed

+199
-366
lines changed

apps/sim/app/api/billing/update-cost/route.ts

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import { db } from '@sim/db'
2-
import { userStats } from '@sim/db/schema'
31
import { createLogger } from '@sim/logger'
4-
import { eq, sql } from 'drizzle-orm'
2+
import { sql } from 'drizzle-orm'
53
import { type NextRequest, NextResponse } from 'next/server'
64
import { z } from 'zod'
7-
import { logModelUsage } from '@/lib/billing/core/usage-log'
5+
import { recordUsage } from '@/lib/billing/core/usage-log'
86
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
97
import { checkInternalApiKey } from '@/lib/copilot/utils'
108
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -87,53 +85,40 @@ export async function POST(req: NextRequest) {
8785
source,
8886
})
8987

90-
// Check if user stats record exists (same as ExecutionLogger)
91-
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
92-
93-
if (userStatsRecords.length === 0) {
94-
logger.error(
95-
`[${requestId}] User stats record not found - should be created during onboarding`,
96-
{
97-
userId,
98-
}
99-
)
100-
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
101-
}
102-
10388
const totalTokens = inputTokens + outputTokens
10489

105-
const updateFields: Record<string, unknown> = {
106-
totalCost: sql`total_cost + ${cost}`,
107-
currentPeriodCost: sql`current_period_cost + ${cost}`,
90+
const additionalStats: Record<string, ReturnType<typeof sql>> = {
10891
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
10992
currentPeriodCopilotCost: sql`current_period_copilot_cost + ${cost}`,
11093
totalCopilotCalls: sql`total_copilot_calls + 1`,
11194
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
112-
lastActive: new Date(),
11395
}
11496

11597
if (isMcp) {
116-
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
117-
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
118-
updateFields.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
98+
additionalStats.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
99+
additionalStats.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
100+
additionalStats.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
119101
}
120102

121-
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
122-
123-
logger.info(`[${requestId}] Updated user stats record`, {
103+
// Atomic write: usage_log INSERT + userStats UPDATE in one transaction
104+
await recordUsage({
124105
userId,
125-
addedCost: cost,
126-
source,
106+
entries: [
107+
{
108+
category: 'model',
109+
source,
110+
description: model,
111+
cost,
112+
metadata: { inputTokens, outputTokens },
113+
},
114+
],
115+
additionalStats,
127116
})
128117

129-
// Log usage for complete audit trail with the original source for visibility
130-
await logModelUsage({
118+
logger.info(`[${requestId}] Recorded usage`, {
131119
userId,
120+
addedCost: cost,
132121
source,
133-
model,
134-
inputTokens,
135-
outputTokens,
136-
cost,
137122
})
138123

139124
// Check if user has hit overage threshold and bill incrementally

apps/sim/app/api/wand/route.ts

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import { db } from '@sim/db'
2-
import { userStats, workflow } from '@sim/db/schema'
2+
import { workflow } from '@sim/db/schema'
33
import { createLogger } from '@sim/logger'
44
import { eq, sql } from 'drizzle-orm'
55
import { type NextRequest, NextResponse } from 'next/server'
66
import { getBYOKKey } from '@/lib/api-key/byok'
77
import { getSession } from '@/lib/auth'
8-
import { logModelUsage } from '@/lib/billing/core/usage-log'
8+
import { recordUsage } from '@/lib/billing/core/usage-log'
99
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
1010
import { env } from '@/lib/core/config/env'
1111
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -134,23 +134,21 @@ async function updateUserStatsForWand(
134134
costToStore = modelCost * costMultiplier
135135
}
136136

137-
await db
138-
.update(userStats)
139-
.set({
140-
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
141-
totalCost: sql`total_cost + ${costToStore}`,
142-
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
143-
lastActive: new Date(),
144-
})
145-
.where(eq(userStats.userId, userId))
146-
147-
await logModelUsage({
137+
// Atomic write: usage_log INSERT + userStats UPDATE in one transaction
138+
await recordUsage({
148139
userId,
149-
source: 'wand',
150-
model: modelName,
151-
inputTokens: promptTokens,
152-
outputTokens: completionTokens,
153-
cost: costToStore,
140+
entries: [
141+
{
142+
category: 'model',
143+
source: 'wand',
144+
description: modelName,
145+
cost: costToStore,
146+
metadata: { inputTokens: promptTokens, outputTokens: completionTokens },
147+
},
148+
],
149+
additionalStats: {
150+
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
151+
},
154152
})
155153

156154
await checkAndBillOverageThreshold(userId)
@@ -341,7 +339,7 @@ export async function POST(req: NextRequest) {
341339
let finalUsage: any = null
342340
let usageRecorded = false
343341

344-
const recordUsage = async () => {
342+
const flushUsage = async () => {
345343
if (usageRecorded || !finalUsage) {
346344
return
347345
}
@@ -360,7 +358,7 @@ export async function POST(req: NextRequest) {
360358

361359
if (done) {
362360
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
363-
await recordUsage()
361+
await flushUsage()
364362
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
365363
controller.close()
366364
break
@@ -390,7 +388,7 @@ export async function POST(req: NextRequest) {
390388
if (data === '[DONE]') {
391389
logger.info(`[${requestId}] Received [DONE] signal`)
392390

393-
await recordUsage()
391+
await flushUsage()
394392

395393
controller.enqueue(
396394
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
@@ -468,7 +466,7 @@ export async function POST(req: NextRequest) {
468466
})
469467

470468
try {
471-
await recordUsage()
469+
await flushUsage()
472470
} catch (usageError) {
473471
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
474472
}

0 commit comments

Comments
 (0)