Skip to content

Commit f94be08

Browse files
authored
fix(billing): atomize usage_log and userStats writes via central recordUsage (#3767)
* fix(billing): atomize usage_log and userStats writes via central recordUsage() * fix(billing): address PR review — re-throw errors, guard reserved keys, handle zero-cost counters * chore(lint): fix formatting in hubspot list_lists.ts from staging * fix(billing): tighten early-return guard to handle empty additionalStats object * lint * chore(billing): remove implementation-decision comments
1 parent 54a862d commit f94be08

File tree

12 files changed

+201
-372
lines changed

12 files changed

+201
-372
lines changed

apps/sim/app/api/billing/update-cost/route.ts

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import { db } from '@sim/db'
2-
import { userStats } from '@sim/db/schema'
31
import { createLogger } from '@sim/logger'
4-
import { eq, sql } from 'drizzle-orm'
2+
import { sql } from 'drizzle-orm'
53
import { type NextRequest, NextResponse } from 'next/server'
64
import { z } from 'zod'
7-
import { logModelUsage } from '@/lib/billing/core/usage-log'
5+
import { recordUsage } from '@/lib/billing/core/usage-log'
86
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
97
import { checkInternalApiKey } from '@/lib/copilot/utils'
108
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -87,53 +85,39 @@ export async function POST(req: NextRequest) {
8785
source,
8886
})
8987

90-
// Check if user stats record exists (same as ExecutionLogger)
91-
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
92-
93-
if (userStatsRecords.length === 0) {
94-
logger.error(
95-
`[${requestId}] User stats record not found - should be created during onboarding`,
96-
{
97-
userId,
98-
}
99-
)
100-
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
101-
}
102-
10388
const totalTokens = inputTokens + outputTokens
10489

105-
const updateFields: Record<string, unknown> = {
106-
totalCost: sql`total_cost + ${cost}`,
107-
currentPeriodCost: sql`current_period_cost + ${cost}`,
90+
const additionalStats: Record<string, ReturnType<typeof sql>> = {
10891
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
10992
currentPeriodCopilotCost: sql`current_period_copilot_cost + ${cost}`,
11093
totalCopilotCalls: sql`total_copilot_calls + 1`,
11194
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
112-
lastActive: new Date(),
11395
}
11496

11597
if (isMcp) {
116-
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
117-
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
118-
updateFields.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
98+
additionalStats.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
99+
additionalStats.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
100+
additionalStats.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
119101
}
120102

121-
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
122-
123-
logger.info(`[${requestId}] Updated user stats record`, {
103+
await recordUsage({
124104
userId,
125-
addedCost: cost,
126-
source,
105+
entries: [
106+
{
107+
category: 'model',
108+
source,
109+
description: model,
110+
cost,
111+
metadata: { inputTokens, outputTokens },
112+
},
113+
],
114+
additionalStats,
127115
})
128116

129-
// Log usage for complete audit trail with the original source for visibility
130-
await logModelUsage({
117+
logger.info(`[${requestId}] Recorded usage`, {
131118
userId,
119+
addedCost: cost,
132120
source,
133-
model,
134-
inputTokens,
135-
outputTokens,
136-
cost,
137121
})
138122

139123
// Check if user has hit overage threshold and bill incrementally

apps/sim/app/api/wand/route.ts

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import { db } from '@sim/db'
2-
import { userStats, workflow } from '@sim/db/schema'
2+
import { workflow } from '@sim/db/schema'
33
import { createLogger } from '@sim/logger'
44
import { eq, sql } from 'drizzle-orm'
55
import { type NextRequest, NextResponse } from 'next/server'
66
import { getBYOKKey } from '@/lib/api-key/byok'
77
import { getSession } from '@/lib/auth'
8-
import { logModelUsage } from '@/lib/billing/core/usage-log'
8+
import { recordUsage } from '@/lib/billing/core/usage-log'
99
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
1010
import { env } from '@/lib/core/config/env'
1111
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
@@ -134,23 +134,20 @@ async function updateUserStatsForWand(
134134
costToStore = modelCost * costMultiplier
135135
}
136136

137-
await db
138-
.update(userStats)
139-
.set({
140-
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
141-
totalCost: sql`total_cost + ${costToStore}`,
142-
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
143-
lastActive: new Date(),
144-
})
145-
.where(eq(userStats.userId, userId))
146-
147-
await logModelUsage({
137+
await recordUsage({
148138
userId,
149-
source: 'wand',
150-
model: modelName,
151-
inputTokens: promptTokens,
152-
outputTokens: completionTokens,
153-
cost: costToStore,
139+
entries: [
140+
{
141+
category: 'model',
142+
source: 'wand',
143+
description: modelName,
144+
cost: costToStore,
145+
metadata: { inputTokens: promptTokens, outputTokens: completionTokens },
146+
},
147+
],
148+
additionalStats: {
149+
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
150+
},
154151
})
155152

156153
await checkAndBillOverageThreshold(userId)
@@ -341,7 +338,7 @@ export async function POST(req: NextRequest) {
341338
let finalUsage: any = null
342339
let usageRecorded = false
343340

344-
const recordUsage = async () => {
341+
const flushUsage = async () => {
345342
if (usageRecorded || !finalUsage) {
346343
return
347344
}
@@ -360,7 +357,7 @@ export async function POST(req: NextRequest) {
360357

361358
if (done) {
362359
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
363-
await recordUsage()
360+
await flushUsage()
364361
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
365362
controller.close()
366363
break
@@ -390,7 +387,7 @@ export async function POST(req: NextRequest) {
390387
if (data === '[DONE]') {
391388
logger.info(`[${requestId}] Received [DONE] signal`)
392389

393-
await recordUsage()
390+
await flushUsage()
394391

395392
controller.enqueue(
396393
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
@@ -468,7 +465,7 @@ export async function POST(req: NextRequest) {
468465
})
469466

470467
try {
471-
await recordUsage()
468+
await flushUsage()
472469
} catch (usageError) {
473470
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
474471
}

0 commit comments

Comments
 (0)