Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ services:
--providers.docker=true
--providers.docker.network=default
--experimental.plugins.captcha-protect.modulename=github.com/libops/captcha-protect
--experimental.plugins.captcha-protect.version=v1.4.0
--experimental.plugins.captcha-protect.version=v1.4.1
volumes:
- /var/run/docker.sock:/var/run/docker.sock:z
- /CHANGEME/TO/A/HOST/PATH/FOR/STATE/FILE:/tmp/state.json:rw
Expand Down
79 changes: 57 additions & 22 deletions ci/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"math/rand"
"net"
Expand All @@ -18,8 +17,10 @@ import (
cp "github.com/libops/captcha-protect"
)

var rateLimit = 5
var exemptIps []*net.IPNet
var (
rateLimit = 5
exemptIps []*net.IPNet
)

const numIPs = 100
const parallelism = 10
Expand All @@ -34,11 +35,7 @@ func main() {
"fc00::/8",
}
for _, ip := range _ips {
parsedIp, err := cp.ParseCIDR(ip)
if err != nil {
slog.Error("error parsing cidr", "ip", ip, "err", err)
os.Exit(1)
}
parsedIp := parseCIDR(ip)
exemptIps = append(exemptIps, parsedIp)
}

Expand Down Expand Up @@ -86,10 +83,23 @@ func generateUniquePublicIPs(n int) []string {
ipSet := make(map[string]struct{})
var ips []string
config := cp.CreateConfig()
bc := &cp.CaptchaProtect{}
bc.SetExemptIps(exemptIps)
err := bc.SetIpv4Mask(16)
if err != nil {
slog.Error("unable to set ipv4 mask")
os.Exit(1)
}

err = bc.SetIpv6Mask(64)
if err != nil {
slog.Error("unable to set ipv6 mask")
os.Exit(1)
}

for len(ips) < n {
ip := randomPublicIP(config)
ip, ipRange := cp.ParseIp(ip, 16, 64)
ip, ipRange := bc.ParseIp(ip)
if _, exists := ipSet[ipRange]; !exists {
ipSet[ipRange] = struct{}{}
ips = append(ips, ip)
Expand Down Expand Up @@ -142,7 +152,9 @@ func runParallelChecks(ips []string, rateLimit int) {
fmt.Printf("Checking %s\n", ip)
output := httpRequest(ip)
if output != "" {
log.Fatalf("Unexpected output for %s: %s", ip, output)
slog.Error("Unexpected output", "ip", ip, "output", output)
os.Exit(1)

}
}(ip)
}
Expand All @@ -157,7 +169,8 @@ func ensureRedirect(ips []string) {
output := httpRequest(ip)

if output != expectedRedirectURL {
log.Fatalf("Unexpected output for %s: %s", ip, output)
slog.Error("Unexpected output", "ip", ip, "output", output)
os.Exit(1)
}

fmt.Printf("Got a redirect! %s\n", output)
Expand All @@ -177,12 +190,15 @@ func httpRequest(ip string) string {

req, err := http.NewRequest("GET", "http://localhost", nil)
if err != nil {
log.Fatalf("Failed to create request: %v", err)
slog.Error("Failed to create request", "err", err)
os.Exit(1)
}
req.Header.Set("X-Forwarded-For", ip)
resp, err := client.Do(req)
if err != nil {
log.Fatalf("Request failed: %v", err)
slog.Error("Request failed", "err", err)
os.Exit(1)

}
defer resp.Body.Close()

Expand All @@ -192,7 +208,9 @@ func httpRequest(ip string) string {
if err == http.ErrNoLocation {
return ""
}
log.Fatalf("Failed to get redirect URL: %v", err)
slog.Error("Failed to get redirect URL", "err", err)
os.Exit(1)

}

return strings.TrimSpace(location.String())
Expand All @@ -211,38 +229,55 @@ func runCommand(name string, args ...string) {
cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", tt))
}
if err := cmd.Run(); err != nil {
log.Fatalf("Command failed: %v", err)
slog.Error("Command failed", "err", err)
os.Exit(1)
}
}

func checkStateReload() {
resp, err := http.Get("http://localhost/captcha-protect/stats")
if err != nil {
log.Fatalf("Failed to make GET request: %v", err)
slog.Error("Failed to make GET request", "err", err)
os.Exit(1)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed to read response body: %v", err)
slog.Error("Failed to read response body", "err", err)
os.Exit(1)

}
var jsonResponse map[string]interface{}
err = json.Unmarshal(body, &jsonResponse)
if err != nil {
log.Fatalf("Failed to unmarshal JSON: %v", err)
slog.Error("Failed to unmarshal JSON", "err", err)
os.Exit(1)

}
bots, exists := jsonResponse["bots"]
if !exists {
log.Fatalf("Key 'bots' not found in JSON response")
slog.Error("Key 'bots' not found in JSON response")
os.Exit(1)
}
botsMap, ok := bots.(map[string]interface{})
if !ok {
log.Fatalf("'bots' is not an array")
slog.Error("'bots' is not an array")
os.Exit(1)
}

if len(botsMap) != numIPs {
log.Fatalf("Expected %d bots, but got %d", numIPs, len(botsMap))
slog.Error("Unexpected number of bots", "expected", numIPs, "received", len(botsMap))
os.Exit(1)
}

log.Println("State reloaded successfully!")
slog.Info("State reloaded successfully!")
}

func parseCIDR(cidr string) *net.IPNet {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
slog.Error("Failed to parse CIDR", "cidr", cidr, "err", err)
}
return block
}
70 changes: 43 additions & 27 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type CaptchaProtect struct {
captchaConfig CaptchaConfig
exemptIps []*net.IPNet
tmpl *template.Template
ipv4Mask net.IPMask
ipv6Mask net.IPMask
}

type CaptchaConfig struct {
Expand Down Expand Up @@ -108,7 +110,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h

level, err := ParseLogLevel(config.LogLevel)
if err != nil {
log.Error("Unknown log level", "err", err)
log.Warn("Unknown log level", "err", err)
}
logLevel.Set(level)

Expand Down Expand Up @@ -165,6 +167,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
tmpl: tmpl,
}

err = bc.SetIpv4Mask(config.IPv4SubnetMask)
if err != nil {
return nil, err
}

err = bc.SetIpv6Mask(config.IPv6SubnetMask)
if err != nil {
return nil, err
}

// set the captcha config based on the provider
// thanks to https://github.com/maxlerebourg/crowdsec-bouncer-traefik-plugin/blob/4708d76854c7ae95fa7313c46fbe21959be2fff1/pkg/captcha/captcha.go#L39-L55
// for the struct/idea
Expand Down Expand Up @@ -437,50 +449,50 @@ func (bc *CaptchaProtect) getClientIP(req *http.Request) (string, string) {
ip = host
}

return ParseIp(ip, bc.config.IPv4SubnetMask, bc.config.IPv6SubnetMask)
return bc.ParseIp(ip)
}

func ParseIp(ip string, ipv4Mask, ipv6Mask int) (string, string) {
func (bc *CaptchaProtect) ParseIp(ip string) (string, string) {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return ip, ip
}

// For IPv4 addresses
if parsedIP.To4() != nil {
ipParts := strings.Split(ip, ".")
var required int
switch ipv4Mask {
case 8:
required = 1
case 16:
required = 2
case 24:
required = 3
default:
// fallback to a default, for example /16
required = 2
}
if len(ipParts) >= required {
subnet := strings.Join(ipParts[:required], ".")
return ip, subnet
}
subnet := parsedIP.Mask(bc.ipv4Mask)
return ip, subnet.String()
}

// For IPv6 addresses
if parsedIP.To16() != nil {
ipParts := strings.Split(ip, ":")
// Calculate the number of hextets required.
required := ipv6Mask / 16
if len(ipParts) >= required {
subnet := strings.Join(ipParts[:required], ":")
return ip, subnet
}
subnet := parsedIP.Mask(bc.ipv6Mask)
return ip, subnet.String()
}

log.Warn("Unknown ip version", "ip", ip)

return ip, ip
}

func (bc *CaptchaProtect) SetIpv4Mask(m int) error {
if m < 8 || m > 32 {
return fmt.Errorf("invalid ipv4 mask: %d. Must be between 8 and 32", m)
}
bc.ipv4Mask = net.CIDRMask(m, 32)

return nil
}

func (bc *CaptchaProtect) SetIpv6Mask(m int) error {
if m < 8 || m > 128 {
return fmt.Errorf("invalid ipv6 mask: %d. Must be between 8 and 128", m)
}
bc.ipv6Mask = net.CIDRMask(m, 128)

return nil
}

func (bc *CaptchaProtect) isGoodBot(req *http.Request, clientIP string) bool {
if bc.config.ProtectParameters == "true" {
if len(req.URL.Query()) > 0 {
Expand Down Expand Up @@ -535,6 +547,10 @@ func IsIpGoodBot(clientIP string, goodBots []string) bool {
return false
}

func (bc *CaptchaProtect) SetExemptIps(exemptIps []*net.IPNet) {
bc.exemptIps = exemptIps
}

func ParseCIDR(cidr string) (*net.IPNet, error) {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
Expand Down
Loading