255 lines
8.4 KiB
Diff
255 lines
8.4 KiB
Diff
From c0af4b587311766650c8c103656dcb595bcfef34 Mon Sep 17 00:00:00 2001
|
|
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
|
|
Date: Fri, 22 Sep 2023 09:24:05 +0800
|
|
Subject: [PATCH 3/9] server: simplify Handler interface, add more meta
|
|
|
|
---
|
|
pkg/server/http_handler.go | 25 +++++++++--------
|
|
pkg/server/iface.go | 18 ++++++++-----
|
|
pkg/server/tcp.go | 21 ++++++++-------
|
|
pkg/server/udp.go | 42 ++++++++---------------------
|
|
pkg/server_handler/entry_handler.go | 29 +++++++++++++++++---
|
|
5 files changed, 71 insertions(+), 64 deletions(-)
|
|
|
|
diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go
|
|
index 3e671e3..5a41314 100644
|
|
--- a/pkg/server/http_handler.go
|
|
+++ b/pkg/server/http_handler.go
|
|
@@ -28,7 +28,6 @@ import (
|
|
"net/netip"
|
|
"strings"
|
|
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
"github.com/miekg/dns"
|
|
"go.uber.org/zap"
|
|
@@ -97,23 +96,23 @@ func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
return
|
|
}
|
|
|
|
- r, err := h.dnsHandler.Handle(req.Context(), q, QueryMeta{ClientAddr: clientAddr})
|
|
- if err != nil {
|
|
- h.warnErr(req, "handler err", err)
|
|
- panic(err) // Force http server to close connection.
|
|
+ queryMeta := QueryMeta{
|
|
+ ClientAddr: clientAddr,
|
|
}
|
|
-
|
|
- b, err := pool.PackBuffer(r)
|
|
- if err != nil {
|
|
+ if u := req.URL; u != nil {
|
|
+ queryMeta.UrlPath = u.Path
|
|
+ }
|
|
+ if tlsStat := req.TLS; tlsStat != nil {
|
|
+ queryMeta.ServerName = tlsStat.ServerName
|
|
+ }
|
|
+ resp := h.dnsHandler.Handle(req.Context(), q, queryMeta, pool.PackBuffer)
|
|
+ if resp == nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
- h.warnErr(req, "failed to unpack handler's response", err)
|
|
return
|
|
}
|
|
- defer pool.ReleaseBuf(b)
|
|
-
|
|
+ defer pool.ReleaseBuf(resp)
|
|
w.Header().Set("Content-Type", "application/dns-message")
|
|
- w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", dnsutils.GetMinimalTTL(r)))
|
|
- if _, err := w.Write(*b); err != nil {
|
|
+ if _, err := w.Write(*resp); err != nil {
|
|
h.warnErr(req, "failed to write response", err)
|
|
return
|
|
}
|
|
diff --git a/pkg/server/iface.go b/pkg/server/iface.go
|
|
index 2f15be1..c45b502 100644
|
|
--- a/pkg/server/iface.go
|
|
+++ b/pkg/server/iface.go
|
|
@@ -10,14 +10,20 @@ import (
|
|
// Handler handles incoming request q and MUST ALWAYS return a response.
|
|
// Handler MUST handle dns errors by itself and return a proper error responses.
|
|
// e.g. Return a SERVFAIL if something goes wrong.
|
|
-// If Handle() returns an error, caller considers that the error is associated
|
|
-// with the downstream connection and will close the downstream connection
|
|
-// immediately.
|
|
+// If Handle() returns a nil resp, the caller will:
|
|
+// udp: do nothing.
|
|
+// tcp/dot: close the connection immediately.
|
|
+// doh: send a 500 response.
|
|
+// doq: close the stream immediately.
|
|
type Handler interface {
|
|
- Handle(ctx context.Context, q *dns.Msg, meta QueryMeta) (resp *dns.Msg, err error)
|
|
+ Handle(ctx context.Context, q *dns.Msg, meta QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) (respPayload *[]byte)
|
|
}
|
|
|
|
type QueryMeta struct {
|
|
- ClientAddr netip.Addr // Maybe invalid
|
|
- FromUDP bool
|
|
+ FromUDP bool
|
|
+
|
|
+ // Optional fields. Each may be left as its zero value (e.g. an invalid ClientAddr).
|
|
+ ClientAddr netip.Addr
|
|
+ ServerName string
|
|
+ UrlPath string
|
|
}
|
|
diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go
|
|
index ddc4846..6faba76 100644
|
|
--- a/pkg/server/tcp.go
|
|
+++ b/pkg/server/tcp.go
|
|
@@ -21,6 +21,7 @@ package server
|
|
|
|
import (
|
|
"context"
|
|
+ "crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
@@ -93,22 +94,22 @@ func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error {
|
|
return // read err, close the connection
|
|
}
|
|
|
|
+ // Try to get server name from tls conn.
|
|
+ var serverName string
|
|
+ if tlsConn, ok := c.(*tls.Conn); ok {
|
|
+ serverName = tlsConn.ConnectionState().ServerName
|
|
+ }
|
|
+
|
|
// handle query
|
|
go func() {
|
|
- r, err := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr})
|
|
- if err != nil {
|
|
- logger.Warn("handler err", zap.Error(err))
|
|
+ r := h.Handle(tcpConnCtx, req, QueryMeta{ClientAddr: clientAddr, ServerName: serverName}, pool.PackTCPBuffer)
|
|
+ if r == nil {
|
|
c.Close() // abort the connection
|
|
return
|
|
}
|
|
- b, err := pool.PackTCPBuffer(r)
|
|
- if err != nil {
|
|
- logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
|
|
- return
|
|
- }
|
|
- defer pool.ReleaseBuf(b)
|
|
+ defer pool.ReleaseBuf(r)
|
|
|
|
- if _, err := c.Write(*b); err != nil {
|
|
+ if _, err := c.Write(*r); err != nil {
|
|
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
|
|
return
|
|
}
|
|
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
|
|
index 22e8d2b..89e57e2 100644
|
|
--- a/pkg/server/udp.go
|
|
+++ b/pkg/server/udp.go
|
|
@@ -63,10 +63,10 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
|
|
n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
|
|
if err != nil {
|
|
if n == 0 {
|
|
- // err with zero read. Most likely becasue c was closed.
|
|
+ // Err with zero read. Most likely because c was closed.
|
|
return fmt.Errorf("unexpected read err: %w", err)
|
|
}
|
|
- // err with some read. Tempory err.
|
|
+ // Temporary err.
|
|
logger.Warn("read err", zap.Error(err))
|
|
continue
|
|
}
|
|
@@ -88,42 +88,22 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
|
|
|
|
// handle query
|
|
go func() {
|
|
- r, err := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true})
|
|
- if err != nil {
|
|
- logger.Warn("handler err", zap.Error(err))
|
|
+ payload := h.Handle(listenerCtx, q, QueryMeta{ClientAddr: remoteAddr.Addr(), FromUDP: true}, pool.PackBuffer)
|
|
+ if payload == nil {
|
|
return
|
|
}
|
|
- if r != nil {
|
|
- r.Truncate(getUDPSize(q))
|
|
- b, err := pool.PackBuffer(r)
|
|
- if err != nil {
|
|
- logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
|
|
- return
|
|
- }
|
|
- defer pool.ReleaseBuf(b)
|
|
+ defer pool.ReleaseBuf(payload)
|
|
|
|
- var oob []byte
|
|
- if oobWriter != nil && dstIpFromCm != nil {
|
|
- oob = oobWriter(dstIpFromCm)
|
|
- }
|
|
- if _, _, err := c.WriteMsgUDPAddrPort(*b, oob, remoteAddr); err != nil {
|
|
- logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
|
|
- }
|
|
+ var oob []byte
|
|
+ if oobWriter != nil && dstIpFromCm != nil {
|
|
+ oob = oobWriter(dstIpFromCm)
|
|
+ }
|
|
+ if _, _, err := c.WriteMsgUDPAddrPort(*payload, oob, remoteAddr); err != nil {
|
|
+ logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
-func getUDPSize(m *dns.Msg) int {
|
|
- var s uint16
|
|
- if opt := m.IsEdns0(); opt != nil {
|
|
- s = opt.UDPSize()
|
|
- }
|
|
- if s < dns.MinMsgSize {
|
|
- s = dns.MinMsgSize
|
|
- }
|
|
- return int(s)
|
|
-}
|
|
-
|
|
type getSrcAddrFromOOB func(oob []byte) (net.IP, error)
|
|
type writeSrcAddrToOOB func(a net.IP) []byte
|
|
diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go
|
|
index 520e3d2..9e3a386 100644
|
|
--- a/pkg/server_handler/entry_handler.go
|
|
+++ b/pkg/server_handler/entry_handler.go
|
|
@@ -71,9 +71,9 @@ func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler {
|
|
}
|
|
|
|
// ServeDNS implements server.Handler.
|
|
-// If entry returns an error, a SERVFAIL response will be set.
|
|
-// If entry returns without a response, a REFUSED response will be set.
|
|
-func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta) (*dns.Msg, error) {
|
|
+// If entry returns an error, a SERVFAIL response will be returned.
|
|
+// If entry returns without a response, a REFUSED response will be returned.
|
|
+func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
|
|
ddl := time.Now().Add(h.opts.QueryTimeout)
|
|
ctx, cancel := context.WithDeadline(ctx, ddl)
|
|
defer cancel()
|
|
@@ -100,5 +100,26 @@ func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.Quer
|
|
respMsg.Rcode = dns.RcodeServerFailure
|
|
}
|
|
respMsg.RecursionAvailable = true
|
|
- return respMsg, nil
|
|
+
|
|
+ if qInfo.FromUDP {
|
|
+ respMsg.Truncate(getUDPSize(q))
|
|
+ }
|
|
+
|
|
+ payload, err := packMsgPayload(respMsg)
|
|
+ if err != nil {
|
|
+ h.opts.Logger.Error("internal err: failed to pack resp msg", zap.Error(err))
|
|
+ return nil
|
|
+ }
|
|
+ return payload
|
|
+}
|
|
+
|
|
+func getUDPSize(m *dns.Msg) int {
|
|
+ var s uint16
|
|
+ if opt := m.IsEdns0(); opt != nil {
|
|
+ s = opt.UDPSize()
|
|
+ }
|
|
+ if s < dns.MinMsgSize {
|
|
+ s = dns.MinMsgSize
|
|
+ }
|
|
+ return int(s)
|
|
}
|
|
--
|
|
2.34.8
|
|
|