236 changes: 177 additions & 59 deletions KubeArmor/core/containerdHandler.go
@@ -7,10 +7,14 @@ package core
import (
"context"
"fmt"

"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/containerd/typeurl/v2"
"google.golang.org/protobuf/proto"
@@ -61,6 +65,22 @@ var defaultCaps = []string{
// Containerd Handler
var Containerd *ContainerdHandler

// short-lived cache for PID/namespace lookups, so bursts of events for the
// same container do not repeatedly hit the task service and /proc
var pidNsCache sync.Map // map[string]pidNsCacheEntry
var pidNsCacheDuration = 5 * time.Second

type pidNsCacheEntry struct {
pid uint32
pidNS int
mntNS int
ts time.Time
}

// metrics
var jobsEnqueued uint64
var jobsProcessed uint64
var workerBusy int64
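Nothing in this diff removes cache entries, so records for exited containers linger until their TTL lapses and they are overwritten. A minimal cleanup sketch, assuming only the declarations above (the `evictPidNs` and `sweepPidNsCache` helpers are invented, not part of the PR):

```go
// evictPidNs drops the cached entry for one container; a natural call site
// would be the "destroy" handling path. Hypothetical helper, not in this PR.
func evictPidNs(containerID string) {
	pidNsCache.Delete(containerID)
}

// sweepPidNsCache removes entries past the TTL; could run on a slow ticker.
func sweepPidNsCache() {
	pidNsCache.Range(func(key, value any) bool {
		if e, ok := value.(pidNsCacheEntry); ok && time.Since(e.ts) >= pidNsCacheDuration {
			pidNsCache.Delete(key)
		}
		return true // keep iterating over the remaining entries
	})
}
```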

// init Function
func init() {
// Spec -> google.protobuf.Any
@@ -87,6 +107,11 @@ type ContainerdHandler struct {
dockerEventsCh <-chan *events.Envelope
}

type containerdEventJob struct {
envelope *events.Envelope
nsCtx context.Context
}

// NewContainerdHandler Function
func NewContainerdHandler() *ContainerdHandler {
ch := &ContainerdHandler{}
@@ -117,6 +142,56 @@ func NewContainerdHandler() *ContainerdHandler {
return ch
}

// getPrimaryPidAndNSCached wraps getPrimaryPidAndNS with a short-lived cache,
// so bursts of events for the same container do not trigger repeated
// TaskService.ListPids calls and /proc reads.
func (ch *ContainerdHandler) getPrimaryPidAndNSCached(ctx context.Context, containerID string) (uint32, int, int, error) {
// Check cache
if v, ok := pidNsCache.Load(containerID); ok {
entry := v.(pidNsCacheEntry)
if time.Since(entry.ts) < pidNsCacheDuration {
return entry.pid, entry.pidNS, entry.mntNS, nil
}
}

pid, pidNS, mntNS, err := ch.getPrimaryPidAndNS(ctx, containerID)
if err == nil {
pidNsCache.Store(containerID, pidNsCacheEntry{pid: pid, pidNS: pidNS, mntNS: mntNS, ts: time.Now()})
}
return pid, pidNS, mntNS, err
}
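A test-style sketch of the cache contract the wrapper relies on (hypothetical, assuming the standard `testing` package alongside this file's imports): an entry stored within `pidNsCacheDuration` must be treated as a hit.

```go
func TestPidNsCacheFreshness(t *testing.T) {
	pidNsCache.Store("c1", pidNsCacheEntry{pid: 42, pidNS: 4026531836, mntNS: 4026531840, ts: time.Now()})
	defer pidNsCache.Delete("c1")

	v, ok := pidNsCache.Load("c1")
	if !ok {
		t.Fatal("expected a cache entry for c1")
	}
	e := v.(pidNsCacheEntry)
	if time.Since(e.ts) >= pidNsCacheDuration {
		t.Fatal("freshly stored entry should still be inside the TTL")
	}
	if e.pid != 42 {
		t.Fatalf("unexpected cached pid %d", e.pid)
	}
}
```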

// getPrimaryPidAndNS reads the primary PID from the containerd task service
// and resolves its PID and mount namespaces from /proc (unchanged logic).
func (ch *ContainerdHandler) getPrimaryPidAndNS(ctx context.Context, containerID string) (uint32, int, int, error) {
taskReq := task.ListPidsRequest{ContainerID: containerID}
taskRes, err := ch.client.TaskService().ListPids(ctx, &taskReq)
if err != nil {
return 0, 0, 0, err
}
if len(taskRes.Processes) == 0 {
return 0, 0, 0, fmt.Errorf("no processes found in container %s", containerID)
}

pid := taskRes.Processes[0].Pid
pidStr := strconv.Itoa(int(pid))

pidNS := 0
mntNS := 0

if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/pid")); err == nil {
if _, err := fmt.Sscanf(data, "pid:[%d]\n", &pidNS); err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse pid namespace from %q: %w", data, err)
}
}

if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/mnt")); err == nil {
if _, err := fmt.Sscanf(data, "mnt:[%d]\n", &mntNS); err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse mount namespace from %q: %w", data, err)
}
}

return pid, pidNS, mntNS, nil
}
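The readlink target for `/proc/<pid>/ns/pid` has the form `pid:[4026531836]` with no trailing newline; a self-contained sketch of the parse used above:

```go
package main

import "fmt"

func main() {
	// example readlink result for /proc/<pid>/ns/pid
	data := "pid:[4026531836]"

	var pidNS int
	if _, err := fmt.Sscanf(data, "pid:[%d]", &pidNS); err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Println(pidNS) // prints 4026531836
}
```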

// Close Function
func (ch *ContainerdHandler) Close() {
if err := ch.client.Close(); err != nil {
@@ -186,59 +261,31 @@ func (ch *ContainerdHandler) GetContainerInfo(ctx context.Context, containerID,

// == //
if eventpid == 0 {
- taskReq := task.ListPidsRequest{ContainerID: container.ContainerID}
- if taskRes, err := ch.client.TaskService().ListPids(ctx, &taskReq); err == nil {
- if len(taskRes.Processes) == 0 {
- return container, err
- }
-
- container.Pid = taskRes.Processes[0].Pid
-
- } else {
+ // Use cached helper to get PID + namespaces from containerd + /proc
+ pid, pidNS, mntNS, err := ch.getPrimaryPidAndNSCached(ctx, container.ContainerID)
+ if err != nil {
+ return container, err
+ }
+
+ container.Pid = pid
+ container.PidNS = uint32(pidNS)
+ container.MntNS = uint32(mntNS)
} else {
+ // We already know the event PID; just resolve namespaces from /proc
container.Pid = eventpid
- }
-
- pid := strconv.Itoa(int(container.Pid))
-
- if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/pid")); err == nil {
- if _, err := fmt.Sscanf(data, "pid:[%d]\n", &container.PidNS); err != nil {
- kg.Warnf("Unable to get PidNS (%s, %s, %s)", containerID, pid, err.Error())
- }
- }
-
- if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/mnt")); err == nil {
- if _, err := fmt.Sscanf(data, "mnt:[%d]\n", &container.MntNS); err != nil {
- kg.Warnf("Unable to get MntNS (%s, %s, %s)", containerID, pid, err.Error())
- }
- }
-
- taskReq := task.ListPidsRequest{ContainerID: container.ContainerID}
- if taskRes, err := ch.client.TaskService().ListPids(ctx, &taskReq); err == nil {
- if len(taskRes.Processes) == 0 {
- return container, err
- }
-
- pid := strconv.Itoa(int(taskRes.Processes[0].Pid))
-
- container.Pid = taskRes.Processes[0].Pid
+ pidStr := strconv.Itoa(int(container.Pid))

- if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/pid")); err == nil {
+ if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/pid")); err == nil {
if _, err := fmt.Sscanf(data, "pid:[%d]\n", &container.PidNS); err != nil {
- kg.Warnf("Unable to get PidNS (%s, %s, %s)", containerID, pid, err.Error())
+ kg.Warnf("Unable to get PidNS (%s, %s, %s)", containerID, pidStr, err.Error())
}
}

- if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pid, "/ns/mnt")); err == nil {
+ if data, err := os.Readlink(filepath.Join(cfg.GlobalCfg.ProcFsMount, pidStr, "/ns/mnt")); err == nil {
if _, err := fmt.Sscanf(data, "mnt:[%d]\n", &container.MntNS); err != nil {
- kg.Warnf("Unable to get MntNS (%s, %s, %s)", containerID, pid, err.Error())
+ kg.Warnf("Unable to get MntNS (%s, %s, %s)", containerID, pidStr, err.Error())
}
}
- } else {
- return container, err
}

// == //
Expand Down Expand Up @@ -299,9 +346,10 @@ func (ch *ContainerdHandler) GetContainerdContainers() map[string]context.Contex
return containers
}

- // UpdateContainerdContainer Function
+ // UpdateContainerdContainer Function (signature unchanged; now also called
+ // from worker goroutines, so it must stay safe for concurrent use)
func (dm *KubeArmorDaemon) UpdateContainerdContainer(ctx context.Context, containerID string, containerPid uint32, action string) error {
// check if Containerd exists

if Containerd == nil {
return fmt.Errorf("containerd client not initialized")
}
@@ -523,6 +571,7 @@ func (dm *KubeArmorDaemon) UpdateContainerdContainer(ctx context.Context, contai
dm.MatchandRemoveContainerFromEndpoint(containerID)
dm.EndPointsLock.Unlock()
}

delete(dm.Containers, containerID)
dm.ContainersLock.Unlock()
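This fragment removes a container while holding ContainersLock and briefly taking EndPointsLock inside it; with events now handled by concurrent workers, keeping one global acquisition order is what prevents deadlocks. A generic sketch of that discipline (illustrative names, not KubeArmor's actual types):

```go
// Always take the outer lock before the inner one and release in reverse;
// two workers deleting different containers then cannot deadlock.
var (
	containersMu sync.Mutex // outer lock, guards items
	endpointsMu  sync.Mutex // inner lock, guards endpoint bookkeeping
	items        = map[string]string{}
)

func removeContainer(id string) {
	containersMu.Lock()
	defer containersMu.Unlock()

	endpointsMu.Lock()
	// detach id from endpoint state here
	endpointsMu.Unlock()

	delete(items, id)
}
```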

Expand Down Expand Up @@ -583,6 +632,8 @@ func (dm *KubeArmorDaemon) UpdateContainerdContainer(ctx context.Context, contai
}

// MonitorContainerdEvents Function
// Implements a bounded worker pool that enqueues incoming events and processes
// them concurrently. Full-parallel mode is used (no per-container ordering), so
// start and exit events for the same container may be handled out of order.
func (dm *KubeArmorDaemon) MonitorContainerdEvents() {
dm.WgDaemon.Add(1)
defer dm.WgDaemon.Done()
@@ -594,79 +645,146 @@ func (dm *KubeArmorDaemon) MonitorContainerdEvents() {
return
}

dm.Logger.Print("Started to monitor Containerd events")
dm.Logger.Print("Started to monitor Containerd events (worker-pool mode)")

// Tunables: adjust as needed
numWorkers := 8
jobQueueSize := 200

jobs := make(chan containerdEventJob, jobQueueSize)

// start metric reporter
go func() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-StopChan:
return
case <-ticker.C:
qLen := len(jobs)
kg.Printf("containerd events: queued=%d enqueued=%d processed=%d busy=%d", qLen, atomic.LoadUint64(&jobsEnqueued), atomic.LoadUint64(&jobsProcessed), atomic.LoadInt64(&workerBusy))
}
}
}()

// Start worker pool
for i := 0; i < numWorkers; i++ {
dm.WgDaemon.Add(1)
workerID := i
go func(id int) {
defer dm.WgDaemon.Done()
for job := range jobs {
// protect workers from panic
func() {
atomic.AddInt64(&workerBusy, 1)
defer atomic.AddInt64(&workerBusy, -1)
defer func() {
if r := recover(); r != nil {
kg.Errf("panic in containerd event worker %d: %v", id, r)
}
}()

// process the event
dm.processContainerdJob(job)
atomic.AddUint64(&jobsProcessed, 1)
}()
}
}(workerID)
}

+ // Seed existing containers synchronously (safer initial sync)
containers := Containerd.GetContainerdContainers()
if len(containers) > 0 {
- for containerID, context := range containers {
- if err := dm.UpdateContainerdContainer(context, containerID, 0, "start"); err != nil {
+ for containerID, ns := range containers {
+ if err := dm.UpdateContainerdContainer(ns, containerID, 0, "start"); err != nil {
kg.Warnf("Failed to update containerd container %s: %s", containerID, err.Error())
continue
}
}
}

// Main subscription loop now only enqueues events (backpressure when jobs full)
for {
select {
case <-StopChan:
// close jobs to stop workers and wait for them (dm.WgDaemon handles waiting)
close(jobs)
return

case envelope := <-Containerd.k8sEventsCh:
- dm.handleContainerdEvent(envelope, Containerd.containerd)
+ // will block when the queue is full: desired backpressure
+ jobs <- containerdEventJob{envelope: envelope, nsCtx: Containerd.containerd}
+ atomic.AddUint64(&jobsEnqueued, 1)

case envelope := <-Containerd.dockerEventsCh:
- dm.handleContainerdEvent(envelope, Containerd.docker)
+ jobs <- containerdEventJob{envelope: envelope, nsCtx: Containerd.docker}
+ atomic.AddUint64(&jobsEnqueued, 1)
}
}
}
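A standalone sketch of the worker-pool shape used above (generic types, not the KubeArmor ones): a full buffer blocks the producer, and closing the channel lets the workers drain the queue and exit.

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	jobs := make(chan int, 4) // small buffer: a full queue blocks the producer
	var wg sync.WaitGroup

	for w := 0; w < 2; w++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for j := range jobs { // exits once jobs is closed and drained
				fmt.Printf("worker %d processed job %d\n", id, j)
			}
		}(w)
	}

	for j := 0; j < 10; j++ {
		jobs <- j // blocks whenever the buffer is full (backpressure)
	}
	close(jobs) // signal shutdown; workers finish the remaining jobs
	wg.Wait()
}
```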

- func (dm *KubeArmorDaemon) handleContainerdEvent(envelope *events.Envelope, context context.Context) {
- if envelope == nil {
+ // processContainerdJob unmarshals the event envelope and dispatches the
+ // appropriate action. It runs inside worker goroutines and should avoid
+ // holding long-lived locks.
+ func (dm *KubeArmorDaemon) processContainerdJob(job containerdEventJob) {
+ if job.envelope == nil {
return
}

- // Handle the different event types
- switch envelope.Topic {
+ env := job.envelope
+
+ switch env.Topic {
case "/containers/delete":
deleteContainer := &apievents.ContainerDelete{}

- err := proto.Unmarshal(envelope.Event.GetValue(), deleteContainer)
+ err := proto.Unmarshal(env.Event.GetValue(), deleteContainer)
if err != nil {
kg.Errf("failed to unmarshal container's delete event: %v", err)
return
}
- if err := dm.UpdateContainerdContainer(context, deleteContainer.GetID(), 0, "destroy"); err != nil {
+ // destroy the container
+ if err := dm.UpdateContainerdContainer(job.nsCtx, deleteContainer.GetID(), 0, "destroy"); err != nil {
kg.Warnf("Failed to destroy containerd container %s: %s", deleteContainer.GetID(), err.Error())
}

case "/tasks/start":
startTask := &apievents.TaskStart{}

- err := proto.Unmarshal(envelope.Event.GetValue(), startTask)
+ err := proto.Unmarshal(env.Event.GetValue(), startTask)
if err != nil {
kg.Errf("failed to unmarshal container's start task: %v", err)
return
}
- if err := dm.UpdateContainerdContainer(context, startTask.GetContainerID(), startTask.GetPid(), "start"); err != nil {
+ // start container handling
+ if err := dm.UpdateContainerdContainer(job.nsCtx, startTask.GetContainerID(), startTask.GetPid(), "start"); err != nil {
kg.Warnf("Failed to start containerd container %s: %s", startTask.GetContainerID(), err.Error())
}

case "/tasks/exit":
exitTask := &apievents.TaskExit{} // "/tasks/exit" carries a TaskExit payload, not TaskStart

- err := proto.Unmarshal(envelope.Event.GetValue(), exitTask)
+ err := proto.Unmarshal(env.Event.GetValue(), exitTask)
if err != nil {
kg.Errf("failed to unmarshal container's exit task: %v", err)
return
}

dm.ContainersLock.RLock()
- pid := dm.Containers[exitTask.GetContainerID()].Pid
+ pid := uint32(0)
+ if c, ok := dm.Containers[exitTask.GetContainerID()]; ok {
+ pid = c.Pid
+ }
dm.ContainersLock.RUnlock()

if pid == exitTask.GetPid() {
- if err := dm.UpdateContainerdContainer(context, exitTask.GetContainerID(), pid, "destroy"); err != nil {
+ if err := dm.UpdateContainerdContainer(job.nsCtx, exitTask.GetContainerID(), pid, "destroy"); err != nil {
kg.Warnf("Failed to destroy containerd container %s: %s", exitTask.GetContainerID(), err.Error())
}
}

default:
// ignore other events
}
}
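For reference, a self-contained sketch of the decode step for one topic, using the same event types and imports as this file (the `decodeTaskExit` helper itself is hypothetical):

```go
// decodeTaskExit mirrors the "/tasks/exit" arm above: check the topic,
// unmarshal the payload, and surface the fields the daemon cares about.
func decodeTaskExit(envelope *events.Envelope) (containerID string, pid uint32, err error) {
	if envelope == nil || envelope.Topic != "/tasks/exit" {
		return "", 0, fmt.Errorf("not a /tasks/exit envelope")
	}
	exit := &apievents.TaskExit{}
	if err := proto.Unmarshal(envelope.Event.GetValue(), exit); err != nil {
		return "", 0, err
	}
	return exit.GetContainerID(), exit.GetPid(), nil
}
```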