Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,8 @@ func (c *Sudoku) Build() (proto.Message, error) {
}

type Xdns struct {
Domain string `json:"domain"`
Domain string `json:"domain"`
Resolvers []string `json:"resolvers,omitempty"`
}

func (c *Xdns) Build() (proto.Message, error) {
Expand All @@ -1669,7 +1670,8 @@ func (c *Xdns) Build() (proto.Message, error) {
}

return &xdns.Config{
Domain: c.Domain,
Domain: c.Domain,
Resolvers: c.Resolvers,
}, nil
}

Expand Down
152 changes: 127 additions & 25 deletions transport/internet/finalmask/xdns/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/xtls/xray-core/common"
Expand All @@ -28,6 +29,23 @@ const (

var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)

type resolverConn struct {
conn net.PacketConn
addr *net.UDPAddr
}

func parseResolverAddr(s string) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(s)
if err != nil {
host = s
port = "53"
}
if host == "" {
return nil, go_errors.New("empty resolver address")
}
return net.ResolveUDPAddr("udp", net.JoinHostPort(host, port))
}

type packet struct {
p []byte
addr net.Addr
Expand All @@ -39,12 +57,20 @@ type xdnsConnClient struct {
clientID []byte
domain Name

resolverConns []*resolverConn
resolverIdx atomic.Uint32
serverAddr atomic.Value // stores net.Addr; set by WriteTo, used by recvLoopFrom in resolver mode
recvWg sync.WaitGroup
sendWg sync.WaitGroup

pollChan chan struct{}
readQueue chan *packet
writeQueue chan *packet

closed bool
mutex sync.Mutex
closed atomic.Bool
closeOnce sync.Once
closeErr error
mutex sync.Mutex
}

func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
Expand All @@ -66,21 +92,71 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {

common.Must2(rand.Read(conn.clientID))

go conn.recvLoop()
go conn.sendLoop()
if len(c.Resolvers) > 0 {
lc := net.ListenConfig{}
if ctrl := resolverSocketControl(raw); ctrl != nil {
lc.Control = ctrl
}
for _, rs := range c.Resolvers {
addr, err := parseResolverAddr(rs)
if err != nil {
return nil, errors.New("invalid resolver address: ", rs, ": ", err)
}
uc, err := lc.ListenPacket(context.Background(), "udp", ":0")
if err != nil {
for _, rc := range conn.resolverConns {
rc.conn.Close()
}
return nil, errors.New("failed to create resolver socket: ", err)
}
conn.resolverConns = append(conn.resolverConns, &resolverConn{conn: uc, addr: addr})
}
for _, rc := range conn.resolverConns {
conn.recvWg.Add(1)
go func(pconn net.PacketConn) {
defer conn.recvWg.Done()
conn.recvLoopFrom(pconn)
}(rc.conn)
}
} else {
conn.recvWg.Add(1)
go func() {
defer conn.recvWg.Done()
conn.recvLoop()
}()
}
conn.sendWg.Add(1)
go func() {
defer conn.sendWg.Done()
conn.sendLoop()
}()

return conn, nil
}

func (c *xdnsConnClient) recvLoop() {
c.recvLoopFrom(c.PacketConn)

errors.LogDebug(context.Background(), "xdns closed")

close(c.pollChan)
close(c.readQueue)

c.closed.Store(true)
c.mutex.Lock()
defer c.mutex.Unlock()
close(c.writeQueue)
}

func (c *xdnsConnClient) recvLoopFrom(conn net.PacketConn) {
var buf [finalmask.UDPSize]byte

for {
if c.closed {
if c.closed.Load() {
break
}

n, addr, err := c.PacketConn.ReadFrom(buf[:])
n, addr, err := conn.ReadFrom(buf[:])
if err != nil || n == 0 {
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) {
break
Expand All @@ -95,6 +171,16 @@ func (c *xdnsConnClient) recvLoop() {
}

payload := dnsResponsePayload(&resp, c.domain)
if payload == nil {
continue
}

pktAddr := net.Addr(addr)
if len(c.resolverConns) > 0 {
if sa := c.serverAddr.Load(); sa != nil {
pktAddr = sa.(net.Addr)
}
}

r := bytes.NewReader(payload)
anyPacket := false
Expand All @@ -110,7 +196,7 @@ func (c *xdnsConnClient) recvLoop() {
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
addr: pktAddr,
}:
default:
errors.LogDebug(context.Background(), addr, " mask read err queue full")
Expand All @@ -124,17 +210,6 @@ func (c *xdnsConnClient) recvLoop() {
}
}
}

errors.LogDebug(context.Background(), "xdns closed")

close(c.pollChan)
close(c.readQueue)

c.mutex.Lock()
defer c.mutex.Unlock()

c.closed = true
close(c.writeQueue)
}

func (c *xdnsConnClient) sendLoop() {
Expand Down Expand Up @@ -179,20 +254,30 @@ func (c *xdnsConnClient) sendLoop() {
}
} else {
if !pollTimer.Stop() {
<-pollTimer.C
select {
case <-pollTimer.C:
default:
}
}
pollDelay = initPollDelay
}
pollTimer.Reset(pollDelay)

if c.closed {
if c.closed.Load() {
return
}

if p != nil {
_, err := c.PacketConn.WriteTo(p.p, p.addr)
var err error
if len(c.resolverConns) > 0 {
idx := c.resolverIdx.Add(1)
rc := c.resolverConns[idx%uint32(len(c.resolverConns))]
_, err = rc.conn.WriteTo(p.p, rc.addr)
} else {
_, err = c.PacketConn.WriteTo(p.p, p.addr)
}
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) {
c.closed = true
c.closed.Store(true)
break
}
}
Expand All @@ -213,10 +298,12 @@ func (c *xdnsConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}

func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.serverAddr.Store(addr)

c.mutex.Lock()
defer c.mutex.Unlock()

if c.closed {
if c.closed.Load() {
return 0, io.ErrClosedPipe
}

Expand All @@ -239,8 +326,23 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
}

func (c *xdnsConnClient) Close() error {
c.closed = true
return c.PacketConn.Close()
c.closeOnce.Do(func() {
c.closed.Store(true)
for _, rc := range c.resolverConns {
rc.conn.Close()
}
c.closeErr = c.PacketConn.Close()
c.recvWg.Wait()
if len(c.resolverConns) > 0 {
close(c.pollChan)
close(c.readQueue)
c.mutex.Lock()
close(c.writeQueue)
c.mutex.Unlock()
}
c.sendWg.Wait()
})
return c.closeErr
}

func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
Expand Down
5 changes: 5 additions & 0 deletions transport/internet/finalmask/xdns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package xdns

import (
"net"

"github.com/xtls/xray-core/common/errors"
)

func (c *Config) UDP() {
}

func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
if len(c.Resolvers) > 0 && level > 0 {
return nil, errors.New("xdns resolver mode cannot be combined with lower finalmask layers because resolver traffic must be valid DNS on the wire")
}
return NewConnClient(c, raw)
}

Expand Down
15 changes: 12 additions & 3 deletions transport/internet/finalmask/xdns/config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions transport/internet/finalmask/xdns/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ option java_multiple_files = true;

message Config {
string domain = 1;
repeated string resolvers = 2;
}

Loading