223 lines
6.9 KiB
Diff
223 lines
6.9 KiB
Diff
From 64a83b8e28b3988df9eec4425130b57a09b15032 Mon Sep 17 00:00:00 2001
|
|
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
|
|
Date: Thu, 21 Sep 2023 22:06:49 +0800
|
|
Subject: [PATCH 1/9] pool: simplify PackBuffer
|
|
|
|
---
|
|
pkg/dnsutils/net_io.go | 7 +++---
|
|
pkg/pool/msg_buf.go | 41 ++++++++++++++++++------------------
|
|
pkg/pool/msg_buf_test.go | 6 +-----
|
|
pkg/server/http_handler.go | 6 +++---
|
|
pkg/server/tcp.go | 6 +++---
|
|
pkg/server/udp.go | 6 +++---
|
|
pkg/upstream/doh/upstream.go | 5 +++--
|
|
7 files changed, 37 insertions(+), 40 deletions(-)
|
|
|
|
diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go
|
|
index f165446..26e6efb 100644
|
|
--- a/pkg/dnsutils/net_io.go
|
|
+++ b/pkg/dnsutils/net_io.go
|
|
@@ -101,13 +101,12 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
|
|
}
|
|
|
|
func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
|
|
- b, buf, err := pool.PackBuffer(m)
|
|
+ b, err := pool.PackBuffer(m)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
- defer pool.ReleaseBuf(buf)
|
|
-
|
|
- return c.Write(b)
|
|
+ defer pool.ReleaseBuf(b)
|
|
+ return c.Write(*b)
|
|
}
|
|
|
|
func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {
|
|
diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go
|
|
index 11faf7d..b5f861c 100644
|
|
--- a/pkg/pool/msg_buf.go
|
|
+++ b/pkg/pool/msg_buf.go
|
|
@@ -26,47 +26,48 @@ import (
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
-// There is no such way to give dns.Msg.PackBuffer() a buffer
|
|
-// with a proper size.
|
|
-// Just give it a big buf and hope the buf will be reused in most scenes.
|
|
-const packBufSize = 4096
|
|
+// dns.Msg.PackBuffer requires a buffer with length of m.Len() + 1.
|
|
+// NOTE(review): the +1 seems to mirror PackBuffer's internal size check (uncompressed length + 1); confirm.
|
|
+func getPackBuffer(m *dns.Msg) int {
|
|
+ return m.Len() + 1
|
|
+}
|
|
|
|
// PackBuffer packs the dns msg m to wire format.
|
|
// Callers should release the buf by calling ReleaseBuf after they have done
|
|
// with the wire []byte.
|
|
-func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) {
|
|
- buf = GetBuf(packBufSize)
|
|
- wire, err = m.PackBuffer(*buf)
|
|
+func PackBuffer(m *dns.Msg) (*[]byte, error) {
|
|
+ b := GetBuf(getPackBuffer(m))
|
|
+ wire, err := m.PackBuffer(*b)
|
|
if err != nil {
|
|
- ReleaseBuf(buf)
|
|
- return nil, nil, err
|
|
+ ReleaseBuf(b)
|
|
+ return nil, err
|
|
}
|
|
- return wire, buf, nil
|
|
+ if &((*b)[0]) != &wire[0] { // reallocated
|
|
+ ReleaseBuf(b)
|
|
+ return nil, dns.ErrBuf
|
|
+ }
|
|
+ return b, nil
|
|
}
|
|
|
|
// PackBuffer packs the dns msg m to wire format, with to bytes length header.
|
|
// Callers should release the buf by calling ReleaseBuf.
|
|
-func PackTCPBuffer(m *dns.Msg) (buf *[]byte, err error) {
|
|
- b := GetBuf(packBufSize)
|
|
+func PackTCPBuffer(m *dns.Msg) (*[]byte, error) {
|
|
+ b := GetBuf(2 + getPackBuffer(m))
|
|
wire, err := m.PackBuffer((*b)[2:])
|
|
if err != nil {
|
|
ReleaseBuf(b)
|
|
return nil, err
|
|
}
|
|
+ if &((*b)[2]) != &wire[0] { // reallocated
|
|
+ ReleaseBuf(b)
|
|
+ return nil, dns.ErrBuf
|
|
+ }
|
|
|
|
l := len(wire)
|
|
if l > dns.MaxMsgSize {
|
|
ReleaseBuf(b)
|
|
return nil, fmt.Errorf("dns payload size %d is too large", l)
|
|
}
|
|
-
|
|
- if &((*b)[2]) != &wire[0] { // reallocated
|
|
- ReleaseBuf(b)
|
|
- b = GetBuf(l + 2)
|
|
- binary.BigEndian.PutUint16((*b)[:2], uint16(l))
|
|
- copy((*b)[2:], wire)
|
|
- return b, nil
|
|
- }
|
|
binary.BigEndian.PutUint16((*b)[:2], uint16(l))
|
|
*b = (*b)[:2+l]
|
|
return b, nil
|
|
diff --git a/pkg/pool/msg_buf_test.go b/pkg/pool/msg_buf_test.go
|
|
index 97d9a76..bfd98d1 100644
|
|
--- a/pkg/pool/msg_buf_test.go
|
|
+++ b/pkg/pool/msg_buf_test.go
|
|
@@ -28,12 +28,8 @@ import (
|
|
func TestPackBuffer_No_Allocation(t *testing.T) {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("123.", dns.TypeAAAA)
|
|
- wire, buf, err := PackBuffer(m)
|
|
+ _, err := PackBuffer(m)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
-
|
|
- if cap(wire) != cap(*buf) {
|
|
- t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(*buf))
|
|
- }
|
|
}
|
|
diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go
|
|
index 58f5811..3e671e3 100644
|
|
--- a/pkg/server/http_handler.go
|
|
+++ b/pkg/server/http_handler.go
|
|
@@ -103,17 +103,17 @@ func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
panic(err) // Force http server to close connection.
|
|
}
|
|
|
|
- b, buf, err := pool.PackBuffer(r)
|
|
+ b, err := pool.PackBuffer(r)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
h.warnErr(req, "failed to unpack handler's response", err)
|
|
return
|
|
}
|
|
- defer pool.ReleaseBuf(buf)
|
|
+ defer pool.ReleaseBuf(b)
|
|
|
|
w.Header().Set("Content-Type", "application/dns-message")
|
|
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", dnsutils.GetMinimalTTL(r)))
|
|
- if _, err := w.Write(b); err != nil {
|
|
+ if _, err := w.Write(*b); err != nil {
|
|
h.warnErr(req, "failed to write response", err)
|
|
return
|
|
}
|
|
diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go
|
|
index 5f479b1..ddc4846 100644
|
|
--- a/pkg/server/tcp.go
|
|
+++ b/pkg/server/tcp.go
|
|
@@ -101,14 +101,14 @@ func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error {
|
|
c.Close() // abort the connection
|
|
return
|
|
}
|
|
- b, buf, err := pool.PackBuffer(r)
|
|
+ b, err := pool.PackTCPBuffer(r)
|
|
if err != nil {
|
|
logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
|
|
return
|
|
}
|
|
- defer pool.ReleaseBuf(buf)
|
|
+ defer pool.ReleaseBuf(b)
|
|
|
|
- if _, err := dnsutils.WriteRawMsgToTCP(c, b); err != nil {
|
|
+ if _, err := c.Write(*b); err != nil {
|
|
logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
|
|
return
|
|
}
|
|
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
|
|
index 4dc1087..22e8d2b 100644
|
|
--- a/pkg/server/udp.go
|
|
+++ b/pkg/server/udp.go
|
|
@@ -95,18 +95,18 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
|
|
}
|
|
if r != nil {
|
|
r.Truncate(getUDPSize(q))
|
|
- b, buf, err := pool.PackBuffer(r)
|
|
+ b, err := pool.PackBuffer(r)
|
|
if err != nil {
|
|
logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
|
|
return
|
|
}
|
|
- defer pool.ReleaseBuf(buf)
|
|
+ defer pool.ReleaseBuf(b)
|
|
|
|
var oob []byte
|
|
if oobWriter != nil && dstIpFromCm != nil {
|
|
oob = oobWriter(dstIpFromCm)
|
|
}
|
|
- if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil {
|
|
+ if _, _, err := c.WriteMsgUDPAddrPort(*b, oob, remoteAddr); err != nil {
|
|
logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
|
|
}
|
|
}
|
|
diff --git a/pkg/upstream/doh/upstream.go b/pkg/upstream/doh/upstream.go
|
|
index abc124b..9cc72c4 100644
|
|
--- a/pkg/upstream/doh/upstream.go
|
|
+++ b/pkg/upstream/doh/upstream.go
|
|
@@ -54,11 +54,12 @@ var (
|
|
)
|
|
|
|
func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, error) {
|
|
- wire, buf, err := pool.PackBuffer(q)
|
|
+ bp, err := pool.PackBuffer(q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to pack query msg, %w", err)
|
|
}
|
|
- defer pool.ReleaseBuf(buf)
|
|
+ defer pool.ReleaseBuf(bp)
|
|
+ wire := *bp
|
|
|
|
// In order to maximize HTTP cache friendliness, DoH clients using media
|
|
// formats that include the ID field from the DNS message header, such
|
|
--
|
|
2.34.8
|
|
|