Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion internal/db/migrations/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ func MigrateDatabase(db *gorm.DB) error {
}

// Run v1.1.0 migration
return V1_1_0_AddKeyHashColumn(db)
if err := V1_1_0_AddKeyHashColumn(db); err != nil {
return err
}

// Run v1.2.0 migration
return V1_2_0_AddTokenColumns(db)
}

// HandleLegacyIndexes removes old indexes from previous versions to prevent migration errors
Expand Down
37 changes: 37 additions & 0 deletions internal/db/migrations/v1_2_0_AddTokenColumns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package db

import (
"gpt-load/internal/models"

"github.com/sirupsen/logrus"
"gorm.io/gorm"
)

// V1_2_0_AddTokenColumns makes sure the token usage columns exist on the
// request_logs table, adding whichever ones are missing.
//
// Columns are processed in a fixed order: prompt_tokens, completion_tokens,
// total_tokens, token_cost_usd. A column that is already present is skipped,
// so running the migration again does nothing. The first AddColumn error is
// returned unchanged and the remaining columns are not attempted.
func V1_2_0_AddTokenColumns(db *gorm.DB) error {
	tokenColumns := [...]string{
		"prompt_tokens",
		"completion_tokens",
		"total_tokens",
		"token_cost_usd",
	}

	for _, column := range tokenColumns {
		if db.Migrator().HasColumn(&models.RequestLog{}, column) {
			continue
		}
		if err := db.Migrator().AddColumn(&models.RequestLog{}, column); err != nil {
			return err
		}
		logrus.Info("Added column " + column + " to request_logs")
	}
	return nil
}
95 changes: 95 additions & 0 deletions internal/handler/metrics_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package handler

import (
"fmt"
"strings"

"gpt-load/internal/models"

"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)

// promLabelEscaper escapes label values for the Prometheus text exposition
// format. Inside a quoted label value only backslash, double quote and line
// feed are escaped (as \\, \" and \n); Go's %q verb must not be used because
// it emits additional escapes (\t, \x.., \u....) that Prometheus parsers reject.
var promLabelEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`)

// metricLabels renders the shared group/model label pair (without the
// surrounding braces) with both values escaped for the exposition format.
func metricLabels(group, model string) string {
	return `group="` + promLabelEscaper.Replace(group) +
		`",model="` + promLabelEscaper.Replace(model) + `"`
}

