From 24c1cd73acc4fb1c9e5fb8a54eff570889ec81a3 Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Wed, 20 Sep 2023 09:25:41 +0800 Subject: [PATCH 1/6] pkg/server: decoupling from query_context --- .../handler.go => http_handler.go} | 68 +++++++++-------- pkg/server/iface.go | 23 ++++++ pkg/server/tcp.go | 74 ++++++++----------- pkg/server/udp.go | 49 ++++-------- pkg/server/utils.go | 7 ++ .../entry_handler.go | 31 ++++---- plugin/server/http_server/http_server.go | 16 ++-- plugin/server/server_utils/handler.go | 10 ++- plugin/server/tcp_server/tcp_server.go | 11 ++- plugin/server/udp_server/udp_server.go | 4 +- 10 files changed, 144 insertions(+), 149 deletions(-) rename pkg/server/{http_handler/handler.go => http_handler.go} (75%) create mode 100644 pkg/server/iface.go create mode 100644 pkg/server/utils.go rename pkg/{server/dns_handler => server_handler}/entry_handler.go (77%) diff --git a/pkg/server/http_handler/handler.go b/pkg/server/http_handler.go similarity index 75% rename from pkg/server/http_handler/handler.go rename to pkg/server/http_handler.go index 25f52e1..5fa76b4 100644 --- a/pkg/server/http_handler/handler.go +++ b/pkg/server/http_handler.go @@ -17,69 +17,67 @@ * along with this program. If not, see . */ -package http_handler +package server import ( "encoding/base64" "errors" "fmt" - "github.com/IrineSistiana/mosdns/v5/mlog" - "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" - "github.com/IrineSistiana/mosdns/v5/pkg/pool" - "github.com/IrineSistiana/mosdns/v5/pkg/query_context" - "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" - "github.com/miekg/dns" - "go.uber.org/zap" "io" "net/http" "net/netip" "strings" -) -type HandlerOpts struct { - // DNSHandler is required. 
- DNSHandler dns_handler.Handler + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" + "go.uber.org/zap" +) - // SrcIPHeader specifies the header that contain client source address. +type HttpHandlerOpts struct { + // GetSrcIPFromHeader specifies the header that contains the client source address. // e.g. "X-Forwarded-For". - SrcIPHeader string + GetSrcIPFromHeader string // Logger specifies the logger which Handler writes its log to. // Default is a nop logger. Logger *zap.Logger } -func (opts *HandlerOpts) init() { - if opts.Logger == nil { - opts.Logger = mlog.Nop() - } - return +type HttpHandler struct { + dnsHandler Handler + logger *zap.Logger + srcIPHeader string } -type Handler struct { - opts HandlerOpts -} +var _ http.Handler = (*HttpHandler)(nil) -func NewHandler(opts HandlerOpts) *Handler { - opts.init() - return &Handler{opts: opts} +func NewHttpHandler(h Handler, opts HttpHandlerOpts) *HttpHandler { + hh := new(HttpHandler) + hh.dnsHandler = h + hh.srcIPHeader = opts.GetSrcIPFromHeader + hh.logger = opts.Logger + if hh.logger == nil { + hh.logger = nopLogger + } + return hh } -func (h *Handler) warnErr(req *http.Request, msg string, err error) { - h.opts.Logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err)) +func (h *HttpHandler) warnErr(req *http.Request, msg string, err error) { + h.logger.Warn(msg, zap.String("from", req.RemoteAddr), zap.String("method", req.Method), zap.String("url", req.RequestURI), zap.Error(err)) } -func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { addrPort, err := netip.ParseAddrPort(req.RemoteAddr) if err != nil { - h.opts.Logger.Error("failed to parse request remote addr", zap.String("addr", req.RemoteAddr), zap.Error(err)) + h.logger.Error("failed to parse request remote addr",
zap.String("addr", req.RemoteAddr), zap.Error(err)) w.WriteHeader(http.StatusInternalServerError) return } clientAddr := addrPort.Addr() // read remote addr from header - if header := h.opts.SrcIPHeader; len(header) != 0 { + if header := h.srcIPHeader; len(header) != 0 { if xff := req.Header.Get(header); len(xff) != 0 { addr, err := readClientAddrFromXFF(xff) if err != nil { @@ -100,12 +98,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - qCtx := query_context.NewContext(q) - query_context.SetClientAddr(qCtx, &clientAddr) - if err := h.opts.DNSHandler.ServeDNS(req.Context(), qCtx); err != nil { - panic(err.Error()) // Force http server to close connection. + r, err := h.dnsHandler.Handle(req.Context(), q, QueryMeta{ClientAddr: clientAddr}) + if err != nil { + h.warnErr(req, "handler err", err) + panic(err) // Force http server to close connection. } - r := qCtx.R() + b, buf, err := pool.PackBuffer(r) if err != nil { w.WriteHeader(http.StatusInternalServerError) diff --git a/pkg/server/iface.go b/pkg/server/iface.go new file mode 100644 index 0000000..2f15be1 --- /dev/null +++ b/pkg/server/iface.go @@ -0,0 +1,23 @@ +package server + +import ( + "context" + "net/netip" + + "github.com/miekg/dns" +) + +// Handler handles incoming request q and MUST ALWAYS return a response. +// Handler MUST handle dns errors by itself and return a proper error response. +// e.g. Return a SERVFAIL if something goes wrong. +// If Handle() returns an error, caller considers that the error is associated +// with the downstream connection and will close the downstream connection +// immediately.
+type Handler interface { + Handle(ctx context.Context, q *dns.Msg, meta QueryMeta) (resp *dns.Msg, err error) +} + +type QueryMeta struct { + ClientAddr netip.Addr // May be invalid + FromUDP bool +} diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go index 5dc80de..5f479b1 100644 --- a/pkg/server/tcp.go +++ b/pkg/server/tcp.go @@ -22,15 +22,13 @@ package server import ( "context" "fmt" - "github.com/IrineSistiana/mosdns/v5/mlog" + "net" + "net/netip" + "time" + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" "github.com/IrineSistiana/mosdns/v5/pkg/pool" - "github.com/IrineSistiana/mosdns/v5/pkg/query_context" - "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" - "github.com/IrineSistiana/mosdns/v5/pkg/utils" "go.uber.org/zap" - "net" - "time" ) const ( @@ -38,33 +36,30 @@ const ( tcpFirstReadTimeout = time.Millisecond * 500 ) -type TCPServer struct { - opts TCPServerOpts -} - -func NewTCPServer(opts TCPServerOpts) *TCPServer { - opts.init() - return &TCPServer{opts: opts} -} - type TCPServerOpts struct { - DNSHandler dns_handler.Handler // Required. - Logger *zap.Logger - IdleTimeout time.Duration -} + // Nil logger == nop + Logger *zap.Logger -func (opts *TCPServerOpts) init() { - if opts.Logger == nil { - opts.Logger = mlog.Nop() - } - utils.SetDefaultNum(&opts.IdleTimeout, defaultTCPIdleTimeout) - return + // Default is defaultTCPIdleTimeout. + IdleTimeout time.Duration } // ServeTCP starts a server at l. It returns if l had an Accept() error. // It always returns a non-nil error.
-func (s *TCPServer) ServeTCP(l net.Listener) error { - // handle listener +func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error { + logger := opts.Logger + if logger == nil { + logger = nopLogger + } + idleTimeout := opts.IdleTimeout + if idleTimeout <= 0 { + idleTimeout = defaultTCPIdleTimeout + } + firstReadTimeout := tcpFirstReadTimeout + if idleTimeout < firstReadTimeout { + firstReadTimeout = idleTimeout + } + listenerCtx, cancel := context.WithCancel(context.Background()) defer cancel() for { @@ -79,14 +74,12 @@ func (s *TCPServer) ServeTCP(l net.Listener) error { defer c.Close() defer cancelConn() - firstReadTimeout := tcpFirstReadTimeout - idleTimeout := s.opts.IdleTimeout - if idleTimeout < firstReadTimeout { - firstReadTimeout = idleTimeout + var clientAddr netip.Addr + ta, ok := c.RemoteAddr().(*net.TCPAddr) + if ok { + clientAddr = ta.AddrPort().Addr() } - clientAddr := utils.GetAddrFromAddr(c.RemoteAddr()) - firstRead := true for { if firstRead { @@ -102,24 +95,21 @@ func (s *TCPServer) ServeTCP(l net.Listener) error { // handle query go func() { - qCtx := query_context.NewContext(req) - query_context.SetClientAddr(qCtx, &clientAddr) - if err := s.opts.DNSHandler.ServeDNS(tcpConnCtx, qCtx); err != nil { - s.opts.Logger.Warn("handler err", zap.Error(err)) - c.Close() + r, err := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr}) + if err != nil { + logger.Warn("handler err", zap.Error(err)) + c.Close() // abort the connection return } - r := qCtx.R() - b, buf, err := pool.PackBuffer(r) if err != nil { - s.opts.Logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) + logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) return } defer pool.ReleaseBuf(buf) if _, err := dnsutils.WriteRawMsgToTCP(c, b); err != nil { - s.opts.Logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) + logger.Warn("failed to write 
response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) return } }() diff --git a/pkg/server/udp.go b/pkg/server/udp.go index 8bb1b85..247455b 100644 --- a/pkg/server/udp.go +++ b/pkg/server/udp.go @@ -24,38 +24,24 @@ import ( "fmt" "net" - "github.com/IrineSistiana/mosdns/v5/mlog" "github.com/IrineSistiana/mosdns/v5/pkg/pool" - "github.com/IrineSistiana/mosdns/v5/pkg/query_context" - "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" "github.com/miekg/dns" "go.uber.org/zap" ) -type UDPServer struct { - opts UDPServerOpts -} - -func NewUDPServer(opts UDPServerOpts) *UDPServer { - opts.init() - return &UDPServer{opts: opts} -} - type UDPServerOpts struct { - DNSHandler dns_handler.Handler // Required. - Logger *zap.Logger -} - -func (opts *UDPServerOpts) init() { - if opts.Logger == nil { - opts.Logger = mlog.Nop() - } - return + Logger *zap.Logger } // ServeUDP starts a server at c. It returns if c had a read error. // It always returns a non-nil error. -func (s *UDPServer) ServeUDP(c *net.UDPConn) error { +// h is required. logger is optional. 
+func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error { + logger := opts.Logger + if logger == nil { + logger = nopLogger + } + listenerCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -78,11 +64,10 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error { if err != nil { return fmt.Errorf("unexpected read err: %w", err) } - clientAddr := remoteAddr.Addr() q := new(dns.Msg) if err := q.Unpack((*rb)[:n]); err != nil { - s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr)) + logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr)) continue } @@ -91,34 +76,32 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error { var err error dstIpFromCm, err = oobReader(ob[:oobn]) if err != nil { - s.opts.Logger.Error("failed to get dst address from oob", zap.Error(err)) + logger.Error("failed to get dst address from oob", zap.Error(err)) } } // handle query go func() { - qCtx := query_context.NewContext(q) - query_context.SetClientAddr(qCtx, &clientAddr) - if err := s.opts.DNSHandler.ServeDNS(listenerCtx, qCtx); err != nil { - s.opts.Logger.Warn("handler err", zap.Error(err)) + r, err := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}) + if err != nil { + logger.Warn("handler err", zap.Error(err)) return } - r := qCtx.R() if r != nil { r.Truncate(getUDPSize(q)) b, buf, err := pool.PackBuffer(r) if err != nil { - s.opts.Logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) + logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) return } defer pool.ReleaseBuf(buf) - var oob []byte + var oob []byte if oobWriter != nil && dstIpFromCm != nil { oob = oobWriter(dstIpFromCm) } if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil { - s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), 
zap.Error(err)) + logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) } } }() diff --git a/pkg/server/utils.go b/pkg/server/utils.go new file mode 100644 index 0000000..5e1b5c1 --- /dev/null +++ b/pkg/server/utils.go @@ -0,0 +1,7 @@ +package server + +import "go.uber.org/zap" + +var ( + nopLogger = zap.NewNop() +) diff --git a/pkg/server/dns_handler/entry_handler.go b/pkg/server_handler/entry_handler.go similarity index 77% rename from pkg/server/dns_handler/entry_handler.go rename to pkg/server_handler/entry_handler.go index cec4123..121d943 100644 --- a/pkg/server/dns_handler/entry_handler.go +++ b/pkg/server_handler/entry_handler.go @@ -17,17 +17,19 @@ * along with this program. If not, see . */ -package dns_handler +package server_handler import ( "context" + "time" + "github.com/IrineSistiana/mosdns/v5/mlog" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/pkg/server" "github.com/IrineSistiana/mosdns/v5/pkg/utils" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" "github.com/miekg/dns" "go.uber.org/zap" - "time" ) const ( @@ -38,18 +40,6 @@ var ( nopLogger = mlog.Nop() ) -// Handler handles dns query. -type Handler interface { - // ServeDNS handles incoming request qCtx and MUST ALWAYS set a response. - // Implements must not keep and use qCtx after the ServeDNS returned. - // ServeDNS should handle dns errors by itself and return a proper error responses - // for clients. - // If ServeDNS returns an error, caller considers that the error is associated - // with the downstream connection and will close the downstream connection - // immediately. - ServeDNS(ctx context.Context, qCtx *query_context.Context) error -} - type EntryHandlerOpts struct { // Logger is used for logging. Default is a noop logger. 
Logger *zap.Logger @@ -73,20 +63,26 @@ type EntryHandler struct { opts EntryHandlerOpts } +var _ server.Handler = (*EntryHandler)(nil) + func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler { opts.init() return &EntryHandler{opts: opts} } -// ServeDNS implements Handler. +// Handle implements server.Handler. // If entry returns an error, a SERVFAIL response will be set. // If entry returns without a response, a REFUSED response will be set. -func (h *EntryHandler) ServeDNS(ctx context.Context, qCtx *query_context.Context) error { +func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta) (*dns.Msg, error) { ddl := time.Now().Add(h.opts.QueryTimeout) ctx, cancel := context.WithDeadline(ctx, ddl) defer cancel() // exec entry + qCtx := query_context.NewContext(q) + if qInfo.ClientAddr.IsValid() { + query_context.SetClientAddr(qCtx, &qInfo.ClientAddr) + } err := h.opts.Entry.Exec(ctx, qCtx) respMsg := qCtx.R() if err != nil { @@ -106,6 +102,5 @@ func (h *EntryHandler) ServeDNS(ctx context.Context, qCtx *query_context.Context respMsg.Rcode = dns.RcodeServerFailure } respMsg.RecursionAvailable = true - qCtx.SetResponse(respMsg) - return nil + return respMsg, nil } diff --git a/plugin/server/http_server/http_server.go b/plugin/server/http_server/http_server.go index 8e66b37..daca6db 100644 --- a/plugin/server/http_server/http_server.go +++ b/plugin/server/http_server/http_server.go @@ -21,13 +21,14 @@ package tcp_server import ( "fmt" + "net/http" + "time" + "github.com/IrineSistiana/mosdns/v5/coremain" - "github.com/IrineSistiana/mosdns/v5/pkg/server/http_handler" + "github.com/IrineSistiana/mosdns/v5/pkg/server" "github.com/IrineSistiana/mosdns/v5/pkg/utils" "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" "golang.org/x/net/http2" - "net/http" - "time" ) const PluginType = "http_server" @@ -73,12 +74,11 @@ func StartServer(bp *coremain.BP, args *Args) (*HttpServer, error) { if err != nil { return nil, fmt.Errorf("failed
to init dns handler, %w", err) } - hhOpts := http_handler.HandlerOpts{ - DNSHandler: dh, - SrcIPHeader: args.SrcIPHeader, - Logger: bp.L(), + hhOpts := server.HttpHandlerOpts{ + GetSrcIPFromHeader: args.SrcIPHeader, + Logger: bp.L(), } - hh := http_handler.NewHandler(hhOpts) + hh := server.NewHttpHandler(dh, hhOpts) mux.Handle(entry.Path, hh) } diff --git a/plugin/server/server_utils/handler.go b/plugin/server/server_utils/handler.go index 2a20e1a..bbc6eab 100644 --- a/plugin/server/server_utils/handler.go +++ b/plugin/server/server_utils/handler.go @@ -21,21 +21,23 @@ package server_utils import ( "fmt" + "github.com/IrineSistiana/mosdns/v5/coremain" - "github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler" + "github.com/IrineSistiana/mosdns/v5/pkg/server" + "github.com/IrineSistiana/mosdns/v5/pkg/server_handler" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" ) -func NewHandler(bp *coremain.BP, entry string) (dns_handler.Handler, error) { +func NewHandler(bp *coremain.BP, entry string) (server.Handler, error) { p := bp.M().GetPlugin(entry) exec := sequence.ToExecutable(p) if exec == nil { return nil, fmt.Errorf("cannot find executable entry by tag %s", entry) } - handlerOpts := dns_handler.EntryHandlerOpts{ + handlerOpts := server_handler.EntryHandlerOpts{ Logger: bp.L(), Entry: exec, } - return dns_handler.NewEntryHandler(handlerOpts), nil + return server_handler.NewEntryHandler(handlerOpts), nil } diff --git a/plugin/server/tcp_server/tcp_server.go b/plugin/server/tcp_server/tcp_server.go index 5aca0f5..f69c667 100644 --- a/plugin/server/tcp_server/tcp_server.go +++ b/plugin/server/tcp_server/tcp_server.go @@ -22,12 +22,13 @@ package tcp_server import ( "crypto/tls" "fmt" + "net" + "time" + "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/server" "github.com/IrineSistiana/mosdns/v5/pkg/utils" "github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils" - "net" - "time" ) const PluginType = "tcp_server" 
@@ -69,9 +70,6 @@ func StartServer(bp *coremain.BP, args *Args) (*TcpServer, error) { return nil, fmt.Errorf("failed to init dns handler, %w", err) } - serverOpts := server.TCPServerOpts{Logger: bp.L(), DNSHandler: dh, IdleTimeout: time.Duration(args.IdleTimeout) * time.Second} - s := server.NewTCPServer(serverOpts) - // Init tls var tc *tls.Config if len(args.Key)+len(args.Cert) > 0 { @@ -91,7 +89,8 @@ func StartServer(bp *coremain.BP, args *Args) (*TcpServer, error) { go func() { defer l.Close() - err := s.ServeTCP(l) + serverOpts := server.TCPServerOpts{Logger: bp.L(), IdleTimeout: time.Duration(args.IdleTimeout) * time.Second} + err := server.ServeTCP(l, dh, serverOpts) bp.M().GetSafeClose().SendCloseSignal(err) }() return &TcpServer{ diff --git a/plugin/server/udp_server/udp_server.go b/plugin/server/udp_server/udp_server.go index 293f720..988f312 100644 --- a/plugin/server/udp_server/udp_server.go +++ b/plugin/server/udp_server/udp_server.go @@ -64,15 +64,13 @@ func StartServer(bp *coremain.BP, args *Args) (*UdpServer, error) { return nil, fmt.Errorf("failed to init dns handler, %w", err) } - serverOpts := server.UDPServerOpts{Logger: bp.L(), DNSHandler: dh} - s := server.NewUDPServer(serverOpts) c, err := net.ListenPacket("udp", args.Listen) if err != nil { return nil, fmt.Errorf("failed to create socket, %w", err) } go func() { defer c.Close() - err := s.ServeUDP(c.(*net.UDPConn)) + err := server.ServeUDP(c.(*net.UDPConn), dh, server.UDPServerOpts{Logger: bp.L()}) bp.M().GetSafeClose().SendCloseSignal(err) }() return &UdpServer{ -- 2.34.8