luci-app-mosdns/mosdns/patches/119-server-simplify-Handler-interface-add-more-meta.patch
2023-09-23 15:17:07 +08:00

255 lines
8.4 KiB
Diff

From c0af4b587311766650c8c103656dcb595bcfef34 Mon Sep 17 00:00:00 2001
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
Date: Fri, 22 Sep 2023 09:24:05 +0800
Subject: [PATCH 3/9] server: simplify Handler interface, add more meta
---
pkg/server/http_handler.go | 25 +++++++++--------
pkg/server/iface.go | 18 ++++++++-----
pkg/server/tcp.go | 21 ++++++++-------
pkg/server/udp.go | 42 ++++++++---------------------
pkg/server_handler/entry_handler.go | 29 +++++++++++++++++---
5 files changed, 71 insertions(+), 64 deletions(-)
diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go
index 3e671e3..5a41314 100644
--- a/pkg/server/http_handler.go
+++ b/pkg/server/http_handler.go
@@ -28,7 +28,6 @@ import (
"net/netip"
"strings"
- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"go.uber.org/zap"
@@ -97,23 +96,23 @@ func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
- r, err := h.dnsHandler.Handle(req.Context(), q, QueryMeta{ClientAddr: clientAddr})
- if err != nil {
- h.warnErr(req, "handler err", err)
- panic(err) // Force http server to close connection.
+ queryMeta := QueryMeta{
+ ClientAddr: clientAddr,
}
-
- b, err := pool.PackBuffer(r)
- if err != nil {
+ if u := req.URL; u != nil {
+ queryMeta.UrlPath = u.Path
+ }
+ if tlsStat := req.TLS; tlsStat != nil {
+ queryMeta.ServerName = tlsStat.ServerName
+ }
+ resp := h.dnsHandler.Handle(req.Context(), q, queryMeta, pool.PackBuffer)
+ if resp == nil {
w.WriteHeader(http.StatusInternalServerError)
- h.warnErr(req, "failed to unpack handler's response", err)
return
}
- defer pool.ReleaseBuf(b)
-
+ defer pool.ReleaseBuf(resp)
w.Header().Set("Content-Type", "application/dns-message")
- w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", dnsutils.GetMinimalTTL(r)))
- if _, err := w.Write(*b); err != nil {
+ if _, err := w.Write(*resp); err != nil {
h.warnErr(req, "failed to write response", err)
return
}
diff --git a/pkg/server/iface.go b/pkg/server/iface.go
index 2f15be1..c45b502 100644
--- a/pkg/server/iface.go
+++ b/pkg/server/iface.go
@@ -10,14 +10,20 @@ import (
// Handler handles incoming request q and MUST ALWAYS return a response.
// Handler MUST handle dns errors by itself and return a proper error responses.
// e.g. Return a SERVFAIL if something goes wrong.
-// If Handle() returns an error, caller considers that the error is associated
-// with the downstream connection and will close the downstream connection
-// immediately.
+// If Handle() returns a nil resp, caller will
+// udp: do nothing.
+// tcp/dot: close the connection immediately.
+// doh: send a 500 response.
+// doq: close the stream immediately.
type Handler interface {
- Handle(ctx context.Context, q *dns.Msg, meta QueryMeta) (resp *dns.Msg, err error)
+ Handle(ctx context.Context, q *dns.Msg, meta QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) (respPayload *[]byte)
}
type QueryMeta struct {
- ClientAddr netip.Addr // Maybe invalid
- FromUDP bool
+ FromUDP bool
+
+ // Optional
+ ClientAddr netip.Addr
+ ServerName string
+ UrlPath string
}
diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go
index ddc4846..6faba76 100644
--- a/pkg/server/tcp.go
+++ b/pkg/server/tcp.go
@@ -21,6 +21,7 @@ package server
import (
"context"
+ "crypto/tls"
"fmt"
"net"
"net/netip"
@@ -93,22 +94,22 @@ func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error {
return // read err, close the connection
}
+ // Try to get server name from tls conn.
+ var serverName string
+ if tlsConn, ok := c.(*tls.Conn); ok {
+ serverName = tlsConn.ConnectionState().ServerName
+ }
+
// handle query
go func() {
- r, err := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr})
- if err != nil {
- logger.Warn("handler err", zap.Error(err))
+ r := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr, ServerName: serverName}, pool.PackTCPBuffer)
+ if r == nil {
c.Close() // abort the connection
return
}
- b, err := pool.PackTCPBuffer(r)
- if err != nil {
- logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
- return
- }
- defer pool.ReleaseBuf(b)
+ defer pool.ReleaseBuf(r)
- if _, err := c.Write(*b); err != nil {
+ if _, err := c.Write(*r); err != nil {
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
return
}
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
index 22e8d2b..89e57e2 100644
--- a/pkg/server/udp.go
+++ b/pkg/server/udp.go
@@ -63,10 +63,10 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
if err != nil {
if n == 0 {
- // err with zero read. Most likely becasue c was closed.
+ // Err with zero read. Most likely because c was closed.
return fmt.Errorf("unexpected read err: %w", err)
}
- // err with some read. Tempory err.
+ // Temporary err.
logger.Warn("read err", zap.Error(err))
continue
}
@@ -88,42 +88,22 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
// handle query
go func() {
- r, err := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true})
- if err != nil {
- logger.Warn("handler err", zap.Error(err))
+ payload := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}, pool.PackBuffer)
+ if payload == nil {
return
}
- if r != nil {
- r.Truncate(getUDPSize(q))
- b, err := pool.PackBuffer(r)
- if err != nil {
- logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
- return
- }
- defer pool.ReleaseBuf(b)
+ defer pool.ReleaseBuf(payload)
- var oob []byte
- if oobWriter != nil && dstIpFromCm != nil {
- oob = oobWriter(dstIpFromCm)
- }
- if _, _, err := c.WriteMsgUDPAddrPort(*b, oob, remoteAddr); err != nil {
- logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
- }
+ var oob []byte
+ if oobWriter != nil && dstIpFromCm != nil {
+ oob = oobWriter(dstIpFromCm)
+ }
+ if _, _, err := c.WriteMsgUDPAddrPort(*payload, oob, remoteAddr); err != nil {
+ logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
}
}()
}
}
-func getUDPSize(m *dns.Msg) int {
- var s uint16
- if opt := m.IsEdns0(); opt != nil {
- s = opt.UDPSize()
- }
- if s < dns.MinMsgSize {
- s = dns.MinMsgSize
- }
- return int(s)
-}
-
type getSrcAddrFromOOB func(oob []byte) (net.IP, error)
type writeSrcAddrToOOB func(a net.IP) []byte
diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go
index 520e3d2..9e3a386 100644
--- a/pkg/server_handler/entry_handler.go
+++ b/pkg/server_handler/entry_handler.go
@@ -71,9 +71,9 @@ func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler {
}
// ServeDNS implements server.Handler.
-// If entry returns an error, a SERVFAIL response will be set.
-// If entry returns without a response, a REFUSED response will be set.
-func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta) (*dns.Msg, error) {
+// If entry returns an error, a SERVFAIL response will be returned.
+// If entry returns without a response, a REFUSED response will be returned.
+func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
ddl := time.Now().Add(h.opts.QueryTimeout)
ctx, cancel := context.WithDeadline(ctx, ddl)
defer cancel()
@@ -100,5 +100,26 @@ func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.Quer
respMsg.Rcode = dns.RcodeServerFailure
}
respMsg.RecursionAvailable = true
- return respMsg, nil
+
+ if qInfo.FromUDP {
+ respMsg.Truncate(getUDPSize(q))
+ }
+
+ payload, err := packMsgPayload(respMsg)
+ if err != nil {
+ h.opts.Logger.Error("internal err: failed to pack resp msg", zap.Error(err))
+ return nil
+ }
+ return payload
+}
+
+func getUDPSize(m *dns.Msg) int {
+ var s uint16
+ if opt := m.IsEdns0(); opt != nil {
+ s = opt.UDPSize()
+ }
+ if s < dns.MinMsgSize {
+ s = dns.MinMsgSize
+ }
+ return int(s)
}
--
2.34.8