@@ -6,6 +6,7 @@ package azdext
66import (
77 "fmt"
88 "net"
9+ "net/http"
910 "net/url"
1011 "os"
1112 "path/filepath"
@@ -25,6 +26,10 @@ type MCPSecurityPolicy struct {
2526 blockedHosts map [string ]bool
2627 // lookupHost is used for DNS resolution; override in tests.
2728 lookupHost func (string ) ([]string , error )
29+ // onBlocked is an optional callback invoked when a URL or path is blocked.
30+ // Parameters: action ("url_blocked", "path_blocked"),
31+ // detail (human-readable explanation). Safe for concurrent use.
32+ onBlocked func (action , detail string )
2833}
2934
3035// NewMCPSecurityPolicy creates an empty security policy.
@@ -111,6 +116,20 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec
111116 return p
112117}
113118
119+ // OnBlocked registers a callback that is invoked whenever a URL or path is
120+ // blocked by the security policy. This enables security audit
121+ // logging without coupling the policy to a specific logging framework.
122+ //
123+ // The callback receives an action tag ("url_blocked", "path_blocked")
124+ // and a human-readable detail string. It must be safe
125+ // for concurrent invocation.
126+ func (p * MCPSecurityPolicy ) OnBlocked (fn func (action , detail string )) * MCPSecurityPolicy {
127+ p .mu .Lock ()
128+ defer p .mu .Unlock ()
129+ p .onBlocked = fn
130+ return p
131+ }
132+
114133// isLocalhostHost returns true if the host is localhost or a loopback address.
115134func isLocalhostHost (host string ) bool {
116135 h := strings .ToLower (host )
@@ -125,8 +144,20 @@ func isLocalhostHost(host string) bool {
125144// Returns an error describing the violation, or nil if allowed.
126145func (p * MCPSecurityPolicy ) CheckURL (rawURL string ) error {
127146 p .mu .RLock ()
128- defer p .mu .RUnlock ()
147+ fn := p .onBlocked
148+ err := p .checkURLCore (rawURL )
149+ p .mu .RUnlock ()
129150
151+ if fn != nil && err != nil {
152+ fn ("url_blocked" , err .Error ())
153+ }
154+
155+ return err
156+ }
157+
158+ // checkURLCore performs URL validation without acquiring the lock or invoking
159+ // the onBlocked callback. Callers must hold p.mu (at least RLock).
160+ func (p * MCPSecurityPolicy ) checkURLCore (rawURL string ) error {
130161 u , err := url .Parse (rawURL )
131162 if err != nil {
132163 return fmt .Errorf ("invalid URL: %w" , err )
@@ -140,7 +171,7 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error {
140171 // always allowed
141172 case "http" :
142173 if p .requireHTTPS && ! isLocalhostHost (host ) {
143- return fmt .Errorf ("HTTPS required: %s" , rawURL )
174+ return fmt .Errorf ("HTTPS required: %s" , redactSecurityURL ( rawURL ) )
144175 }
145176 default :
146177 return fmt .Errorf ("scheme not allowed: %q (only http and https are permitted)" , u .Scheme )
@@ -179,6 +210,16 @@ func (p *MCPSecurityPolicy) CheckURL(rawURL string) error {
179210 return nil
180211}
181212
213+ func redactSecurityURL (rawURL string ) string {
214+ u , err := url .Parse (rawURL )
215+ if err != nil {
216+ return "<invalid-url>"
217+ }
218+ u .RawQuery = ""
219+ u .Fragment = ""
220+ return u .String ()
221+ }
222+
182223func (p * MCPSecurityPolicy ) checkIP (ip net.IP , originalHost string ) error {
183224 for _ , cidr := range p .blockedCIDRs {
184225 if cidr .Contains (ip ) {
@@ -193,55 +234,18 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error {
193234 return fmt .Errorf ("blocked IP %s (private/loopback/link-local) for host %s" , ip , originalHost )
194235 }
195236
196- // Handle encoding variants that Go's net.IP methods don't classify, by extracting
197- // the embedded IPv4 address and re-checking it against all blocked ranges.
198- if len (ip ) == net .IPv6len && ip .To4 () == nil {
199- // IPv4-compatible (::x.x.x.x, RFC 4291 §2.5.5.1): first 12 bytes are zero.
200- isV4Compatible := true
201- for i := 0 ; i < 12 ; i ++ {
202- if ip [i ] != 0 {
203- isV4Compatible = false
204- break
237+ // Handle encoding variants that Go's net.IP methods don't classify,
238+ // by extracting the embedded IPv4 and re-checking it.
239+ if v4 := extractEmbeddedIPv4 (ip ); v4 != nil {
240+ for _ , cidr := range p .blockedCIDRs {
241+ if cidr .Contains (v4 ) {
242+ return fmt .Errorf ("blocked IP %s (embedded %s, CIDR %s) for host %s" ,
243+ ip , v4 , cidr , originalHost )
205244 }
206245 }
207- if isV4Compatible && (ip [12 ] != 0 || ip [13 ] != 0 || ip [14 ] != 0 || ip [15 ] != 0 ) {
208- v4 := net .IPv4 (ip [12 ], ip [13 ], ip [14 ], ip [15 ])
209- for _ , cidr := range p .blockedCIDRs {
210- if cidr .Contains (v4 ) {
211- return fmt .Errorf ("blocked IP %s (IPv4-compatible %s, CIDR %s) for host %s" ,
212- ip , v4 , cidr , originalHost )
213- }
214- }
215- if v4 .IsLoopback () || v4 .IsPrivate () || v4 .IsLinkLocalUnicast () || v4 .IsUnspecified () {
216- return fmt .Errorf ("blocked IP %s (IPv4-compatible %s, private/loopback) for host %s" ,
217- ip , v4 , originalHost )
218- }
219- }
220-
221- // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765 §4.2.1): bytes 0-7 zero,
222- // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4.
223- // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil.
224- isV4Translated := ip [8 ] == 0xFF && ip [9 ] == 0xFF && ip [10 ] == 0x00 && ip [11 ] == 0x00
225- if isV4Translated {
226- for i := 0 ; i < 8 ; i ++ {
227- if ip [i ] != 0 {
228- isV4Translated = false
229- break
230- }
231- }
232- }
233- if isV4Translated && (ip [12 ] != 0 || ip [13 ] != 0 || ip [14 ] != 0 || ip [15 ] != 0 ) {
234- v4 := net .IPv4 (ip [12 ], ip [13 ], ip [14 ], ip [15 ])
235- for _ , cidr := range p .blockedCIDRs {
236- if cidr .Contains (v4 ) {
237- return fmt .Errorf ("blocked IP %s (IPv4-translated %s, CIDR %s) for host %s" ,
238- ip , v4 , cidr , originalHost )
239- }
240- }
241- if v4 .IsLoopback () || v4 .IsPrivate () || v4 .IsLinkLocalUnicast () || v4 .IsUnspecified () {
242- return fmt .Errorf ("blocked IP %s (IPv4-translated %s, private/loopback) for host %s" ,
243- ip , v4 , originalHost )
244- }
246+ if v4 .IsLoopback () || v4 .IsPrivate () || v4 .IsLinkLocalUnicast () || v4 .IsUnspecified () {
247+ return fmt .Errorf ("blocked IP %s (embedded %s, private/loopback) for host %s" ,
248+ ip , v4 , originalHost )
245249 }
246250 }
247251 }
@@ -251,10 +255,37 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error {
251255
252256// CheckPath validates a file path against the security policy.
253257// Resolves symlinks and checks for directory traversal.
258+ //
259+ // Security note (TOCTOU): There is an inherent time-of-check to time-of-use
260+ // gap between the symlink resolution performed here and the caller's
261+ // subsequent file operation. An adversary with write access to the filesystem
262+ // could create or modify a symlink between the check and the use. This is a
263+ // fundamental limitation of path-based validation on POSIX systems.
264+ //
265+ // Mitigations callers should consider:
266+ // - Use O_NOFOLLOW when opening files after validation (prevents symlink
267+ // following at the final component).
268+ // - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on
269+ // Linux 5.6+) where possible.
270+ // - Avoid writing to directories that untrusted users can modify.
271+ // - Consider validating the opened fd's path post-open via /proc/self/fd/N
272+ // or fstat.
254273func (p * MCPSecurityPolicy ) CheckPath (path string ) error {
255274 p .mu .RLock ()
256- defer p .mu .RUnlock ()
275+ fn := p .onBlocked
276+ err := p .checkPathCore (path )
277+ p .mu .RUnlock ()
278+
279+ if fn != nil && err != nil {
280+ fn ("path_blocked" , err .Error ())
281+ }
282+
283+ return err
284+ }
257285
286+ // checkPathCore performs path validation without acquiring the lock or invoking
287+ // the onBlocked callback. Callers must hold p.mu (at least RLock).
288+ func (p * MCPSecurityPolicy ) checkPathCore (path string ) error {
258289 if len (p .allowedBasePaths ) == 0 {
259290 return nil
260291 }
@@ -348,3 +379,163 @@ func resolveExistingPrefix(p string) string {
348379 }
349380 }
350381}
382+
383+ // ---------------------------------------------------------------------------
384+ // Redirect SSRF protection
385+ // ---------------------------------------------------------------------------
386+
387+ // redirectBlockedHosts lists cloud metadata service endpoints that must never
388+ // be the target of an HTTP redirect.
389+ var redirectBlockedHosts = map [string ]bool {
390+ "169.254.169.254" : true ,
391+ "fd00:ec2::254" : true ,
392+ "metadata.google.internal" : true ,
393+ "100.100.100.200" : true ,
394+ }
395+
396+ // SSRFSafeRedirect is an [http.Client] CheckRedirect function that blocks
397+ // redirects to private/loopback IP literals, hostnames that resolve to private
398+ // networks, and cloud metadata endpoints. It prevents
399+ // redirect-based SSRF attacks where an attacker-controlled URL redirects to
400+ // an internal service.
401+ //
402+ // Usage:
403+ //
404+ // client := &http.Client{CheckRedirect: azdext.SSRFSafeRedirect}
405+ func SSRFSafeRedirect (req * http.Request , via []* http.Request ) error {
406+ return ssrfSafeRedirect (req , via , net .LookupHost )
407+ }
408+
409+ func ssrfSafeRedirect (req * http.Request , via []* http.Request , lookupHost func (string ) ([]string , error )) error {
410+ const maxRedirects = 10
411+ if len (via ) >= maxRedirects {
412+ return fmt .Errorf ("stopped after %d redirects" , maxRedirects )
413+ }
414+
415+ // Block HTTPS → HTTP scheme downgrades to prevent leaking
416+ // Authorization headers (including Bearer tokens) in cleartext.
417+ // Go's net/http preserves headers on same-host redirects regardless
418+ // of scheme change.
419+ if len (via ) > 0 && via [len (via )- 1 ].URL .Scheme == "https" && req .URL .Scheme != "https" {
420+ return fmt .Errorf (
421+ "redirect from HTTPS to %s blocked (credential protection)" , req .URL .Scheme )
422+ }
423+
424+ host := req .URL .Hostname ()
425+
426+ // Block redirects to known metadata endpoints.
427+ if redirectBlockedHosts [strings .ToLower (host )] {
428+ return fmt .Errorf ("redirect to metadata endpoint %s blocked (SSRF protection)" , host )
429+ }
430+
431+ // Block redirects to localhost hostnames (e.g. "localhost",
432+ // "127.0.0.1") regardless of how they are spelled, preventing
433+ // hostname-based SSRF bypasses of the IP-literal checks below.
434+ if isLocalhostHost (host ) {
435+ return fmt .Errorf ("redirect to localhost %s blocked (SSRF protection)" , host )
436+ }
437+
438+ // Block redirects to private/loopback IP addresses, including
439+ // IPv4-compatible and IPv4-translated IPv6 encoding variants
440+ // that bypass Go's IsPrivate()/IsLoopback() classification.
441+ if ip := net .ParseIP (host ); ip != nil {
442+ if ip .IsLoopback () || ip .IsPrivate () || ip .IsLinkLocalUnicast () || ip .IsUnspecified () {
443+ return fmt .Errorf ("redirect to private/loopback IP %s blocked (SSRF protection)" , ip )
444+ }
445+
446+ // Check IPv6 encoding variants (IPv4-compatible, IPv4-translated)
447+ // that embed private IPv4 addresses but aren't caught by Go's
448+ // net.IP classifier methods.
449+ if err := checkIPEncodingVariants (ip , host ); err != nil {
450+ return err
451+ }
452+ }
453+
454+ // Resolve hostnames and block redirects to private/loopback resolved IPs.
455+ ips , err := lookupHost (host )
456+ if err != nil {
457+ return fmt .Errorf ("redirect host %s DNS resolution failed (SSRF protection): %w" , host , err )
458+ }
459+ for _ , rawIP := range ips {
460+ ip := net .ParseIP (rawIP )
461+ if ip == nil {
462+ continue
463+ }
464+ if ip .IsLoopback () || ip .IsPrivate () || ip .IsLinkLocalUnicast () || ip .IsUnspecified () {
465+ return fmt .Errorf ("redirect host %s resolved to private/loopback IP %s blocked (SSRF protection)" , host , ip )
466+ }
467+ if err := checkIPEncodingVariants (ip , host ); err != nil {
468+ return err
469+ }
470+ }
471+
472+ return nil
473+ }
474+
475+ // checkIPEncodingVariants detects IPv4-compatible (::x.x.x.x) and
476+ // IPv4-translated (::ffff:0:x.x.x.x) IPv6 addresses that embed
477+ // private IPv4 addresses but bypass Go's IsPrivate()/IsLoopback().
478+ func checkIPEncodingVariants (ip net.IP , originalHost string ) error {
479+ v4 := extractEmbeddedIPv4 (ip )
480+ if v4 == nil {
481+ return nil
482+ }
483+
484+ if v4 .IsLoopback () || v4 .IsPrivate () || v4 .IsLinkLocalUnicast () || v4 .IsUnspecified () {
485+ return fmt .Errorf (
486+ "redirect to embedded IPv4 address %s (embedded %s) blocked (SSRF protection)" ,
487+ ip , v4 )
488+ }
489+
490+ return nil
491+ }
492+
493+ // extractEmbeddedIPv4 returns the embedded IPv4 address from IPv4-compatible
494+ // (::x.x.x.x, RFC 4291 §2.5.5.1) or IPv4-translated (::ffff:0:x.x.x.x,
495+ // RFC 2765 §4.2.1) IPv6 encodings. Returns nil if the address is not one of
496+ // these encoding variants.
497+ //
498+ // This handles addresses that Go's net.IP.To4() does not classify as IPv4
499+ // (To4 returns nil for these), which means Go's IsPrivate()/IsLoopback()
500+ // methods also return false for them.
501+ func extractEmbeddedIPv4 (ip net.IP ) net.IP {
502+ if len (ip ) != net .IPv6len || ip .To4 () != nil {
503+ return nil // Not a pure IPv6 address or already handled as IPv4-mapped
504+ }
505+
506+ // Check if last 4 bytes are non-zero (otherwise it's just :: which is
507+ // already handled by IsUnspecified).
508+ if ip [12 ] == 0 && ip [13 ] == 0 && ip [14 ] == 0 && ip [15 ] == 0 {
509+ return nil
510+ }
511+
512+ // IPv4-compatible (::x.x.x.x): first 12 bytes are zero.
513+ isV4Compatible := true
514+ for i := 0 ; i < 12 ; i ++ {
515+ if ip [i ] != 0 {
516+ isV4Compatible = false
517+ break
518+ }
519+ }
520+ if isV4Compatible {
521+ return net .IPv4 (ip [12 ], ip [13 ], ip [14 ], ip [15 ])
522+ }
523+
524+ // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765): bytes 0-7 zero,
525+ // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4.
526+ // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil.
527+ if ip [8 ] == 0xFF && ip [9 ] == 0xFF && ip [10 ] == 0x00 && ip [11 ] == 0x00 {
528+ allZero := true
529+ for i := 0 ; i < 8 ; i ++ {
530+ if ip [i ] != 0 {
531+ allZero = false
532+ break
533+ }
534+ }
535+ if allZero {
536+ return net .IPv4 (ip [12 ], ip [13 ], ip [14 ], ip [15 ])
537+ }
538+ }
539+
540+ return nil
541+ }
0 commit comments