Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions internal/cli/doctor.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,52 +361,52 @@ func (d *Doctor) checkFrontingDomain(ntw mtglib.Network) bool {
}

func (d *Doctor) checkSecretHost(resolver *net.Resolver, ntw mtglib.Network) bool {
addresses, err := resolver.LookupIPAddr(context.Background(), d.conf.Secret.Host)
if err != nil {
res := runSNICheck(context.Background(), resolver, d.conf, ntw)

if res.ResolveErr != nil {
tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"description": fmt.Sprintf("cannot resolve DNS name of %s", d.conf.Secret.Host),
"error": err,
"error": res.ResolveErr,
})
return false
}

ourIP4 := d.conf.PublicIPv4.Get(nil)
if ourIP4 == nil {
ourIP4 = getIP(ntw, "tcp4")
}

ourIP6 := d.conf.PublicIPv6.Get(nil)
if ourIP6 == nil {
ourIP6 = getIP(ntw, "tcp6")
}

if ourIP4 == nil && ourIP6 == nil {
if !res.PublicIPKnown() {
tplError.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"description": "cannot detect public IP address",
"error": errors.New("cannot detect automatically and public-ipv4/public-ipv6 are not set in config"),
})
return false
}

strAddresses := []string{}
for _, value := range addresses {
if (ourIP4 != nil && value.IP.String() == ourIP4.String()) ||
(ourIP6 != nil && value.IP.String() == ourIP6.String()) {
tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": value.IP,
"hostname": d.conf.Secret.Host,
})
return true
if res.IPv4Match || res.IPv6Match {
var matched net.IP

for _, ip := range res.Resolved {
if (res.OurIPv4 != nil && ip.String() == res.OurIPv4.String()) ||
(res.OurIPv6 != nil && ip.String() == res.OurIPv6.String()) {
matched = ip
break
}
}

strAddresses = append(strAddresses, `"`+value.IP.String()+`"`)
tplODNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"ip": matched,
"hostname": d.conf.Secret.Host,
})
return true
}

strAddresses := make([]string, 0, len(res.Resolved))
for _, ip := range res.Resolved {
strAddresses = append(strAddresses, `"`+ip.String()+`"`)
}

tplEDNSSNIMatch.Execute(os.Stdout, map[string]any{ //nolint: errcheck
"hostname": d.conf.Secret.Host,
"resolved": strings.Join(strAddresses, ", "),
"ip4": ourIP4,
"ip6": ourIP6,
"ip4": res.OurIPv4,
"ip6": res.OurIPv6,
})

return false
Expand Down
51 changes: 16 additions & 35 deletions internal/cli/run_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,72 +215,53 @@ func warnSNIMismatch(conf *config.Config, ntw mtglib.Network, log mtglib.Logger)
return
}

addresses, err := net.DefaultResolver.LookupIPAddr(context.Background(), host)
if err != nil {
res := runSNICheck(context.Background(), net.DefaultResolver, conf, ntw)

if res.ResolveErr != nil {
log.BindStr("hostname", host).
WarningError("SNI-DNS check: cannot resolve secret hostname", err)
WarningError("SNI-DNS check: cannot resolve secret hostname", res.ResolveErr)
return
}

ourIP4 := conf.PublicIPv4.Get(nil)
if ourIP4 == nil {
ourIP4 = getIP(ntw, "tcp4")
}

ourIP6 := conf.PublicIPv6.Get(nil)
if ourIP6 == nil {
ourIP6 = getIP(ntw, "tcp6")
}

if ourIP4 == nil && ourIP6 == nil {
if !res.PublicIPKnown() {
log.Warning("SNI-DNS check: cannot detect public IP address; set public-ipv4/public-ipv6 in config or run 'mtg doctor'")
return
}

v4Match := ourIP4 == nil
v6Match := ourIP6 == nil

for _, addr := range addresses {
if ourIP4 != nil && addr.IP.String() == ourIP4.String() {
v4Match = true
}

if ourIP6 != nil && addr.IP.String() == ourIP6.String() {
v6Match = true
}
}
v4Match := res.OurIPv4 == nil || res.IPv4Match
v6Match := res.OurIPv6 == nil || res.IPv6Match

if v4Match && v6Match {
return
}

resolved := make([]string, 0, len(addresses))
for _, addr := range addresses {
resolved = append(resolved, addr.IP.String())
resolved := make([]string, 0, len(res.Resolved))
for _, ip := range res.Resolved {
resolved = append(resolved, ip.String())
}

our := ""
if ourIP4 != nil {
our = ourIP4.String()
if res.OurIPv4 != nil {
our = res.OurIPv4.String()
}

if ourIP6 != nil {
if res.OurIPv6 != nil {
if our != "" {
our += "/"
}

our += ourIP6.String()
our += res.OurIPv6.String()
}

entry := log.BindStr("hostname", host).
BindStr("resolved", strings.Join(resolved, ", ")).
BindStr("public_ip", our)

if ourIP4 != nil {
if res.OurIPv4 != nil {
entry = entry.BindStr("ipv4_match", fmt.Sprintf("%t", v4Match))
}

if ourIP6 != nil {
if res.OurIPv6 != nil {
entry = entry.BindStr("ipv6_match", fmt.Sprintf("%t", v6Match))
}

Expand Down
78 changes: 78 additions & 0 deletions internal/cli/sni_check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package cli

import (
"context"
"net"

"github.com/9seconds/mtg/v2/internal/config"
"github.com/9seconds/mtg/v2/mtglib"
)

// sniCheckResult holds the data gathered while comparing the secret
// hostname's DNS records against this server's public IP addresses.
//
// IPv4Match / IPv6Match report whether a resolved record actually equals the
// corresponding public IP. They are false when that family's public IP could
// not be determined — there is nothing to compare against. Callers decide
// what counts as a clean result from these fields: `mtg doctor` and the
// startup warning apply different rules.
type sniCheckResult struct {
Resolved []net.IP
OurIPv4 net.IP
OurIPv6 net.IP
IPv4Match bool
IPv6Match bool
ResolveErr error
}

// PublicIPKnown reports whether at least one public IP family was detected.
func (r sniCheckResult) PublicIPKnown() bool {
return r.OurIPv4 != nil || r.OurIPv6 != nil
}

// runSNICheck resolves conf.Secret.Host and compares the records with this
// server's public IPv4 and IPv6. Public IPs come from config first and fall
// back to on-the-fly detection via ntw. It gathers data only — it does not
// decide success; see sniCheckResult.
func runSNICheck(
ctx context.Context,
resolver *net.Resolver,
conf *config.Config,
ntw mtglib.Network,
) sniCheckResult {
res := sniCheckResult{}

addrs, err := resolver.LookupIPAddr(ctx, conf.Secret.Host)
if err != nil {
res.ResolveErr = err

return res
}

res.Resolved = make([]net.IP, 0, len(addrs))
for _, a := range addrs {
res.Resolved = append(res.Resolved, a.IP)
}

res.OurIPv4 = conf.PublicIPv4.Get(nil)
if res.OurIPv4 == nil {
res.OurIPv4 = getIP(ntw, "tcp4")
}

res.OurIPv6 = conf.PublicIPv6.Get(nil)
if res.OurIPv6 == nil {
res.OurIPv6 = getIP(ntw, "tcp6")
}

for _, ip := range res.Resolved {
if res.OurIPv4 != nil && ip.String() == res.OurIPv4.String() {
res.IPv4Match = true
}

if res.OurIPv6 != nil && ip.String() == res.OurIPv6.String() {
res.IPv6Match = true
}
}

return res
}
Loading