Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions backend/internal/app/pluginadmin/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,24 +176,31 @@ func (s *Service) ListMarketplace(ctx context.Context) ([]MarketplacePlugin, err
return nil, err
}

installedVersions := make(map[string]string)
type installedPlugin struct {
version string
isDev bool
}
installed := make(map[string]installedPlugin)
for _, meta := range s.manager.GetAllPluginMeta() {
installedVersions[meta.Name] = meta.Version
installed[meta.Name] = installedPlugin{
version: meta.Version,
isDev: meta.IsDev,
}
}

result := make([]MarketplacePlugin, 0, len(items))
for _, item := range items {
installedVer, installed := installedVersions[item.Name]
meta, ok := installed[item.Name]
result = append(result, MarketplacePlugin{
Name: item.Name,
Version: item.Version,
Description: item.Description,
Author: item.Author,
Type: item.Type,
GithubRepo: item.GithubRepo,
Installed: installed,
InstalledVersion: installedVer,
HasUpdate: installed && isNewerVersion(item.Version, installedVer),
Installed: ok,
InstalledVersion: meta.version,
HasUpdate: ok && !meta.isDev && isNewerVersion(item.Version, meta.version),
})
}
return result, nil
Expand Down
18 changes: 18 additions & 0 deletions backend/internal/app/pluginadmin/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ func TestListMarketplaceMarksInstalled(t *testing.T) {
}
}

func TestListMarketplaceDoesNotOfferUpdatesForDevPlugin(t *testing.T) {
service := NewService(pluginAdminManagerStub{
allMeta: []plugin.PluginMeta{{Name: "airgate-playground", Version: "0.1.0", IsDev: true}},
}, pluginMarketplaceStub{
listAvailable: func(context.Context) ([]plugin.MarketplacePlugin, error) {
return []plugin.MarketplacePlugin{{Name: "airgate-playground", Version: "0.1.10"}}, nil
},
})

items, err := service.ListMarketplace(t.Context())
if err != nil {
t.Fatalf("ListMarketplace() error = %v", err)
}
if len(items) != 1 || !items[0].Installed || items[0].HasUpdate {
t.Fatalf("unexpected marketplace items: %+v", items)
}
}

