1541 lines
42 KiB
Diff
1541 lines
42 KiB
Diff
From 5b96077d2a44d085a8ae44b1b368808184dea64b Mon Sep 17 00:00:00 2001
|
|
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
|
|
Date: Thu, 21 Sep 2023 11:52:20 +0800
|
|
Subject: [PATCH 6/6] upstream: switch to self-implemented bootstrap
|
|
|
|
refine code structure
|
|
---
|
|
pkg/upstream/bootstrap/bootstrap.go | 324 ++++++++++++----------
|
|
pkg/upstream/bootstrap/bootstrap_test.go | 38 ---
|
|
pkg/upstream/doh/upstream.go | 24 +-
|
|
pkg/upstream/doq/upstream.go | 66 +----
|
|
pkg/upstream/event_stat.go | 2 +-
|
|
pkg/upstream/transport/dns_conn.go | 3 +-
|
|
pkg/upstream/transport/dns_conn_test.go | 7 +-
|
|
pkg/upstream/transport/pipeline.go | 3 +-
|
|
pkg/upstream/transport/pipeline_test.go | 3 +-
|
|
pkg/upstream/transport/reuse.go | 3 +-
|
|
pkg/upstream/transport/reuse_test.go | 3 +-
|
|
pkg/upstream/transport/transport.go | 3 +-
|
|
pkg/upstream/transport/utils.go | 46 ++--
|
|
pkg/upstream/upstream.go | 337 ++++++++++++++++++-----
|
|
pkg/upstream/upstream_test.go | 8 +-
|
|
pkg/upstream/utils.go | 72 ++++-
|
|
pkg/upstream/utils_unix.go | 3 +-
|
|
plugin/executable/forward/forward.go | 9 +-
|
|
18 files changed, 586 insertions(+), 368 deletions(-)
|
|
delete mode 100644 pkg/upstream/bootstrap/bootstrap_test.go
|
|
|
|
diff --git a/pkg/upstream/bootstrap/bootstrap.go b/pkg/upstream/bootstrap/bootstrap.go
|
|
index b831757..2cd8ef9 100644
|
|
--- a/pkg/upstream/bootstrap/bootstrap.go
|
|
+++ b/pkg/upstream/bootstrap/bootstrap.go
|
|
@@ -21,184 +21,226 @@ package bootstrap
|
|
|
|
import (
|
|
"context"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
|
- "github.com/miekg/dns"
|
|
+ "errors"
|
|
+ "fmt"
|
|
"net"
|
|
- "strings"
|
|
+ "net/netip"
|
|
"sync"
|
|
+ "sync/atomic"
|
|
"time"
|
|
+
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
|
+ "github.com/miekg/dns"
|
|
+ "go.uber.org/zap"
|
|
)
|
|
|
|
-// NewBootstrap returns a customized *net.Resolver which can be used as a Bootstrap for a
|
|
-// certain domain. Its Dial func is modified to dial s through udp.
|
|
-// It also has a small built-in cache.
|
|
-// s SHOULD be a literal IP address and the port SHOULD also be literal.
|
|
-// Port can be omitted. In this case, the default port is :53.
|
|
-// e.g. NewBootstrap("8.8.8.8"), NewBootstrap("127.0.0.1:5353").
|
|
-// If s is empty, NewBootstrap returns nil. (A nil *net.Resolver is valid in net.Dialer.)
|
|
-// Note that not all platform support a customized *net.Resolver. It also depends on the
|
|
-// version of go runtime.
|
|
-// See the package docs from the net package for more info.
|
|
-func NewBootstrap(s string) *net.Resolver {
|
|
- if len(s) == 0 {
|
|
- return nil
|
|
- }
|
|
- // Add port.
|
|
- _, _, err := net.SplitHostPort(s)
|
|
- if err != nil { // no port, add it.
|
|
- s = net.JoinHostPort(strings.Trim(s, "[]"), "53")
|
|
- }
|
|
+const (
|
|
+ minimumUpdateInterval = time.Minute * 5
|
|
+ retryInterval = time.Second * 2
|
|
+ queryTimeout = time.Second * 5
|
|
+)
|
|
|
|
- bs := newBootstrap(s)
|
|
+var (
|
|
+ errNoAddrInResp = errors.New("resp does not have ip address")
|
|
+)
|
|
|
|
- return &net.Resolver{
|
|
- PreferGo: true,
|
|
- StrictErrors: false,
|
|
- Dial: bs.dial,
|
|
+func New(
|
|
+ host string,
|
|
+ port uint16,
|
|
+ bootstrapServer netip.AddrPort,
|
|
+ bootstrapVer int, // 0,4,6
|
|
+ logger *zap.Logger, // not nil
|
|
+) (*Bootstrap, error) {
|
|
+ dp := new(Bootstrap)
|
|
+ dp.fqdn = dns.Fqdn(host)
|
|
+ dp.port = port
|
|
+ if !bootstrapServer.IsValid() {
|
|
+ return nil, errors.New("invalid bootstrap server address")
|
|
+ }
|
|
+ dp.bootstrap = net.UDPAddrFromAddrPort(bootstrapServer)
|
|
+ qt, ok := bootstrapVer2Qt(bootstrapVer)
|
|
+ if !ok {
|
|
+ return nil, fmt.Errorf("invalid bootstrap version %d", bootstrapVer)
|
|
}
|
|
-}
|
|
+ dp.qt = qt
|
|
+ dp.logger = logger
|
|
|
|
-type bootstrap struct {
|
|
- upstream string
|
|
- cache *cache
|
|
+ dp.readyNotify = make(chan struct{})
|
|
+ return dp, nil
|
|
}
|
|
|
|
-func newBootstrap(upstream string) *bootstrap {
|
|
- return &bootstrap{
|
|
- upstream: upstream,
|
|
- cache: newCache(0),
|
|
- }
|
|
-}
|
|
+type Bootstrap struct {
|
|
+ fqdn string
|
|
+ port uint16
|
|
+ bootstrap *net.UDPAddr
|
|
+ qt uint16 // dns.TypeA or dns.TypeAAAA
|
|
+ logger *zap.Logger // not nil
|
|
|
|
-func (b *bootstrap) dial(_ context.Context, _, _ string) (net.Conn, error) {
|
|
- c1, c2 := net.Pipe()
|
|
- go func() {
|
|
- _ = b.handlePipe(c2)
|
|
- _ = c2.Close()
|
|
- }()
|
|
- return c1, nil
|
|
-}
|
|
+ updating atomic.Bool
|
|
+ nextUpdate time.Time
|
|
|
|
-func (b *bootstrap) handlePipe(c net.Conn) error {
|
|
- q, _, err := dnsutils.ReadMsgFromTCP(c)
|
|
- if err != nil {
|
|
- return err
|
|
- }
|
|
-
|
|
- var resp *dns.Msg
|
|
- if len(q.Question) == 1 {
|
|
- k := q.Question[0]
|
|
- m := b.cache.lookup(k)
|
|
- if m != nil {
|
|
- resp = m.Copy()
|
|
- resp.Id = q.Id
|
|
- }
|
|
- }
|
|
+ readyNotify chan struct{}
|
|
+ m sync.Mutex
|
|
+ ready bool
|
|
+ addrStr string
|
|
+}
|
|
|
|
- if resp == nil {
|
|
- d := net.Dialer{}
|
|
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
|
- defer cancel()
|
|
- upstreamConn, err := d.DialContext(ctx, "udp", b.upstream)
|
|
- if err != nil {
|
|
- return err
|
|
- }
|
|
- defer upstreamConn.Close()
|
|
- _ = upstreamConn.SetDeadline(time.Now().Add(time.Second * 3))
|
|
- if _, err := dnsutils.WriteMsgToUDP(upstreamConn, q); err != nil {
|
|
- return err
|
|
- }
|
|
- m, _, err := dnsutils.ReadMsgFromUDP(upstreamConn, 1500)
|
|
- if err != nil {
|
|
- return err
|
|
- }
|
|
- resp = m
|
|
- }
|
|
+func (sp *Bootstrap) GetAddrPortStr(ctx context.Context) (string, error) {
|
|
+ sp.tryUpdate()
|
|
|
|
- if resp.Rcode == dns.RcodeSuccess && hasIP(resp) {
|
|
- if len(resp.Question) == 1 {
|
|
- k := resp.Question[0]
|
|
- ttl := time.Duration(dnsutils.GetMinimalTTL(resp)) * time.Second
|
|
- b.cache.store(k, resp, ttl)
|
|
- }
|
|
+ select {
|
|
+ case <-ctx.Done():
|
|
+ return "", ctx.Err()
|
|
+ case <-sp.readyNotify:
|
|
}
|
|
|
|
- _, err = dnsutils.WriteMsgToTCP(c, resp)
|
|
- return err
|
|
+ sp.m.Lock()
|
|
+ addr := sp.addrStr
|
|
+ sp.m.Unlock()
|
|
+ return addr, nil
|
|
}
|
|
|
|
-func hasIP(m *dns.Msg) bool {
|
|
- for _, rr := range m.Answer {
|
|
- switch rr.Header().Rrtype {
|
|
- case dns.TypeA, dns.TypeAAAA:
|
|
- return true
|
|
+func (sp *Bootstrap) tryUpdate() {
|
|
+ if sp.updating.CompareAndSwap(false, true) {
|
|
+ if time.Now().After(sp.nextUpdate) {
|
|
+ go func() {
|
|
+ defer sp.updating.Store(false)
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), queryTimeout)
|
|
+ defer cancel()
|
|
+ start := time.Now()
|
|
+ addr, ttl, err := sp.updateAddr(ctx)
|
|
+ if err != nil {
|
|
+ sp.logger.Warn("failed to update bootstrap addr", zap.String("fqdn", sp.fqdn), zap.Error(err))
|
|
+ sp.nextUpdate = time.Now().Add(retryInterval)
|
|
+ } else {
|
|
+ updateInterval := time.Second * time.Duration(ttl)
|
|
+ if updateInterval < minimumUpdateInterval {
|
|
+ updateInterval = minimumUpdateInterval
|
|
+ }
|
|
+ sp.logger.Info(
|
|
+ "bootstrap addr updated",
|
|
+ zap.String("fqdn", sp.fqdn),
|
|
+ zap.Stringer("addr", addr),
|
|
+ zap.Duration("ttl", updateInterval),
|
|
+ zap.Duration("elapse", time.Since(start)),
|
|
+ )
|
|
+ sp.nextUpdate = time.Now().Add(updateInterval)
|
|
+ }
|
|
+ }()
|
|
+ } else {
|
|
+ sp.updating.Store(false)
|
|
}
|
|
}
|
|
- return false
|
|
}
|
|
|
|
-type cache struct {
|
|
- l int
|
|
+func (sp *Bootstrap) updateAddr(ctx context.Context) (netip.Addr, uint32, error) {
|
|
+ addr, ttl, err := sp.resolve(ctx, sp.qt)
|
|
+ if err != nil {
|
|
+ return netip.Addr{}, 0, err
|
|
+ }
|
|
|
|
- m sync.Mutex
|
|
- c map[key]*elem
|
|
+ addrPort := netip.AddrPortFrom(addr, sp.port).String()
|
|
+ sp.m.Lock()
|
|
+ sp.addrStr = addrPort
|
|
+ if !sp.ready {
|
|
+ sp.ready = true
|
|
+ close(sp.readyNotify)
|
|
+ }
|
|
+ sp.m.Unlock()
|
|
+ return addr, ttl, nil
|
|
}
|
|
|
|
-type elem struct {
|
|
- m *dns.Msg
|
|
- expirationTime time.Time
|
|
-}
|
|
+func (sp *Bootstrap) resolve(ctx context.Context, qt uint16) (netip.Addr, uint32, error) {
|
|
+ const edns0UdpSize = 1200
|
|
|
|
-type key = dns.Question
|
|
+ q := new(dns.Msg)
|
|
+ q.SetQuestion(sp.fqdn, qt)
|
|
+ q.SetEdns0(edns0UdpSize, false)
|
|
|
|
-func newCache(size int) *cache {
|
|
- const defaultSize = 8
|
|
- if size <= 0 {
|
|
- size = defaultSize
|
|
- }
|
|
- return &cache{
|
|
- l: size,
|
|
- c: make(map[key]*elem),
|
|
+ c, err := net.DialUDP("udp", nil, sp.bootstrap)
|
|
+ if err != nil {
|
|
+ return netip.Addr{}, 0, err
|
|
}
|
|
-}
|
|
+ defer c.Close()
|
|
|
|
-// lookup returns a cached msg. Note: the msg must not be modified.
|
|
-func (c *cache) lookup(k key) *dns.Msg {
|
|
- now := time.Now()
|
|
- c.m.Lock()
|
|
- defer c.m.Unlock()
|
|
- e, ok := c.c[k]
|
|
- if !ok {
|
|
- return nil
|
|
- }
|
|
- if e.expirationTime.Before(now) {
|
|
- delete(c.c, k)
|
|
- return nil
|
|
+ writeErrC := make(chan error, 1)
|
|
+ type res struct {
|
|
+ resp *dns.Msg
|
|
+ err error
|
|
}
|
|
- return e.m
|
|
-}
|
|
+ readResC := make(chan res, 1)
|
|
|
|
-// store stores a msg to cache. The caller MUST NOT modify m anymore.
|
|
-func (c *cache) store(k key, m *dns.Msg, ttl time.Duration) {
|
|
- if ttl <= 0 {
|
|
- return
|
|
- }
|
|
- expirationTime := time.Now().Add(ttl)
|
|
+ cancelWrite := make(chan struct{})
|
|
+ defer close(cancelWrite)
|
|
+ go func() {
|
|
+ if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil {
|
|
+ writeErrC <- err
|
|
+ return
|
|
+ }
|
|
|
|
- c.m.Lock()
|
|
- defer c.m.Unlock()
|
|
+ retryTicker := time.NewTicker(time.Second)
|
|
+ defer retryTicker.Stop()
|
|
+ for {
|
|
+ select {
|
|
+ case <-cancelWrite:
|
|
+ return
|
|
+ case <-retryTicker.C:
|
|
+ if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil {
|
|
+ writeErrC <- err
|
|
+ return
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+ }()
|
|
+
|
|
+ go func() {
|
|
+ m, _, err := dnsutils.ReadMsgFromUDP(c, edns0UdpSize)
|
|
+ readResC <- res{resp: m, err: err}
|
|
+ }()
|
|
|
|
- if len(c.c)+1 > c.l {
|
|
- for k := range c.c {
|
|
- if len(c.c)+1 <= c.l {
|
|
- break
|
|
+ select {
|
|
+ case <-ctx.Done():
|
|
+ return netip.Addr{}, 0, ctx.Err()
|
|
+ case err := <-writeErrC:
|
|
+ return netip.Addr{}, 0, fmt.Errorf("failed to write query, %w", err)
|
|
+ case r := <-readResC:
|
|
+ resp := r.resp
|
|
+ err := r.err
|
|
+ if err != nil {
|
|
+ return netip.Addr{}, 0, fmt.Errorf("failed to read resp, %w", err)
|
|
+ }
|
|
+
|
|
+ for _, v := range resp.Answer {
|
|
+ var ip net.IP
|
|
+ var ttl uint32
|
|
+ switch rr := v.(type) {
|
|
+ case *dns.A:
|
|
+ ip = rr.A
|
|
+ ttl = rr.Hdr.Ttl
|
|
+ case *dns.AAAA:
|
|
+ ip = rr.AAAA
|
|
+ ttl = rr.Hdr.Ttl
|
|
+ default:
|
|
+ continue
|
|
+ }
|
|
+ addr, ok := netip.AddrFromSlice(ip)
|
|
+ if ok {
|
|
+ return addr, ttl, nil
|
|
}
|
|
- delete(c.c, k)
|
|
}
|
|
+
|
|
+ // No ip addr in resp.
|
|
+ return netip.Addr{}, 0, errNoAddrInResp
|
|
}
|
|
+}
|
|
|
|
- c.c[k] = &elem{
|
|
- m: m,
|
|
- expirationTime: expirationTime,
|
|
+func bootstrapVer2Qt(ver int) (uint16, bool) {
|
|
+ switch ver {
|
|
+ case 0, 4:
|
|
+ return dns.TypeA, true
|
|
+ case 6:
|
|
+ return dns.TypeAAAA, true
|
|
+ default:
|
|
+ return 0, false
|
|
}
|
|
}
|
|
diff --git a/pkg/upstream/bootstrap/bootstrap_test.go b/pkg/upstream/bootstrap/bootstrap_test.go
|
|
deleted file mode 100644
|
|
index 93f14ff..0000000
|
|
--- a/pkg/upstream/bootstrap/bootstrap_test.go
|
|
+++ /dev/null
|
|
@@ -1,38 +0,0 @@
|
|
-/*
|
|
- * Copyright (C) 2020-2022, IrineSistiana
|
|
- *
|
|
- * This file is part of mosdns.
|
|
- *
|
|
- * mosdns is free software: you can redistribute it and/or modify
|
|
- * it under the terms of the GNU General Public License as published by
|
|
- * the Free Software Foundation, either version 3 of the License, or
|
|
- * (at your option) any later version.
|
|
- *
|
|
- * mosdns is distributed in the hope that it will be useful,
|
|
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
- * GNU General Public License for more details.
|
|
- *
|
|
- * You should have received a copy of the GNU General Public License
|
|
- * along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
- */
|
|
-
|
|
-package bootstrap
|
|
-
|
|
-import (
|
|
- "context"
|
|
- "os"
|
|
- "testing"
|
|
-)
|
|
-
|
|
-func Test_Bootstrap(t *testing.T) {
|
|
- if os.Getenv("TEST_BOOTSTRAP") == "" {
|
|
- t.SkipNow()
|
|
- }
|
|
-
|
|
- r := NewBootstrap("8.8.8.8")
|
|
- _, err := r.LookupIP(context.Background(), "ip", "google.com")
|
|
- if err != nil {
|
|
- t.Fatal(err)
|
|
- }
|
|
-}
|
|
diff --git a/pkg/upstream/doh/upstream.go b/pkg/upstream/doh/upstream.go
|
|
index addf5b7..abc124b 100644
|
|
--- a/pkg/upstream/doh/upstream.go
|
|
+++ b/pkg/upstream/doh/upstream.go
|
|
@@ -23,13 +23,14 @@ import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
- "github.com/miekg/dns"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
+
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
const (
|
|
@@ -42,21 +43,10 @@ type Upstream struct {
|
|
EndPoint string
|
|
// Client is a http.Client that sends http requests.
|
|
Client *http.Client
|
|
-
|
|
- // AddOnCloser will be closed when Upstream is closed.
|
|
- AddOnCloser io.Closer
|
|
}
|
|
|
|
-func (u *Upstream) CloseIdleConnections() {
|
|
- u.Client.CloseIdleConnections()
|
|
-}
|
|
-
|
|
-func (u *Upstream) Close() error {
|
|
- u.Client.CloseIdleConnections()
|
|
- if u.AddOnCloser != nil {
|
|
- u.AddOnCloser.Close()
|
|
- }
|
|
- return nil
|
|
+func NewUpstream(endPoint string, client *http.Client) *Upstream {
|
|
+ return &Upstream{EndPoint: endPoint, Client: client}
|
|
}
|
|
|
|
var (
|
|
@@ -127,7 +117,7 @@ func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, e
|
|
func (u *Upstream) exchange(ctx context.Context, url string) (*dns.Msg, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
- return nil, fmt.Errorf("interal err: NewRequestWithContext: %w", err)
|
|
+ return nil, fmt.Errorf("internal err: NewRequestWithContext: %w", err)
|
|
}
|
|
|
|
req.Header["Accept"] = []string{"application/dns-message"}
|
|
diff --git a/pkg/upstream/doq/upstream.go b/pkg/upstream/doq/upstream.go
|
|
index 7b2d4e7..23d7f1c 100644
|
|
--- a/pkg/upstream/doq/upstream.go
|
|
+++ b/pkg/upstream/doq/upstream.go
|
|
@@ -22,8 +22,6 @@ package doq
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
- "crypto/tls"
|
|
- "errors"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
@@ -39,48 +37,27 @@ const (
|
|
defaultDoQTimeout = time.Second * 5
|
|
dialTimeout = time.Second * 3
|
|
connectionLostThreshold = time.Second * 5
|
|
- handshakeTimeout = time.Second * 3
|
|
)
|
|
|
|
var (
|
|
- doqAlpn = []string{"doq"}
|
|
+ DoqAlpn = []string{"doq"}
|
|
)
|
|
|
|
type Upstream struct {
|
|
- t *quic.Transport
|
|
- addr string
|
|
- tlsConfig *tls.Config
|
|
- quicConfig *quic.Config
|
|
+ dialer func(ctx context.Context) (quic.Connection, error)
|
|
|
|
cm sync.Mutex
|
|
lc *lazyConn
|
|
}
|
|
|
|
-// tlsConfig cannot be nil, it should have Servername or InsecureSkipVerify.
|
|
-func NewUpstream(addr string, lc *net.UDPConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*Upstream, error) {
|
|
- srk, err := initSrk()
|
|
- if err != nil {
|
|
- return nil, err
|
|
- }
|
|
- if tlsConfig == nil {
|
|
- return nil, errors.New("nil tls config")
|
|
- }
|
|
-
|
|
- tlsConfig = tlsConfig.Clone()
|
|
- tlsConfig.NextProtos = doqAlpn
|
|
-
|
|
+func NewUpstream(dialer func(ctx context.Context) (quic.Connection, error)) *Upstream {
|
|
return &Upstream{
|
|
- t: &quic.Transport{
|
|
- Conn: lc,
|
|
- StatelessResetKey: (*quic.StatelessResetKey)(srk),
|
|
- },
|
|
- addr: addr,
|
|
- tlsConfig: tlsConfig,
|
|
- quicConfig: quicConfig,
|
|
- }, nil
|
|
+ dialer: dialer,
|
|
+ }
|
|
}
|
|
|
|
-func initSrk() (*[32]byte, error) {
|
|
+// A helper func to init quic stateless reset key.
|
|
+func InitSrk() (*[32]byte, error) {
|
|
var b [32]byte
|
|
_, err := rand.Read(b[:])
|
|
if err != nil {
|
|
@@ -89,10 +66,6 @@ func initSrk() (*[32]byte, error) {
|
|
return &b, nil
|
|
}
|
|
|
|
-func (u *Upstream) Close() error {
|
|
- return u.t.Close()
|
|
-}
|
|
-
|
|
func (u *Upstream) newStream(ctx context.Context) (quic.Stream, *lazyConn, error) {
|
|
var lc *lazyConn
|
|
u.cm.Lock()
|
|
@@ -154,31 +127,10 @@ func (u *Upstream) asyncDialConn() *lazyConn {
|
|
|
|
go func() {
|
|
defer cancel()
|
|
+ c, err := u.dialer(ctx)
|
|
|
|
- ua, err := net.ResolveUDPAddr("udp", u.addr) // TODO: Support bootstrap.
|
|
- if err != nil {
|
|
- lc.err = err
|
|
- return
|
|
- }
|
|
-
|
|
- var c quic.Connection
|
|
- ec, err := u.t.DialEarly(ctx, ua, u.tlsConfig, u.quicConfig)
|
|
- if ec != nil {
|
|
- // This is a workaround to
|
|
- // 1. recover from strange 0rtt rejected err.
|
|
- // 2. avoid NextConnection might block forever.
|
|
- // TODO: Remove this workaround.
|
|
- select {
|
|
- case <-ctx.Done():
|
|
- err = context.Cause(ctx)
|
|
- ec.CloseWithError(0, "")
|
|
- case <-ec.HandshakeComplete():
|
|
- c = ec.NextConnection()
|
|
- }
|
|
- }
|
|
-
|
|
- var closeC bool
|
|
lc.m.Lock()
|
|
+ var closeC bool
|
|
if lc.closed {
|
|
closeC = true // lc was closed, nothing to do
|
|
} else {
|
|
diff --git a/pkg/upstream/event_stat.go b/pkg/upstream/event_stat.go
|
|
index cfb40c5..6b7b803 100644
|
|
--- a/pkg/upstream/event_stat.go
|
|
+++ b/pkg/upstream/event_stat.go
|
|
@@ -37,7 +37,7 @@ type EventObserver interface {
|
|
|
|
type nopEO struct{}
|
|
|
|
-func (n nopEO) OnEvent(_ Event) { return }
|
|
+func (n nopEO) OnEvent(_ Event) {}
|
|
|
|
type connWrapper struct {
|
|
net.Conn
|
|
diff --git a/pkg/upstream/transport/dns_conn.go b/pkg/upstream/transport/dns_conn.go
|
|
index 1c81a92..6df458e 100644
|
|
--- a/pkg/upstream/transport/dns_conn.go
|
|
+++ b/pkg/upstream/transport/dns_conn.go
|
|
@@ -22,10 +22,11 @@ package transport
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
- "github.com/miekg/dns"
|
|
"io"
|
|
"sync"
|
|
"sync/atomic"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
// dnsConn is a low-level connection for dns.
|
|
diff --git a/pkg/upstream/transport/dns_conn_test.go b/pkg/upstream/transport/dns_conn_test.go
|
|
index c677797..6e4eb27 100644
|
|
--- a/pkg/upstream/transport/dns_conn_test.go
|
|
+++ b/pkg/upstream/transport/dns_conn_test.go
|
|
@@ -22,9 +22,6 @@ package transport
|
|
import (
|
|
"context"
|
|
"errors"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
- "github.com/miekg/dns"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
@@ -32,6 +29,10 @@ import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
+
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
var (
|
|
diff --git a/pkg/upstream/transport/pipeline.go b/pkg/upstream/transport/pipeline.go
|
|
index e4bc39e..10c3d23 100644
|
|
--- a/pkg/upstream/transport/pipeline.go
|
|
+++ b/pkg/upstream/transport/pipeline.go
|
|
@@ -21,10 +21,11 @@ package transport
|
|
|
|
import (
|
|
"context"
|
|
- "github.com/miekg/dns"
|
|
"math/rand"
|
|
"sync"
|
|
"time"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
// PipelineTransport will pipeline queries as RFC 7766 6.2.1.1 suggested.
|
|
diff --git a/pkg/upstream/transport/pipeline_test.go b/pkg/upstream/transport/pipeline_test.go
|
|
index 5fcc0ae..653d779 100644
|
|
--- a/pkg/upstream/transport/pipeline_test.go
|
|
+++ b/pkg/upstream/transport/pipeline_test.go
|
|
@@ -21,10 +21,11 @@ package transport
|
|
|
|
import (
|
|
"context"
|
|
- "github.com/miekg/dns"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
// Leak and race tests.
|
|
diff --git a/pkg/upstream/transport/reuse.go b/pkg/upstream/transport/reuse.go
|
|
index c005dc0..5489833 100644
|
|
--- a/pkg/upstream/transport/reuse.go
|
|
+++ b/pkg/upstream/transport/reuse.go
|
|
@@ -21,8 +21,9 @@ package transport
|
|
|
|
import (
|
|
"context"
|
|
- "github.com/miekg/dns"
|
|
"sync"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
type ReuseConnTransport struct {
|
|
diff --git a/pkg/upstream/transport/reuse_test.go b/pkg/upstream/transport/reuse_test.go
|
|
index 3b45961..ac68142 100644
|
|
--- a/pkg/upstream/transport/reuse_test.go
|
|
+++ b/pkg/upstream/transport/reuse_test.go
|
|
@@ -21,10 +21,11 @@ package transport
|
|
|
|
import (
|
|
"context"
|
|
- "github.com/miekg/dns"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
// Leak and race tests.
|
|
diff --git a/pkg/upstream/transport/transport.go b/pkg/upstream/transport/transport.go
|
|
index 4542d65..d393600 100644
|
|
--- a/pkg/upstream/transport/transport.go
|
|
+++ b/pkg/upstream/transport/transport.go
|
|
@@ -22,9 +22,10 @@ package transport
|
|
import (
|
|
"context"
|
|
"errors"
|
|
- "github.com/miekg/dns"
|
|
"io"
|
|
"time"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
var (
|
|
diff --git a/pkg/upstream/transport/utils.go b/pkg/upstream/transport/utils.go
|
|
index 870a88b..233f2a8 100644
|
|
--- a/pkg/upstream/transport/utils.go
|
|
+++ b/pkg/upstream/transport/utils.go
|
|
@@ -20,10 +20,11 @@
|
|
package transport
|
|
|
|
import (
|
|
- "github.com/miekg/dns"
|
|
"math/rand"
|
|
- "sync/atomic"
|
|
+ "sync"
|
|
"time"
|
|
+
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
func shadowCopy(m *dns.Msg) *dns.Msg {
|
|
@@ -94,10 +95,10 @@ func slicePopLatest[T any](s *[]T) (T, bool) {
|
|
}
|
|
|
|
type idleTimer struct {
|
|
- d time.Duration
|
|
- updating atomic.Bool
|
|
- t *time.Timer
|
|
- stopped bool
|
|
+ d time.Duration
|
|
+ m sync.Mutex
|
|
+ t *time.Timer
|
|
+ stopped bool
|
|
}
|
|
|
|
func newIdleTimer(d time.Duration, f func()) *idleTimer {
|
|
@@ -108,22 +109,29 @@ func newIdleTimer(d time.Duration, f func()) *idleTimer {
|
|
}
|
|
|
|
func (t *idleTimer) reset(d time.Duration) {
|
|
- if t.updating.CompareAndSwap(false, true) {
|
|
- defer t.updating.Store(false)
|
|
- if t.stopped {
|
|
- return
|
|
- }
|
|
- if d <= 0 {
|
|
- d = t.d
|
|
- }
|
|
- if !t.t.Reset(t.d) {
|
|
- t.stopped = true
|
|
- // re-activated. stop it
|
|
- t.t.Stop()
|
|
- }
|
|
+ t.m.Lock()
|
|
+ defer t.m.Unlock()
|
|
+ if t.stopped {
|
|
+ return
|
|
+ }
|
|
+
|
|
+ if d <= 0 {
|
|
+ d = t.d
|
|
+ }
|
|
+
|
|
+ if !t.t.Reset(d) {
|
|
+ t.stopped = true
|
|
+ // re-activated. stop it
|
|
+ t.t.Stop()
|
|
}
|
|
}
|
|
|
|
func (t *idleTimer) stop() {
|
|
+ t.m.Lock()
|
|
+ defer t.m.Unlock()
|
|
+ if t.stopped {
|
|
+ return
|
|
+ }
|
|
+ t.stopped = true
|
|
t.t.Stop()
|
|
}
|
|
diff --git a/pkg/upstream/upstream.go b/pkg/upstream/upstream.go
|
|
index 1bc80a2..ab77aba 100644
|
|
--- a/pkg/upstream/upstream.go
|
|
+++ b/pkg/upstream/upstream.go
|
|
@@ -22,10 +22,12 @@ package upstream
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
+ "errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
+ "net/netip"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
@@ -42,10 +44,11 @@ import (
|
|
"github.com/quic-go/quic-go/http3"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/net/http2"
|
|
+ "golang.org/x/net/proxy"
|
|
)
|
|
|
|
const (
|
|
- tlsHandshakeTimeout = time.Second * 5
|
|
+ tlsHandshakeTimeout = time.Second * 3
|
|
)
|
|
|
|
// Upstream represents a DNS upstream.
|
|
@@ -59,12 +62,17 @@ type Upstream interface {
|
|
|
|
type Opt struct {
|
|
// DialAddr specifies the address the upstream will
|
|
- // actually dial to.
|
|
+ // actually dial to in the network layer by overwriting
|
|
+ // the address inferred from upstream url.
|
|
+ // It won't affect high level layers. (e.g. SNI, HTTP HOST header won't be changed).
|
|
+ // Can be an IP or a domain. Port is optional.
|
|
+ // Tip: If the upstream url host is a domain, specify an IP address
|

+ // here to skip resolving the IP of this domain.
|
|
DialAddr string
|
|
|
|
// Socks5 specifies the socks5 proxy server that the upstream
|
|
// will connect though.
|
|
- // Not implemented for udp upstreams and doh upstreams with http/3.
|
|
+ // Not implemented for udp based protocols (aka. dns over udp, http3, quic).
|
|
Socks5 string
|
|
|
|
// SoMark sets the socket SO_MARK option in unix system.
|
|
@@ -75,40 +83,42 @@ type Opt struct {
|
|
|
|
// IdleTimeout specifies the idle timeout for long-connections.
|
|
// Available for TCP, DoT, DoH.
|
|
- // If negative, TCP, DoT will not reuse connections.
|
|
- // Default: TCP, DoT: 10s , DoH: 30s.
|
|
+ // Default: TCP, DoT: 10s , DoH, DoQ: 30s.
|
|
IdleTimeout time.Duration
|
|
|
|
// EnablePipeline enables query pipelining support as RFC 7766 6.2.1.1 suggested.
|
|
// Available for TCP, DoT upstream with IdleTimeout >= 0.
|
|
+ // Note: There is no fallback.
|
|
EnablePipeline bool
|
|
|
|
// EnableHTTP3 enables HTTP/3 protocol for DoH upstream.
|
|
+ // Note: There is no fallback.
|
|
EnableHTTP3 bool
|
|
|
|
// MaxConns limits the total number of connections, including connections
|
|
// in the dialing states.
|
|
- // Implemented for TCP/DoT pipeline enabled upstreams and DoH upstreams.
|
|
+ // Implemented for TCP/DoT pipeline enabled upstream and DoH upstream.
|
|
// Default is 2.
|
|
MaxConns int
|
|
|
|
- // Bootstrap specifies a plain dns server for the go runtime to solve the
|
|
- // domain of the upstream server. It SHOULD be an IP address. Custom port
|
|
- // is supported.
|
|
- // Note: Use a domain address may cause dead resolve loop and additional
|
|
- // latency to dial upstream server.
|
|
- // HTTP3 is not supported.
|
|
+ // Bootstrap specifies a plain dns server used to resolve the
|

+ // domain name of the upstream server.
|
|
+ // It must be an IP address. Port is optional.
|
|
Bootstrap string
|
|
|
|
+ // Bootstrap version. One of 0 (default equals 4), 4, 6.
|
|
+ // TODO: Support dual-stack.
|
|
+ BootstrapVer int
|
|
+
|
|
// TLSConfig specifies the tls.Config that the TLS client will use.
|
|
- // Available for DoT, DoH upstreams.
|
|
+ // Available for DoT, DoH upstream.
|
|
TLSConfig *tls.Config
|
|
|
|
// Logger specifies the logger that the upstream will use.
|
|
Logger *zap.Logger
|
|
|
|
// EventObserver can observe connection events.
|
|
- // Note: Not Implemented for HTTP/3 upstreams.
|
|
+ // Not implemented for udp based protocols (dns over udp, http3, quic).
|
|
EventObserver EventObserver
|
|
}
|
|
|
|
@@ -129,17 +139,127 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
return nil, fmt.Errorf("invalid server address, %w", err)
|
|
}
|
|
|
|
+ // If host is an IPv6 address without a port, it is wrapped in []. This causes errors when
|

+ // splitting and joining the address and port, so strip the brackets now.
|
|
+ addrUrlHost := tryTrimIpv6Brackets(addrURL.Host)
|
|
+
|
|
dialer := &net.Dialer{
|
|
- Resolver: bootstrap.NewBootstrap(opt.Bootstrap),
|
|
Control: getSocketControlFunc(socketOpts{
|
|
so_mark: opt.SoMark,
|
|
bind_to_device: opt.BindToDevice,
|
|
}),
|
|
}
|
|
|
|
+ var bootstrapAp netip.AddrPort
|
|
+ if s := opt.Bootstrap; len(s) > 0 {
|
|
+ bootstrapAp, err = parseBootstrapAp(s)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("invalid bootstrap, %w", err)
|
|
+ }
|
|
+ }
|
|
+
|
|
+ newUdpAddrResolveFunc := func(defaultPort uint16) (func(ctx context.Context) (*net.UDPAddr, error), error) {
|
|
+ host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
+
|
|
+ if addr, err := netip.ParseAddr(host); err == nil { // host is an ip.
|
|
+ ua := net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
|
+ return func(ctx context.Context) (*net.UDPAddr, error) {
|
|
+ return ua, nil
|
|
+ }, nil
|
|
+ } else { // Not an ip, assuming it's a domain name.
|
|
+ if bootstrapAp.IsValid() {
|
|
+ // Bootstrap enabled.
|
|
+ bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger)
|
|
+ if err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
+
|
|
+ return func(ctx context.Context) (*net.UDPAddr, error) {
|
|
+ s, err := bs.GetAddrPortStr(ctx)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("bootstrap failed, %w", err)
|
|
+ }
|
|
+ return net.ResolveUDPAddr("udp", s)
|
|
+ }, nil
|
|
+ } else {
|
|
+ // Bootstrap disabled.
|
|
+ dialAddr := joinPort(host, port)
|
|
+ return func(ctx context.Context) (*net.UDPAddr, error) {
|
|
+ return net.ResolveUDPAddr("udp", dialAddr)
|
|
+ }, nil
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+
|
|
+ newTcpDialer := func(dialAddrMustBeIp bool, defaultPort uint16) (func(ctx context.Context) (net.Conn, error), error) {
|
|
+ host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
+
|
|
+ // Socks5 enabled.
|
|
+ if s5Addr := opt.Socks5; len(s5Addr) > 0 {
|
|
+ socks5Dialer, err := proxy.SOCKS5("tcp", s5Addr, nil, dialer)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("failed to init socks5 dialer: %w", err)
|
|
+ }
|
|
+
|
|
+ contextDialer := socks5Dialer.(proxy.ContextDialer)
|
|
+ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
|
+ return func(ctx context.Context) (net.Conn, error) {
|
|
+ return contextDialer.DialContext(ctx, "tcp", dialAddr)
|
|
+ }, nil
|
|
+ }
|
|
+
|
|
+ if _, err := netip.ParseAddr(host); err == nil {
|
|
+ // Host is an ip addr. No need to resolve it.
|
|
+ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
|
+ return func(ctx context.Context) (net.Conn, error) {
|
|
+ return dialer.DialContext(ctx, "tcp", dialAddr)
|
|
+ }, nil
|
|
+ } else {
|
|
+ if dialAddrMustBeIp {
|
|
+ return nil, errors.New("addr must be an ip address")
|
|
+ }
|
|
+ // Host is not an ip addr, assuming it is a domain.
|
|
+ if bootstrapAp.IsValid() {
|
|
+ // Bootstrap enabled.
|
|
+ bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger)
|
|
+ if err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
+
|
|
+ return func(ctx context.Context) (net.Conn, error) {
|
|
+ dialAddr, err := bs.GetAddrPortStr(ctx)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("bootstrap failed, %w", err)
|
|
+ }
|
|
+ return dialer.DialContext(ctx, "tcp", dialAddr)
|
|
+ }, nil
|
|
+ } else {
|
|
+ // Bootstrap disabled.
|
|
+ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
|
+ return func(ctx context.Context) (net.Conn, error) {
|
|
+ return dialer.DialContext(ctx, "tcp", dialAddr)
|
|
+ }, nil
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+
|
|
switch addrURL.Scheme {
|
|
case "", "udp":
|
|
- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 53)
|
|
+ const defaultPort = 53
|
|
+ host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
+ if _, err := netip.ParseAddr(host); err != nil {
|
|
+ return nil, fmt.Errorf("addr must be an ip address, %w", err)
|
|
+ }
|
|
+ dialAddr := joinPort(host, port)
|
|
uto := transport.IOOpts{
|
|
DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) {
|
|
c, err := dialer.DialContext(ctx, "udp", dialAddr)
|
|
@@ -166,10 +286,14 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
t: transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: tto}),
|
|
}, nil
|
|
case "tcp":
|
|
- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 53)
|
|
+ const defaultPort = 53
|
|
+ tcpDialer, err := newTcpDialer(true, defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
|
|
+ }
|
|
to := transport.IOOpts{
|
|
DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) {
|
|
- c, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer)
|
|
+ c, err := tcpDialer(ctx)
|
|
c = wrapConn(c, opt.EventObserver)
|
|
return c, err
|
|
},
|
|
@@ -182,18 +306,22 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
}
|
|
return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil
|
|
case "tls":
|
|
+ const defaultPort = 853
|
|
tlsConfig := opt.TLSConfig.Clone()
|
|
if tlsConfig == nil {
|
|
tlsConfig = new(tls.Config)
|
|
}
|
|
if len(tlsConfig.ServerName) == 0 {
|
|
- tlsConfig.ServerName = tryRemovePort(addrURL.Host)
|
|
+ tlsConfig.ServerName = tryRemovePort(addrUrlHost)
|
|
}
|
|
|
|
- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853)
|
|
+ tcpDialer, err := newTcpDialer(false, defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
|
|
+ }
|
|
to := transport.IOOpts{
|
|
DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) {
|
|
- conn, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer)
|
|
+ conn, err := tcpDialer(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
@@ -203,7 +331,6 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
tlsConn.Close()
|
|
return nil, err
|
|
}
|
|
-
|
|
return tlsConn, nil
|
|
},
|
|
WriteFunc: dnsutils.WriteMsgToTCP,
|
|
@@ -215,6 +342,7 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
}
|
|
return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil
|
|
case "https":
|
|
+ const defaultPort = 443
|
|
idleConnTimeout := time.Second * 30
|
|
if opt.IdleTimeout > 0 {
|
|
idleConnTimeout = opt.IdleTimeout
|
|
@@ -224,41 +352,42 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
maxConn = opt.MaxConns
|
|
}
|
|
|
|
- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 443)
|
|
var t http.RoundTripper
|
|
- var addonCloser io.Closer // udpConn
|
|
+ var addonCloser io.Closer
|
|
if opt.EnableHTTP3 {
|
|
+ udpBootstrap, err := newUdpAddrResolveFunc(defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err)
|
|
+ }
|
|
+
|
|
lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
|
|
conn, err := lc.ListenPacket(context.Background(), "udp", "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to init udp socket for quic")
|
|
}
|
|
- qt := &quic.Transport{
|
|
+ quicTransport := &quic.Transport{
|
|
Conn: conn,
|
|
}
|
|
- addonCloser = qt
|
|
+ addonCloser = quicTransport
|
|
t = &http3.RoundTripper{
|
|
TLSClientConfig: opt.TLSConfig,
|
|
- QuicConfig: &quic.Config{
|
|
- TokenStore: quic.NewLRUTokenStore(4, 8),
|
|
- InitialStreamReceiveWindow: 4 * 1024,
|
|
- MaxStreamReceiveWindow: 4 * 1024,
|
|
- InitialConnectionReceiveWindow: 8 * 1024,
|
|
- MaxConnectionReceiveWindow: 64 * 1024,
|
|
- },
|
|
+ QuicConfig: newDefaultQuicConfig(),
|
|
Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
|
- ua, err := net.ResolveUDPAddr("udp", dialAddr) // TODO: Support bootstrap.
|
|
+ ua, err := udpBootstrap(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
-
|
|
- return qt.DialEarly(ctx, ua, tlsCfg, cfg)
|
|
+ return quicTransport.DialEarly(ctx, ua, tlsCfg, cfg)
|
|
},
|
|
}
|
|
} else {
|
|
+ tcpDialer, err := newTcpDialer(false, defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("failed to init tcp dialer, %w", err)
|
|
+ }
|
|
t1 := &http.Transport{
|
|
DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { // overwrite server addr
|
|
- c, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer)
|
|
+ c, err := tcpDialer(ctx)
|
|
c = wrapConn(c, opt.EventObserver)
|
|
return c, err
|
|
},
|
|
@@ -281,26 +410,34 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
t = t1
|
|
}
|
|
|
|
- return &doh.Upstream{
|
|
- EndPoint: addr,
|
|
- Client: &http.Client{Transport: t},
|
|
- AddOnCloser: addonCloser,
|
|
+ return &dohWithClose{
|
|
+ u: doh.NewUpstream(addr, &http.Client{Transport: t}),
|
|
+ closer: addonCloser,
|
|
}, nil
|
|
case "quic", "doq":
|
|
+ const defaultPort = 853
|
|
tlsConfig := opt.TLSConfig.Clone()
|
|
if tlsConfig == nil {
|
|
tlsConfig = new(tls.Config)
|
|
}
|
|
if len(tlsConfig.ServerName) == 0 {
|
|
- tlsConfig.ServerName = tryRemovePort(addrURL.Host)
|
|
+ tlsConfig.ServerName = tryRemovePort(addrUrlHost)
|
|
}
|
|
+ tlsConfig.NextProtos = doq.DoqAlpn
|
|
|
|
- quicConfig := &quic.Config{
|
|
- TokenStore: quic.NewLRUTokenStore(4, 8),
|
|
- InitialStreamReceiveWindow: 4 * 1024,
|
|
- MaxStreamReceiveWindow: 4 * 1024,
|
|
- InitialConnectionReceiveWindow: 8 * 1024,
|
|
- MaxConnectionReceiveWindow: 64 * 1024,
|
|
+ quicConfig := newDefaultQuicConfig()
|
|
+ if opt.IdleTimeout > 0 {
|
|
+ quicConfig.MaxIdleTimeout = opt.IdleTimeout
|
|
+ }
|
|
+
|
|
+ udpBootstrap, err := newUdpAddrResolveFunc(defaultPort)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err)
|
|
+ }
|
|
+
|
|
+ srk, err := doq.InitSrk()
|
|
+ if err != nil {
|
|
+ opt.Logger.Error("failed to init quic stateless reset key", zap.Error(err))
|
|
}
|
|
|
|
lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
|
|
@@ -309,35 +446,44 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
|
|
return nil, fmt.Errorf("failed to init udp socket for quic")
|
|
}
|
|
|
|
- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853)
|
|
- u, err := doq.NewUpstream(dialAddr, uc.(*net.UDPConn), tlsConfig, quicConfig)
|
|
- if err != nil {
|
|
- return nil, fmt.Errorf("failed to setup doq upstream, %w", err)
|
|
+ t := &quic.Transport{
|
|
+ Conn: uc,
|
|
+ StatelessResetKey: (*quic.StatelessResetKey)(srk),
|
|
}
|
|
- return u, nil
|
|
- default:
|
|
- return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme)
|
|
- }
|
|
-}
|
|
|
|
-func getDialAddrWithPort(host, dialAddr string, defaultPort int) string {
|
|
- addr := host
|
|
- if len(dialAddr) > 0 {
|
|
- addr = dialAddr
|
|
- }
|
|
- _, _, err := net.SplitHostPort(addr)
|
|
- if err != nil { // no port, add it.
|
|
- return net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(defaultPort))
|
|
- }
|
|
- return addr
|
|
-}
|
|
+ dialer := func(ctx context.Context) (quic.Connection, error) {
|
|
+ ua, err := udpBootstrap(ctx)
|
|
+ if err != nil {
|
|
+ return nil, fmt.Errorf("bootstrap failed, %w", err)
|
|
+ }
|
|
+ var c quic.Connection
|
|
+ ec, err := t.DialEarly(ctx, ua, tlsConfig, quicConfig)
|
|
+ if err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
|
|
-func tryRemovePort(s string) string {
|
|
- host, _, err := net.SplitHostPort(s)
|
|
- if err != nil {
|
|
- return s
|
|
+ // This is a workaround to
|
|
+ // 1. recover from strange 0rtt rejected err.
|
|
+ // 2. avoid NextConnection might block forever.
|
|
+ // TODO: Remove this workaround.
|
|
+ select {
|
|
+ case <-ctx.Done():
|
|
+ err := context.Cause(ctx)
|
|
+ ec.CloseWithError(0, "")
|
|
+ return nil, err
|
|
+ case <-ec.HandshakeComplete():
|
|
+ c = ec.NextConnection()
|
|
+ }
|
|
+ return c, nil
|
|
+ }
|
|
+
|
|
+ return &doqWithClose{
|
|
+ u: doq.NewUpstream(dialer),
|
|
+ t: t,
|
|
+ }, nil
|
|
+ default:
|
|
+ return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme)
|
|
}
|
|
- return host
|
|
}
|
|
|
|
type udpWithFallback struct {
|
|
@@ -361,3 +507,48 @@ func (u *udpWithFallback) Close() error {
|
|
u.t.Close()
|
|
return nil
|
|
}
|
|
+
|
|
+// doqWithClose bundles a DoQ upstream with the quic.Transport that its
|

+// dialer uses, so the transport can be closed together with the upstream.
|

+type doqWithClose struct {
|

+ u *doq.Upstream
|

+ t *quic.Transport
|

+}
|

+
|

+// ExchangeContext forwards the query to the wrapped DoQ upstream unchanged.
|

+func (u *doqWithClose) ExchangeContext(ctx context.Context, m *dns.Msg) (*dns.Msg, error) {
|

+ return u.u.ExchangeContext(ctx, m)
|

+}
|

+
|

+// Close closes the quic.Transport only; no Close is called on u itself.
|

+// NOTE(review): assumes closing the transport also ends any QUIC
|

+// connections u holds — confirm against quic-go's Transport.Close docs.
|

+func (u *doqWithClose) Close() error {
|

+ return u.t.Close()
|

+}
|
|
+
|
|
+// dohWithClose bundles a DoH upstream with an optional extra resource to
|

+// release on Close. NewUpstream sets closer to the quic.Transport when
|

+// HTTP/3 is enabled and leaves it nil otherwise.
|

+type dohWithClose struct {
|

+ u *doh.Upstream
|

+ closer io.Closer // maybe nil
|

+}
|

+
|

+// ExchangeContext forwards the query to the wrapped DoH upstream unchanged.
|

+func (u *dohWithClose) ExchangeContext(ctx context.Context, m *dns.Msg) (*dns.Msg, error) {
|

+ return u.u.ExchangeContext(ctx, m)
|

+}
|

+
|

+// Close closes closer, if one was set, and returns its error. With a nil
|

+// closer it does nothing and returns nil.
|

+// NOTE(review): the http.Client inside u is not touched here, so pooled
|

+// HTTP/1.1 and HTTP/2 connections presumably linger until their idle
|

+// timeout — consider CloseIdleConnections if prompt teardown matters.
|

+func (u *dohWithClose) Close() error {
|

+ if u.closer != nil {
|

+ return u.closer.Close()
|

+ }
|

+ return nil
|

+}
|
|
+
|
|
+// newDefaultQuicConfig returns a newly allocated quic.Config holding the
|

+// defaults shared by the HTTP/3 (DoH) and DoQ upstreams. Each call returns
|

+// a fresh value, so callers may adjust fields (the DoQ path overrides
|

+// MaxIdleTimeout) without affecting other upstreams.
|

+func newDefaultQuicConfig() *quic.Config {
|

+ return &quic.Config{
|

+ TokenStore: quic.NewLRUTokenStore(4, 8),
|

+
|

+ // DNS exchanges carry little data, so the stream and connection
|

+ // receive windows are kept small.
|

+ InitialStreamReceiveWindow: 4 * 1024,
|

+ MaxStreamReceiveWindow: 4 * 1024,
|

+ InitialConnectionReceiveWindow: 8 * 1024,
|

+ MaxConnectionReceiveWindow: 64 * 1024,
|

+
|

+ // KeepAlivePeriod (25s) is set below MaxIdleTimeout (30s).
|

+ // NOTE(review): the DoQ path may lower MaxIdleTimeout via
|

+ // opt.IdleTimeout without adjusting KeepAlivePeriod — confirm how
|

+ // quic-go treats a keep-alive period longer than the idle timeout.
|

+ MaxIdleTimeout: time.Second * 30,
|

+ KeepAlivePeriod: time.Second * 25,
|

+ HandshakeIdleTimeout: tlsHandshakeTimeout,
|

+ }
|

+}
|
|
diff --git a/pkg/upstream/upstream_test.go b/pkg/upstream/upstream_test.go
|
|
index 675b1ff..55e3b6d 100644
|
|
--- a/pkg/upstream/upstream_test.go
|
|
+++ b/pkg/upstream/upstream_test.go
|
|
@@ -24,12 +24,13 @@ import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
- "github.com/miekg/dns"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
+
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
+ "github.com/miekg/dns"
|
|
)
|
|
|
|
func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
|
@@ -68,6 +69,9 @@ func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownF
|
|
func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) {
|
|
serverName := "test"
|
|
cert, err := utils.GenerateCertificate(serverName)
|
|
+ if err != nil {
|
|
+ t.Fatal(err)
|
|
+ }
|
|
tlsConfig := new(tls.Config)
|
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
|
tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
|
|
diff --git a/pkg/upstream/utils.go b/pkg/upstream/utils.go
|
|
index 7d297f3..8a42125 100644
|
|
--- a/pkg/upstream/utils.go
|
|
+++ b/pkg/upstream/utils.go
|
|
@@ -20,10 +20,10 @@
|
|
package upstream
|
|
|
|
import (
|
|
- "context"
|
|
"fmt"
|
|
- "golang.org/x/net/proxy"
|
|
"net"
|
|
+ "net/netip"
|
|
+ "strconv"
|
|
)
|
|
|
|
type socketOpts struct {
|
|
@@ -31,14 +31,70 @@ type socketOpts struct {
|
|
bind_to_device string
|
|
}
|
|
|
|
-func dialTCP(ctx context.Context, addr, socks5 string, dialer *net.Dialer) (net.Conn, error) {
|
|
- if len(socks5) > 0 {
|
|
- socks5Dialer, err := proxy.SOCKS5("tcp", socks5, nil, dialer)
|
|
+// parseDialAddr picks the address to dial and splits it into host and port.
|

+// dialAddr, when non-empty, takes precedence over urlHost. If the chosen
|

+// address has no port (or an explicit port of 0), defaultPort is used.
|

+// NOTE(review): a bracketed IPv6 literal without a port (e.g. "[::1]") is
|

+// returned with its brackets intact — confirm callers handle that.
|

+func parseDialAddr(urlHost, dialAddr string, defaultPort uint16) (string, uint16, error) {
|

+ addr := urlHost
|

+ if len(dialAddr) > 0 {
|

+ addr = dialAddr
|

+ }
|

+ host, port, err := trySplitHostPort(addr)
|

+ if err != nil {
|

+ return "", 0, err
|

+ }
|

+ // Port 0 means "no port given", so the default is substituted.
|

+ if port == 0 {
|

+ port = defaultPort
|

+ }
|

+ return host, port, nil
|

+}
|
|
+
|
|
+// joinPort formats host and port as a dial address using net.JoinHostPort,
|

+// which wraps IPv6 literals in brackets (e.g. "[::1]:53").
|

+func joinPort(host string, port uint16) string {
|

+ return net.JoinHostPort(host, strconv.Itoa(int(port)))
|

+}
|
|
+
|
|
+// tryRemovePort returns the host part of s when net.SplitHostPort can split
|

+// it, and s unchanged otherwise (e.g. when s carries no port).
|

+// NOTE(review): a bracketed IPv6 host without a port ("[::1]") comes back
|

+// with its brackets; callers use this result as the TLS ServerName.
|

+func tryRemovePort(s string) string {
|

+ host, _, err := net.SplitHostPort(s)
|

+ if err != nil {
|

+ return s
|

+ }
|

+ return host
|

+}
|
|
+
|
|
+// trySplitHostPort splits s into a host and a numeric port.
|

+// If net.SplitHostPort fails for any reason (typically because s has no
|

+// port), it returns s, 0, nil so the caller can apply a default port.
|

+// A port field that is not a valid 16-bit decimal number (including an
|

+// empty one, as in "host:") is reported as an error.
|

+func trySplitHostPort(s string) (string, uint16, error) {
|

+ var port uint16
|

+ host, portS, err := net.SplitHostPort(s)
|

+ if err == nil {
|

+ // bitSize 16 makes ParseUint reject ports above 65535.
|

+ n, err := strconv.ParseUint(portS, 10, 16)
|

if err != nil {
|

- return nil, fmt.Errorf("failed to init socks5 dialer: %w", err)
|

+ return "", 0, fmt.Errorf("invalid port, %w", err)
|

}
|

- return socks5Dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", addr)
|

+ port = uint16(n)
|

+ return host, port, nil
|

}
|

+ return s, 0, nil
|

+}
|
|
|
|
- return dialer.DialContext(ctx, "tcp", addr)
|
|
+// parseBootstrapAp parses the bootstrap server address s into an AddrPort.
|

+// s must be an IP literal, optionally followed by a port; port 53 is used
|

+// when none is given (or it is 0). Host names fail in netip.ParseAddr.
|

+// NOTE(review): a bracketed IPv6 literal without a port ("[::1]") fails to
|

+// parse because the brackets are not trimmed — confirm whether that input
|

+// should be accepted.
|

+func parseBootstrapAp(s string) (netip.AddrPort, error) {
|

+ host, port, err := trySplitHostPort(s)
|

+ if err != nil {
|

+ return netip.AddrPort{}, err
|

+ }
|

+ if port == 0 {
|

+ port = 53
|

+ }
|

+ addr, err := netip.ParseAddr(host)
|

+ if err != nil {
|

+ return netip.AddrPort{}, err
|

+ }
|

+ return netip.AddrPortFrom(addr, port), nil
|

+}
|
|
+
|
|
+// tryTrimIpv6Brackets strips one pair of enclosing square brackets, as
|

+// written around IPv6 literals ("[::1]" -> "::1"). Any other string,
|

+// including one shorter than two bytes, is returned unchanged.
|

+func tryTrimIpv6Brackets(s string) string {
|

+ if len(s) < 2 {
|

+ return s
|

+ }
|

+ if s[0] == '[' && s[len(s)-1] == ']' {
|

+ // The closing bracket sits at index len(s)-1, which is the exclusive
|

+ // end of the inner text; ending at len(s)-2 would drop its last byte.
|

+ return s[1 : len(s)-1]
|

+ }
|

+ return s
|

}
|
|
diff --git a/pkg/upstream/utils_unix.go b/pkg/upstream/utils_unix.go
|
|
index b8ced47..685a82b 100644
|
|
--- a/pkg/upstream/utils_unix.go
|
|
+++ b/pkg/upstream/utils_unix.go
|
|
@@ -22,9 +22,10 @@
|
|
package upstream
|
|
|
|
import (
|
|
- "golang.org/x/sys/unix"
|
|
"os"
|
|
"syscall"
|
|
+
|
|
+ "golang.org/x/sys/unix"
|
|
)
|
|
|
|
func getSocketControlFunc(opts socketOpts) func(string, string, syscall.RawConn) error {
|
|
diff --git a/plugin/executable/forward/forward.go b/plugin/executable/forward/forward.go
|
|
index be39ad7..bba0da2 100644
|
|
--- a/plugin/executable/forward/forward.go
|
|
+++ b/plugin/executable/forward/forward.go
|
|
@@ -24,6 +24,9 @@ import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
+ "strings"
|
|
+ "time"
|
|
+
|
|
"github.com/IrineSistiana/mosdns/v5/coremain"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/upstream"
|
|
@@ -32,8 +35,6 @@ import (
|
|
"github.com/miekg/dns"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"go.uber.org/zap"
|
|
- "strings"
|
|
- "time"
|
|
)
|
|
|
|
const PluginType = "forward"
|
|
@@ -57,6 +58,7 @@ type Args struct {
|
|
SoMark int `yaml:"so_mark"`
|
|
BindToDevice string `yaml:"bind_to_device"`
|
|
Bootstrap string `yaml:"bootstrap"`
|
|
+ BootstrapVer int `yaml:"bootstrap_version"`
|
|
}
|
|
|
|
type UpstreamConfig struct {
|
|
@@ -73,6 +75,7 @@ type UpstreamConfig struct {
|
|
SoMark int `yaml:"so_mark"`
|
|
BindToDevice string `yaml:"bind_to_device"`
|
|
Bootstrap string `yaml:"bootstrap"`
|
|
+ BootstrapVer int `yaml:"bootstrap_version"`
|
|
}
|
|
|
|
func Init(bp *coremain.BP, args any) (any, error) {
|
|
@@ -125,6 +128,7 @@ func NewForward(args *Args, opt Opts) (*Forward, error) {
|
|
utils.SetDefaultUnsignNum(&c.SoMark, args.SoMark)
|
|
utils.SetDefaultString(&c.BindToDevice, args.BindToDevice)
|
|
utils.SetDefaultString(&c.Bootstrap, args.Bootstrap)
|
|
+ utils.SetDefaultUnsignNum(&c.BootstrapVer, args.BootstrapVer)
|
|
}
|
|
|
|
for i, c := range args.Upstreams {
|
|
@@ -144,6 +148,7 @@ func NewForward(args *Args, opt Opts) (*Forward, error) {
|
|
EnablePipeline: c.EnablePipeline,
|
|
EnableHTTP3: c.EnableHTTP3,
|
|
Bootstrap: c.Bootstrap,
|
|
+ BootstrapVer: c.BootstrapVer,
|
|
TLSConfig: &tls.Config{
|
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
|
ClientSessionCache: tls.NewLRUClientSessionCache(4),
|
|
--
|
|
2.34.8
|
|
|