Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/port/builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ func TestBuiltIn(t *testing.T) {
return d
}
testsuite.Run(t, pf)
testsuite.RunTCPTransparent(t, pf)
}
68 changes: 63 additions & 5 deletions pkg/port/builtin/child/child.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"io"
"net"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"syscall"

"golang.org/x/sys/unix"

Expand All @@ -25,7 +28,8 @@ func NewDriver(logWriter io.Writer) port.ChildDriver {
}

type childDriver struct {
logWriter io.Writer
logWriter io.Writer
routingSetup sync.Once
}

func (d *childDriver) RunChildDriver(opaque map[string]string, quit <-chan struct{}, detachedNetNSPath string) error {
Expand Down Expand Up @@ -119,7 +123,6 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er
}
// dialProto does not need "4", "6" suffix
dialProto := strings.TrimSuffix(strings.TrimSuffix(req.Proto, "6"), "4")
var dialer net.Dialer
ip := req.IP
if ip == "" {
ip = "127.0.0.1"
Expand All @@ -135,9 +138,24 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er
}
ip = p.String()
}
targetConn, err := dialer.Dial(dialProto, net.JoinHostPort(ip, strconv.Itoa(req.Port)))
if err != nil {
return err
targetAddr := net.JoinHostPort(ip, strconv.Itoa(req.Port))

var targetConn net.Conn
var err error
if req.SourceIP != "" && req.SourcePort != 0 && dialProto == "tcp" {
d.routingSetup.Do(func() { d.setupTransparentRouting() })
targetConn, err = transparentDial(dialProto, targetAddr, req.SourceIP, req.SourcePort)
if err != nil {
fmt.Fprintf(d.logWriter, "transparent dial failed, falling back: %v\n", err)
targetConn, err = nil, nil
}
}
if targetConn == nil {
var dialer net.Dialer
targetConn, err = dialer.Dial(dialProto, targetAddr)
if err != nil {
return err
}
}
defer targetConn.Close() // no effect on duplicated FD
targetConnFiler, ok := targetConn.(filer)
Expand All @@ -164,6 +182,46 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er
return err
}

// setupTransparentRouting sets up policy routing so that SYN-ACK packets
// from services to transparent-bound source IPs are routed back via loopback.
// This is safe because the "from 127.0.0.0/8" rule only matches loopback-sourced
// packets, leaving TAP traffic unaffected.
func (d *childDriver) setupTransparentRouting() {
cmds := [][]string{
{"ip", "route", "add", "local", "default", "dev", "lo", "table", "100"},
{"ip", "rule", "add", "from", "127.0.0.0/8", "lookup", "100", "priority", "100"},
{"ip", "-6", "route", "add", "local", "default", "dev", "lo", "table", "100"},
{"ip", "-6", "rule", "add", "from", "::1/128", "lookup", "100", "priority", "100"},
}
for _, args := range cmds {
if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil {
fmt.Fprintf(d.logWriter, "transparent routing setup: %v: %s\n", err, out)
}
}
}

// transparentDial dials targetAddr using IP_TRANSPARENT, binding to the given
// source IP and port so the backend service sees the real client address.
func transparentDial(dialProto, targetAddr, sourceIP string, sourcePort int) (net.Conn, error) {
dialer := net.Dialer{
LocalAddr: &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort},
Control: func(network, address string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
if strings.Contains(network, "6") {
sockErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1)
} else {
sockErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1)
}
}); err != nil {
return err
}
return sockErr
},
}
return dialer.Dial(dialProto, targetAddr)
}

// filer is implemented by *net.TCPConn and *net.UDPConn
type filer interface {
File() (f *os.File, err error)
Expand Down
18 changes: 13 additions & 5 deletions pkg/port/builtin/msg/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Request struct {
Port int
ParentIP string
HostGatewayIP string
SourceIP string `json:",omitempty"` // real client IP for IP_TRANSPARENT
SourcePort int `json:",omitempty"` // real client port for IP_TRANSPARENT
}

// Reply may contain FD as OOB
Expand Down Expand Up @@ -69,7 +71,9 @@ func hostGatewayIP() string {

// ConnectToChild connects to the child UNIX socket, and obtains TCP or UDP socket FD
// that corresponds to the port spec.
func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) {
// sourceAddr is the real client address (e.g., from net.Conn.RemoteAddr()) for IP_TRANSPARENT support.
// Pass nil to skip source IP preservation.
func ConnectToChild(c *net.UnixConn, spec port.Spec, sourceAddr net.Addr) (int, error) {
req := Request{
Type: RequestTypeConnect,
Proto: spec.Proto,
Expand All @@ -78,6 +82,10 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) {
ParentIP: spec.ParentIP,
HostGatewayIP: hostGatewayIP(),
}
if tcpAddr, ok := sourceAddr.(*net.TCPAddr); ok && tcpAddr != nil {
req.SourceIP = tcpAddr.IP.String()
req.SourcePort = tcpAddr.Port
}
if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil {
return 0, err
}
Expand Down Expand Up @@ -114,21 +122,21 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) {
}

// ConnectToChildWithSocketPath wraps ConnectToChild
func ConnectToChildWithSocketPath(socketPath string, spec port.Spec) (int, error) {
func ConnectToChildWithSocketPath(socketPath string, spec port.Spec, sourceAddr net.Addr) (int, error) {
var dialer net.Dialer
conn, err := dialer.Dial("unix", socketPath)
if err != nil {
return 0, err
}
defer conn.Close()
c := conn.(*net.UnixConn)
return ConnectToChild(c, spec)
return ConnectToChild(c, spec, sourceAddr)
}

// ConnectToChildWithRetry retries ConnectToChild every (i*5) milliseconds.
func ConnectToChildWithRetry(socketPath string, spec port.Spec, retries int) (int, error) {
func ConnectToChildWithRetry(socketPath string, spec port.Spec, retries int, sourceAddr net.Addr) (int, error) {
for i := 0; i < retries; i++ {
fd, err := ConnectToChildWithSocketPath(socketPath, spec)
fd, err := ConnectToChildWithSocketPath(socketPath, spec, sourceAddr)
if i == retries-1 && err != nil {
return 0, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/port/builtin/parent/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch
func copyConnToChild(c net.Conn, socketPath string, spec port.Spec, stopCh <-chan struct{}) error {
defer c.Close()
// get fd from the child as an SCM_RIGHTS cmsg
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10)
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, c.RemoteAddr())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/port/builtin/parent/udp/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch
Listener: c,
BackendDial: func() (*net.UDPConn, error) {
// get fd from the child as an SCM_RIGHTS cmsg
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10)
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, nil)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading