Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/ateapi/internal/controlapi/functional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"testing"
"time"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store"
"github.com/agent-substrate/substrate/cmd/ateapi/internal/store/ateredis"
"github.com/agent-substrate/substrate/internal/ateinterceptors"
"github.com/agent-substrate/substrate/internal/proto/ateletpb"
Expand Down Expand Up @@ -1203,7 +1204,7 @@ func TestSuspendActor_DanglingWorker(t *testing.T) {
deleteWorkerPod(t, tc, ns, "worker-1")

// 3. Call SuspendActor -> Should succeed (our fix skips missing pod execution)
actors, _ := tc.persistence.ListActors(context.Background())
actors, _ := tc.persistence.ListActors(context.Background(), store.ListOptions{})
t.Logf("Actors in Redis before Suspend: %d", len(actors))
for _, a := range actors {
t.Logf(" Actor: %s/%s/%s", a.GetActorTemplateNamespace(), a.GetActorTemplateName(), a.GetActorId())
Expand Down
3 changes: 2 additions & 1 deletion cmd/ateapi/internal/controlapi/list_actors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ import (
"context"
"fmt"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store"
"github.com/agent-substrate/substrate/pkg/proto/ateapipb"
)

func (s *Service) ListActors(ctx context.Context, req *ateapipb.ListActorsRequest) (*ateapipb.ListActorsResponse, error) {
if err := validateListActorsRequest(req); err != nil {
return nil, err
}
actors, err := s.persistence.ListActors(ctx)
actors, err := s.persistence.ListActors(ctx, store.ListOptions{})
if err != nil {
return nil, fmt.Errorf("while listing actors in db: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/ateapi/internal/controlapi/list_workers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ import (
"context"
"fmt"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store"
"github.com/agent-substrate/substrate/pkg/proto/ateapipb"
)

func (s *Service) ListWorkers(ctx context.Context, req *ateapipb.ListWorkersRequest) (*ateapipb.ListWorkersResponse, error) {
if err := validateListWorkersRequest(req); err != nil {
return nil, err
}
workers, err := s.persistence.ListWorkers(ctx)
workers, err := s.persistence.ListWorkers(ctx, store.ListOptions{})
if err != nil {
return nil, fmt.Errorf("while listing workers in db: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/ateapi/internal/controlapi/syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestSyncer_Lifecycle(t *testing.T) {
poolName := "pool1"

// 1. Verify no workers in Redis initially
workers, err := persistence.ListWorkers(context.Background())
workers, err := persistence.ListWorkers(context.Background(), store.ListOptions{})
if err != nil {
t.Fatalf("failed to list workers: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/ateapi/internal/controlapi/workflow_resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (s *AssignWorkerStep) IsComplete(ctx context.Context, input *ResumeInput, s
return state.Actor.GetStatus() == ateapipb.Actor_STATUS_RUNNING, nil
}
func (s *AssignWorkerStep) Execute(ctx context.Context, input *ResumeInput, state *ResumeState) error {
workers, err := s.store.ListWorkers(ctx)
workers, err := s.store.ListWorkers(ctx, store.ListOptions{})
if err != nil {
return fmt.Errorf("while listing workers: %w", err)
}
Expand Down
94 changes: 92 additions & 2 deletions cmd/ateapi/internal/store/ateredis/ateredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func (s *Persistence) UpdateActor(ctx context.Context, actor *ateapipb.Actor, ex
return nil
}

func (s *Persistence) ListWorkers(ctx context.Context) ([]*ateapipb.Worker, error) {
func (s *Persistence) ListWorkers(ctx context.Context, opts store.ListOptions) ([]*ateapipb.Worker, error) {
var result []*ateapipb.Worker
var mu sync.Mutex

Expand All @@ -377,6 +377,10 @@ func (s *Persistence) ListWorkers(ctx context.Context) ([]*ateapipb.Worker, erro
return fmt.Errorf("in protojson.Unmarshal: %w", err)
}

if !matchesWorker(worker, opts) {
continue
}

mu.Lock()
result = append(result, worker)
mu.Unlock()
Expand All @@ -393,7 +397,7 @@ func (s *Persistence) ListWorkers(ctx context.Context) ([]*ateapipb.Worker, erro
return result, nil
}

func (s *Persistence) ListActors(ctx context.Context) ([]*ateapipb.Actor, error) {
func (s *Persistence) ListActors(ctx context.Context, opts store.ListOptions) ([]*ateapipb.Actor, error) {
var result []*ateapipb.Actor
var mu sync.Mutex

Expand All @@ -416,6 +420,10 @@ func (s *Persistence) ListActors(ctx context.Context) ([]*ateapipb.Actor, error)
return fmt.Errorf("in protojson.Unmarshal: %w", err)
}

if !matchesActor(actor, opts) {
continue
}

mu.Lock()
result = append(result, actor)
mu.Unlock()
Expand All @@ -429,6 +437,88 @@ func (s *Persistence) ListActors(ctx context.Context) ([]*ateapipb.Actor, error)
return result, nil
}

func matchesWorker(w *ateapipb.Worker, opts store.ListOptions) bool {
if len(opts.FieldSelector) == 0 {
return true
}
for k, v := range opts.FieldSelector {
switch k {
case "worker_namespace":
if w.GetWorkerNamespace() != v {
return false
}
case "worker_pool":
if w.GetWorkerPool() != v {
return false
}
case "worker_pod":
if w.GetWorkerPod() != v {
return false
}
case "actor_namespace":
if w.GetActorNamespace() != v {
return false
}
case "actor_template":
if w.GetActorTemplate() != v {
return false
}
case "actor_id":
if w.GetActorId() != v {
return false
}
case "ip":
if w.GetIp() != v {
return false
}
default:
return false
}
}
return true
}

func matchesActor(a *ateapipb.Actor, opts store.ListOptions) bool {
if len(opts.FieldSelector) == 0 {
return true
}
for k, v := range opts.FieldSelector {
switch k {
case "actor_id":
if a.GetActorId() != v {
return false
}
case "actor_template_namespace":
if a.GetActorTemplateNamespace() != v {
return false
}
case "actor_template_name":
if a.GetActorTemplateName() != v {
return false
}
case "status":
if a.GetStatus().String() != v {
return false
}
case "ateom_pod_namespace":
if a.GetAteomPodNamespace() != v {
return false
}
case "ateom_pod_name":
if a.GetAteomPodName() != v {
return false
}
case "ateom_pod_ip":
if a.GetAteomPodIp() != v {
return false
}
default:
return false
}
}
return true
}

func (s *Persistence) AcquireLock(ctx context.Context, key string, value string, ttl time.Duration) (bool, error) {
ok, err := s.rdb.SetNX(ctx, key, value, ttl).Result()
if err != nil {
Expand Down
149 changes: 145 additions & 4 deletions cmd/ateapi/internal/store/ateredis/ateredis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ func TestListWorkers(t *testing.T) {
t.Fatalf("failed to create worker2: %v", err)
}

workers, err := s.ListWorkers(ctx)
workers, err := s.ListWorkers(ctx, store.ListOptions{})
if err != nil {
t.Fatalf("ListWorkers failed: %v", err)
}
Expand Down Expand Up @@ -424,7 +424,7 @@ func TestListActors(t *testing.T) {
t.Fatalf("failed to create actor2: %v", err)
}

actors, err := s.ListActors(ctx)
actors, err := s.ListActors(ctx, store.ListOptions{})
if err != nil {
t.Fatalf("ListActors failed: %v", err)
}
Expand Down Expand Up @@ -515,7 +515,7 @@ func TestListWorkers_Empty(t *testing.T) {
mr, s, ctx := setupTest(t)
defer mr.Close()

workers, err := s.ListWorkers(ctx)
workers, err := s.ListWorkers(ctx, store.ListOptions{})
if err != nil {
t.Fatalf("ListWorkers failed: %v", err)
}
Expand All @@ -529,7 +529,7 @@ func TestListActors_Empty(t *testing.T) {
mr, s, ctx := setupTest(t)
defer mr.Close()

actors, err := s.ListActors(ctx)
actors, err := s.ListActors(ctx, store.ListOptions{})
if err != nil {
t.Fatalf("ListActors failed: %v", err)
}
Expand Down Expand Up @@ -721,3 +721,144 @@ func TestAcquireLock_NonReentry(t *testing.T) {
t.Errorf("expected second lock acquisition to fail (non-reentrant)")
}
}

func TestListWorkers_Filtering(t *testing.T) {
mr, s, ctx := setupTest(t)
defer mr.Close()

worker1 := &ateapipb.Worker{
WorkerNamespace: "ns1",
WorkerPool: "pool1",
WorkerPod: "pod1",
ActorId: "actor1",
}
worker2 := &ateapipb.Worker{
WorkerNamespace: "ns2",
WorkerPool: "pool2",
WorkerPod: "pod2",
ActorId: "",
}
if err := s.CreateWorker(ctx, worker1); err != nil {
t.Fatalf("failed to create worker1: %v", err)
}
if err := s.CreateWorker(ctx, worker2); err != nil {
t.Fatalf("failed to create worker2: %v", err)
}

tests := []struct {
name string
selector map[string]string
expectedPodIDs []string
}{
{
name: "match pool1",
selector: map[string]string{"worker_pool": "pool1"},
expectedPodIDs: []string{"pod1"},
},
{
name: "match empty actor_id (idle worker)",
selector: map[string]string{"actor_id": ""},
expectedPodIDs: []string{"pod2"},
},
{
name: "match worker namespace and pool",
selector: map[string]string{"worker_namespace": "ns2", "worker_pool": "pool2"},
expectedPodIDs: []string{"pod2"},
},
{
name: "no match",
selector: map[string]string{"worker_pool": "non-existent"},
expectedPodIDs: []string{},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
opts := store.ListOptions{FieldSelector: tc.selector}
workers, err := s.ListWorkers(ctx, opts)
if err != nil {
t.Fatalf("ListWorkers failed: %v", err)
}

if len(workers) != len(tc.expectedPodIDs) {
t.Fatalf("expected %d workers, got %d", len(tc.expectedPodIDs), len(workers))
}
for i, w := range workers {
if w.GetWorkerPod() != tc.expectedPodIDs[i] {
t.Errorf("expected worker %s, got %s", tc.expectedPodIDs[i], w.GetWorkerPod())
}
}
})
}
}

func TestListActors_Filtering(t *testing.T) {
mr, s, ctx := setupTest(t)
defer mr.Close()

actor1 := &ateapipb.Actor{
ActorId: "id1",
ActorTemplateNamespace: "ns1",
ActorTemplateName: "tmpl1",
Status: ateapipb.Actor_STATUS_RUNNING,
}
actor2 := &ateapipb.Actor{
ActorId: "id2",
ActorTemplateNamespace: "ns2",
ActorTemplateName: "tmpl2",
Status: ateapipb.Actor_STATUS_SUSPENDED,
}

if err := s.CreateActor(ctx, actor1); err != nil {
t.Fatalf("failed to create actor1: %v", err)
}
if err := s.CreateActor(ctx, actor2); err != nil {
t.Fatalf("failed to create actor2: %v", err)
}

tests := []struct {
name string
selector map[string]string
expectedActorIDs []string
}{
{
name: "match status running",
selector: map[string]string{"status": "STATUS_RUNNING"},
expectedActorIDs: []string{"id1"},
},
{
name: "match status suspended",
selector: map[string]string{"status": "STATUS_SUSPENDED"},
expectedActorIDs: []string{"id2"},
},
{
name: "match template name",
selector: map[string]string{"actor_template_name": "tmpl1"},
expectedActorIDs: []string{"id1"},
},
{
name: "no match",
selector: map[string]string{"status": "STATUS_UNSPECIFIED"},
expectedActorIDs: []string{},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
opts := store.ListOptions{FieldSelector: tc.selector}
actors, err := s.ListActors(ctx, opts)
if err != nil {
t.Fatalf("ListActors failed: %v", err)
}

if len(actors) != len(tc.expectedActorIDs) {
t.Fatalf("expected %d actors, got %d", len(tc.expectedActorIDs), len(actors))
}
for i, a := range actors {
if a.GetActorId() != tc.expectedActorIDs[i] {
t.Errorf("expected actor %s, got %s", tc.expectedActorIDs[i], a.GetActorId())
}
}
})
}
}
Loading
Loading