Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions banditcallback/banditcallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package banditcallback

import (
"context"
"net"
"net/http"
"net/url"
"sync"
Expand Down Expand Up @@ -76,10 +77,14 @@ func (e *Emitter) Enabled() bool {

// EmitIfFirstSeen records the device-id and, if it hasn't been seen
// within the TTL window, fires an async best-effort GET to the
// configured callback URL. Returns immediately — the HTTP request runs
// in its own goroutine and any failure is logged at debug only (the
// callback is a hint to the bandit, not a correctness signal).
func (e *Emitter) EmitIfFirstSeen(ctx context.Context, deviceID string) {
// configured callback URL. clientIP, when non-empty, is forwarded as a
// True-Client-IP header so the API can attribute the reward to the
// device's ASN rather than to the proxy VPS's egress IP — the source
// address of our outbound request. Returns immediately — the HTTP
// request runs in its own goroutine and any failure is logged at debug
// only (the callback is a hint to the bandit, not a correctness
// signal).
func (e *Emitter) EmitIfFirstSeen(ctx context.Context, deviceID, clientIP string) {
if !e.Enabled() || deviceID == "" {
return
}
Expand All @@ -90,7 +95,7 @@ func (e *Emitter) EmitIfFirstSeen(ctx context.Context, deviceID string) {
return
}

go e.fire(ctx, deviceID)
go e.fire(ctx, deviceID, clientIP)
}

// checkAndRecord returns true if the deviceID is first-seen within the
Expand Down Expand Up @@ -120,7 +125,7 @@ func (e *Emitter) checkAndRecord(deviceID string, now time.Time) bool {
return true
}

func (e *Emitter) fire(ctx context.Context, deviceID string) {
func (e *Emitter) fire(ctx context.Context, deviceID, clientIP string) {
// Detach from the request context so a closing client connection
// doesn't cancel the outbound HTTP request. The callback is for
// the bandit's benefit, not the client's; we want it to complete
Expand All @@ -143,6 +148,9 @@ func (e *Emitter) fire(ctx context.Context, deviceID string) {
log.Debugf("banditcallback: build request: %v", err)
return
}
if clientIP != "" {
req.Header.Set("True-Client-IP", clientIP)
}

resp, err := e.client.Do(req)
if err != nil {
Expand Down Expand Up @@ -183,10 +191,17 @@ func NewFilter(headerName string, emitter *Emitter) *Filter {
}

// Apply implements filters.Filter. Forwards unconditionally (the
// emitter is a side-effect; failures are non-fatal).
// emitter is a side-effect; failures are non-fatal). The client IP is
// taken from req.RemoteAddr (same source opsfilter uses for measured
// reporting) so the API receives the device's real IP via the
// True-Client-IP header instead of our VPS egress.
func (f *Filter) Apply(cs *filters.ConnectionState, req *http.Request, next filters.Next) (*http.Response, *filters.ConnectionState, error) {
if f.emitter != nil && f.emitter.Enabled() {
f.emitter.EmitIfFirstSeen(req.Context(), req.Header.Get(f.deviceIDHeader))
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIP = req.RemoteAddr
}
f.emitter.EmitIfFirstSeen(req.Context(), req.Header.Get(f.deviceIDHeader), clientIP)
}
return next(cs, req)
}
39 changes: 32 additions & 7 deletions banditcallback/banditcallback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestEmitter_DisabledWhenUnconfigured(t *testing.T) {
if e.Enabled() {
t.Fatal("expected disabled")
}
e.EmitIfFirstSeen(context.Background(), "did-1")
e.EmitIfFirstSeen(context.Background(), "did-1", "")
emitted, _ := e.Stats()
if emitted != 0 {
t.Fatalf("expected 0 emits, got %d", emitted)
Expand All @@ -48,12 +48,15 @@ func TestEmitter_FirstSeenFires(t *testing.T) {
if got := r.URL.Query().Get("did"); got == "" {
t.Error("missing did")
}
if got := r.Header.Get("True-Client-IP"); got != "203.0.113.7" {
t.Errorf("True-Client-IP mismatch: %q", got)
}
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()

e := New("arm-test", srv.URL, time.Minute)
e.EmitIfFirstSeen(context.Background(), "device-a")
e.EmitIfFirstSeen(context.Background(), "device-a", "203.0.113.7")

// Emission is async; wait for the goroutine. 1s ceiling is generous.
deadline := time.Now().Add(time.Second)
Expand Down Expand Up @@ -83,7 +86,7 @@ func TestEmitter_DedupSuppressesRepeat(t *testing.T) {

e := New("arm-test", srv.URL, time.Minute)
for i := 0; i < 10; i++ {
e.EmitIfFirstSeen(context.Background(), "device-a")
e.EmitIfFirstSeen(context.Background(), "device-a", "203.0.113.7")
}

// One emit, nine suppressed. Wait for async emit to complete.
Expand Down Expand Up @@ -123,7 +126,7 @@ func TestEmitter_ConcurrentFirstSeenIsSingleFire(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
e.EmitIfFirstSeen(context.Background(), "race-device")
e.EmitIfFirstSeen(context.Background(), "race-device", "")
}()
}
wg.Wait()
Expand All @@ -146,14 +149,36 @@ func TestEmitter_ReEmitsAfterTTL(t *testing.T) {
defer srv.Close()

e := New("arm-test", srv.URL, 50*time.Millisecond)
e.EmitIfFirstSeen(context.Background(), "ttl-device")
e.EmitIfFirstSeen(context.Background(), "ttl-device", "")
time.Sleep(20 * time.Millisecond) // within TTL — suppressed
e.EmitIfFirstSeen(context.Background(), "ttl-device")
e.EmitIfFirstSeen(context.Background(), "ttl-device", "")
time.Sleep(80 * time.Millisecond) // past TTL — fires again
e.EmitIfFirstSeen(context.Background(), "ttl-device")
e.EmitIfFirstSeen(context.Background(), "ttl-device", "")

time.Sleep(100 * time.Millisecond)
if got := atomic.LoadInt32(&hits); got != 2 {
t.Fatalf("expected 2 hits across TTL window, got %d", got)
}
}

func TestEmitter_OmitsTrueClientIPWhenEmpty(t *testing.T) {
// Empty clientIP must NOT set the header — leaving it absent
// lets the API fall through to its existing RemoteAddr-based
// MaxMind lookup. Setting an empty header would shadow that
// fallback with a junk value.
var sawHeader int32
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := r.Header["True-Client-Ip"]; ok {
atomic.StoreInt32(&sawHeader, 1)
}
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()

e := New("arm-test", srv.URL, time.Minute)
e.EmitIfFirstSeen(context.Background(), "device-x", "")
time.Sleep(100 * time.Millisecond)
if atomic.LoadInt32(&sawHeader) != 0 {
t.Fatal("True-Client-IP must be absent when clientIP is empty")
}
}
Loading