diff --git a/mosdns/patches/111-pkg-server-decoupling-from-query_context.patch b/mosdns/patches/111-pkg-server-decoupling-from-query_context.patch new file mode 100644 index 0000000..91e4115 --- /dev/null +++ b/mosdns/patches/111-pkg-server-decoupling-from-query_context.patch @@ -0,0 +1,639 @@ +From 24c1cd73acc4fb1c9e5fb8a54eff570889ec81a3 Mon Sep 17 00:00:00 2001 +From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> +Date: Wed, 20 Sep 2023 09:25:41 +0800 +Subject: [PATCH 1/6] pkg/server: decoupling from query_context + +--- + .../handler.go => http_handler.go} | 68 +++++++++-------- + pkg/server/iface.go | 23 ++++++ + pkg/server/tcp.go | 74 ++++++++----------- + pkg/server/udp.go | 49 ++++-------- + pkg/server/utils.go | 7 ++ + .../entry_handler.go | 31 ++++---- + plugin/server/http_server/http_server.go | 16 ++-- + plugin/server/server_utils/handler.go | 10 ++- + plugin/server/tcp_server/tcp_server.go | 11 ++- + plugin/server/udp_server/udp_server.go | 4 +- + 10 files changed, 144 insertions(+), 149 deletions(-) + rename pkg/server/{http_handler/handler.go => http_handler.go} (75%) + create mode 100644 pkg/server/iface.go + create mode 100644 pkg/server/utils.go + rename pkg/{server/dns_handler => server_handler}/entry_handler.go (77%) + +diff --git a/pkg/server/http_handler/handler.go b/pkg/server/http_handler.go +similarity index 75% +rename from pkg/server/http_handler/handler.go +rename to pkg/server/http_handler.go +index 25f52e1..5fa76b4 100644 +--- a/pkg/server/http_handler/handler.go ++++ b/pkg/server/http_handler.go +@@ -17,69 +17,67 @@ + * along with this program. If not, see . + */ + +-package http_handler ++package server + + import ( + "encoding/base64" + "errors" + "fmt" +- "github.com/IrineSistiana/mosdns/v5/mlog" +- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" +- "github.com/IrineSistiana/mosdns/v5/pkg/pool" +- "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +- "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" +- "github.com/miekg/dns" +- "go.uber.org/zap" + "io" + "net/http" + "net/netip" + "strings" +-) + +-type HandlerOpts struct { +- // DNSHandler is required. +- DNSHandler dns_handler.Handler ++ "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" ++ "github.com/IrineSistiana/mosdns/v5/pkg/pool" ++ "github.com/miekg/dns" ++ "go.uber.org/zap" ++) + +- // SrcIPHeader specifies the header that contain client source address. ++type HttpHandlerOpts struct { ++ // GetSrcIPFromHeader specifies the header that contain client source address. + // e.g. "X-Forwarded-For". +- SrcIPHeader string ++ GetSrcIPFromHeader string + + // Logger specifies the logger which Handler writes its log to. + // Default is a nop logger. + Logger *zap.Logger + } + +-func (opts *HandlerOpts) init() { +- if opts.Logger == nil { +- opts.Logger = mlog.Nop() +- } +- return ++type HttpHandler struct { ++ dnsHandler Handler ++ logger *zap.Logger ++ srcIPHeader string + } + +-type Handler struct { +- opts HandlerOpts +-} ++var _ http.Handler = (*HttpHandler)(nil) + +-func NewHandler(opts HandlerOpts) *Handler { +- opts.init() +- return &Handler{opts: opts} ++func NewHttpHandler(h Handler, opts HttpHandlerOpts) *HttpHandler { ++ hh := new(HttpHandler) ++ hh.dnsHandler = h ++ hh.srcIPHeader = opts.GetSrcIPFromHeader ++ hh.logger = opts.Logger ++ if hh.logger == nil { ++ hh.logger = nopLogger ++ } ++ return hh + } + +-func (h *Handler) warnErr(req *http.Request, msg string, err error) { +- h.opts.Logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err)) ++func (h *HttpHandler) warnErr(req *http.Request, msg string, err error) { ++ h.logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err)) + } + +-func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ++func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + addrPort, err := netip.ParseAddrPort(req.RemoteAddr) + if err != nil { +- h.opts.Logger.Error("failed to parse request remote addr", zap.String("addr", req.RemoteAddr), zap.Error(err)) ++ h.logger.Error("failed to parse request remote addr", zap.String("addr", req.RemoteAddr), zap.Error(err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + clientAddr := addrPort.Addr() + + // read remote addr from header +- if header := h.opts.SrcIPHeader; len(header) != 0 { ++ if header := h.srcIPHeader; len(header) != 0 { + if xff := req.Header.Get(header); len(xff) != 0 { + addr, err := readClientAddrFromXFF(xff) + if err != nil { +@@ -100,12 +98,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + return + } + +- qCtx := query_context.NewContext(q) +- query_context.SetClientAddr(qCtx, &clientAddr) +- if err := h.opts.DNSHandler.ServeDNS(req.Context(), qCtx); err != nil { +- panic(err.Error()) // Force http server to close connection. ++ r, err := h.dnsHandler.Handle(req.Context(), q, QueryMeta{ClientAddr: clientAddr}) ++ if err != nil { ++ h.warnErr(req, "handler err", err) ++ panic(err) // Force http server to close connection. + } +- r := qCtx.R() ++ + b, buf, err := pool.PackBuffer(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) +diff --git a/pkg/server/iface.go b/pkg/server/iface.go +new file mode 100644 +index 0000000..2f15be1 +--- /dev/null ++++ b/pkg/server/iface.go +@@ -0,0 +1,23 @@ ++package server ++ ++import ( ++ "context" ++ "net/netip" ++ ++ "github.com/miekg/dns" ++) ++ ++// Handler handles incoming request q and MUST ALWAYS return a response. ++// Handler MUST handle dns errors by itself and return a proper error responses. ++// e.g. Return a SERVFAIL if something goes wrong. ++// If Handle() returns an error, caller considers that the error is associated ++// with the downstream connection and will close the downstream connection ++// immediately. ++type Handler interface { ++ Handle(ctx context.Context, q *dns.Msg, meta QueryMeta) (resp *dns.Msg, err error) ++} ++ ++type QueryMeta struct { ++ ClientAddr netip.Addr // Maybe invalid ++ FromUDP bool ++} +diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go +index 5dc80de..5f479b1 100644 +--- a/pkg/server/tcp.go ++++ b/pkg/server/tcp.go +@@ -22,15 +22,13 @@ package server + import ( + "context" + "fmt" +- "github.com/IrineSistiana/mosdns/v5/mlog" ++ "net" ++ "net/netip" ++ "time" ++ + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" +- "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +- "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" +- "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "go.uber.org/zap" +- "net" +- "time" + ) + + const ( +@@ -38,33 +36,30 @@ const ( + tcpFirstReadTimeout = time.Millisecond * 500 + ) + +-type TCPServer struct { +- opts TCPServerOpts +-} +- +-func NewTCPServer(opts TCPServerOpts) *TCPServer { +- opts.init() +- return &TCPServer{opts: opts} +-} +- + type TCPServerOpts struct { +- DNSHandler dns_handler.Handler // Required. +- Logger *zap.Logger +- IdleTimeout time.Duration +-} ++ // Nil logger == nop ++ Logger *zap.Logger + +-func (opts *TCPServerOpts) init() { +- if opts.Logger == nil { +- opts.Logger = mlog.Nop() +- } +- utils.SetDefaultNum(&opts.IdleTimeout, defaultTCPIdleTimeout) +- return ++ // Default is defaultTCPIdleTimeout. ++ IdleTimeout time.Duration + } + + // ServeTCP starts a server at l. It returns if l had an Accept() error. + // It always returns a non-nil error. +-func (s *TCPServer) ServeTCP(l net.Listener) error { +- // handle listener ++func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error { ++ logger := opts.Logger ++ if logger == nil { ++ logger = nopLogger ++ } ++ idleTimeout := opts.IdleTimeout ++ if idleTimeout <= 0 { ++ idleTimeout = defaultTCPIdleTimeout ++ } ++ firstReadTimeout := tcpFirstReadTimeout ++ if idleTimeout < firstReadTimeout { ++ firstReadTimeout = idleTimeout ++ } ++ + listenerCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + for { +@@ -79,14 +74,12 @@ func (s *TCPServer) ServeTCP(l net.Listener) error { + defer c.Close() + defer cancelConn() + +- firstReadTimeout := tcpFirstReadTimeout +- idleTimeout := s.opts.IdleTimeout +- if idleTimeout < firstReadTimeout { +- firstReadTimeout = idleTimeout ++ var clientAddr netip.Addr ++ ta, ok := c.RemoteAddr().(*net.TCPAddr) ++ if ok { ++ clientAddr = ta.AddrPort().Addr() + } + +- clientAddr := utils.GetAddrFromAddr(c.RemoteAddr()) +- + firstRead := true + for { + if firstRead { +@@ -102,24 +95,21 @@ func (s *TCPServer) ServeTCP(l net.Listener) error { + + // handle query + go func() { +- qCtx := query_context.NewContext(req) +- query_context.SetClientAddr(qCtx, &clientAddr) +- if err := s.opts.DNSHandler.ServeDNS(tcpConnCtx, qCtx); err != nil { +- s.opts.Logger.Warn("handler err", zap.Error(err)) +- c.Close() ++ r, err := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr}) ++ if err != nil { ++ logger.Warn("handler err", zap.Error(err)) ++ c.Close() // abort the connection + return + } +- r := qCtx.R() +- + b, buf, err := pool.PackBuffer(r) + if err != nil { +- s.opts.Logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) ++ logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) + return + } + defer pool.ReleaseBuf(buf) + + if _, err := dnsutils.WriteRawMsgToTCP(c, b); err != nil { +- s.opts.Logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) ++ logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) + return + } + }() +diff --git a/pkg/server/udp.go b/pkg/server/udp.go +index 8bb1b85..247455b 100644 +--- a/pkg/server/udp.go ++++ b/pkg/server/udp.go +@@ -24,38 +24,24 @@ import ( + "fmt" + "net" + +- "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" +- "github.com/IrineSistiana/mosdns/v5/pkg/query_context" +- "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" + "github.com/miekg/dns" + "go.uber.org/zap" + ) + +-type UDPServer struct { +- opts UDPServerOpts +-} +- +-func NewUDPServer(opts UDPServerOpts) *UDPServer { +- opts.init() +- return &UDPServer{opts: opts} +-} +- + type UDPServerOpts struct { +- DNSHandler dns_handler.Handler // Required. +- Logger *zap.Logger +-} +- +-func (opts *UDPServerOpts) init() { +- if opts.Logger == nil { +- opts.Logger = mlog.Nop() +- } +- return ++ Logger *zap.Logger + } + + // ServeUDP starts a server at c. It returns if c had a read error. + // It always returns a non-nil error. +-func (s *UDPServer) ServeUDP(c *net.UDPConn) error { ++// h is required. logger is optional. ++func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error { ++ logger := opts.Logger ++ if logger == nil { ++ logger = nopLogger ++ } ++ + listenerCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + +@@ -78,11 +64,10 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error { + if err != nil { + return fmt.Errorf("unexpected read err: %w", err) + } +- clientAddr := remoteAddr.Addr() + + q := new(dns.Msg) + if err := q.Unpack((*rb)[:n]); err != nil { +- s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr)) ++ logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr)) + continue + } + +@@ -91,34 +76,32 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error { + var err error + dstIpFromCm, err = oobReader(ob[:oobn]) + if err != nil { +- s.opts.Logger.Error("failed to get dst address from oob", zap.Error(err)) ++ logger.Error("failed to get dst address from oob", zap.Error(err)) + } + } + + // handle query + go func() { +- qCtx := query_context.NewContext(q) +- query_context.SetClientAddr(qCtx, &clientAddr) +- if err := s.opts.DNSHandler.ServeDNS(listenerCtx, qCtx); err != nil { +- s.opts.Logger.Warn("handler err", zap.Error(err)) ++ r, err := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}) ++ if err != nil { ++ logger.Warn("handler err", zap.Error(err)) + return + } +- r := qCtx.R() + if r != nil { + r.Truncate(getUDPSize(q)) + b, buf, err := pool.PackBuffer(r) + if err != nil { +- s.opts.Logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) ++ logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) + return + } + defer pool.ReleaseBuf(buf) +- var oob []byte + ++ var oob []byte + if oobWriter != nil && dstIpFromCm != nil { + oob = oobWriter(dstIpFromCm) + } + if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil { +- s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) ++ logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) + } + } + }() +diff --git a/pkg/server/utils.go b/pkg/server/utils.go +new file mode 100644 +index 0000000..5e1b5c1 +--- /dev/null ++++ b/pkg/server/utils.go +@@ -0,0 +1,7 @@ ++package server ++ ++import "go.uber.org/zap" ++ ++var ( ++ nopLogger = zap.NewNop() ++) +diff --git a/pkg/server/dns_handler/entry_handler.go b/pkg/server_handler/entry_handler.go +similarity index 77% +rename from pkg/server/dns_handler/entry_handler.go +rename to pkg/server_handler/entry_handler.go +index cec4123..121d943 100644 +--- a/pkg/server/dns_handler/entry_handler.go ++++ b/pkg/server_handler/entry_handler.go +@@ -17,17 +17,19 @@ + * along with this program. If not, see . + */ + +-package dns_handler ++package server_handler + + import ( + "context" ++ "time" ++ + "github.com/IrineSistiana/mosdns/v5/mlog" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" ++ "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "go.uber.org/zap" +- "time" + ) + + const ( +@@ -38,18 +40,6 @@ var ( + nopLogger = mlog.Nop() + ) + +-// Handler handles dns query. +-type Handler interface { +- // ServeDNS handles incoming request qCtx and MUST ALWAYS set a response. +- // Implements must not keep and use qCtx after the ServeDNS returned. +- // ServeDNS should handle dns errors by itself and return a proper error responses +- // for clients. +- // If ServeDNS returns an error, caller considers that the error is associated +- // with the downstream connection and will close the downstream connection +- // immediately. +- ServeDNS(ctx context.Context, qCtx *query_context.Context) error +-} +- + type EntryHandlerOpts struct { + // Logger is used for logging. Default is a noop logger. + Logger *zap.Logger +@@ -73,20 +63,26 @@ type EntryHandler struct { + opts EntryHandlerOpts + } + ++var _ server.Handler = (*EntryHandler)(nil) ++ + func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler { + opts.init() + return &EntryHandler{opts: opts} + } + +-// ServeDNS implements Handler. ++// ServeDNS implements server.Handler. + // If entry returns an error, a SERVFAIL response will be set. + // If entry returns without a response, a REFUSED response will be set. +-func (h *EntryHandler) ServeDNS(ctx context.Context, qCtx *query_context.Context) error { ++func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta) (*dns.Msg, error) { + ddl := time.Now().Add(h.opts.QueryTimeout) + ctx, cancel := context.WithDeadline(ctx, ddl) + defer cancel() + + // exec entry ++ qCtx := query_context.NewContext(q) ++ if qInfo.ClientAddr.IsValid() { ++ query_context.SetClientAddr(qCtx, &qInfo.ClientAddr) ++ } + err := h.opts.Entry.Exec(ctx, qCtx) + respMsg := qCtx.R() + if err != nil { +@@ -106,6 +102,5 @@ func (h *EntryHandler) ServeDNS(ctx context.Context, qCtx *query_context.Context + respMsg.Rcode = dns.RcodeServerFailure + } + respMsg.RecursionAvailable = true +- qCtx.SetResponse(respMsg) +- return nil ++ return respMsg, nil + } +diff --git a/plugin/server/http_server/http_server.go b/plugin/server/http_server/http_server.go +index 8e66b37..daca6db 100644 +--- a/plugin/server/http_server/http_server.go ++++ b/plugin/server/http_server/http_server.go +@@ -21,13 +21,14 @@ package tcp_server + + import ( + "fmt" ++ "net/http" ++ "time" ++ + "github.com/IrineSistiana/mosdns/v5/coremain" +- "github.com/IrineSistiana/mosdns/v5/pkg/server/http_handler" ++ "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" + "golang.org/x/net/http2" +- "net/http" +- "time" + ) + + const PluginType = "http_server" +@@ -73,12 +74,11 @@ func StartServer(bp *coremain.BP, args *Args) (*HttpServer, error) { + if err != nil { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } +- hhOpts := http_handler.HandlerOpts{ +- DNSHandler: dh, +- SrcIPHeader: args.SrcIPHeader, +- Logger: bp.L(), ++ hhOpts := server.HttpHandlerOpts{ ++ GetSrcIPFromHeader: args.SrcIPHeader, ++ Logger: bp.L(), + } +- hh := http_handler.NewHandler(hhOpts) ++ hh := server.NewHttpHandler(dh, hhOpts) + mux.Handle(entry.Path, hh) + } + +diff --git a/plugin/server/server_utils/handler.go b/plugin/server/server_utils/handler.go +index 2a20e1a..bbc6eab 100644 +--- a/plugin/server/server_utils/handler.go ++++ b/plugin/server/server_utils/handler.go +@@ -21,21 +21,23 @@ package server_utils + + import ( + "fmt" ++ + "github.com/IrineSistiana/mosdns/v5/coremain" +- "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" ++ "github.com/IrineSistiana/mosdns/v5/pkg/server" ++ "github.com/IrineSistiana/mosdns/v5/pkg/server_handler" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + ) + +-func NewHandler(bp *coremain.BP, entry string) (dns_handler.Handler, error) { ++func NewHandler(bp *coremain.BP, entry string) (server.Handler, error) { + p := bp.M().GetPlugin(entry) + exec := sequence.ToExecutable(p) + if exec == nil { + return nil, fmt.Errorf("cannot find executable entry by tag %s", entry) + } + +- handlerOpts := dns_handler.EntryHandlerOpts{ ++ handlerOpts := server_handler.EntryHandlerOpts{ + Logger: bp.L(), + Entry: exec, + } +- return dns_handler.NewEntryHandler(handlerOpts), nil ++ return server_handler.NewEntryHandler(handlerOpts), nil + } +diff --git a/plugin/server/tcp_server/tcp_server.go b/plugin/server/tcp_server/tcp_server.go +index 5aca0f5..f69c667 100644 +--- a/plugin/server/tcp_server/tcp_server.go ++++ b/plugin/server/tcp_server/tcp_server.go +@@ -22,12 +22,13 @@ package tcp_server + import ( + "crypto/tls" + "fmt" ++ "net" ++ "time" ++ + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/utils" + "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" +- "net" +- "time" + ) + + const PluginType = "tcp_server" +@@ -69,9 +70,6 @@ func StartServer(bp *coremain.BP, args *Args) (*TcpServer, error) { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } + +- serverOpts := server.TCPServerOpts{Logger: bp.L(), DNSHandler: dh, IdleTimeout: time.Duration(args.IdleTimeout) * time.Second} +- s := server.NewTCPServer(serverOpts) +- + // Init tls + var tc *tls.Config + if len(args.Key)+len(args.Cert) > 0 { +@@ -91,7 +89,8 @@ func StartServer(bp *coremain.BP, args *Args) (*TcpServer, error) { + + go func() { + defer l.Close() +- err := s.ServeTCP(l) ++ serverOpts := server.TCPServerOpts{Logger: bp.L(), IdleTimeout: time.Duration(args.IdleTimeout) * time.Second} ++ err := server.ServeTCP(l, dh, serverOpts) + bp.M().GetSafeClose().SendCloseSignal(err) + }() + return &TcpServer{ +diff --git a/plugin/server/udp_server/udp_server.go b/plugin/server/udp_server/udp_server.go +index 293f720..988f312 100644 +--- a/plugin/server/udp_server/udp_server.go ++++ b/plugin/server/udp_server/udp_server.go +@@ -64,15 +64,13 @@ func StartServer(bp *coremain.BP, args *Args) (*UdpServer, error) { + return nil, fmt.Errorf("failed to init dns handler, %w", err) + } + +- serverOpts := server.UDPServerOpts{Logger: bp.L(), DNSHandler: dh} +- s := server.NewUDPServer(serverOpts) + c, err := net.ListenPacket("udp", args.Listen) + if err != nil { + return nil, fmt.Errorf("failed to create socket, %w", err) + } + go func() { + defer c.Close() +- err := s.ServeUDP(c.(*net.UDPConn)) ++ err := server.ServeUDP(c.(*net.UDPConn), dh, server.UDPServerOpts{Logger: bp.L()}) + bp.M().GetSafeClose().SendCloseSignal(err) + }() + return &UdpServer{ +-- +2.34.8 + diff --git a/mosdns/patches/112-server-don-t-exit-udp-server-on-tempory-read-err.patch b/mosdns/patches/112-server-don-t-exit-udp-server-on-tempory-read-err.patch new file mode 100644 index 0000000..ba8bf0d --- /dev/null +++ b/mosdns/patches/112-server-don-t-exit-udp-server-on-tempory-read-err.patch @@ -0,0 +1,31 @@ +From 61c1586082d21ad793447c3c4510230b492ffbc0 Mon Sep 17 00:00:00 2001 +From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> +Date: Wed, 20 Sep 2023 09:31:39 +0800 +Subject: [PATCH 2/6] server: don't exit udp server on tempory read err + +--- + pkg/server/udp.go | 8 +++++++- + 1 file changed, 7 insertions(+), 1 deletion(-) + +diff --git a/pkg/server/udp.go b/pkg/server/udp.go +index 247455b..4dc1087 100644 +--- a/pkg/server/udp.go ++++ b/pkg/server/udp.go +@@ -62,7 +62,13 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error { + for { + n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob) + if err != nil { +- return fmt.Errorf("unexpected read err: %w", err) ++ if n == 0 { ++ // err with zero read. Most likely becasue c was closed. ++ return fmt.Errorf("unexpected read err: %w", err) ++ } ++ // err with some read. Tempory err. ++ logger.Warn("read err", zap.Error(err)) ++ continue + } + + q := new(dns.Msg) +-- +2.34.8 + diff --git a/mosdns/patches/113-pool-fixed-bytes-pool-size-was-1k.patch b/mosdns/patches/113-pool-fixed-bytes-pool-size-was-1k.patch new file mode 100644 index 0000000..e06d600 --- /dev/null +++ b/mosdns/patches/113-pool-fixed-bytes-pool-size-was-1k.patch @@ -0,0 +1,25 @@ +From c19d24ab47674c2a82591c9e16fb450df7882465 Mon Sep 17 00:00:00 2001 +From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> +Date: Thu, 21 Sep 2023 08:57:07 +0800 +Subject: [PATCH 3/6] pool: fixed bytes pool size was 1k + +--- + pkg/pool/allocator.go | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/pkg/pool/allocator.go b/pkg/pool/allocator.go +index eb011ea..84b1110 100644 +--- a/pkg/pool/allocator.go ++++ b/pkg/pool/allocator.go +@@ -24,7 +24,7 @@ import ( + ) + + var ( +- _pool = bytesPool.NewPool(10) // 1Mbyte pool, should be enough. ++ _pool = bytesPool.NewPool(20) // 1Mbyte pool, should be enough. + GetBuf = _pool.Get + ReleaseBuf = _pool.Release + ) +-- +2.34.8 + diff --git a/mosdns/patches/114-pool-fixed-PackTCPBuffer-always-re-allocate.patch b/mosdns/patches/114-pool-fixed-PackTCPBuffer-always-re-allocate.patch new file mode 100644 index 0000000..c619b7e --- /dev/null +++ b/mosdns/patches/114-pool-fixed-PackTCPBuffer-always-re-allocate.patch @@ -0,0 +1,32 @@ +From bedebc75e1f88d02e737203b09041c39094d5777 Mon Sep 17 00:00:00 2001 +From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> +Date: Thu, 21 Sep 2023 11:04:03 +0800 +Subject: [PATCH 4/6] pool: fixed PackTCPBuffer always re-allocate + +--- + pkg/pool/msg_buf.go | 3 ++- + 1 file changed, 2 insertions(+), 1 deletion(-) + +diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go +index f132cc2..11faf7d 100644 +--- a/pkg/pool/msg_buf.go ++++ b/pkg/pool/msg_buf.go +@@ -60,13 +60,14 @@ func PackTCPBuffer(m *dns.Msg) (buf *[]byte, err error) { + return nil, fmt.Errorf("dns payload size %d is too large", l) + } + +- if &((*b)[0]) != &wire[0] { // reallocated ++ if &((*b)[2]) != &wire[0] { // reallocated + ReleaseBuf(b) + b = GetBuf(l + 2) + binary.BigEndian.PutUint16((*b)[:2], uint16(l)) + copy((*b)[2:], wire) + return b, nil + } ++ binary.BigEndian.PutUint16((*b)[:2], uint16(l)) + *b = (*b)[:2+l] + return b, nil + } +-- +2.34.8 + diff --git a/mosdns/patches/115-dnsutils-let-WriteMsgToTCP-use-PackTCPBuffer.patch b/mosdns/patches/115-dnsutils-let-WriteMsgToTCP-use-PackTCPBuffer.patch new file mode 100644 index 0000000..3bd7c4e --- /dev/null +++ b/mosdns/patches/115-dnsutils-let-WriteMsgToTCP-use-PackTCPBuffer.patch @@ -0,0 +1,31 @@ +From f0005ccc3a27dcbcc2266c550ffb7acf688523f0 Mon Sep 17 00:00:00 2001 +From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> +Date: Thu, 21 Sep 2023 11:04:35 +0800 +Subject: [PATCH 5/6] dnsutils: let WriteMsgToTCP use PackTCPBuffer + +--- + pkg/dnsutils/net_io.go | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go +index 8fc769b..f165446 100644 +--- a/pkg/dnsutils/net_io.go ++++ b/pkg/dnsutils/net_io.go +@@ -78,12 +78,12 @@ func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) { + // WriteMsgToTCP packs and writes m to c in RFC 1035 format. + // n represents how many bytes are written to c. + func WriteMsgToTCP(c io.Writer, m *dns.Msg) (n int, err error) { +- mRaw, buf, err := pool.PackBuffer(m) ++ buf, err := pool.PackTCPBuffer(m) + if err != nil { + return 0, err + } + defer pool.ReleaseBuf(buf) +- return WriteRawMsgToTCP(c, mRaw) ++ return c.Write(*buf) + } + + // WriteRawMsgToTCP See WriteMsgToTCP +-- +2.34.8 + diff --git a/mosdns/patches/116-upstream-switch-to-self-implement-bootstrap.patch b/mosdns/patches/116-upstream-switch-to-self-implement-bootstrap.patch new file mode 100644 index 0000000..c51e3f5 --- /dev/null +++ b/mosdns/patches/116-upstream-switch-to-self-implement-bootstrap.patch @@ -0,0 +1,1540 @@ +From 5b96077d2a44d085a8ae44b1b368808184dea64b Mon Sep 17 00:00:00 2001 +From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> +Date: Thu, 21 Sep 2023 11:52:20 +0800 +Subject: [PATCH 6/6] upstream: switch to self implement bootstrap + +refine code structure +--- + pkg/upstream/bootstrap/bootstrap.go | 324 ++++++++++++---------- + pkg/upstream/bootstrap/bootstrap_test.go | 38 --- + pkg/upstream/doh/upstream.go | 24 +- + pkg/upstream/doq/upstream.go | 66 +---- + pkg/upstream/event_stat.go | 2 +- + pkg/upstream/transport/dns_conn.go | 3 +- + pkg/upstream/transport/dns_conn_test.go | 7 +- + pkg/upstream/transport/pipeline.go | 3 +- + pkg/upstream/transport/pipeline_test.go | 3 +- + pkg/upstream/transport/reuse.go | 3 +- + pkg/upstream/transport/reuse_test.go | 3 +- + pkg/upstream/transport/transport.go | 3 +- + pkg/upstream/transport/utils.go | 46 ++-- + pkg/upstream/upstream.go | 337 ++++++++++++++++++----- + pkg/upstream/upstream_test.go | 8 +- + pkg/upstream/utils.go | 72 ++++- + pkg/upstream/utils_unix.go | 3 +- + plugin/executable/forward/forward.go | 9 +- + 18 files changed, 586 insertions(+), 368 deletions(-) + delete mode 100644 pkg/upstream/bootstrap/bootstrap_test.go + +diff --git a/pkg/upstream/bootstrap/bootstrap.go b/pkg/upstream/bootstrap/bootstrap.go +index b831757..2cd8ef9 100644 +--- a/pkg/upstream/bootstrap/bootstrap.go ++++ b/pkg/upstream/bootstrap/bootstrap.go +@@ -21,184 +21,226 @@ package bootstrap + + import ( + "context" +- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" +- "github.com/miekg/dns" ++ "errors" ++ "fmt" + "net" +- "strings" ++ "net/netip" + "sync" ++ "sync/atomic" + "time" ++ ++ "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" ++ "github.com/miekg/dns" ++ "go.uber.org/zap" + ) + +-// NewBootstrap returns a customized *net.Resolver which can be used as a Bootstrap for a +-// certain domain. Its Dial func is modified to dial s through udp. +-// It also has a small built-in cache. +-// s SHOULD be a literal IP address and the port SHOULD also be literal. +-// Port can be omitted. In this case, the default port is :53. +-// e.g. NewBootstrap("8.8.8.8"), NewBootstrap("127.0.0.1:5353"). +-// If s is empty, NewBootstrap returns nil. (A nil *net.Resolver is valid in net.Dialer.) +-// Note that not all platform support a customized *net.Resolver. It also depends on the +-// version of go runtime. +-// See the package docs from the net package for more info. +-func NewBootstrap(s string) *net.Resolver { +- if len(s) == 0 { +- return nil +- } +- // Add port. +- _, _, err := net.SplitHostPort(s) +- if err != nil { // no port, add it. +- s = net.JoinHostPort(strings.Trim(s, "[]"), "53") +- } ++const ( ++ minimumUpdateInterval = time.Minute * 5 ++ retryInterval = time.Second * 2 ++ queryTimeout = time.Second * 5 ++) + +- bs := newBootstrap(s) ++var ( ++ errNoAddrInResp = errors.New("resp does not have ip address") ++) + +- return &net.Resolver{ +- PreferGo: true, +- StrictErrors: false, +- Dial: bs.dial, ++func New( ++ host string, ++ port uint16, ++ bootstrapServer netip.AddrPort, ++ bootstrapVer int, // 0,4,6 ++ logger *zap.Logger, // not nil ++) (*Bootstrap, error) { ++ dp := new(Bootstrap) ++ dp.fqdn = dns.Fqdn(host) ++ dp.port = port ++ if !bootstrapServer.IsValid() { ++ return nil, errors.New("invalid bootstrap server address") ++ } ++ dp.bootstrap = net.UDPAddrFromAddrPort(bootstrapServer) ++ qt, ok := bootstrapVer2Qt(bootstrapVer) ++ if !ok { ++ return nil, fmt.Errorf("invalid bootstrap version %d", bootstrapVer) + } +-} ++ dp.qt = qt ++ dp.logger = logger + +-type bootstrap struct { +- upstream string +- cache *cache ++ dp.readyNotify = make(chan struct{}) ++ return dp, nil + } + +-func newBootstrap(upstream string) *bootstrap { +- return &bootstrap{ +- upstream: upstream, +- cache: newCache(0), +- } +-} ++type Bootstrap struct { ++ fqdn string ++ port uint16 ++ bootstrap *net.UDPAddr ++ qt uint16 // dns.TypeA or dns.TypeAAAA ++ logger *zap.Logger // not nil + +-func (b *bootstrap) dial(_ context.Context, _, _ string) (net.Conn, error) { +- c1, c2 := net.Pipe() +- go func() { +- _ = b.handlePipe(c2) +- _ = c2.Close() +- }() +- return c1, nil +-} ++ updating atomic.Bool ++ nextUpdate time.Time + +-func (b *bootstrap) handlePipe(c net.Conn) error { +- q, _, err := dnsutils.ReadMsgFromTCP(c) +- if err != nil { +- return err +- } +- +- var resp *dns.Msg +- if len(q.Question) == 1 { +- k := q.Question[0] +- m := b.cache.lookup(k) +- if m != nil { +- resp = m.Copy() +- resp.Id = q.Id +- } +- } ++ readyNotify chan struct{} ++ m sync.Mutex ++ ready bool ++ addrStr string ++} + +- if resp == nil { +- d := net.Dialer{} +- ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) +- defer cancel() +- upstreamConn, err := d.DialContext(ctx, "udp", b.upstream) +- if err != nil { +- return err +- } +- defer upstreamConn.Close() +- _ = upstreamConn.SetDeadline(time.Now().Add(time.Second * 3)) +- if _, err := dnsutils.WriteMsgToUDP(upstreamConn, q); err != nil { +- return err +- } +- m, _, err := dnsutils.ReadMsgFromUDP(upstreamConn, 1500) +- if err != nil { +- return err +- } +- resp = m +- } ++func (sp *Bootstrap) GetAddrPortStr(ctx context.Context) (string, error) { ++ sp.tryUpdate() + +- if resp.Rcode == dns.RcodeSuccess && hasIP(resp) { +- if len(resp.Question) == 1 { +- k := resp.Question[0] +- ttl := time.Duration(dnsutils.GetMinimalTTL(resp)) * time.Second +- b.cache.store(k, resp, ttl) +- } ++ select { ++ case <-ctx.Done(): ++ return "", ctx.Err() ++ case <-sp.readyNotify: + } + +- _, err = dnsutils.WriteMsgToTCP(c, resp) +- return err ++ sp.m.Lock() ++ addr := sp.addrStr ++ sp.m.Unlock() ++ return addr, nil + } + +-func hasIP(m *dns.Msg) bool { +- for _, rr := range m.Answer { +- switch rr.Header().Rrtype { +- case dns.TypeA, dns.TypeAAAA: +- return true ++func (sp *Bootstrap) tryUpdate() { ++ if sp.updating.CompareAndSwap(false, true) { ++ if time.Now().After(sp.nextUpdate) { ++ go func() { ++ defer sp.updating.Store(false) ++ ctx, cancel := context.WithTimeout(context.Background(), queryTimeout) ++ defer cancel() ++ start := time.Now() ++ addr, ttl, err := sp.updateAddr(ctx) ++ if err != nil { ++ sp.logger.Warn("failed to update bootstrap addr", zap.String("fqdn", sp.fqdn), zap.Error(err)) ++ sp.nextUpdate = time.Now().Add(retryInterval) ++ } else { ++ updateInterval := time.Second * time.Duration(ttl) ++ if updateInterval < minimumUpdateInterval { ++ updateInterval = minimumUpdateInterval ++ } ++ sp.logger.Info( ++ "bootstrap addr updated", ++ zap.String("fqdn", sp.fqdn), ++ zap.Stringer("addr", addr), ++ zap.Duration("ttl", updateInterval), ++ zap.Duration("elapse", time.Since(start)), ++ ) ++ sp.nextUpdate = time.Now().Add(updateInterval) ++ } ++ }() ++ } else { ++ sp.updating.Store(false) + } + } +- return false + } + +-type cache struct { +- l int ++func (sp *Bootstrap) updateAddr(ctx context.Context) (netip.Addr, uint32, error) { ++ addr, ttl, err := sp.resolve(ctx, sp.qt) ++ if err != nil { ++ return netip.Addr{}, 0, err ++ } + +- m sync.Mutex +- c map[key]*elem ++ addrPort := netip.AddrPortFrom(addr, sp.port).String() ++ sp.m.Lock() ++ sp.addrStr = addrPort ++ if !sp.ready { ++ sp.ready = true ++ close(sp.readyNotify) ++ } ++ sp.m.Unlock() ++ return addr, ttl, nil + } + +-type elem struct { +- m *dns.Msg +- expirationTime time.Time +-} ++func (sp *Bootstrap) resolve(ctx context.Context, qt uint16) (netip.Addr, uint32, error) { ++ const edns0UdpSize = 1200 + +-type key = dns.Question ++ q := new(dns.Msg) ++ q.SetQuestion(sp.fqdn, qt) ++ q.SetEdns0(edns0UdpSize, false) + +-func newCache(size int) *cache { +- const defaultSize = 8 +- if size <= 0 { +- size = defaultSize +- } +- return &cache{ +- l: size, +- c: make(map[key]*elem), ++ c, err := net.DialUDP("udp", nil, sp.bootstrap) ++ if err != nil { ++ return netip.Addr{}, 0, err + } +-} ++ defer c.Close() + +-// lookup returns a cached msg. Note: the msg must not be modified. +-func (c *cache) lookup(k key) *dns.Msg { +- now := time.Now() +- c.m.Lock() +- defer c.m.Unlock() +- e, ok := c.c[k] +- if !ok { +- return nil +- } +- if e.expirationTime.Before(now) { +- delete(c.c, k) +- return nil ++ writeErrC := make(chan error, 1) ++ type res struct { ++ resp *dns.Msg ++ err error + } +- return e.m +-} ++ readResC := make(chan res, 1) + +-// store stores a msg to cache. The caller MUST NOT modify m anymore. +-func (c *cache) store(k key, m *dns.Msg, ttl time.Duration) { +- if ttl <= 0 { +- return +- } +- expirationTime := time.Now().Add(ttl) ++ cancelWrite := make(chan struct{}) ++ defer close(cancelWrite) ++ go func() { ++ if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil { ++ writeErrC <- err ++ return ++ } + +- c.m.Lock() +- defer c.m.Unlock() ++ retryTicker := time.NewTicker(time.Second) ++ defer retryTicker.Stop() ++ for { ++ select { ++ case <-cancelWrite: ++ return ++ case <-retryTicker.C: ++ if _, err := dnsutils.WriteMsgToUDP(c, q); err != nil { ++ writeErrC <- err ++ return ++ } ++ } ++ } ++ }() ++ ++ go func() { ++ m, _, err := dnsutils.ReadMsgFromUDP(c, edns0UdpSize) ++ readResC <- res{resp: m, err: err} ++ }() + +- if len(c.c)+1 > c.l { +- for k := range c.c { +- if len(c.c)+1 <= c.l { +- break ++ select { ++ case <-ctx.Done(): ++ return netip.Addr{}, 0, ctx.Err() ++ case err := <-writeErrC: ++ return netip.Addr{}, 0, fmt.Errorf("failed to write query, %w", err) ++ case r := <-readResC: ++ resp := r.resp ++ err := r.err ++ if err != nil { ++ return netip.Addr{}, 0, fmt.Errorf("failed to read resp, %w", err) ++ } ++ ++ for _, v := range resp.Answer { ++ var ip net.IP ++ var ttl uint32 ++ switch rr := v.(type) { ++ case *dns.A: ++ ip = rr.A ++ ttl = rr.Hdr.Ttl ++ case *dns.AAAA: ++ ip = rr.AAAA ++ ttl = rr.Hdr.Ttl ++ default: ++ continue ++ } ++ addr, ok := netip.AddrFromSlice(ip) ++ if ok { ++ return addr, ttl, nil + } +- delete(c.c, k) + } ++ ++ // No ip addr in resp. ++ return netip.Addr{}, 0, errNoAddrInResp + } ++} + +- c.c[k] = &elem{ +- m: m, +- expirationTime: expirationTime, ++func bootstrapVer2Qt(ver int) (uint16, bool) { ++ switch ver { ++ case 0, 4: ++ return dns.TypeA, true ++ case 6: ++ return dns.TypeAAAA, true ++ default: ++ return 0, false + } + } +diff --git a/pkg/upstream/bootstrap/bootstrap_test.go b/pkg/upstream/bootstrap/bootstrap_test.go +deleted file mode 100644 +index 93f14ff..0000000 +--- a/pkg/upstream/bootstrap/bootstrap_test.go ++++ /dev/null +@@ -1,38 +0,0 @@ +-/* +- * Copyright (C) 2020-2022, IrineSistiana +- * +- * This file is part of mosdns. +- * +- * mosdns is free software: you can redistribute it and/or modify +- * it under the terms of the GNU General Public License as published by +- * the Free Software Foundation, either version 3 of the License, or +- * (at your option) any later version. +- * +- * mosdns is distributed in the hope that it will be useful, +- * but WITHOUT ANY WARRANTY; without even the implied warranty of +- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +- * GNU General Public License for more details. +- * +- * You should have received a copy of the GNU General Public License +- * along with this program. If not, see . +- */ +- +-package bootstrap +- +-import ( +- "context" +- "os" +- "testing" +-) +- +-func Test_Bootstrap(t *testing.T) { +- if os.Getenv("TEST_BOOTSTRAP") == "" { +- t.SkipNow() +- } +- +- r := NewBootstrap("8.8.8.8") +- _, err := r.LookupIP(context.Background(), "ip", "google.com") +- if err != nil { +- t.Fatal(err) +- } +-} +diff --git a/pkg/upstream/doh/upstream.go b/pkg/upstream/doh/upstream.go +index addf5b7..abc124b 100644 +--- a/pkg/upstream/doh/upstream.go ++++ b/pkg/upstream/doh/upstream.go +@@ -23,13 +23,14 @@ import ( + "context" + "encoding/base64" + "fmt" +- "github.com/IrineSistiana/mosdns/v5/pkg/pool" +- "github.com/IrineSistiana/mosdns/v5/pkg/utils" +- "github.com/miekg/dns" + "io" + "net/http" + "strings" + "time" ++ ++ "github.com/IrineSistiana/mosdns/v5/pkg/pool" ++ "github.com/IrineSistiana/mosdns/v5/pkg/utils" ++ "github.com/miekg/dns" + ) + + const ( +@@ -42,21 +43,10 @@ type Upstream struct { + EndPoint string + // Client is a http.Client that sends http requests. + Client *http.Client +- +- // AddOnCloser will be closed when Upstream is closed. +- AddOnCloser io.Closer + } + +-func (u *Upstream) CloseIdleConnections() { +- u.Client.CloseIdleConnections() +-} +- +-func (u *Upstream) Close() error { +- u.Client.CloseIdleConnections() +- if u.AddOnCloser != nil { +- u.AddOnCloser.Close() +- } +- return nil ++func NewUpstream(endPoint string, client *http.Client) *Upstream { ++ return &Upstream{EndPoint: endPoint, Client: client} + } + + var ( +@@ -127,7 +117,7 @@ func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, e + func (u *Upstream) exchange(ctx context.Context, url string) (*dns.Msg, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { +- return nil, fmt.Errorf("interal err: NewRequestWithContext: %w", err) ++ return nil, fmt.Errorf("internal err: NewRequestWithContext: %w", err) + } + + req.Header["Accept"] = []string{"application/dns-message"} +diff --git a/pkg/upstream/doq/upstream.go b/pkg/upstream/doq/upstream.go +index 7b2d4e7..23d7f1c 100644 +--- a/pkg/upstream/doq/upstream.go ++++ b/pkg/upstream/doq/upstream.go +@@ -22,8 +22,6 @@ package doq + import ( + "context" + "crypto/rand" +- "crypto/tls" +- "errors" + "net" + "sync" + "sync/atomic" +@@ -39,48 +37,27 @@ const ( + defaultDoQTimeout = time.Second * 5 + dialTimeout = time.Second * 3 + connectionLostThreshold = time.Second * 5 +- handshakeTimeout = time.Second * 3 + ) + + var ( +- doqAlpn = []string{"doq"} ++ DoqAlpn = []string{"doq"} + ) + + type Upstream struct { +- t *quic.Transport +- addr string +- tlsConfig *tls.Config +- quicConfig *quic.Config ++ dialer func(ctx context.Context) (quic.Connection, error) + + cm sync.Mutex + lc *lazyConn + } + +-// tlsConfig cannot be nil, it should have Servername or InsecureSkipVerify. +-func NewUpstream(addr string, lc *net.UDPConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*Upstream, error) { +- srk, err := initSrk() +- if err != nil { +- return nil, err +- } +- if tlsConfig == nil { +- return nil, errors.New("nil tls config") +- } +- +- tlsConfig = tlsConfig.Clone() +- tlsConfig.NextProtos = doqAlpn +- ++func NewUpstream(dialer func(ctx context.Context) (quic.Connection, error)) *Upstream { + return &Upstream{ +- t: &quic.Transport{ +- Conn: lc, +- StatelessResetKey: (*quic.StatelessResetKey)(srk), +- }, +- addr: addr, +- tlsConfig: tlsConfig, +- quicConfig: quicConfig, +- }, nil ++ dialer: dialer, ++ } + } + +-func initSrk() (*[32]byte, error) { ++// A helper func to init quic stateless reset key. ++func InitSrk() (*[32]byte, error) { + var b [32]byte + _, err := rand.Read(b[:]) + if err != nil { +@@ -89,10 +66,6 @@ func initSrk() (*[32]byte, error) { + return &b, nil + } + +-func (u *Upstream) Close() error { +- return u.t.Close() +-} +- + func (u *Upstream) newStream(ctx context.Context) (quic.Stream, *lazyConn, error) { + var lc *lazyConn + u.cm.Lock() +@@ -154,31 +127,10 @@ func (u *Upstream) asyncDialConn() *lazyConn { + + go func() { + defer cancel() ++ c, err := u.dialer(ctx) + +- ua, err := net.ResolveUDPAddr("udp", u.addr) // TODO: Support bootstrap. +- if err != nil { +- lc.err = err +- return +- } +- +- var c quic.Connection +- ec, err := u.t.DialEarly(ctx, ua, u.tlsConfig, u.quicConfig) +- if ec != nil { +- // This is a workaround to +- // 1. recover from strange 0rtt rejected err. +- // 2. avoid NextConnection might block forever. +- // TODO: Remove this workaround. +- select { +- case <-ctx.Done(): +- err = context.Cause(ctx) +- ec.CloseWithError(0, "") +- case <-ec.HandshakeComplete(): +- c = ec.NextConnection() +- } +- } +- +- var closeC bool + lc.m.Lock() ++ var closeC bool + if lc.closed { + closeC = true // lc was closed, nothing to do + } else { +diff --git a/pkg/upstream/event_stat.go b/pkg/upstream/event_stat.go +index cfb40c5..6b7b803 100644 +--- a/pkg/upstream/event_stat.go ++++ b/pkg/upstream/event_stat.go +@@ -37,7 +37,7 @@ type EventObserver interface { + + type nopEO struct{} + +-func (n nopEO) OnEvent(_ Event) { return } ++func (n nopEO) OnEvent(_ Event) {} + + type connWrapper struct { + net.Conn +diff --git a/pkg/upstream/transport/dns_conn.go b/pkg/upstream/transport/dns_conn.go +index 1c81a92..6df458e 100644 +--- a/pkg/upstream/transport/dns_conn.go ++++ b/pkg/upstream/transport/dns_conn.go +@@ -22,10 +22,11 @@ package transport + import ( + "context" + "fmt" +- "github.com/miekg/dns" + "io" + "sync" + "sync/atomic" ++ ++ "github.com/miekg/dns" + ) + + // dnsConn is a low-level connection for dns. +diff --git a/pkg/upstream/transport/dns_conn_test.go b/pkg/upstream/transport/dns_conn_test.go +index c677797..6e4eb27 100644 +--- a/pkg/upstream/transport/dns_conn_test.go ++++ b/pkg/upstream/transport/dns_conn_test.go +@@ -22,9 +22,6 @@ package transport + import ( + "context" + "errors" +- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" +- "github.com/IrineSistiana/mosdns/v5/pkg/pool" +- "github.com/miekg/dns" + "io" + "math/rand" + "net" +@@ -32,6 +29,10 @@ import ( + "sync" + "testing" + "time" ++ ++ "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" ++ "github.com/IrineSistiana/mosdns/v5/pkg/pool" ++ "github.com/miekg/dns" + ) + + var ( +diff --git a/pkg/upstream/transport/pipeline.go b/pkg/upstream/transport/pipeline.go +index e4bc39e..10c3d23 100644 +--- a/pkg/upstream/transport/pipeline.go ++++ b/pkg/upstream/transport/pipeline.go +@@ -21,10 +21,11 @@ package transport + + import ( + "context" +- "github.com/miekg/dns" + "math/rand" + "sync" + "time" ++ ++ "github.com/miekg/dns" + ) + + // PipelineTransport will pipeline queries as RFC 7766 6.2.1.1 suggested. +diff --git a/pkg/upstream/transport/pipeline_test.go b/pkg/upstream/transport/pipeline_test.go +index 5fcc0ae..653d779 100644 +--- a/pkg/upstream/transport/pipeline_test.go ++++ b/pkg/upstream/transport/pipeline_test.go +@@ -21,10 +21,11 @@ package transport + + import ( + "context" +- "github.com/miekg/dns" + "sync" + "testing" + "time" ++ ++ "github.com/miekg/dns" + ) + + // Leak and race tests. +diff --git a/pkg/upstream/transport/reuse.go b/pkg/upstream/transport/reuse.go +index c005dc0..5489833 100644 +--- a/pkg/upstream/transport/reuse.go ++++ b/pkg/upstream/transport/reuse.go +@@ -21,8 +21,9 @@ package transport + + import ( + "context" +- "github.com/miekg/dns" + "sync" ++ ++ "github.com/miekg/dns" + ) + + type ReuseConnTransport struct { +diff --git a/pkg/upstream/transport/reuse_test.go b/pkg/upstream/transport/reuse_test.go +index 3b45961..ac68142 100644 +--- a/pkg/upstream/transport/reuse_test.go ++++ b/pkg/upstream/transport/reuse_test.go +@@ -21,10 +21,11 @@ package transport + + import ( + "context" +- "github.com/miekg/dns" + "sync" + "testing" + "time" ++ ++ "github.com/miekg/dns" + ) + + // Leak and race tests. +diff --git a/pkg/upstream/transport/transport.go b/pkg/upstream/transport/transport.go +index 4542d65..d393600 100644 +--- a/pkg/upstream/transport/transport.go ++++ b/pkg/upstream/transport/transport.go +@@ -22,9 +22,10 @@ package transport + import ( + "context" + "errors" +- "github.com/miekg/dns" + "io" + "time" ++ ++ "github.com/miekg/dns" + ) + + var ( +diff --git a/pkg/upstream/transport/utils.go b/pkg/upstream/transport/utils.go +index 870a88b..233f2a8 100644 +--- a/pkg/upstream/transport/utils.go ++++ b/pkg/upstream/transport/utils.go +@@ -20,10 +20,11 @@ + package transport + + import ( +- "github.com/miekg/dns" + "math/rand" +- "sync/atomic" ++ "sync" + "time" ++ ++ "github.com/miekg/dns" + ) + + func shadowCopy(m *dns.Msg) *dns.Msg { +@@ -94,10 +95,10 @@ func slicePopLatest[T any](s *[]T) (T, bool) { + } + + type idleTimer struct { +- d time.Duration +- updating atomic.Bool +- t *time.Timer +- stopped bool ++ d time.Duration ++ m sync.Mutex ++ t *time.Timer ++ stopped bool + } + + func newIdleTimer(d time.Duration, f func()) *idleTimer { +@@ -108,22 +109,29 @@ func newIdleTimer(d time.Duration, f func()) *idleTimer { + } + + func (t *idleTimer) reset(d time.Duration) { +- if t.updating.CompareAndSwap(false, true) { +- defer t.updating.Store(false) +- if t.stopped { +- return +- } +- if d <= 0 { +- d = t.d +- } +- if !t.t.Reset(t.d) { +- t.stopped = true +- // re-activated. stop it +- t.t.Stop() +- } ++ t.m.Lock() ++ defer t.m.Unlock() ++ if t.stopped { ++ return ++ } ++ ++ if d <= 0 { ++ d = t.d ++ } ++ ++ if !t.t.Reset(d) { ++ t.stopped = true ++ // re-activated. stop it ++ t.t.Stop() + } + } + + func (t *idleTimer) stop() { ++ t.m.Lock() ++ defer t.m.Unlock() ++ if t.stopped { ++ return ++ } ++ t.stopped = true + t.t.Stop() + } +diff --git a/pkg/upstream/upstream.go b/pkg/upstream/upstream.go +index 1bc80a2..ab77aba 100644 +--- a/pkg/upstream/upstream.go ++++ b/pkg/upstream/upstream.go +@@ -22,10 +22,12 @@ package upstream + import ( + "context" + "crypto/tls" ++ "errors" + "fmt" + "io" + "net" + "net/http" ++ "net/netip" + "net/url" + "strconv" + "strings" +@@ -42,10 +44,11 @@ import ( + "github.com/quic-go/quic-go/http3" + "go.uber.org/zap" + "golang.org/x/net/http2" ++ "golang.org/x/net/proxy" + ) + + const ( +- tlsHandshakeTimeout = time.Second * 5 ++ tlsHandshakeTimeout = time.Second * 3 + ) + + // Upstream represents a DNS upstream. +@@ -59,12 +62,17 @@ type Upstream interface { + + type Opt struct { + // DialAddr specifies the address the upstream will +- // actually dial to. ++ // actually dial to in the network layer by overwriting ++ // the address inferred from upstream url. ++ // It won't affect high level layers. (e.g. SNI, HTTP HOST header won't be changed). ++ // Can be an IP or a domain. Port is optional. ++ // Tips: If the upstream url host is a domain, specific an IP address ++ // here can skip resolving ip of this domain. + DialAddr string + + // Socks5 specifies the socks5 proxy server that the upstream + // will connect though. +- // Not implemented for udp upstreams and doh upstreams with http/3. ++ // Not implemented for udp based protocols (aka. dns over udp, http3, quic). + Socks5 string + + // SoMark sets the socket SO_MARK option in unix system. +@@ -75,40 +83,42 @@ type Opt struct { + + // IdleTimeout specifies the idle timeout for long-connections. + // Available for TCP, DoT, DoH. +- // If negative, TCP, DoT will not reuse connections. +- // Default: TCP, DoT: 10s , DoH: 30s. ++ // Default: TCP, DoT: 10s , DoH, DoQ: 30s. + IdleTimeout time.Duration + + // EnablePipeline enables query pipelining support as RFC 7766 6.2.1.1 suggested. + // Available for TCP, DoT upstream with IdleTimeout >= 0. ++ // Note: There is no fallback. + EnablePipeline bool + + // EnableHTTP3 enables HTTP/3 protocol for DoH upstream. ++ // Note: There is no fallback. + EnableHTTP3 bool + + // MaxConns limits the total number of connections, including connections + // in the dialing states. +- // Implemented for TCP/DoT pipeline enabled upstreams and DoH upstreams. ++ // Implemented for TCP/DoT pipeline enabled upstream and DoH upstream. + // Default is 2. + MaxConns int + +- // Bootstrap specifies a plain dns server for the go runtime to solve the +- // domain of the upstream server. It SHOULD be an IP address. Custom port +- // is supported. +- // Note: Use a domain address may cause dead resolve loop and additional +- // latency to dial upstream server. +- // HTTP3 is not supported. ++ // Bootstrap specifies a plain dns server to solve the ++ // upstream server domain address. ++ // It must be an IP address. Port is optional. + Bootstrap string + ++ // Bootstrap version. One of 0 (default equals 4), 4, 6. ++ // TODO: Support dual-stack. ++ BootstrapVer int ++ + // TLSConfig specifies the tls.Config that the TLS client will use. +- // Available for DoT, DoH upstreams. ++ // Available for DoT, DoH upstream. + TLSConfig *tls.Config + + // Logger specifies the logger that the upstream will use. + Logger *zap.Logger + + // EventObserver can observe connection events. +- // Note: Not Implemented for HTTP/3 upstreams. ++ // Not implemented for udp based protocols (dns over udp, http3, quic). + EventObserver EventObserver + } + +@@ -129,17 +139,127 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + return nil, fmt.Errorf("invalid server address, %w", err) + } + ++ // If host is a ipv6 without port, it will be in []. This will cause err when ++ // split and join address and port. Try to remove brackets now. ++ addrUrlHost := tryTrimIpv6Brackets(addrURL.Host) ++ + dialer := &net.Dialer{ +- Resolver: bootstrap.NewBootstrap(opt.Bootstrap), + Control: getSocketControlFunc(socketOpts{ + so_mark: opt.SoMark, + bind_to_device: opt.BindToDevice, + }), + } + ++ var bootstrapAp netip.AddrPort ++ if s := opt.Bootstrap; len(s) > 0 { ++ bootstrapAp, err = parseBootstrapAp(s) ++ if err != nil { ++ return nil, fmt.Errorf("invalid bootstrap, %w", err) ++ } ++ } ++ ++ newUdpAddrResolveFunc := func(defaultPort uint16) (func(ctx context.Context) (*net.UDPAddr, error), error) { ++ host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) ++ if err != nil { ++ return nil, err ++ } ++ ++ if addr, err := netip.ParseAddr(host); err == nil { // host is an ip. ++ ua := net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) ++ return func(ctx context.Context) (*net.UDPAddr, error) { ++ return ua, nil ++ }, nil ++ } else { // Not an ip, assuming it's a domain name. ++ if bootstrapAp.IsValid() { ++ // Bootstrap enabled. ++ bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger) ++ if err != nil { ++ return nil, err ++ } ++ ++ return func(ctx context.Context) (*net.UDPAddr, error) { ++ s, err := bs.GetAddrPortStr(ctx) ++ if err != nil { ++ return nil, fmt.Errorf("bootstrap failed, %w", err) ++ } ++ return net.ResolveUDPAddr("udp", s) ++ }, nil ++ } else { ++ // Bootstrap disabled. ++ dialAddr := joinPort(host, port) ++ return func(ctx context.Context) (*net.UDPAddr, error) { ++ return net.ResolveUDPAddr("udp", dialAddr) ++ }, nil ++ } ++ } ++ } ++ ++ newTcpDialer := func(dialAddrMustBeIp bool, defaultPort uint16) (func(ctx context.Context) (net.Conn, error), error) { ++ host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) ++ if err != nil { ++ return nil, err ++ } ++ ++ // Socks5 enabled. ++ if s5Addr := opt.Socks5; len(s5Addr) > 0 { ++ socks5Dialer, err := proxy.SOCKS5("tcp", s5Addr, nil, dialer) ++ if err != nil { ++ return nil, fmt.Errorf("failed to init socks5 dialer: %w", err) ++ } ++ ++ contextDialer := socks5Dialer.(proxy.ContextDialer) ++ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) ++ return func(ctx context.Context) (net.Conn, error) { ++ return contextDialer.DialContext(ctx, "tcp", dialAddr) ++ }, nil ++ } ++ ++ if _, err := netip.ParseAddr(host); err == nil { ++ // Host is an ip addr. No need to resolve it. ++ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) ++ return func(ctx context.Context) (net.Conn, error) { ++ return dialer.DialContext(ctx, "tcp", dialAddr) ++ }, nil ++ } else { ++ if dialAddrMustBeIp { ++ return nil, errors.New("addr must be an ip address") ++ } ++ // Host is not an ip addr, assuming it is a domain. ++ if bootstrapAp.IsValid() { ++ // Bootstrap enabled. ++ bs, err := bootstrap.New(host, port, bootstrapAp, opt.BootstrapVer, opt.Logger) ++ if err != nil { ++ return nil, err ++ } ++ ++ return func(ctx context.Context) (net.Conn, error) { ++ dialAddr, err := bs.GetAddrPortStr(ctx) ++ if err != nil { ++ return nil, fmt.Errorf("bootstrap failed, %w", err) ++ } ++ return dialer.DialContext(ctx, "tcp", dialAddr) ++ }, nil ++ } else { ++ // Bootstrap disabled. ++ dialAddr := net.JoinHostPort(host, strconv.Itoa(int(port))) ++ return func(ctx context.Context) (net.Conn, error) { ++ return dialer.DialContext(ctx, "tcp", dialAddr) ++ }, nil ++ } ++ } ++ } ++ + switch addrURL.Scheme { + case "", "udp": +- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 53) ++ const defaultPort = 53 ++ host, port, err := parseDialAddr(addrUrlHost, opt.DialAddr, defaultPort) ++ if err != nil { ++ return nil, err ++ } ++ if _, err := netip.ParseAddr(host); err != nil { ++ return nil, fmt.Errorf("addr must be an ip address, %w", err) ++ } ++ dialAddr := joinPort(host, port) + uto := transport.IOOpts{ + DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) { + c, err := dialer.DialContext(ctx, "udp", dialAddr) +@@ -166,10 +286,14 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + t: transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: tto}), + }, nil + case "tcp": +- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 53) ++ const defaultPort = 53 ++ tcpDialer, err := newTcpDialer(true, defaultPort) ++ if err != nil { ++ return nil, fmt.Errorf("failed to init tcp dialer, %w", err) ++ } + to := transport.IOOpts{ + DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) { +- c, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer) ++ c, err := tcpDialer(ctx) + c = wrapConn(c, opt.EventObserver) + return c, err + }, +@@ -182,18 +306,22 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + } + return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil + case "tls": ++ const defaultPort = 853 + tlsConfig := opt.TLSConfig.Clone() + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + if len(tlsConfig.ServerName) == 0 { +- tlsConfig.ServerName = tryRemovePort(addrURL.Host) ++ tlsConfig.ServerName = tryRemovePort(addrUrlHost) + } + +- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853) ++ tcpDialer, err := newTcpDialer(false, defaultPort) ++ if err != nil { ++ return nil, fmt.Errorf("failed to init tcp dialer, %w", err) ++ } + to := transport.IOOpts{ + DialFunc: func(ctx context.Context) (io.ReadWriteCloser, error) { +- conn, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer) ++ conn, err := tcpDialer(ctx) + if err != nil { + return nil, err + } +@@ -203,7 +331,6 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + tlsConn.Close() + return nil, err + } +- + return tlsConn, nil + }, + WriteFunc: dnsutils.WriteMsgToTCP, +@@ -215,6 +342,7 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + } + return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil + case "https": ++ const defaultPort = 443 + idleConnTimeout := time.Second * 30 + if opt.IdleTimeout > 0 { + idleConnTimeout = opt.IdleTimeout +@@ -224,41 +352,42 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + maxConn = opt.MaxConns + } + +- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 443) + var t http.RoundTripper +- var addonCloser io.Closer // udpConn ++ var addonCloser io.Closer + if opt.EnableHTTP3 { ++ udpBootstrap, err := newUdpAddrResolveFunc(defaultPort) ++ if err != nil { ++ return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err) ++ } ++ + lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} + conn, err := lc.ListenPacket(context.Background(), "udp", "") + if err != nil { + return nil, fmt.Errorf("failed to init udp socket for quic") + } +- qt := &quic.Transport{ ++ quicTransport := &quic.Transport{ + Conn: conn, + } +- addonCloser = qt ++ addonCloser = quicTransport + t = &http3.RoundTripper{ + TLSClientConfig: opt.TLSConfig, +- QuicConfig: &quic.Config{ +- TokenStore: quic.NewLRUTokenStore(4, 8), +- InitialStreamReceiveWindow: 4 * 1024, +- MaxStreamReceiveWindow: 4 * 1024, +- InitialConnectionReceiveWindow: 8 * 1024, +- MaxConnectionReceiveWindow: 64 * 1024, +- }, ++ QuicConfig: newDefaultQuicConfig(), + Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { +- ua, err := net.ResolveUDPAddr("udp", dialAddr) // TODO: Support bootstrap. ++ ua, err := udpBootstrap(ctx) + if err != nil { + return nil, err + } +- +- return qt.DialEarly(ctx, ua, tlsCfg, cfg) ++ return quicTransport.DialEarly(ctx, ua, tlsCfg, cfg) + }, + } + } else { ++ tcpDialer, err := newTcpDialer(false, defaultPort) ++ if err != nil { ++ return nil, fmt.Errorf("failed to init tcp dialer, %w", err) ++ } + t1 := &http.Transport{ + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { // overwrite server addr +- c, err := dialTCP(ctx, dialAddr, opt.Socks5, dialer) ++ c, err := tcpDialer(ctx) + c = wrapConn(c, opt.EventObserver) + return c, err + }, +@@ -281,26 +410,34 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + t = t1 + } + +- return &doh.Upstream{ +- EndPoint: addr, +- Client: &http.Client{Transport: t}, +- AddOnCloser: addonCloser, ++ return &dohWithClose{ ++ u: doh.NewUpstream(addr, &http.Client{Transport: t}), ++ closer: addonCloser, + }, nil + case "quic", "doq": ++ const defaultPort = 853 + tlsConfig := opt.TLSConfig.Clone() + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + if len(tlsConfig.ServerName) == 0 { +- tlsConfig.ServerName = tryRemovePort(addrURL.Host) ++ tlsConfig.ServerName = tryRemovePort(addrUrlHost) + } ++ tlsConfig.NextProtos = doq.DoqAlpn + +- quicConfig := &quic.Config{ +- TokenStore: quic.NewLRUTokenStore(4, 8), +- InitialStreamReceiveWindow: 4 * 1024, +- MaxStreamReceiveWindow: 4 * 1024, +- InitialConnectionReceiveWindow: 8 * 1024, +- MaxConnectionReceiveWindow: 64 * 1024, ++ quicConfig := newDefaultQuicConfig() ++ if opt.IdleTimeout > 0 { ++ quicConfig.MaxIdleTimeout = opt.IdleTimeout ++ } ++ ++ udpBootstrap, err := newUdpAddrResolveFunc(defaultPort) ++ if err != nil { ++ return nil, fmt.Errorf("failed to init udp addr bootstrap, %w", err) ++ } ++ ++ srk, err := doq.InitSrk() ++ if err != nil { ++ opt.Logger.Error("failed to init quic stateless reset key", zap.Error(err)) + } + + lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} +@@ -309,35 +446,44 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { + return nil, fmt.Errorf("failed to init udp socket for quic") + } + +- dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853) +- u, err := doq.NewUpstream(dialAddr, uc.(*net.UDPConn), tlsConfig, quicConfig) +- if err != nil { +- return nil, fmt.Errorf("failed to setup doq upstream, %w", err) ++ t := &quic.Transport{ ++ Conn: uc, ++ StatelessResetKey: (*quic.StatelessResetKey)(srk), + } +- return u, nil +- default: +- return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme) +- } +-} + +-func getDialAddrWithPort(host, dialAddr string, defaultPort int) string { +- addr := host +- if len(dialAddr) > 0 { +- addr = dialAddr +- } +- _, _, err := net.SplitHostPort(addr) +- if err != nil { // no port, add it. +- return net.JoinHostPort(strings.Trim(addr, "[]"), strconv.Itoa(defaultPort)) +- } +- return addr +-} ++ dialer := func(ctx context.Context) (quic.Connection, error) { ++ ua, err := udpBootstrap(ctx) ++ if err != nil { ++ return nil, fmt.Errorf("bootstrap failed, %w", err) ++ } ++ var c quic.Connection ++ ec, err := t.DialEarly(ctx, ua, tlsConfig, quicConfig) ++ if err != nil { ++ return nil, err ++ } + +-func tryRemovePort(s string) string { +- host, _, err := net.SplitHostPort(s) +- if err != nil { +- return s ++ // This is a workaround to ++ // 1. recover from strange 0rtt rejected err. ++ // 2. avoid NextConnection might block forever. ++ // TODO: Remove this workaround. ++ select { ++ case <-ctx.Done(): ++ err := context.Cause(ctx) ++ ec.CloseWithError(0, "") ++ return nil, err ++ case <-ec.HandshakeComplete(): ++ c = ec.NextConnection() ++ } ++ return c, nil ++ } ++ ++ return &doqWithClose{ ++ u: doq.NewUpstream(dialer), ++ t: t, ++ }, nil ++ default: ++ return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme) + } +- return host + } + + type udpWithFallback struct { +@@ -361,3 +507,48 @@ func (u *udpWithFallback) Close() error { + u.t.Close() + return nil + } ++ ++type doqWithClose struct { ++ u *doq.Upstream ++ t *quic.Transport ++} ++ ++func (u *doqWithClose) ExchangeContext(ctx context.Context, m *dns.Msg) (*dns.Msg, error) { ++ return u.u.ExchangeContext(ctx, m) ++} ++ ++func (u *doqWithClose) Close() error { ++ return u.t.Close() ++} ++ ++type dohWithClose struct { ++ u *doh.Upstream ++ closer io.Closer // maybe nil ++} ++ ++func (u *dohWithClose) ExchangeContext(ctx context.Context, m *dns.Msg) (*dns.Msg, error) { ++ return u.u.ExchangeContext(ctx, m) ++} ++ ++func (u *dohWithClose) Close() error { ++ if u.closer != nil { ++ return u.closer.Close() ++ } ++ return nil ++} ++ ++func newDefaultQuicConfig() *quic.Config { ++ return &quic.Config{ ++ TokenStore: quic.NewLRUTokenStore(4, 8), ++ ++ // Dns does not need large amount of io, so the rx/tx windows are small. ++ InitialStreamReceiveWindow: 4 * 1024, ++ MaxStreamReceiveWindow: 4 * 1024, ++ InitialConnectionReceiveWindow: 8 * 1024, ++ MaxConnectionReceiveWindow: 64 * 1024, ++ ++ MaxIdleTimeout: time.Second * 30, ++ KeepAlivePeriod: time.Second * 25, ++ HandshakeIdleTimeout: tlsHandshakeTimeout, ++ } ++} +diff --git a/pkg/upstream/upstream_test.go b/pkg/upstream/upstream_test.go +index 675b1ff..55e3b6d 100644 +--- a/pkg/upstream/upstream_test.go ++++ b/pkg/upstream/upstream_test.go +@@ -24,12 +24,13 @@ import ( + "crypto/tls" + "errors" + "fmt" +- "github.com/IrineSistiana/mosdns/v5/pkg/utils" +- "github.com/miekg/dns" + "net" + "sync" + "testing" + "time" ++ ++ "github.com/IrineSistiana/mosdns/v5/pkg/utils" ++ "github.com/miekg/dns" + ) + + func newUDPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { +@@ -68,6 +69,9 @@ func newTCPTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownF + func newDoTTestServer(t testing.TB, handler dns.Handler) (addr string, shutdownFunc func()) { + serverName := "test" + cert, err := utils.GenerateCertificate(serverName) ++ if err != nil { ++ t.Fatal(err) ++ } + tlsConfig := new(tls.Config) + tlsConfig.Certificates = []tls.Certificate{cert} + tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig) +diff --git a/pkg/upstream/utils.go b/pkg/upstream/utils.go +index 7d297f3..8a42125 100644 +--- a/pkg/upstream/utils.go ++++ b/pkg/upstream/utils.go +@@ -20,10 +20,10 @@ + package upstream + + import ( +- "context" + "fmt" +- "golang.org/x/net/proxy" + "net" ++ "net/netip" ++ "strconv" + ) + + type socketOpts struct { +@@ -31,14 +31,70 @@ type socketOpts struct { + bind_to_device string + } + +-func dialTCP(ctx context.Context, addr, socks5 string, dialer *net.Dialer) (net.Conn, error) { +- if len(socks5) > 0 { +- socks5Dialer, err := proxy.SOCKS5("tcp", socks5, nil, dialer) ++func parseDialAddr(urlHost, dialAddr string, defaultPort uint16) (string, uint16, error) { ++ addr := urlHost ++ if len(dialAddr) > 0 { ++ addr = dialAddr ++ } ++ host, port, err := trySplitHostPort(addr) ++ if err != nil { ++ return "", 0, err ++ } ++ if port == 0 { ++ port = defaultPort ++ } ++ return host, port, nil ++} ++ ++func joinPort(host string, port uint16) string { ++ return net.JoinHostPort(host, strconv.Itoa(int(port))) ++} ++ ++func tryRemovePort(s string) string { ++ host, _, err := net.SplitHostPort(s) ++ if err != nil { ++ return s ++ } ++ return host ++} ++ ++// trySplitHostPort splits host and port. ++// If s has no port, it returns s,0,nil ++func trySplitHostPort(s string) (string, uint16, error) { ++ var port uint16 ++ host, portS, err := net.SplitHostPort(s) ++ if err == nil { ++ n, err := strconv.ParseUint(portS, 10, 16) + if err != nil { +- return nil, fmt.Errorf("failed to init socks5 dialer: %w", err) ++ return "", 0, fmt.Errorf("invalid port, %w", err) + } +- return socks5Dialer.(proxy.ContextDialer).DialContext(ctx, "tcp", addr) ++ port = uint16(n) ++ return host, port, nil + } ++ return s, 0, nil ++} + +- return dialer.DialContext(ctx, "tcp", addr) ++func parseBootstrapAp(s string) (netip.AddrPort, error) { ++ host, port, err := trySplitHostPort(s) ++ if err != nil { ++ return netip.AddrPort{}, err ++ } ++ if port == 0 { ++ port = 53 ++ } ++ addr, err := netip.ParseAddr(host) ++ if err != nil { ++ return netip.AddrPort{}, err ++ } ++ return netip.AddrPortFrom(addr, port), nil ++} ++ ++func tryTrimIpv6Brackets(s string) string { ++ if len(s) < 2 { ++ return s ++ } ++ if s[0] == '[' && s[len(s)-1] == ']' { ++ return s[1 : len(s)-2] ++ } ++ return s + } +diff --git a/pkg/upstream/utils_unix.go b/pkg/upstream/utils_unix.go +index b8ced47..685a82b 100644 +--- a/pkg/upstream/utils_unix.go ++++ b/pkg/upstream/utils_unix.go +@@ -22,9 +22,10 @@ + package upstream + + import ( +- "golang.org/x/sys/unix" + "os" + "syscall" ++ ++ "golang.org/x/sys/unix" + ) + + func getSocketControlFunc(opts socketOpts) func(string, string, syscall.RawConn) error { +diff --git a/plugin/executable/forward/forward.go b/plugin/executable/forward/forward.go +index be39ad7..bba0da2 100644 +--- a/plugin/executable/forward/forward.go ++++ b/plugin/executable/forward/forward.go +@@ -24,6 +24,9 @@ import ( + "crypto/tls" + "errors" + "fmt" ++ "strings" ++ "time" ++ + "github.com/IrineSistiana/mosdns/v5/coremain" + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream" +@@ -32,8 +35,6 @@ import ( + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +- "strings" +- "time" + ) + + const PluginType = "forward" +@@ -57,6 +58,7 @@ type Args struct { + SoMark int `yaml:"so_mark"` + BindToDevice string `yaml:"bind_to_device"` + Bootstrap string `yaml:"bootstrap"` ++ BootstrapVer int `yaml:"bootstrap_version"` + } + + type UpstreamConfig struct { +@@ -73,6 +75,7 @@ type UpstreamConfig struct { + SoMark int `yaml:"so_mark"` + BindToDevice string `yaml:"bind_to_device"` + Bootstrap string `yaml:"bootstrap"` ++ BootstrapVer int `yaml:"bootstrap_version"` + } + + func Init(bp *coremain.BP, args any) (any, error) { +@@ -125,6 +128,7 @@ func NewForward(args *Args, opt Opts) (*Forward, error) { + utils.SetDefaultUnsignNum(&c.SoMark, args.SoMark) + utils.SetDefaultString(&c.BindToDevice, args.BindToDevice) + utils.SetDefaultString(&c.Bootstrap, args.Bootstrap) ++ utils.SetDefaultUnsignNum(&c.BootstrapVer, args.BootstrapVer) + } + + for i, c := range args.Upstreams { +@@ -144,6 +148,7 @@ func NewForward(args *Args, opt Opts) (*Forward, error) { + EnablePipeline: c.EnablePipeline, + EnableHTTP3: c.EnableHTTP3, + Bootstrap: c.Bootstrap, ++ BootstrapVer: c.BootstrapVer, + TLSConfig: &tls.Config{ + InsecureSkipVerify: c.InsecureSkipVerify, + ClientSessionCache: tls.NewLRUClientSessionCache(4), +-- +2.34.8 + diff --git a/mosdns/patches/200-add-debug-log-again.patch b/mosdns/patches/200-add-debug-log-again.patch index 624e94a..c6cfc67 100644 --- a/mosdns/patches/200-add-debug-log-again.patch +++ b/mosdns/patches/200-add-debug-log-again.patch @@ -4,31 +4,17 @@ Date: Sun, 25 Jun 2023 06:50:27 +0800 Subject: [PATCH 10/10] add debug log again --- - pkg/server/dns_handler/entry_handler.go | 4 +++- - pkg/server/http_handler/handler.go | 1 + - plugin/executable/cache/cache.go | 3 +++ - 3 files changed, 7 insertions(+), 1 deletion(-) + pkg/server/http_handler.go | 1 + + pkg/server_handler/entry_handler.go | 2 ++ + pkg/upstream/bootstrap/bootstrap.go | 2 +- + plugin/executable/cache/cache.go | 3 +++ + 4 files changed, 7 insertions(+), 1 deletion(-) -diff --git a/pkg/server/dns_handler/entry_handler.go b/pkg/server/dns_handler/entry_handler.go -index 4737811..cec4123 100644 ---- a/pkg/server/dns_handler/entry_handler.go -+++ b/pkg/server/dns_handler/entry_handler.go -@@ -90,7 +90,9 @@ func (h *EntryHandler) ServeDNS(ctx context.Context, qCtx *query_context.Context - err := h.opts.Entry.Exec(ctx, qCtx) - respMsg := qCtx.R() - if err != nil { -- h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err)) -+ h.opts.Logger.Warn("entry returned an err", qCtx.InfoField(), zap.Error(err)) -+ } else { -+ h.opts.Logger.Debug("entry returned", qCtx.InfoField()) - } - - if err == nil && respMsg == nil { -diff --git a/pkg/server/http_handler/handler.go b/pkg/server/http_handler/handler.go -index 3e800f9..25f52e1 100644 ---- a/pkg/server/http_handler/handler.go -+++ b/pkg/server/http_handler/handler.go -@@ -96,6 +96,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go +index 58f5811..5fa76b4 100644 +--- a/pkg/server/http_handler.go ++++ b/pkg/server/http_handler.go +@@ -94,6 +94,7 @@ func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { if err != nil { h.warnErr(req, "invalid request", err) w.WriteHeader(http.StatusBadRequest) @@ -36,6 +22,32 @@ index 3e800f9..25f52e1 100644 return } +diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go +index 520e3d2..38df952 100644 +--- a/pkg/server_handler/entry_handler.go ++++ b/pkg/server_handler/entry_handler.go +@@ -87,6 +87,8 @@ func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.Quer + respMsg := qCtx.R() + if err != nil { + h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err)) ++ } else { ++ h.opts.Logger.Debug("entry returned", qCtx.InfoField()) + } + + if err == nil && respMsg == nil { +diff --git a/pkg/upstream/bootstrap/bootstrap.go b/pkg/upstream/bootstrap/bootstrap.go +index 2cd8ef9..5192053 100644 +--- a/pkg/upstream/bootstrap/bootstrap.go ++++ b/pkg/upstream/bootstrap/bootstrap.go +@@ -117,7 +117,7 @@ func (sp *Bootstrap) tryUpdate() { + if updateInterval < minimumUpdateInterval { + updateInterval = minimumUpdateInterval + } +- sp.logger.Info( ++ sp.logger.Debug( + "bootstrap addr updated", + zap.String("fqdn", sp.fqdn), + zap.Stringer("addr", addr), diff --git a/plugin/executable/cache/cache.go b/plugin/executable/cache/cache.go index 58162ee..dd833dc 100644 --- a/plugin/executable/cache/cache.go