type pluginAdminManagerStub struct {
allMeta []plugin.PluginMeta
}
Expand Down
23 changes: 10 additions & 13 deletions backend/internal/auth/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,20 @@ func (i *APIKeyInfo) UserGroupRate() (float64, bool) {
// GenerateAPIKey 生成 API Key 和对应的哈希值
// 返回明文密钥(仅展示一次)和用于存储的哈希
func GenerateAPIKey() (key string, hash string, err error) {
// 生成 32 字节随机数据
return generatePrefixedAPIKey(apiKeyPrefix)
}

// GenerateAdminAPIKey 生成管理员 API Key,返回明文密钥和哈希。
func GenerateAdminAPIKey() (key string, hash string, err error) {
return generatePrefixedAPIKey(adminKeyPrefix)
}

func generatePrefixedAPIKey(prefix string) (key string, hash string, err error) {
b := make([]byte, 32)
if _, err = rand.Read(b); err != nil {
return "", "", err
}
key = apiKeyPrefix + hex.EncodeToString(b)
key = prefix + hex.EncodeToString(b)
hash = HashAPIKey(key)
return key, hash, nil
}
Expand All @@ -115,17 +123,6 @@ func HashAPIKey(key string) string {
return hex.EncodeToString(h[:])
}

// GenerateAdminAPIKey 生成管理员 API Key,返回明文密钥和哈希。
func GenerateAdminAPIKey() (key string, hash string, err error) {
b := make([]byte, 32)
if _, err = rand.Read(b); err != nil {
return "", "", err
}
key = adminKeyPrefix + hex.EncodeToString(b)
hash = HashAPIKey(key)
return key, hash, nil
}

// AdminKeyHint 生成管理员 API Key 的显示提示(前缀 + 前4位...后4位)。
func AdminKeyHint(key string) string {
if len(key) <= 12 {
Expand Down
32 changes: 32 additions & 0 deletions backend/internal/auth/apikey_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package auth

import (
"strings"
"testing"
)

func TestGenerateAPIKeyPrefixesAndHashes(t *testing.T) {
t.Parallel()

key, hash, err := GenerateAPIKey()
if err != nil {
t.Fatalf("GenerateAPIKey error: %v", err)
}
if !strings.HasPrefix(key, apiKeyPrefix) {
t.Fatalf("API key prefix = %q, want %q", key[:len(apiKeyPrefix)], apiKeyPrefix)
}
if hash != HashAPIKey(key) {
t.Fatalf("hash = %q, want HashAPIKey(key)", hash)
}

adminKey, adminHash, err := GenerateAdminAPIKey()
if err != nil {
t.Fatalf("GenerateAdminAPIKey error: %v", err)
}
if !strings.HasPrefix(adminKey, adminKeyPrefix) {
t.Fatalf("admin key prefix = %q, want %q", adminKey[:len(adminKeyPrefix)], adminKeyPrefix)
}
if adminHash != HashAPIKey(adminKey) {
t.Fatalf("admin hash = %q, want HashAPIKey(adminKey)", adminHash)
}
}
6 changes: 3 additions & 3 deletions backend/internal/plugin/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (f *Forwarder) Forward(c *gin.Context) {
return
}

var hardExclude []int
hardExclude := make([]int, 0, maxFailoverAttempts*len(routes))
var mwBag map[string]string
beginCalled := false
ctx := c.Request.Context()
Expand All @@ -134,7 +134,7 @@ func (f *Forwarder) Forward(c *gin.Context) {
state.selectedRoute = route
state.keyInfo = keyInfoForRoute(state.keyInfo, route)

softExclude := []int(nil)
softExclude := make([]int, 0, maxFailoverAttempts)
attempt := 0
queueDeadline := time.Now().Add(queueWaitTimeout)

Expand All @@ -146,7 +146,7 @@ func (f *Forwarder) Forward(c *gin.Context) {
if err := f.pickAccount(c, state, exclude...); err != nil {
failureSummary.recordPickAccountError(err)
if len(softExclude) > 0 && time.Now().Before(queueDeadline) {
softExclude = nil
softExclude = softExclude[:0]
select {
case <-ctx.Done():
return
Expand Down
33 changes: 33 additions & 0 deletions backend/internal/plugin/forwarder_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package plugin

import (
"bytes"
"context"
"mime/multipart"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -49,6 +51,37 @@ func TestParseBody_StreamTrue(t *testing.T) {
}
}

func TestParseBody_MultipartIgnoresFileParts(t *testing.T) {
t.Parallel()

var body bytes.Buffer
writer := multipart.NewWriter(&body)
file, err := writer.CreateFormFile("image", "input.png")
if err != nil {
t.Fatalf("CreateFormFile error: %v", err)
}
if _, err := file.Write(bytes.Repeat([]byte("x"), 1024)); err != nil {
t.Fatalf("file.Write error: %v", err)
}
if err := writer.WriteField("model", " gpt-image-1 "); err != nil {
t.Fatalf("WriteField(model) error: %v", err)
}
if err := writer.WriteField("stream", "true"); err != nil {
t.Fatalf("WriteField(stream) error: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("writer.Close error: %v", err)
}

parsed := parseBody(body.Bytes(), writer.FormDataContentType())
if parsed.Model != "gpt-image-1" {
t.Fatalf("Model = %q, want %q", parsed.Model, "gpt-image-1")
}
if !parsed.Stream {
t.Fatalf("Stream = false, want true")
}
}

func TestParseBody_ReasoningEffort(t *testing.T) {
t.Parallel()

Expand Down
15 changes: 9 additions & 6 deletions backend/internal/plugin/host_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func (h *HostService) forward(ctx context.Context, req *pb.HostForwardRequest) (
fwdCtx, cancel := context.WithTimeout(ctx, hostForwardTimeout(req))
defer cancel()

var hardExclude []int
hardExclude := make([]int, 0, maxHostForwardAttempts*len(routes))
for _, route := range routes {
model := h.resolveHostModel(route.Platform, req.Model)
if model == "" {
Expand Down Expand Up @@ -592,7 +592,7 @@ func (h *HostService) forwardStream(ctx context.Context, req *pb.HostForwardRequ
defer cancel()

sw := &hostStreamWriter{stream: stream}
var hardExclude []int
hardExclude := make([]int, 0, maxHostForwardAttempts*len(routes))

for _, route := range routes {
model := h.resolveHostModel(route.Platform, req.Model)
Expand Down Expand Up @@ -911,13 +911,16 @@ func (h *HostService) recordHostForwardUsage(
// listPlatforms 列出已加载的网关平台。
func (h *HostService) listPlatforms(_ context.Context, _ *pb.HostListPlatformsRequest) (*pb.HostListPlatformsResponse, error) {
metas := h.manager.GetAllPluginMeta()
seen := make(map[string]bool)
var platforms []*pb.HostPlatform
seen := make(map[string]struct{}, len(metas))
platforms := make([]*pb.HostPlatform, 0, len(metas))
for _, m := range metas {
if m.Type != "gateway" || m.Platform == "" || seen[m.Platform] {
if m.Type != "gateway" || m.Platform == "" {
continue
}
seen[m.Platform] = true
if _, ok := seen[m.Platform]; ok {
continue
}
seen[m.Platform] = struct{}{}
platforms = append(platforms, &pb.HostPlatform{
Name: m.Platform,
DisplayName: m.DisplayName,
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/plugin/marketplace.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ var officialPlugins = []MarketplacePlugin{
},
{
Name: "gateway-claude",
Version: "1.0.0",
Version: "0.2.0",
Description: "Claude Messages API 网关插件:OAuth 授权、TLS 指纹、用量监控",
Author: "AirGate",
Type: "gateway",
Expand Down
7 changes: 6 additions & 1 deletion backend/internal/plugin/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,14 @@ func parseMultipartFields(body []byte, contentType string) parsedRequest {
if err != nil {
break
}
name := part.FormName()
if name != "model" && name != "stream" {
_ = part.Close()
continue
}
data, _ := io.ReadAll(part)
_ = part.Close()
switch part.FormName() {
switch name {
case "model":
pr.Model = strings.TrimSpace(string(data))
case "stream":
Expand Down
11 changes: 6 additions & 5 deletions backend/internal/scheduler/selection.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ func (s *Scheduler) SelectAccount(ctx context.Context, platform, model string, u
}

now := time.Now()
var normalCandidates, stickyCandidates []*ent.Account
normalCandidates := make([]*ent.Account, 0, len(candidates))
stickyCandidates := make([]*ent.Account, 0, len(candidates))
for _, acc := range candidates {
switch s.checkSchedulability(ctx, acc, model, now) {
case Normal:
Expand Down Expand Up @@ -91,13 +92,13 @@ func excludeAccounts(candidates []*ent.Account, excludeIDs []int) []*ent.Account
if len(excludeIDs) == 0 {
return candidates
}
excludeSet := make(map[int]bool, len(excludeIDs))
excludeSet := make(map[int]struct{}, len(excludeIDs))
for _, id := range excludeIDs {
excludeSet[id] = true
excludeSet[id] = struct{}{}
}
filtered := candidates[:0]
filtered := make([]*ent.Account, 0, len(candidates))
for _, acc := range candidates {
if !excludeSet[acc.ID] {
if _, excluded := excludeSet[acc.ID]; !excluded {
filtered = append(filtered, acc)
}
}
Expand Down
21 changes: 21 additions & 0 deletions backend/internal/scheduler/selection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package scheduler

import (
"testing"

"github.com/DouDOU-start/airgate-core/ent"
)

func TestExcludeAccountsDoesNotMutateCandidates(t *testing.T) {
t.Parallel()

candidates := []*ent.Account{{ID: 1}, {ID: 2}, {ID: 3}}
got := excludeAccounts(candidates, []int{2})

if len(got) != 2 || got[0].ID != 1 || got[1].ID != 3 {
t.Fatalf("excludeAccounts result = %+v, want IDs [1 3]", got)
}
if len(candidates) != 3 || candidates[0].ID != 1 || candidates[1].ID != 2 || candidates[2].ID != 3 {
t.Fatalf("candidates mutated to %+v, want original IDs [1 2 3]", candidates)
}
}
Loading
Loading