// Metrics serves a Prometheus text-format /metrics endpoint exposing request
// counts, token usage and cost aggregated from request_logs.
//
// Only successful, non-streaming requests are aggregated, grouped by group
// name and model. A NULL group name is reported as "" and a NULL model as
// "unknown". Token and cost samples whose aggregated value is zero are
// omitted; request counts are always emitted.
//
// This is deliberately kept minimal: no Prometheus client library is
// introduced. The body is written by hand in the exposition format so that
// any standard Prometheus server can scrape it.
//
// On a query failure the error is logged and a plain 500 response is sent.
func (s *Server) Metrics(c *gin.Context) {
	var results []struct {
		GroupName       string
		Model           string
		TotalRequests   int64
		TotalTokens     int64
		TotalCost       float64
		TotalPrompt     int64
		TotalCompletion int64
	}

	// Aggregate token usage and request count from successful non-streaming
	// requests (those are the ones where usage data can be extracted).
	if err := s.DB.Model(&models.RequestLog{}).
		Select(`COALESCE(group_name, '') as group_name,
			COALESCE(model, 'unknown') as model,
			COUNT(*) as total_requests,
			COALESCE(SUM(total_tokens), 0) as total_tokens,
			COALESCE(SUM(token_cost_usd), 0) as total_cost,
			COALESCE(SUM(prompt_tokens), 0) as total_prompt,
			COALESCE(SUM(completion_tokens), 0) as total_completion`).
		Where("is_success = ? AND is_stream = ?", true, false).
		Group("group_name, model").
		Scan(&results).Error; err != nil {
		logrus.WithError(err).Error("Failed to query metrics")
		c.String(500, "internal error\n")
		return
	}

	// Escape each row's label pair once; it is reused by every metric family.
	labels := make([]string, len(results))
	for i, r := range results {
		labels[i] = metricLabels(r.GroupName, r.Model)
	}

	var sb strings.Builder
	sb.WriteString("# HELP gpt_load_requests_total Total number of successful proxy requests by group and model\n")
	sb.WriteString("# TYPE gpt_load_requests_total counter\n")
	for i, r := range results {
		fmt.Fprintf(&sb, "gpt_load_requests_total{%s} %d\n", labels[i], r.TotalRequests)
	}

	sb.WriteString("\n# HELP gpt_load_tokens_total Total token count by type, group, and model\n")
	sb.WriteString("# TYPE gpt_load_tokens_total counter\n")
	for i, r := range results {
		if r.TotalPrompt > 0 {
			fmt.Fprintf(&sb, "gpt_load_tokens_total{type=\"prompt\",%s} %d\n", labels[i], r.TotalPrompt)
		}
		if r.TotalCompletion > 0 {
			fmt.Fprintf(&sb, "gpt_load_tokens_total{type=\"completion\",%s} %d\n", labels[i], r.TotalCompletion)
		}
		if r.TotalTokens > 0 {
			fmt.Fprintf(&sb, "gpt_load_tokens_total{type=\"total\",%s} %d\n", labels[i], r.TotalTokens)
		}
	}

	sb.WriteString("\n# HELP gpt_load_cost_total Total cost in USD by group and model\n")
	sb.WriteString("# TYPE gpt_load_cost_total counter\n")
	for i, r := range results {
		if r.TotalCost > 0 {
			fmt.Fprintf(&sb, "gpt_load_cost_total{%s} %.6f\n", labels[i], r.TotalCost)
		}
	}

	// version=0.0.4 identifies the Prometheus text exposition format.
	c.Header("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
	c.String(200, sb.String())
}
44 changes: 24 additions & 20 deletions internal/models/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,26 +135,30 @@ const (

// RequestLog maps to the request_logs table; each value is one logged proxy
// request.
//
// PromptTokens, CompletionTokens, TotalTokens and TokenCostUSD hold the token
// usage recorded for the request. They default to 0, which is also the value
// stored when no usage information was recorded.
type RequestLog struct {
	ID               string    `gorm:"type:varchar(36);primaryKey" json:"id"`
	Timestamp        time.Time `gorm:"not null;index" json:"timestamp"`
	GroupID          uint      `gorm:"not null;index" json:"group_id"`
	GroupName        string    `gorm:"type:varchar(255);index" json:"group_name"`
	ParentGroupID    uint      `gorm:"index" json:"parent_group_id"`
	ParentGroupName  string    `gorm:"type:varchar(255);index" json:"parent_group_name"`
	KeyValue         string    `gorm:"type:text" json:"key_value"`
	KeyHash          string    `gorm:"type:varchar(128);index" json:"key_hash"`
	Model            string    `gorm:"type:varchar(255);index" json:"model"`
	IsSuccess        bool      `gorm:"not null" json:"is_success"`
	SourceIP         string    `gorm:"type:varchar(64)" json:"source_ip"`
	StatusCode       int       `gorm:"not null" json:"status_code"`
	RequestPath      string    `gorm:"type:varchar(500)" json:"request_path"`
	Duration         int64     `gorm:"not null" json:"duration_ms"`
	ErrorMessage     string    `gorm:"type:text" json:"error_message"`
	UserAgent        string    `gorm:"type:varchar(512)" json:"user_agent"`
	RequestType      string    `gorm:"type:varchar(20);not null;default:'final';index" json:"request_type"`
	UpstreamAddr     string    `gorm:"type:varchar(500)" json:"upstream_addr"`
	IsStream         bool      `gorm:"not null" json:"is_stream"`
	RequestBody      string    `gorm:"type:text" json:"request_body"`
	PromptTokens     int64     `gorm:"not null;default:0" json:"prompt_tokens"`
	CompletionTokens int64     `gorm:"not null;default:0" json:"completion_tokens"`
	TotalTokens      int64     `gorm:"not null;default:0" json:"total_tokens"`
	TokenCostUSD     float64   `gorm:"not null;default:0" json:"token_cost_usd"`
}

// StatCard 用于仪表盘的单个统计卡片数据
Expand Down
33 changes: 30 additions & 3 deletions internal/proxy/response_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,35 @@ func (ps *ProxyServer) handleStreamingResponse(c *gin.Context, resp *http.Respon
}
}

func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) {
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
logUpstreamError("copying response body", err)
// handleNormalResponse buffers the upstream response body only when token usage
// extraction is needed (successful chat-completion responses). For all other
// non-stream responses it streams directly to the client via io.Copy to avoid
// buffering large payloads into memory.
// Note: response headers and status code are already set by the caller (HandleProxy).
func (ps *ProxyServer) handleNormalResponse(c *gin.Context, resp *http.Response) *TokenUsage {
needsUsage := resp.StatusCode < 400 && isChatCompletionPath(c.Request.URL.Path)

if !needsUsage {
// Stream directly to client — no buffering.
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
logUpstreamError("streaming response to client", err)
}
return nil
}

// Buffer the body for usage extraction.
body, err := io.ReadAll(resp.Body)
if err != nil {
logUpstreamError("reading response body", err)
return nil
}

body = handleGzipCompression(resp, body)

if _, writeErr := c.Writer.Write(body); writeErr != nil {
logUpstreamError("writing buffered body to client", writeErr)
return nil
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

return extractTokenUsage(body)
}
24 changes: 18 additions & 6 deletions internal/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (ps *ProxyServer) executeRequestWithRetry(
if err != nil {
logrus.Errorf("Failed to select a key for group %s on attempt %d: %v", group.Name, retryCount+1, err)
response.Error(c, app_errors.NewAPIError(app_errors.ErrNoKeysAvailable, err.Error()))
ps.logRequest(c, originalGroup, group, nil, startTime, http.StatusServiceUnavailable, err, isStream, "", channelHandler, bodyBytes, models.RequestTypeFinal)
ps.logRequest(c, originalGroup, group, nil, startTime, http.StatusServiceUnavailable, err, isStream, "", channelHandler, bodyBytes, models.RequestTypeFinal, nil)
return
}

Expand Down Expand Up @@ -169,7 +169,7 @@ func (ps *ProxyServer) executeRequestWithRetry(
finalBodyBytes, err := channelHandler.ApplyModelRedirect(req, bodyBytes, group)
if err != nil {
response.Error(c, app_errors.NewAPIError(app_errors.ErrBadRequest, err.Error()))
ps.logRequest(c, originalGroup, group, apiKey, startTime, http.StatusBadRequest, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal)
ps.logRequest(c, originalGroup, group, apiKey, startTime, http.StatusBadRequest, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal, nil)
return
}

Expand Down Expand Up @@ -206,7 +206,7 @@ func (ps *ProxyServer) executeRequestWithRetry(
if err != nil || shouldRetryByStatus {
if err != nil && app_errors.IsIgnorableError(err) {
logrus.Debugf("Client-side ignorable error for key %s, aborting retries: %v", utils.MaskAPIKey(apiKey.KeyValue), err)
ps.logRequest(c, originalGroup, group, apiKey, startTime, 499, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal)
ps.logRequest(c, originalGroup, group, apiKey, startTime, 499, err, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal, nil)
return
}

Expand Down Expand Up @@ -244,7 +244,7 @@ func (ps *ProxyServer) executeRequestWithRetry(
requestType = models.RequestTypeFinal
}

ps.logRequest(c, originalGroup, group, apiKey, startTime, statusCode, errors.New(parsedError), isStream, upstreamURL, channelHandler, bodyBytes, requestType)
ps.logRequest(c, originalGroup, group, apiKey, startTime, statusCode, errors.New(parsedError), isStream, upstreamURL, channelHandler, bodyBytes, requestType, nil)

// 如果是最后一次尝试,直接返回错误,不再递归
if isLastAttempt {
Expand All @@ -264,6 +264,8 @@ func (ps *ProxyServer) executeRequestWithRetry(
// ps.keyProvider.UpdateStatus(apiKey, group, true) // 请求成功不再重置成功次数,减少IO消耗
logrus.Debugf("Request for group %s succeeded on attempt %d with key %s", group.Name, retryCount+1, utils.MaskAPIKey(apiKey.KeyValue))

var usage *TokenUsage

// Check if this is a model list request (needs special handling)
if shouldInterceptModelList(c.Request.URL.Path, c.Request.Method) {
ps.handleModelListResponse(c, resp, group, channelHandler)
Expand All @@ -278,11 +280,11 @@ func (ps *ProxyServer) executeRequestWithRetry(
if isStream {
ps.handleStreamingResponse(c, resp)
} else {
ps.handleNormalResponse(c, resp)
usage = ps.handleNormalResponse(c, resp)
}
}

ps.logRequest(c, originalGroup, group, apiKey, startTime, resp.StatusCode, nil, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal)
ps.logRequest(c, originalGroup, group, apiKey, startTime, resp.StatusCode, nil, isStream, upstreamURL, channelHandler, bodyBytes, models.RequestTypeFinal, usage)
}

func shouldFailoverOnStatusCode(statusCode int, group *models.Group) bool {
Expand All @@ -293,6 +295,7 @@ func shouldFailoverOnStatusCode(statusCode int, group *models.Group) bool {
}

// logRequest is a helper function to create and record a request log.
// usage may be nil for streaming requests or failed requests.
func (ps *ProxyServer) logRequest(
c *gin.Context,
originalGroup *models.Group,
Expand All @@ -306,6 +309,7 @@ func (ps *ProxyServer) logRequest(
channelHandler channel.ChannelProxy,
bodyBytes []byte,
requestType string,
usage *TokenUsage,
) {
if ps.requestLogService == nil {
return
Expand Down Expand Up @@ -335,6 +339,14 @@ func (ps *ProxyServer) logRequest(
RequestBody: requestBodyToLog,
}

// Set token usage data extracted from the response body.
if usage != nil {
logEntry.PromptTokens = usage.PromptTokens
logEntry.CompletionTokens = usage.CompletionTokens
logEntry.TotalTokens = usage.TotalTokens
logEntry.TokenCostUSD = usage.CostUSD
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// Set parent group
if originalGroup != nil && originalGroup.GroupType == "aggregate" && originalGroup.ID != group.ID {
logEntry.ParentGroupID = originalGroup.ID
Expand Down
Loading