From c0af4b587311766650c8c103656dcb595bcfef34 Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Fri, 22 Sep 2023 09:24:05 +0800 Subject: [PATCH 3/9] server: simplify Handler interface, add more meta --- pkg/server/http_handler.go | 25 +++++++++-------- pkg/server/iface.go | 18 ++++++++----- pkg/server/tcp.go | 21 ++++++++------- pkg/server/udp.go | 42 ++++++++--------------------- pkg/server_handler/entry_handler.go | 29 +++++++++++++++++--- 5 files changed, 71 insertions(+), 64 deletions(-) diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go index 3e671e3..5a41314 100644 --- a/pkg/server/http_handler.go +++ b/pkg/server/http_handler.go @@ -28,7 +28,6 @@ import ( "net/netip" "strings" - "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" "github.com/IrineSistiana/mosdns/v5/pkg/pool" "github.com/miekg/dns" "go.uber.org/zap" @@ -97,23 +96,23 @@ func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - r, err := h.dnsHandler.Handle(req.Context(), q, QueryMeta{ClientAddr: clientAddr}) - if err != nil { - h.warnErr(req, "handler err", err) - panic(err) // Force http server to close connection. + queryMeta := QueryMeta{ + ClientAddr: clientAddr, } - - b, err := pool.PackBuffer(r) - if err != nil { + if u := req.URL; u != nil { + queryMeta.UrlPath = u.Path + } + if tlsStat := req.TLS; tlsStat != nil { + queryMeta.ServerName = tlsStat.ServerName + } + resp := h.dnsHandler.Handle(req.Context(), q, queryMeta, pool.PackBuffer) + if resp == nil { w.WriteHeader(http.StatusInternalServerError) - h.warnErr(req, "failed to unpack handler's response", err) return } - defer pool.ReleaseBuf(b) - + defer pool.ReleaseBuf(resp) w.Header().Set("Content-Type", "application/dns-message") - w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", dnsutils.GetMinimalTTL(r))) - if _, err := w.Write(*b); err != nil { + if _, err := w.Write(*resp); err != nil { h.warnErr(req, "failed to write response", err) return } diff --git a/pkg/server/iface.go b/pkg/server/iface.go index 2f15be1..c45b502 100644 --- a/pkg/server/iface.go +++ b/pkg/server/iface.go @@ -10,14 +10,20 @@ import ( // Handler handles incoming request q and MUST ALWAYS return a response. // Handler MUST handle dns errors by itself and return a proper error responses. // e.g. Return a SERVFAIL if something goes wrong. -// If Handle() returns an error, caller considers that the error is associated -// with the downstream connection and will close the downstream connection -// immediately. +// If Handle() returns a nil resp, caller will +// udp: do nothing. +// tcp/dot: close the connection immediately. +// doh: send a 500 response. +// doq: close the stream immediately. type Handler interface { - Handle(ctx context.Context, q *dns.Msg, meta QueryMeta) (resp *dns.Msg, err error) + Handle(ctx context.Context, q *dns.Msg, meta QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) (respPayload *[]byte) } type QueryMeta struct { - ClientAddr netip.Addr // Maybe invalid - FromUDP bool + FromUDP bool + + // Optional + ClientAddr netip.Addr + ServerName string + UrlPath string } diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go index ddc4846..6faba76 100644 --- a/pkg/server/tcp.go +++ b/pkg/server/tcp.go @@ -21,6 +21,7 @@ package server import ( "context" + "crypto/tls" "fmt" "net" "net/netip" @@ -93,22 +94,22 @@ func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error { return // read err, close the connection } + // Try to get server name from tls conn. + var serverName string + if tlsConn, ok := c.(*tls.Conn); ok { + serverName = tlsConn.ConnectionState().ServerName + } + // handle query go func() { - r, err := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr}) - if err != nil { - logger.Warn("handler err", zap.Error(err)) + r := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr, ServerName: serverName}, pool.PackTCPBuffer) + if r == nil { c.Close() // abort the connection return } - b, err := pool.PackTCPBuffer(r) - if err != nil { - logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) - return - } - defer pool.ReleaseBuf(b) + defer pool.ReleaseBuf(r) - if _, err := c.Write(*b); err != nil { + if _, err := c.Write(*r); err != nil { logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err)) return } diff --git a/pkg/server/udp.go b/pkg/server/udp.go index 22e8d2b..89e57e2 100644 --- a/pkg/server/udp.go +++ b/pkg/server/udp.go @@ -63,10 +63,10 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error { n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob) if err != nil { if n == 0 { - // err with zero read. Most likely becasue c was closed. + // Err with zero read. Most likely because c was closed. return fmt.Errorf("unexpected read err: %w", err) } - // err with some read. Tempory err. + // Temporary err. logger.Warn("read err", zap.Error(err)) continue } @@ -88,42 +88,22 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error { // handle query go func() { - r, err := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}) - if err != nil { - logger.Warn("handler err", zap.Error(err)) + payload := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}, pool.PackBuffer) + if payload == nil { return } - if r != nil { - r.Truncate(getUDPSize(q)) - b, err := pool.PackBuffer(r) - if err != nil { - logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r)) - return - } - defer pool.ReleaseBuf(b) + defer pool.ReleaseBuf(payload) - var oob []byte - if oobWriter != nil && dstIpFromCm != nil { - oob = oobWriter(dstIpFromCm) - } - if _, _, err := c.WriteMsgUDPAddrPort(*b, oob, remoteAddr); err != nil { - logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) - } + var oob []byte + if oobWriter != nil && dstIpFromCm != nil { + oob = oobWriter(dstIpFromCm) + } + if _, _, err := c.WriteMsgUDPAddrPort(*payload, oob, remoteAddr); err != nil { + logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err)) } }() } } -func getUDPSize(m *dns.Msg) int { - var s uint16 - if opt := m.IsEdns0(); opt != nil { - s = opt.UDPSize() - } - if s < dns.MinMsgSize { - s = dns.MinMsgSize - } - return int(s) -} - type getSrcAddrFromOOB func(oob []byte) (net.IP, error) type writeSrcAddrToOOB func(a net.IP) []byte diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go index 520e3d2..9e3a386 100644 --- a/pkg/server_handler/entry_handler.go +++ b/pkg/server_handler/entry_handler.go @@ -71,9 +71,9 @@ func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler { } // ServeDNS implements server.Handler. -// If entry returns an error, a SERVFAIL response will be set. -// If entry returns without a response, a REFUSED response will be set. -func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta) (*dns.Msg, error) { +// If entry returns an error, a SERVFAIL response will be returned. +// If entry returns without a response, a REFUSED response will be returned. +func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte { ddl := time.Now().Add(h.opts.QueryTimeout) ctx, cancel := context.WithDeadline(ctx, ddl) defer cancel() @@ -100,5 +100,26 @@ func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.Quer respMsg.Rcode = dns.RcodeServerFailure } respMsg.RecursionAvailable = true - return respMsg, nil + + if qInfo.FromUDP { + respMsg.Truncate(getUDPSize(q)) + } + + payload, err := packMsgPayload(respMsg) + if err != nil { + h.opts.Logger.Error("internal err: failed to pack resp msg", zap.Error(err)) + return nil + } + return payload +} + +func getUDPSize(m *dns.Msg) int { + var s uint16 + if opt := m.IsEdns0(); opt != nil { + s = opt.UDPSize() + } + if s < dns.MinMsgSize { + s = dns.MinMsgSize + } + return int(s) } -- 2.34.8