
Commit 413b5d3

refactor(containerd): add worker-pool & pid/ns cache
Signed-off-by: Keshav Kapoor <keshav.10919051722@std.ggsipu.ac.in>
1 parent d93a7bf commit 413b5d3

File tree: 1 file changed (+172 / -59 lines)


KubeArmor/core/containerdHandler.go

Lines changed: 172 additions & 59 deletions
@@ -7,10 +7,14 @@ package core
 import (
     "context"
     "fmt"
+
     "os"
     "path/filepath"
     "strconv"
     "strings"
+    "sync"
+    "sync/atomic"
+    "time"
 
     "github.com/containerd/typeurl/v2"
     "google.golang.org/protobuf/proto"
@@ -61,6 +65,22 @@ var defaultCaps = []string{
 // Containerd Handler
 var Containerd *ContainerdHandler
 
+// small cache for PID/NS lookups to avoid repeated /proc lookups in quick succession
+var pidNsCache sync.Map // map[string]pidNsCacheEntry
+var pidNsCacheDuration = 5 * time.Second
+
+type pidNsCacheEntry struct {
+    pid   uint32
+    pidNS int
+    mntNS int
+    ts    time.Time
+}
+
+// metrics
+var jobsEnqueued uint64
+var jobsProcessed uint64
+var workerBusy int64
+
 // init Function
 func init() {
     // Spec -> google.protobuf.Any
@@ -87,6 +107,11 @@ type ContainerdHandler struct {
     dockerEventsCh <-chan *events.Envelope
 }
 
+type containerdEventJob struct {
+    envelope *events.Envelope
+    nsCtx    context.Context
+}
+
 // NewContainerdHandler Function
 func NewContainerdHandler() *ContainerdHandler {
     ch := &ContainerdHandler{}
@@ -117,6 +142,51 @@ func NewContainerdHandler() *ContainerdHandler {
     return ch
 }
 
+// getPrimaryPidAndNSCached looks up the primary PID via containerd's TaskService.ListPids and reads pid/ns information.
+// It employs a short-lived cache to prevent repeated /proc lookups when events arrive in bursts.
+func (ch *ContainerdHandler) getPrimaryPidAndNSCached(ctx context.Context, containerID string) (uint32, int, int, error) {
+    // Check cache
+    if v, ok := pidNsCache.Load(containerID); ok {
+        entry := v.(pidNsCacheEntry)
+        if time.Since(entry.ts) < pidNsCacheDuration {
+            return entry.pid, entry.pidNS, entry.mntNS, nil
+        }
+    }
+
+    pid, pidNS, mntNS, err := ch.getPrimaryPidAndNS(ctx, containerID)
+    if err == nil {
+        pidNsCache.Store(containerID, pidNsCacheEntry{pid: pid, pidNS: pidNS, mntNS: mntNS, ts: time.Now()})
+    }
+    return pid, pidNS, mntNS, err
+}
+
+// getPrimaryPidAndNS keeps the original lookup logic as-is (reads from the task service and /proc)
+func (ch *ContainerdHandler) getPrimaryPidAndNS(ctx context.Context, containerID string) (uint32, int, int, error) {
+    taskReq := task.ListPidsRequest{ContainerID: containerID}
+    taskRes, err := ch.client.TaskService().ListPids(ctx, &taskReq)
+    if err != nil {
+        return 0, 0, 0, err
+    }
+    if len(taskRes.Processes) == 0 {
+        return 0, 0, 0, fmt.Errorf("no processes found in container %s", containerID)
+    }
+
+    pid := taskRes.Processes[0].Pid
+    pidStr := strconv.Itoa(int(pid))
+
+    pidNS := 0
+    mntNS := 0
+
+    if data, e := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/pid")); e == nil {
+        fmt.Sscanf(data, "pid:[%d]\n", &pidNS)
+    }
+    if data, e := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/mnt")); e == nil {
+        fmt.Sscanf(data, "mnt:[%d]\n", &mntNS)
+    }
+
+    return pid, pidNS, mntNS, nil
+}
+
 // Close Function
 func (ch *ContainerdHandler) Close() {
     if err := ch.client.Close(); err != nil {
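Note: the caching helper added above follows a common Go idiom, a sync.Map of timestamped entries that are treated as stale once a short TTL has passed. A minimal, self-contained sketch of that pattern (the names entry, lookupCache, cacheTTL, slowLookup and cachedLookup below are illustrative only and are not part of this change) might look like:

package main

import (
    "fmt"
    "sync"
    "time"
)

// cached value plus the time it was stored
type entry struct {
    value string
    ts    time.Time
}

var (
    lookupCache sync.Map // map[string]entry
    cacheTTL    = 5 * time.Second
)

// slowLookup stands in for an expensive call (e.g. an RPC or a /proc read)
func slowLookup(key string) string {
    time.Sleep(50 * time.Millisecond)
    return "value-for-" + key
}

// cachedLookup returns a fresh-enough cached value or refreshes it
func cachedLookup(key string) string {
    if v, ok := lookupCache.Load(key); ok {
        e := v.(entry)
        if time.Since(e.ts) < cacheTTL {
            return e.value
        }
    }
    val := slowLookup(key)
    lookupCache.Store(key, entry{value: val, ts: time.Now()})
    return val
}

func main() {
    fmt.Println(cachedLookup("c1")) // slow path, fills the cache
    fmt.Println(cachedLookup("c1")) // served from the cache within the TTL
}

The trade-off is that a hit can return data up to one TTL old; that is acceptable here because a container's primary PID and namespaces rarely change within a few seconds.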
@@ -186,59 +256,31 @@ func (ch *ContainerdHandler) GetContainerInfo(ctx context.Context, containerID,
 
     // == //
     if eventpid == 0 {
-        taskReq := task.ListPidsRequest{ContainerID: container.ContainerID}
-        if taskRes, err := ch.client.TaskService().ListPids(ctx, &taskReq); err == nil {
-            if len(taskRes.Processes) == 0 {
-                return container, err
-            }
-
-            container.Pid = taskRes.Processes[0].Pid
-
-        } else {
+        // Use cached helper to get PID + namespaces from containerd + /proc
+        pid, pidNS, mntNS, err := ch.getPrimaryPidAndNSCached(ctx, container.ContainerID)
+        if err != nil {
             return container, err
         }
 
+        container.Pid = pid
+        container.PidNS = uint32(pidNS)
+        container.MntNS = uint32(mntNS)
     } else {
+        // We already know the event PID; just resolve namespaces from /proc
         container.Pid = eventpid
-    }
-
-    pid := strconv.Itoa(int(container.Pid))
+        pidStr := strconv.Itoa(int(container.Pid))
 
-    if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/pid")); err == nil {
-        if _, err := fmt.Sscanf(data, "pid:[%d]\n", &container.PidNS); err != nil {
-            kg.Warnf("Unable to get PidNS (%s, %s, %s)", containerID, pid, err.Error())
-        }
-    }
-
-    if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/mnt")); err == nil {
-        if _, err := fmt.Sscanf(data, "mnt:[%d]\n", &container.MntNS); err != nil {
-            kg.Warnf("Unable to get MntNS (%s, %s, %s)", containerID, pid, err.Error())
-        }
-    }
-
-    taskReq := task.ListPidsRequest{ContainerID: container.ContainerID}
-    if taskRes, err := ch.client.TaskService().ListPids(ctx, &taskReq); err == nil {
-        if len(taskRes.Processes) == 0 {
-            return container, err
-        }
-
-        pid := strconv.Itoa(int(taskRes.Processes[0].Pid))
-
-        container.Pid = taskRes.Processes[0].Pid
-
-        if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/pid")); err == nil {
+        if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/pid")); err == nil {
             if _, err := fmt.Sscanf(data, "pid:[%d]\n", &container.PidNS); err != nil {
-                kg.Warnf("Unable to get PidNS (%s, %s, %s)", containerID, pid, err.Error())
+                kg.Warnf("Unable to get PidNS (%s, %s, %s)", containerID, pidStr, err.Error())
             }
         }
 
-        if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/mnt")); err == nil {
+        if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/mnt")); err == nil {
             if _, err := fmt.Sscanf(data, "mnt:[%d]\n", &container.MntNS); err != nil {
-                kg.Warnf("Unable to get MntNS (%s, %s, %s)", containerID, pid, err.Error())
+                kg.Warnf("Unable to get MntNS (%s, %s, %s)", containerID, pidStr, err.Error())
             }
         }
-    } else {
-        return container, err
     }
 
     // == //
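Note: the namespace resolution in this hunk relies on /proc/&lt;pid&gt;/ns/pid and /proc/&lt;pid&gt;/ns/mnt being symlinks of the form pid:[4026531836] and mnt:[4026531841]. A standalone, Linux-only sketch of the same readlink-and-parse approach (the helper nsID is illustrative and not part of KubeArmor) is:

package main

import (
    "fmt"
    "os"
    "path/filepath"
    "strconv"
)

// nsID parses the /proc/<pid>/ns/<kind> symlink (e.g. "pid:[4026531836]")
// and returns the numeric namespace identifier.
func nsID(pid int, kind string) (int, error) {
    link := filepath.Join("/proc", strconv.Itoa(pid), "ns", kind)
    target, err := os.Readlink(link)
    if err != nil {
        return 0, err
    }
    id := 0
    if _, err := fmt.Sscanf(target, kind+":[%d]", &id); err != nil {
        return 0, err
    }
    return id, nil
}

func main() {
    self := os.Getpid()
    if pidNS, err := nsID(self, "pid"); err == nil {
        fmt.Println("pid namespace:", pidNS)
    }
    if mntNS, err := nsID(self, "mnt"); err == nil {
        fmt.Println("mnt namespace:", mntNS)
    }
}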
@@ -299,9 +341,10 @@ func (ch *ContainerdHandler) GetContainerdContainers() map[string]context.Context
     return containers
 }
 
-// UpdateContainerdContainer Function
+// UpdateContainerdContainer Function (unchanged signature); it is now also called from worker goroutines
 func (dm *KubeArmorDaemon) UpdateContainerdContainer(ctx context.Context, containerID string, containerPid uint32, action string) error {
     // check if Containerd exists
+
     if Containerd == nil {
         return fmt.Errorf("containerd client not initialized")
     }
@@ -523,6 +566,7 @@ func (dm *KubeArmorDaemon) UpdateContainerdContainer(ctx context.Context, containerID string, containerPid uint32, action string) error {
         dm.MatchandRemoveContainerFromEndpoint(containerID)
         dm.EndPointsLock.Unlock()
     }
+
     delete(dm.Containers, containerID)
     dm.ContainersLock.Unlock()
@@ -583,6 +627,8 @@ func (dm *KubeArmorDaemon) UpdateContainerdContainer(ctx context.Context, containerID string, containerPid uint32, action string) error {
 }
 
 // MonitorContainerdEvents Function
+// Implements a bounded worker pool that enqueues events and processes them concurrently.
+// Full-parallel mode (no per-container ordering) is used as requested.
 func (dm *KubeArmorDaemon) MonitorContainerdEvents() {
     dm.WgDaemon.Add(1)
     defer dm.WgDaemon.Done()
@@ -594,79 +640,146 @@ func (dm *KubeArmorDaemon) MonitorContainerdEvents() {
         return
     }
 
-    dm.Logger.Print("Started to monitor Containerd events")
+    dm.Logger.Print("Started to monitor Containerd events (worker-pool mode)")
+
+    // Tunables; adjust as needed
+    numWorkers := 8
+    jobQueueSize := 200
+
+    jobs := make(chan containerdEventJob, jobQueueSize)
+
+    // start metric reporter
+    go func() {
+        ticker := time.NewTicker(15 * time.Second)
+        defer ticker.Stop()
+        for {
+            select {
+            case <-StopChan:
+                return
+            case <-ticker.C:
+                qLen := uint64(len(jobs))
+                kg.Printf("containerd events: queued=%d processed=%d busy=%d", qLen, atomic.LoadUint64(&jobsProcessed), atomic.LoadInt64(&workerBusy))
+            }
+        }
+    }()
+
+    // Start worker pool
+    for i := 0; i < numWorkers; i++ {
+        dm.WgDaemon.Add(1)
+        workerID := i
+        go func(id int) {
+            defer dm.WgDaemon.Done()
+            for job := range jobs {
+                // protect workers from panics
+                func() {
+                    atomic.AddInt64(&workerBusy, 1)
+                    defer atomic.AddInt64(&workerBusy, -1)
+                    defer func() {
+                        if r := recover(); r != nil {
+                            kg.Errf("panic in containerd event worker %d: %v", id, r)
+                        }
+                    }()
+
+                    // process the event
+                    dm.processContainerdJob(job)
+                    atomic.AddUint64(&jobsProcessed, 1)
+                }()
+            }
+        }(workerID)
+    }
 
+    // Seed existing containers synchronously (safer initial sync)
     containers := Containerd.GetContainerdContainers()
-
     if len(containers) > 0 {
-        for containerID, context := range containers {
-            if err := dm.UpdateContainerdContainer(context, containerID, 0, "start"); err != nil {
+        for containerID, ns := range containers {
+            if err := dm.UpdateContainerdContainer(ns, containerID, 0, "start"); err != nil {
                 kg.Warnf("Failed to update containerd container %s: %s", containerID, err.Error())
                 continue
            }
        }
    }
+
+    // Main subscription loop now only enqueues events (applies backpressure when the jobs queue is full)
     for {
         select {
         case <-StopChan:
+            // close jobs to stop workers and wait for them (dm.WgDaemon handles waiting)
+            close(jobs)
             return
 
         case envelope := <-Containerd.k8sEventsCh:
-            dm.handleContainerdEvent(envelope, Containerd.containerd)
+            // will block when the queue is full; this is the desired backpressure
+            jobs <- containerdEventJob{envelope: envelope, nsCtx: Containerd.containerd}
+            atomic.AddUint64(&jobsEnqueued, 1)
 
         case envelope := <-Containerd.dockerEventsCh:
-            dm.handleContainerdEvent(envelope, Containerd.docker)
-
+            jobs <- containerdEventJob{envelope: envelope, nsCtx: Containerd.docker}
+            atomic.AddUint64(&jobsEnqueued, 1)
         }
     }
 }
 
-func (dm *KubeArmorDaemon) handleContainerdEvent(envelope *events.Envelope, context context.Context) {
-    if envelope == nil {
+// processContainerdJob unmarshals the event envelope and dispatches the appropriate action.
+// This function runs inside worker goroutines and should avoid holding long locks.
+func (dm *KubeArmorDaemon) processContainerdJob(job containerdEventJob) {
+    if job.envelope == nil {
         return
     }
 
-    // Handle the different event types
-    switch envelope.Topic {
+    env := job.envelope
+
+    switch env.Topic {
     case "/containers/delete":
         deleteContainer := &apievents.ContainerDelete{}
 
-        err := proto.Unmarshal(envelope.Event.GetValue(), deleteContainer)
+        err := proto.Unmarshal(env.Event.GetValue(), deleteContainer)
         if err != nil {
             kg.Errf("failed to unmarshal container's delete event: %v", err)
+            return
         }
-        if err := dm.UpdateContainerdContainer(context, deleteContainer.GetID(), 0, "destroy"); err != nil {
+
+        // destroy the container
+        if err := dm.UpdateContainerdContainer(job.nsCtx, deleteContainer.GetID(), 0, "destroy"); err != nil {
             kg.Warnf("Failed to destroy containerd container %s: %s", deleteContainer.GetID(), err.Error())
         }
 
     case "/tasks/start":
         startTask := &apievents.TaskStart{}
 
-        err := proto.Unmarshal(envelope.Event.GetValue(), startTask)
+        err := proto.Unmarshal(env.Event.GetValue(), startTask)
         if err != nil {
             kg.Errf("failed to unmarshal container's start task: %v", err)
+            return
         }
-        if err := dm.UpdateContainerdContainer(context, startTask.GetContainerID(), startTask.GetPid(), "start"); err != nil {
+
+        // start container handling
+        if err := dm.UpdateContainerdContainer(job.nsCtx, startTask.GetContainerID(), startTask.GetPid(), "start"); err != nil {
             kg.Warnf("Failed to start containerd container %s: %s", startTask.GetContainerID(), err.Error())
         }
 
     case "/tasks/exit":
         exitTask := &apievents.TaskStart{}
 
-        err := proto.Unmarshal(envelope.Event.GetValue(), exitTask)
+        err := proto.Unmarshal(env.Event.GetValue(), exitTask)
         if err != nil {
             kg.Errf("failed to unmarshal container's exit task: %v", err)
+            return
         }
 
         dm.ContainersLock.RLock()
-        pid := dm.Containers[exitTask.GetContainerID()].Pid
+        pid := uint32(0)
+        if c, ok := dm.Containers[exitTask.GetContainerID()]; ok {
+            pid = c.Pid
+        }
         dm.ContainersLock.RUnlock()
 
         if pid == exitTask.GetPid() {
-            if err := dm.UpdateContainerdContainer(context, exitTask.GetContainerID(), pid, "destroy"); err != nil {
+            if err := dm.UpdateContainerdContainer(job.nsCtx, exitTask.GetContainerID(), pid, "destroy"); err != nil {
                 kg.Warnf("Failed to destroy containerd container %s: %s", exitTask.GetContainerID(), err.Error())
             }
         }
 
+    default:
+        // ignore other events
     }
 }
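Note: stripped of the KubeArmor specifics, the event-handling shape introduced above is a bounded worker pool: a buffered channel acts as the queue, a fixed number of goroutines drain it, sends block once the buffer is full (which is the backpressure), and closing the channel shuts the workers down. A compact, self-contained sketch of that shape (job, numWorkers, queueSize and the counters below are illustrative names and values, not the ones this commit must use) is:

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

type job struct{ id int }

func main() {
    const numWorkers = 4
    const queueSize = 16

    jobs := make(chan job, queueSize) // buffered channel = bounded queue
    var processed uint64
    var wg sync.WaitGroup

    // fixed pool of workers draining the queue
    for w := 0; w < numWorkers; w++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for j := range jobs { // exits once jobs is closed and drained
                func() {
                    defer func() { // keep one bad job from killing the worker
                        if r := recover(); r != nil {
                            fmt.Printf("worker %d recovered: %v\n", id, r)
                        }
                    }()
                    _ = j // handle the job here
                    atomic.AddUint64(&processed, 1)
                }()
            }
        }(w)
    }

    // producer: blocks whenever the buffer is full (backpressure)
    for i := 0; i < 100; i++ {
        jobs <- job{id: i}
    }
    close(jobs) // signal the workers to finish
    wg.Wait()

    fmt.Println("processed:", atomic.LoadUint64(&processed))
}

Whether 8 workers and a 200-slot queue are the right tunables depends on event burst size and how long each UpdateContainerdContainer call takes; the pattern itself is independent of those numbers.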
