From 5b96077d2a44d085a8ae44b1b368808184dea64b Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Thu, 21 Sep 2023 11:52:20 +0800 Subject: [PATCH 6/6] upstream: switch to self implement bootstrap refine code structure --- pkg/upstream/bootstrap/bootstrap.go | 324 ++++++++++++---------- pkg/upstream/bootstrap/bootstrap_test.go | 38 --- pkg/upstream/doh/upstream.go | 24 +- pkg/upstream/doq/upstream.go | 66 +---- pkg/upstream/event_stat.go | 2 +- pkg/upstream/transport/dns_conn.go | 3 +- pkg/upstream/transport/dns_conn_test.go | 7 +- pkg/upstream/transport/pipeline.go | 3 +- pkg/upstream/transport/pipeline_test.go | 3 +- pkg/upstream/transport/reuse.go | 3 +- pkg/upstream/transport/reuse_test.go | 3 +- pkg/upstream/transport/transport.go | 3 +- pkg/upstream/transport/utils.go | 46 ++-- pkg/upstream/upstream.go | 337 ++++++++++++++++++----- pkg/upstream/upstream_test.go | 8 +- pkg/upstream/utils.go | 72 ++++- pkg/upstream/utils_unix.go | 3 +- plugin/executable/forward/forward.go | 9 +- 18 files changed, 586 insertions(+), 368 deletions(-) delete mode 100644 pkg/upstream/bootstrap/bootstrap_test.go diff --git a/pkg/upstream/bootstrap/bootstrap.go b/pkg/upstream/bootstrap/bootstrap.go index b831757..2cd8ef9 100644 --- a/pkg/upstream/bootstrap/bootstrap.go +++ b/pkg/upstream/bootstrap/bootstrap.go @@ -21,184 +21,226 @@ package bootstrap import ( "context" - "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" - "github.com/miekg/dns" + "errors" + "fmt" "net" - "strings" + "net/netip" "sync" + "sync/atomic" "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/miekg/dns" + "go.uber.org/zap" ) -// NewBootstrap returns a customized *net.Resolver which can be used as a Bootstrap for a -// certain domain. Its Dial func is modified to dial s through udp. -// It also has a small built-in cache. -// s SHOULD be a literal IP address and the port SHOULD also be literal. -// Port can be omitted. 
In this case, the default port is :53. -// e.g. NewBootstrap("8.8.8.8"), NewBootstrap("127.0.0.1:5353"). -// If s is empty, NewBootstrap returns nil. (A nil *net.Resolver is valid in net.Dialer.) -// Note that not all platform support a customized *net.Resolver. It also depends on the -// version of go runtime. -// See the package docs from the net package for more info. -func NewBootstrap(s string) *net.Resolver { - if len(s) == 0 { - return nil - } - // Add port. - _, _, err := net.SplitHostPort(s) - if err != nil { // no port, add it. - s = net.JoinHostPort(strings.Trim(s, "[]"), "53") - } +const ( + minimumUpdateInterval = time.Minute * 5 + retryInterval = time.Second * 2 + queryTimeout = time.Second * 5 +) - bs := newBootstrap(s) +var ( + errNoAddrInResp = errors.New("resp does not have ip address") +) - return &net.Resolver{ - PreferGo: true, - StrictErrors: false, - Dial: bs.dial, +func New( + host string, + port uint16, + bootstrapServer netip.AddrPort, + bootstrapVer int, // 0,4,6 + logger *zap.Logger, // not nil +) (*Bootstrap, error) { + dp := new(Bootstrap) + dp.fqdn = dns.Fqdn(host) + dp.port = port + if !bootstrapServer.IsValid() { + return nil, errors.New("invalid bootstrap server address") + } + dp.bootstrap = net.UDPAddrFromAddrPort(bootstrapServer) + qt, ok := bootstrapVer2Qt(bootstrapVer) + if !ok { + return nil, fmt.Errorf("invalid bootstrap version %d", bootstrapVer) } -} + dp.qt = qt + dp.logger = logger -type bootstrap struct { - upstream string - cache *cache + dp.readyNotify = make(chan struct{}) + return dp, nil } -func newBootstrap(upstream string) *bootstrap { - return &bootstrap{ - upstream: upstream, - cache: newCache(0), - } -} +type Bootstrap struct { + fqdn string + port uint16 + bootstrap *net.UDPAddr + qt uint16 // dns.TypeA or dns.TypeAAAA + logger *zap.Logger // not nil -func (b *bootstrap) dial(_ context.Context, _, _ string) (net.Conn, error) { - c1, c2 := net.Pipe() - go func() { - _ = b.handlePipe(c2) - _ = c2.Close() - }() 
- return c1, nil -} + updating atomic.Bool + nextUpdate time.Time -func (b *bootstrap) handlePipe(c net.Conn) error { - q, _, err := dnsutils.ReadMsgFromTCP(c) - if err != nil { - return err - } - - var resp *dns.Msg - if len(q.Question) == 1 { - k := q.Question[0] - m := b.cache.lookup(k) - if m != nil { - resp = m.Copy() - resp.Id = q.Id - } - } + readyNotify chan struct{} + m sync.Mutex + ready bool + addrStr string +} - if resp == nil { - d := net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - defer cancel() - upstreamConn, err := d.DialContext(ctx, "udp", b.upstream) - if err != nil { - return err - } - defer upstreamConn.Close() - _ = upstreamConn.SetDeadline(time.Now().Add(time.Second * 3)) - if _, err := dnsutils.WriteMsgToUDP(upstreamConn, q); err != nil { - return err - } - m, _, err := dnsutils.ReadMsgFromUDP(upstreamConn, 1500) - if err != nil { - return err - } - resp = m - } +func (sp *Bootstrap) GetAddrPortStr(ctx context.Context) (string, error) { + sp.tryUpdate() - if resp.Rcode == dns.RcodeSuccess && hasIP(resp) { - if len(resp.Question) == 1 { - k := resp.Question[0] - ttl := time.Duration(dnsutils.GetMinimalTTL(resp)) * time.Second - b.cache.store(k, resp, ttl) - } + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-sp.readyNotify: } - _, err = dnsutils.WriteMsgToTCP(c, resp) - return err + sp.m.Lock() + addr := sp.addrStr + sp.m.Unlock() + return addr, nil } -func hasIP(m *dns.Msg) bool { - for _, rr := range m.Answer { - switch rr.Header().Rrtype { - case dns.TypeA, dns.TypeAAAA: - return true +func (sp *Bootstrap) tryUpdate() { + if sp.updating.CompareAndSwap(false, true) { + if time.Now().After(sp.nextUpdate) { + go func() { + defer sp.updating.Store(false) + ctx, cancel := context.WithTimeout(context.Background(), queryTimeout) + defer cancel() + start := time.Now() + addr, ttl, err := sp.updateAddr(ctx) + if err != nil { + sp.logger.Warn("failed to update bootstrap addr", zap.String("fqdn", 
sp.fqdn), zap.Error(err)) + sp.nextUpdate = time.Now().Add(retryInterval) + } else { + updateInterval := time.Second * time.Duration(ttl) + if updateInterval < minimumUpdateInterval { + updateInterval = minimumUpdateInterval + } + sp.logger.Info( + "bootstrap addr updated", + zap.String("fqdn", sp.fqdn), + zap.Stringer("addr", addr), + zap.Duration("ttl", updateInterval), + zap.Duration("elapse", time.Since(start)), + ) + sp.nextUpdate = time.Now().Add(updateInterval) + } + }() + } else { + sp.updating.Store(false) } } - return false } -type cache struct { - l int +func (sp *Bootstrap) updateAddr(ctx context.Context) (netip.Addr, uint32, error) { + addr, ttl, err := sp.resolve(ctx, sp.qt) + if err != nil { + return netip.Addr{}, 0, err + } - m sync.Mutex - c map[key]*elem + addrPort := netip.AddrPortFrom(addr, sp.port).String() + sp.m.Lock() + sp.addrStr = addrPort + if !sp.ready { + sp.ready = true + close(sp.readyNotify) + } + sp.m.Unlock() + return addr, ttl, nil } -type elem struct { - m *dns.Msg - expirationTime time.Time -} +func (sp *Bootstrap) resolve(ctx context.Context, qt uint16) (netip.Addr, uint32, error) { + const edns0UdpSize = 1200 -type key = dns.Question + q := new(dns.Msg) + q.SetQuestion(sp.fqdn, qt) + q.SetEdns0(edns0UdpSize, false) -func newCache(size int) *cache { - const defaultSize = 8 - if size <= 0 { - size = defaultSize - } - return &cache{ - l: size, - c: make(map[key]*elem), + c, err := net.DialUDP("udp", nil, sp.bootstrap) + if err != nil { + return netip.Addr{}, 0, err } -} + defer c.Close() -// lookup returns a cached msg. Note: the msg must not be modified. -func (c *cache) lookup(k key) *dns.Msg { - now := time.Now() - c.m.Lock() - defer c.m.Unlock() - e, ok := c.c[k] - if !ok { - return nil - } - if e.expirationTime.Before(now) { - delete(c.c, k) - return nil + writeErrC := make(chan error, 1) + type res struct { + resp *dns.Msg + err error } - return e.m -} + readResC := make(chan res, 1) -// store stores a msg to cache. 
The caller MUST NOT modify m anymore. -func (c *cache) store(k key, m *dns.Msg, ttl time.Duration) { - if ttl <= 0 { - return - } - expirationTime := time.Now().Add(ttl) + cancelWrite := make(chan struct{}) + defer close(cancelWrite) + go func() { + if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil { + writeErrC <- err + return + } - c.m.Lock() - defer c.m.Unlock() + retryTicker := time.NewTicker(time.Second) + defer retryTicker.Stop() + for { + select { + case <-cancelWrite: + return + case <-retryTicker.C: + if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil { + writeErrC <- err + return + } + } + } + }() + + go func() { + m, _, err := dnsutils.ReadMsgFromUDP(c, edns0UdpSize) + readResC <- res{resp: m, err: err} + }() - if len(c.c)+1 > c.l { - for k := range c.c { - if len(c.c)+1 <= c.l { - break + select { + case <-ctx.Done(): + return netip.Addr{}, 0, ctx.Err() + case err := <-writeErrC: + return netip.Addr{}, 0, fmt.Errorf("failed to write query, %w", err) + case r := <-readResC: + resp := r.resp + err := r.err + if err != nil { + return netip.Addr{}, 0, fmt.Errorf("failed to read resp, %w", err) + } + + for _, v := range resp.Answer { + var ip net.IP + var ttl uint32 + switch rr := v.(type) { + case *dns.A: + ip = rr.A + ttl = rr.Hdr.Ttl + case *dns.AAAA: + ip = rr.AAAA + ttl = rr.Hdr.Ttl + default: + continue + } + addr, ok := netip.AddrFromSlice(ip) + if ok { + return addr, ttl, nil } - delete(c.c, k) } + + // No ip addr in resp. 
+ return netip.Addr{}, 0, errNoAddrInResp } +} - c.c[k] = &elem{ - m: m, - expirationTime: expirationTime, +func bootstrapVer2Qt(ver int) (uint16, bool) { + switch ver { + case 0, 4: + return dns.TypeA, true + case 6: + return dns.TypeAAAA, true + default: + return 0, false } } diff --git a/pkg/upstream/bootstrap/bootstrap_test.go b/pkg/upstream/bootstrap/bootstrap_test.go deleted file mode 100644 index 93f14ff..0000000 --- a/pkg/upstream/bootstrap/bootstrap_test.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (C) 2020-2022, IrineSistiana - * - * This file is part of mosdns. - * - * mosdns is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * mosdns is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . 
- */ - -package bootstrap - -import ( - "context" - "os" - "testing" -) - -func Test_Bootstrap(t *testing.T) { - if os.Getenv("TEST_BOOTSTRAP") == "" { - t.SkipNow() - } - - r := NewBootstrap("8.8.8.8") - _, err := r.LookupIP(context.Background(), "ip", "google.com") - if err != nil { - t.Fatal(err) - } -} diff --git a/pkg/upstream/doh/upstream.go b/pkg/upstream/doh/upstream.go index addf5b7..abc124b 100644 --- a/pkg/upstream/doh/upstream.go +++ b/pkg/upstream/doh/upstream.go @@ -23,13 +23,14 @@ import ( "context" "encoding/base64" "fmt" - "github.com/IrineSistiana/mosdns/v5/pkg/pool" - "github.com/IrineSistiana/mosdns/v5/pkg/utils" - "github.com/miekg/dns" "io" "net/http" "strings" "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/miekg/dns" ) const ( @@ -42,21 +43,10 @@ type Upstream struct { EndPoint string // Client is a http.Client that sends http requests. Client *http.Client - - // AddOnCloser will be closed when Upstream is closed. 
- AddOnCloser io.Closer } -func (u *Upstream) CloseIdleConnections() { - u.Client.CloseIdleConnections() -} - -func (u *Upstream) Close() error { - u.Client.CloseIdleConnections() - if u.AddOnCloser != nil { - u.AddOnCloser.Close() - } - return nil +func NewUpstream(endPoint string, client *http.Client) *Upstream { + return &Upstream{EndPoint: endPoint, Client: client} } var ( @@ -127,7 +117,7 @@ func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, e func (u *Upstream) exchange(ctx context.Context, url string) (*dns.Msg, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return nil, fmt.Errorf("interal err: NewRequestWithContext: %w", err) + return nil, fmt.Errorf("internal err: NewRequestWithContext: %w", err) } req.Header["Accept"] = []string{"application/dns-message"} diff --git a/pkg/upstream/doq/upstream.go b/pkg/upstream/doq/upstream.go index 7b2d4e7..23d7f1c 100644 --- a/pkg/upstream/doq/upstream.go +++ b/pkg/upstream/doq/upstream.go @@ -22,8 +22,6 @@ package doq import ( "context" "crypto/rand" - "crypto/tls" - "errors" "net" "sync" "sync/atomic" @@ -39,48 +37,27 @@ const ( defaultDoQTimeout = time.Second * 5 dialTimeout = time.Second * 3 connectionLostThreshold = time.Second * 5 - handshakeTimeout = time.Second * 3 ) var ( - doqAlpn = []string{"doq"} + DoqAlpn = []string{"doq"} ) type Upstream struct { - t *quic.Transport - addr string - tlsConfig *tls.Config - quicConfig *quic.Config + dialer func(ctx context.Context) (quic.Connection, error) cm sync.Mutex lc *lazyConn } -// tlsConfig cannot be nil, it should have Servername or InsecureSkipVerify. 
-func NewUpstream(addr string, lc *net.UDPConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*Upstream, error) { - srk, err := initSrk() - if err != nil { - return nil, err - } - if tlsConfig == nil { - return nil, errors.New("nil tls config") - } - - tlsConfig = tlsConfig.Clone() - tlsConfig.NextProtos = doqAlpn - +func NewUpstream(dialer func(ctx context.Context) (quic.Connection, error)) *Upstream { return &Upstream{ - t: &quic.Transport{ - Conn: lc, - StatelessResetKey: (*quic.StatelessResetKey)(srk), - }, - addr: addr, - tlsConfig: tlsConfig, - quicConfig: quicConfig, - }, nil + dialer: dialer, + } } -func initSrk() (*[32]byte, error) { +// A helper func to init quic stateless reset key. +func InitSrk() (*[32]byte, error) { var b [32]byte _, err := rand.Read(b[:]) if err != nil { @@ -89,10 +66,6 @@ func initSrk() (*[32]byte, error) { return &b, nil } -func (u *Upstream) Close() error { - return u.t.Close() -} - func (u *Upstream) newStream(ctx context.Context) (quic.Stream, *lazyConn, error) { var lc *lazyConn u.cm.Lock() @@ -154,31 +127,10 @@ func (u *Upstream) asyncDialConn() *lazyConn { go func() { defer cancel() + c, err := u.dialer(ctx) - ua, err := net.ResolveUDPAddr("udp", u.addr) // TODO: Support bootstrap. - if err != nil { - lc.err = err - return - } - - var c quic.Connection - ec, err := u.t.DialEarly(ctx, ua, u.tlsConfig, u.quicConfig) - if ec != nil { - // This is a workaround to - // 1. recover from strange 0rtt rejected err. - // 2. avoid NextConnection might block forever. - // TODO: Remove this workaround. 
- select { - case <-ctx.Done(): - err = context.Cause(ctx) - ec.CloseWithError(0, "") - case <-ec.HandshakeComplete(): - c = ec.NextConnection() - } - } - - var closeC bool lc.m.Lock() + var closeC bool if lc.closed { closeC = true // lc was closed, nothing to do } else { diff --git a/pkg/upstream/event_stat.go b/pkg/upstream/event_stat.go index cfb40c5..6b7b803 100644 --- a/pkg/upstream/event_stat.go +++ b/pkg/upstream/event_stat.go @@ -37,7 +37,7 @@ type EventObserver interface { type nopEO struct{} -func (n nopEO) OnEvent(_ Event) { return } +func (n nopEO) OnEvent(_ Event) {} type connWrapper struct { net.Conn diff --git a/pkg/upstream/transport/dns_conn.go b/pkg/upstream/transport/dns_conn.go index 1c81a92..6df458e 100644 --- a/pkg/upstream/transport/dns_conn.go +++ b/pkg/upstream/transport/dns_conn.go @@ -22,10 +22,11 @@ package transport import ( "context" "fmt" - "github.com/miekg/dns" "io" "sync" "sync/atomic" + + "github.com/miekg/dns" ) // dnsConn is a low-level connection for dns. 
diff --git a/pkg/upstream/transport/dns_conn_test.go b/pkg/upstream/transport/dns_conn_test.go index c677797..6e4eb27 100644 --- a/pkg/upstream/transport/dns_conn_test.go +++ b/pkg/upstream/transport/dns_conn_test.go @@ -22,9 +22,6 @@ package transport import ( "context" "errors" - "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" - "github.com/IrineSistiana/mosdns/v5/pkg/pool" - "github.com/miekg/dns" "io" "math/rand" "net" @@ -32,6 +29,10 @@ import ( "sync" "testing" "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" ) var ( diff --git a/pkg/upstream/transport/pipeline.go b/pkg/upstream/transport/pipeline.go index e4bc39e..10c3d23 100644 --- a/pkg/upstream/transport/pipeline.go +++ b/pkg/upstream/transport/pipeline.go @@ -21,10 +21,11 @@ package transport import ( "context" - "github.com/miekg/dns" "math/rand" "sync" "time" + + "github.com/miekg/dns" ) // PipelineTransport will pipeline queries as RFC 7766 6.2.1.1 suggested. diff --git a/pkg/upstream/transport/pipeline_test.go b/pkg/upstream/transport/pipeline_test.go index 5fcc0ae..653d779 100644 --- a/pkg/upstream/transport/pipeline_test.go +++ b/pkg/upstream/transport/pipeline_test.go @@ -21,10 +21,11 @@ package transport import ( "context" - "github.com/miekg/dns" "sync" "testing" "time" + + "github.com/miekg/dns" ) // Leak and race tests. 
diff --git a/pkg/upstream/transport/reuse.go b/pkg/upstream/transport/reuse.go index c005dc0..5489833 100644 --- a/pkg/upstream/transport/reuse.go +++ b/pkg/upstream/transport/reuse.go @@ -21,8 +21,9 @@ package transport import ( "context" - "github.com/miekg/dns" "sync" + + "github.com/miekg/dns" ) type ReuseConnTransport struct { diff --git a/pkg/upstream/transport/reuse_test.go b/pkg/upstream/transport/reuse_test.go index 3b45961..ac68142 100644 --- a/pkg/upstream/transport/reuse_test.go +++ b/pkg/upstream/transport/reuse_test.go @@ -21,10 +21,11 @@ package transport import ( "context" - "github.com/miekg/dns" "sync" "testing" "time" + + "github.com/miekg/dns" ) // Leak and race tests. diff --git a/pkg/upstream/transport/transport.go b/pkg/upstream/transport/transport.go index 4542d65..d393600 100644 --- a/pkg/upstream/transport/transport.go +++ b/pkg/upstream/transport/transport.go @@ -22,9 +22,10 @@ package transport import ( "context" "errors" - "github.com/miekg/dns" "io" "time" + + "github.com/miekg/dns" ) var ( diff --git a/pkg/upstream/transport/utils.go b/pkg/upstream/transport/utils.go index 870a88b..233f2a8 100644 --- a/pkg/upstream/transport/utils.go +++ b/pkg/upstream/transport/utils.go @@ -20,10 +20,11 @@ package transport import ( - "github.com/miekg/dns" "math/rand" - "sync/atomic" + "sync" "time" + + "github.com/miekg/dns" ) func shadowCopy(m *dns.Msg) *dns.Msg { @@ -94,10 +95,10 @@ func slicePopLatest[T any](s *[]T) (T, bool) { } type idleTimer struct { - d time.Duration - updating atomic.Bool - t *time.Timer - stopped bool + d time.Duration + m sync.Mutex + t *time.Timer + stopped bool } func newIdleTimer(d time.Duration, f func()) *idleTimer { @@ -108,22 +109,29 @@ func newIdleTimer(d time.Duration, f func()) *idleTimer { } func (t *idleTimer) reset(d time.Duration) { - if t.updating.CompareAndSwap(false, true) { - defer t.updating.Store(false) - if t.stopped { - return - } - if d <= 0 { - d = t.d - } - if !t.t.Reset(t.d) { - t.stopped = true 
- // re-activated. stop it - t.t.Stop() - } + t.m.Lock() + defer t.m.Unlock() + if t.stopped { + return + } + + if d <= 0 { + d = t.d + } + + if !t.t.Reset(d) { + t.stopped = true + // re-activated. stop it + t.t.Stop() } } func (t *idleTimer) stop() { + t.m.Lock() + defer t.m.Unlock() + if t.stopped { + return + } + t.stopped = true t.t.Stop() } diff --git a/pkg/upstream/upstream.go b/pkg/upstream/upstream.go index 1bc80a2..ab77aba 100644 --- a/pkg/upstream/upstream.go +++ b/pkg/upstream/upstream.go @@ -22,10 +22,12 @@ package upstream import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" "net/http" + "net/netip" "net/url" "strconv" "strings" @@ -42,10 +44,11 @@ import ( "github.com/quic-go/quic-go/http3" "go.uber.org/zap" "golang.org/x/net/http2" + "golang.org/x/net/proxy" ) const ( - tlsHandshakeTimeout = time.Second * 5 + tlsHandshakeTimeout = time.Second * 3 ) // Upstream represents a DNS upstream. @@ -59,12 +62,17 @@ type Upstream interface { type Opt struct { // DialAddr specifies the address the upstream will - // actually dial to. + // actually dial to in the network layer by overwriting + // the address inferred from upstream url. + // It won't affect high level layers. (e.g. SNI, HTTP HOST header won't be changed). + // Can be an IP or a domain. Port is optional. + // Tips: If the upstream url host is a domain, specific an IP address + // here can skip resolving ip of this domain. DialAddr string // Socks5 specifies the socks5 proxy server that the upstream // will connect though. - // Not implemented for udp upstreams and doh upstreams with http/3. + // Not implemented for udp based protocols (aka. dns over udp, http3, quic). Socks5 string // SoMark sets the socket SO_MARK option in unix system. @@ -75,40 +83,42 @@ type Opt struct { // IdleTimeout specifies the idle timeout for long-connections. // Available for TCP, DoT, DoH. - // If negative, TCP, DoT will not reuse connections. - // Default: TCP, DoT: 10s , DoH: 30s. 
+ // Default: TCP, DoT: 10s , DoH, DoQ: 30s. IdleTimeout time.Duration // EnablePipeline enables query pipelining support as RFC 7766 6.2.1.1 suggested. // Available for TCP, DoT upstream with IdleTimeout >= 0. + // Note: There is no fallback. EnablePipeline bool // EnableHTTP3 enables HTTP/3 protocol for DoH upstream. + // Note: There is no fallback. EnableHTTP3 bool // MaxConns limits the total number of connections, including connections // in the dialing states. - // Implemented for TCP/DoT pipeline enabled upstreams and DoH upstreams. + // Implemented for TCP/DoT pipeline enabled upstream and DoH upstream. // Default is 2. MaxConns int - // Bootstrap specifies a plain dns server for the go runtime to solve the - // domain of the upstream server. It SHOULD be an IP address. Custom port - // is supported. - // Note: Use a domain address may cause dead resolve loop and additional - // latency to dial upstream server. - // HTTP3 is not supported. + // Bootstrap specifies a plain dns server to solve the + // upstream server domain address. + // It must be an IP address. Port is optional. Bootstrap string + // Bootstrap version. One of 0 (default equals 4), 4, 6. + // TODO: Support dual-stack. + BootstrapVer int + // TLSConfig specifies the tls.Config that the TLS client will use. - // Available for DoT, DoH upstreams. + // Available for DoT, DoH upstream. TLSConfig *tls.Config // Logger specifies the logger that the upstream will use. Logger *zap.Logger // EventObserver can observe connection events. - // Note: Not Implemented for HTTP/3 upstreams. + // Not implemented for udp based protocols (dns over udp, http3, quic). EventObserver EventObserver } @@ -129,17 +139,127 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { return nil, fmt.Errorf("invalid server address, %w", err) } + // If host is a ipv6 without port, it will be in []. This will cause err when + // split and join address and port. Try to remove brackets now. 
+ addrUrlHost := tryTrimIpv6Brackets(addrURL.Host) + dialer := &net.Dialer{ - Resolver: bootstrap.NewBootstrap(opt.Bootstrap), Control: getSocketControlFunc(socketOpts{ so_mark: opt.SoMark, bind_to_device: opt.BindToDevice, }), } + var bootstrapAp netip.AddrPort + if s := opt.Bootstrap; len(s) > 0 { + bootstrapAp, err = parseBootstrapAp(s) + if err != nil { + return nil, fmt.Errorf("invalid bootstrap, %w", err) + } + } + + newUdpAddrResolveFunc := func(defaultPort uint16) (func(ctx context.Context) (*net.UDPAddr, error), error) { + host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) + if err != nil { + return nil, err + } + + if addr, err := netip.ParseAddr(host); err == nil { // host is an ip. + ua := net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) + return func(ctx context.Context) (*net.UDPAddr, error) { + return ua, nil + }, nil + } else { // Not an ip, assuming it's a domain name. + if bootstrapAp.IsValid() { + // Bootstrap enabled. + bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger) + if err != nil { + return nil, err + } + + return func(ctx context.Context) (*net.UDPAddr, error) { + s, err := bs.GetAddrPortStr(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap failed, %w", err) + } + return net.ResolveUDPAddr("udp", s) + }, nil + } else { + // Bootstrap disabled. + dialAddr := joinPort(host, port) + return func(ctx context.Context) (*net.UDPAddr, error) { + return net.ResolveUDPAddr("udp", dialAddr) + }, nil + } + } + } + + newTcpDialer := func(dialAddrMustBeIp bool, defaultPort uint16) (func(ctx context.Context) (net.Conn, error), error) { + host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) + if err != nil { + return nil, err + } + + // Socks5 enabled. 
+ if s5Addr := opt.Socks5; len(s5Addr) > 0 { + socks5Dialer, err := proxy.SOCKS5("tcp", s5Addr, nil, dialer) + if err != nil { + return nil, fmt.Errorf("failed to init socks5 dialer: %w", err) + } + + contextDialer := socks5Dialer.(proxy.ContextDialer) + dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) + return func(ctx context.Context) (net.Conn, error) { + return contextDialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } + + if _, err := netip.ParseAddr(host); err == nil { + // Host is an ip addr. No need to resolve it. + dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) + return func(ctx context.Context) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } else { + if dialAddrMustBeIp { + return nil, errors.New("addr must be an ip address") + } + // Host is not an ip addr, assuming it is a domain. + if bootstrapAp.IsValid() { + // Bootstrap enabled. + bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger) + if err != nil { + return nil, err + } + + return func(ctx context.Context) (net.Conn, error) { + dialAddr, err := bs.GetAddrPortStr(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap failed, %w", err) + } + return dialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } else { + // Bootstrap disabled. 
+ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) + return func(ctx context.Context) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", dialAddr) + }, nil + } + } + } + switch addrURL.Scheme { case "", "udp": - dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 53) + const defaultPort = 53 + host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) + if err != nil { + return nil, err + } + if _, err := netip.ParseAddr(host); err != nil { + return nil, fmt.Errorf("addr must be an ip address, %w", err) + } + dialAddr := joinPort(host, port) uto := transport.IOOpts{ DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) { c, err := dialer.DialContext(ctx, "udp", dialAddr) @@ -166,10 +286,14 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { t: transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: tto}), }, nil case "tcp": - dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 53) + const defaultPort = 53 + tcpDialer, err := newTcpDialer(true, defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init tcp dialer, %w", err) + } to := transport.IOOpts{ DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) { - c, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer) + c, err := tcpDialer(ctx) c = wrapConn(c, opt.EventObserver) return c, err }, @@ -182,18 +306,22 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { } return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil case "tls": + const defaultPort = 853 tlsConfig := opt.TLSConfig.Clone() if tlsConfig == nil { tlsConfig = new(tls.Config) } if len(tlsConfig.ServerName) == 0 { - tlsConfig.ServerName = tryRemovePort(addrURL.Host) + tlsConfig.ServerName = tryRemovePort(addrUrlHost) } - dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853) + tcpDialer, err := newTcpDialer(false, defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init tcp dialer, %w", err) + } to 
:= transport.IOOpts{ DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) { - conn, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer) + conn, err := tcpDialer(ctx) if err != nil { return nil, err } @@ -203,7 +331,6 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { tlsConn.Close() return nil, err } - return tlsConn, nil }, WriteFunc: dnsutils.WriteMsgToTCP, @@ -215,6 +342,7 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { } return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil case "https": + const defaultPort = 443 idleConnTimeout := time.Second * 30 if opt.IdleTimeout > 0 { idleConnTimeout = opt.IdleTimeout @@ -224,41 +352,42 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { maxConn = opt.MaxConns } - dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 443) var t http.RoundTripper - var addonCloser io.Closer // udpConn + var addonCloser io.Closer if opt.EnableHTTP3 { + udpBootstrap, err := newUdpAddrResolveFunc(defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err) + } + lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} conn, err := lc.ListenPacket(context.Background(), "udp", "") if err != nil { return nil, fmt.Errorf("failed to init udp socket for quic") } - qt := &quic.Transport{ + quicTransport := &quic.Transport{ Conn: conn, } - addonCloser = qt + addonCloser = quicTransport t = &http3.RoundTripper{ TLSClientConfig: opt.TLSConfig, - QuicConfig: &quic.Config{ - TokenStore: quic.NewLRUTokenStore(4, 8), - InitialStreamReceiveWindow: 4 * 1024, - MaxStreamReceiveWindow: 4 * 1024, - InitialConnectionReceiveWindow: 8 * 1024, - MaxConnectionReceiveWindow: 64 * 1024, - }, + QuicConfig: newDefaultQuicConfig(), Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - ua, err := net.ResolveUDPAddr("udp", dialAddr) // 
TODO: Support bootstrap. + ua, err := udpBootstrap(ctx) if err != nil { return nil, err } - - return qt.DialEarly(ctx, ua, tlsCfg, cfg) + return quicTransport.DialEarly(ctx, ua, tlsCfg, cfg) }, } } else { + tcpDialer, err := newTcpDialer(false, defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init tcp dialer, %w", err) + } t1 := &http.Transport{ DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { // overwrite server addr - c, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer) + c, err := tcpDialer(ctx) c = wrapConn(c, opt.EventObserver) return c, err }, @@ -281,26 +410,34 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { t = t1 } - return &doh.Upstream{ - EndPoint: addr, - Client: &http.Client{Transport: t}, - AddOnCloser: addonCloser, + return &dohWithClose{ + u: doh.NewUpstream(addr, &http.Client{Transport: t}), + closer: addonCloser, }, nil case "quic", "doq": + const defaultPort = 853 tlsConfig := opt.TLSConfig.Clone() if tlsConfig == nil { tlsConfig = new(tls.Config) } if len(tlsConfig.ServerName) == 0 { - tlsConfig.ServerName = tryRemovePort(addrURL.Host) + tlsConfig.ServerName = tryRemovePort(addrUrlHost) } + tlsConfig.NextProtos = doq.DoqAlpn - quicConfig := &quic.Config{ - TokenStore: quic.NewLRUTokenStore(4, 8), - InitialStreamReceiveWindow: 4 * 1024, - MaxStreamReceiveWindow: 4 * 1024, - InitialConnectionReceiveWindow: 8 * 1024, - MaxConnectionReceiveWindow: 64 * 1024, + quicConfig := newDefaultQuicConfig() + if opt.IdleTimeout > 0 { + quicConfig.MaxIdleTimeout = opt.IdleTimeout + } + + udpBootstrap, err := newUdpAddrResolveFunc(defaultPort) + if err != nil { + return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err) + } + + srk, err := doq.InitSrk() + if err != nil { + opt.Logger.Error("failed to init quic stateless reset key", zap.Error(err)) } lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} @@ -309,35 +446,44 @@ 
func NewUpstream(addr string, opt Opt) (Upstream, error) { return nil, fmt.Errorf("failed to init udp socket for quic") } - dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853) - u, err := doq.NewUpstream(dialAddr, uc.(*net.UDPConn), tlsConfig, quicConfig) - if err != nil { - return nil, fmt.Errorf("failed to setup doq upstream, %w", err) + t := &quic.Transport{ + Conn: uc, + StatelessResetKey: (*quic.StatelessResetKey)(srk), } - return u, nil - default: - return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme) - } -} -func getDialAddrWithPort(host, dialAddr string, defaultPort int) string { - addr := host - if len(dialAddr) > 0 { - addr = dialAddr - } - _, _, err := net.SplitHostPort(addr) - if err != nil { // no port, add it. - return net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(defaultPort)) - } - return addr -} + dialer := func(ctx context.Context) (quic.Connection, error) { + ua, err := udpBootstrap(ctx) + if err != nil { + return nil, fmt.Errorf("bootstrap failed, %w", err) + } + var c quic.Connection + ec, err := t.DialEarly(ctx, ua, tlsConfig, quicConfig) + if err != nil { + return nil, err + } -func tryRemovePort(s string) string { - host, _, err := net.SplitHostPort(s) - if err != nil { - return s + // This is a workaround to + // 1. recover from strange 0rtt rejected err. + // 2. avoid NextConnection might block forever. + // TODO: Remove this workaround. 
+ select { + case <-ctx.Done(): + err := context.Cause(ctx) + ec.CloseWithError(0, "") + return nil, err + case <-ec.HandshakeComplete(): + c = ec.NextConnection() + } + return c, nil + } + + return &doqWithClose{ + u: doq.NewUpstream(dialer), + t: t, + }, nil + default: + return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme) } - return host } type udpWithFallback struct { @@ -361,3 +507,48 @@ func (u *udpWithFallback) Close() error { u.t.Close() return nil } + +type doqWithClose struct { + u *doq.Upstream + t *quic.Transport +} + +func (u *doqWithClose) ExchangeContext(ctx context.Context, m *dns.Msg) (*dns.Msg, error) { + return u.u.ExchangeContext(ctx, m) +} + +func (u *doqWithClose) Close() error { + return u.t.Close() +} + +type dohWithClose struct { + u *doh.Upstream + closer io.Closer // maybe nil +} + +func (u *dohWithClose) ExchangeContext(ctx context.Context, m *dns.Msg) (*dns.Msg, error) { + return u.u.ExchangeContext(ctx, m) +} + +func (u *dohWithClose) Close() error { + if u.closer != nil { + return u.closer.Close() + } + return nil +} + +func newDefaultQuicConfig() *quic.Config { + return &quic.Config{ + TokenStore: quic.NewLRUTokenStore(4, 8), + + // Dns does not need large amount of io, so the rx/tx windows are small. 
+ InitialStreamReceiveWindow: 4 * 1024, + MaxStreamReceiveWindow: 4 * 1024, + InitialConnectionReceiveWindow: 8 * 1024, + MaxConnectionReceiveWindow: 64 * 1024, + + MaxIdleTimeout: time.Second * 30, + KeepAlivePeriod: time.Second * 25, + HandshakeIdleTimeout: tlsHandshakeTimeout, + } +} diff --git a/pkg/upstream/upstream_test.go b/pkg/upstream/upstream_test.go index 675b1ff..55e3b6d 100644 --- a/pkg/upstream/upstream_test.go +++ b/pkg/upstream/upstream_test.go @@ -24,12 +24,13 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/IrineSistiana/mosdns/v5/pkg/utils" - "github.com/miekg/dns" "net" "sync" "testing" "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/miekg/dns" ) func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { @@ -68,6 +69,9 @@ func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownF func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { serverName := "test" cert, err := utils.GenerateCertificate(serverName) + if err != nil { + t.Fatal(err) + } tlsConfig := new(tls.Config) tlsConfig.Certificates = []tls.Certificate{cert} tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig) diff --git a/pkg/upstream/utils.go b/pkg/upstream/utils.go index 7d297f3..8a42125 100644 --- a/pkg/upstream/utils.go +++ b/pkg/upstream/utils.go @@ -20,10 +20,10 @@ package upstream import ( - "context" "fmt" - "golang.org/x/net/proxy" "net" + "net/netip" + "strconv" ) type socketOpts struct { @@ -31,14 +31,70 @@ type socketOpts struct { bind_to_device string } -func dialTCP(ctx context.Context, addr, socks5 string, dialer *net.Dialer) (net.Conn, error) { - if len(socks5) > 0 { - socks5Dialer, err := proxy.SOCKS5("tcp", socks5, nil, dialer) +func parseDialAddr(urlHost, dialAddr string, defaultPort uint16) (string, uint16, error) { + addr := urlHost + if len(dialAddr) > 0 { + addr = dialAddr + } + host, port, err := trySplitHostPort(addr) + if err 
!= nil { + return "", 0, err + } + if port == 0 { + port = defaultPort + } + return host, port, nil +} + +func joinPort(host string, port uint16) string { + return net.JoinHostPort(host, strconv.Itoa(int(port))) +} + +func tryRemovePort(s string) string { + host, _, err := net.SplitHostPort(s) + if err != nil { + return s + } + return host +} + +// trySplitHostPort splits host and port. +// If s has no port, it returns s,0,nil +func trySplitHostPort(s string) (string, uint16, error) { + var port uint16 + host, portS, err := net.SplitHostPort(s) + if err == nil { + n, err := strconv.ParseUint(portS, 10, 16) if err != nil { - return nil, fmt.Errorf("failed to init socks5 dialer: %w", err) + return "", 0, fmt.Errorf("invalid port, %w", err) } - return socks5Dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", addr) + port = uint16(n) + return host, port, nil } + return s, 0, nil +} - return dialer.DialContext(ctx, "tcp", addr) +func parseBootstrapAp(s string) (netip.AddrPort, error) { + host, port, err := trySplitHostPort(s) + if err != nil { + return netip.AddrPort{}, err + } + if port == 0 { + port = 53 + } + addr, err := netip.ParseAddr(tryTrimIpv6Brackets(host)) + if err != nil { + return netip.AddrPort{}, err + } + return netip.AddrPortFrom(addr, port), nil +} + +func tryTrimIpv6Brackets(s string) string { + if len(s) < 2 { + return s + } + if s[0] == '[' && s[len(s)-1] == ']' { + return s[1 : len(s)-1] + } + return s } diff --git a/pkg/upstream/utils_unix.go b/pkg/upstream/utils_unix.go index b8ced47..685a82b 100644 --- a/pkg/upstream/utils_unix.go +++ b/pkg/upstream/utils_unix.go @@ -22,9 +22,10 @@ package upstream import ( - "golang.org/x/sys/unix" "os" "syscall" + + "golang.org/x/sys/unix" ) func getSocketControlFunc(opts socketOpts) func(string, string, syscall.RawConn) error { diff --git a/plugin/executable/forward/forward.go b/plugin/executable/forward/forward.go index be39ad7..bba0da2 100644 --- a/plugin/executable/forward/forward.go +++ b/plugin/executable/forward/forward.go
@@ -24,6 +24,9 @@ import ( "crypto/tls" "errors" "fmt" + "strings" + "time" + "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/IrineSistiana/mosdns/v5/pkg/upstream" @@ -32,8 +35,6 @@ import ( "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" - "strings" - "time" ) const PluginType = "forward" @@ -57,6 +58,7 @@ type Args struct { SoMark int `yaml:"so_mark"` BindToDevice string `yaml:"bind_to_device"` Bootstrap string `yaml:"bootstrap"` + BootstrapVer int `yaml:"bootstrap_version"` } type UpstreamConfig struct { @@ -73,6 +75,7 @@ type UpstreamConfig struct { SoMark int `yaml:"so_mark"` BindToDevice string `yaml:"bind_to_device"` Bootstrap string `yaml:"bootstrap"` + BootstrapVer int `yaml:"bootstrap_version"` } func Init(bp *coremain.BP, args any) (any, error) { @@ -125,6 +128,7 @@ func NewForward(args *Args, opt Opts) (*Forward, error) { utils.SetDefaultUnsignNum(&c.SoMark, args.SoMark) utils.SetDefaultString(&c.BindToDevice, args.BindToDevice) utils.SetDefaultString(&c.Bootstrap, args.Bootstrap) + utils.SetDefaultUnsignNum(&c.BootstrapVer, args.BootstrapVer) } for i, c := range args.Upstreams { @@ -144,6 +148,7 @@ func NewForward(args *Args, opt Opts) (*Forward, error) { EnablePipeline: c.EnablePipeline, EnableHTTP3: c.EnableHTTP3, Bootstrap: c.Bootstrap, + BootstrapVer: c.BootstrapVer, TLSConfig: &tls.Config{ InsecureSkipVerify: c.InsecureSkipVerify, ClientSessionCache: tls.NewLRUClientSessionCache(4), -- 2.34.8