From 64a83b8e28b3988df9eec4425130b57a09b15032 Mon Sep 17 00:00:00 2001
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
Date: Thu, 21 Sep 2023 22:06:49 +0800
Subject: [PATCH 1/9] pool: simplify PackBuffer

Size pooled buffers from the msg instead of a fixed 4096 bytes, so that
dns.Msg.PackBuffer always packs in place. The size is based on the
uncompressed length, because that is what PackBuffer reserves even when
m.Compress is set (e.g. after Truncate). PackBuffer now returns the
pooled buffer trimmed to the wire length.
---
 pkg/dnsutils/net_io.go       |  7 ++---
 pkg/pool/msg_buf.go          | 54 ++++++++++++++++++++++--------------
 pkg/pool/msg_buf_test.go     | 34 +++++++++++++++++++----
 pkg/server/http_handler.go   |  6 ++--
 pkg/server/tcp.go            |  6 ++--
 pkg/server/udp.go            |  6 ++--
 pkg/upstream/doh/upstream.go |  5 ++--
 7 files changed, 77 insertions(+), 41 deletions(-)

diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go
index f165446..26e6efb 100644
--- a/pkg/dnsutils/net_io.go
+++ b/pkg/dnsutils/net_io.go
@@ -101,13 +101,12 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
 }
 
 func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
-	b, buf, err := pool.PackBuffer(m)
+	b, err := pool.PackBuffer(m)
 	if err != nil {
 		return 0, err
 	}
-	defer pool.ReleaseBuf(buf)
-
-	return c.Write(b)
+	defer pool.ReleaseBuf(b)
+	return c.Write(*b)
 }
 
 func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {
diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go
index 11faf7d..b5f861c 100644
--- a/pkg/pool/msg_buf.go
+++ b/pkg/pool/msg_buf.go
@@ -26,47 +26,59 @@ import (
 	"github.com/miekg/dns"
 )
 
-// There is no such way to give dns.Msg.PackBuffer() a buffer
-// with a proper size.
-// Just give it a big buf and hope the buf will be reused in most scenes.
-const packBufSize = 4096
+// packBufferSize returns the buffer length dns.Msg.PackBuffer needs to pack
+// m into the given buffer without reallocating: the uncompressed wire length
+// of m plus one byte. m.Len() returns the compressed length when m.Compress
+// is set (e.g. after m.Truncate), which would be too small, so compression is
+// switched off while measuring. m is modified temporarily and must not be
+// used concurrently.
+func packBufferSize(m *dns.Msg) int {
+	compress := m.Compress
+	m.Compress = false
+	l := m.Len()
+	m.Compress = compress
+	return l + 1
+}
 
 // PackBuffer packs the dns msg m to wire format.
 // Callers should release the buf by calling ReleaseBuf after they have done
 // with the wire []byte.
-func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) {
-	buf = GetBuf(packBufSize)
-	wire, err = m.PackBuffer(*buf)
+func PackBuffer(m *dns.Msg) (*[]byte, error) {
+	b := GetBuf(packBufferSize(m))
+	wire, err := m.PackBuffer(*b)
 	if err != nil {
-		ReleaseBuf(buf)
-		return nil, nil, err
+		ReleaseBuf(b)
+		return nil, err
 	}
-	return wire, buf, nil
+	if &((*b)[0]) != &wire[0] { // reallocated
+		ReleaseBuf(b)
+		return nil, dns.ErrBuf
+	}
+	// *b may be longer than the packed msg (see packBufferSize).
+	// Trim it so that callers only see the wire bytes.
+	*b = (*b)[:len(wire)]
+	return b, nil
 }
 
-// PackBuffer packs the dns msg m to wire format, with to bytes length header.
+// PackTCPBuffer packs the dns msg m to wire format, with a two-byte length header.
 // Callers should release the buf by calling ReleaseBuf.
-func PackTCPBuffer(m *dns.Msg) (buf *[]byte, err error) {
-	b := GetBuf(packBufSize)
+func PackTCPBuffer(m *dns.Msg) (*[]byte, error) {
+	b := GetBuf(2 + packBufferSize(m))
 	wire, err := m.PackBuffer((*b)[2:])
 	if err != nil {
 		ReleaseBuf(b)
 		return nil, err
 	}
+	if &((*b)[2]) != &wire[0] { // reallocated
+		ReleaseBuf(b)
+		return nil, dns.ErrBuf
+	}
 
 	l := len(wire)
 	if l > dns.MaxMsgSize {
 		ReleaseBuf(b)
 		return nil, fmt.Errorf("dns payload size %d is too large", l)
 	}
-
-	if &((*b)[2]) != &wire[0] { // reallocated
-		ReleaseBuf(b)
-		b = GetBuf(l + 2)
-		binary.BigEndian.PutUint16((*b)[:2], uint16(l))
-		copy((*b)[2:], wire)
-		return b, nil
-	}
 	binary.BigEndian.PutUint16((*b)[:2], uint16(l))
 	*b = (*b)[:2+l]
 	return b, nil
diff --git a/pkg/pool/msg_buf_test.go b/pkg/pool/msg_buf_test.go
index 97d9a76..bfd98d1 100644
--- a/pkg/pool/msg_buf_test.go
+++ b/pkg/pool/msg_buf_test.go
@@ -28,12 +28,36 @@ import (
 func TestPackBuffer_No_Allocation(t *testing.T) {
 	m := new(dns.Msg)
 	m.SetQuestion("123.", dns.TypeAAAA)
-	wire, buf, err := PackBuffer(m)
+	checkPackBuffer(t, m)
+}
+
+// Compressed msgs (e.g. after Truncate) must be packed in place, too.
+func TestPackBuffer_Compressed(t *testing.T) {
+	m := new(dns.Msg)
+	m.SetQuestion("example.com.", dns.TypeA)
+	m.Compress = true
+	for i := 0; i < 4; i++ {
+		m.Answer = append(m.Answer, &dns.CNAME{
+			Hdr:    dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET},
+			Target: "example.com.",
+		})
+	}
+	checkPackBuffer(t, m)
+}
+
+// checkPackBuffer asserts that PackBuffer returns exactly the bytes of m.Pack().
+func checkPackBuffer(t *testing.T, m *dns.Msg) {
+	t.Helper()
+	want, err := m.Pack()
+	if err != nil {
+		t.Fatal(err)
+	}
+	b, err := PackBuffer(m)
 	if err != nil {
 		t.Fatal(err)
 	}
-
-	if cap(wire) != cap(*buf) {
-		t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(*buf))
-	}
+	defer ReleaseBuf(b)
+	if string(*b) != string(want) {
+		t.Fatalf("PackBuffer() = %x, want %x", *b, want)
+	}
 }
diff --git a/pkg/server/http_handler.go b/pkg/server/http_handler.go
index 58f5811..3e671e3 100644
--- a/pkg/server/http_handler.go
+++ b/pkg/server/http_handler.go
@@ -103,17 +103,17 @@ func (h *HttpHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 		panic(err) // Force http server to close connection.
 	}
 
-	b, buf, err := pool.PackBuffer(r)
+	b, err := pool.PackBuffer(r)
 	if err != nil {
 		w.WriteHeader(http.StatusInternalServerError)
 		h.warnErr(req, "failed to unpack handler's response", err)
 		return
 	}
-	defer pool.ReleaseBuf(buf)
+	defer pool.ReleaseBuf(b)
 
 	w.Header().Set("Content-Type", "application/dns-message")
 	w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", dnsutils.GetMinimalTTL(r)))
-	if _, err := w.Write(b); err != nil {
+	if _, err := w.Write(*b); err != nil {
 		h.warnErr(req, "failed to write response", err)
 		return
 	}
diff --git a/pkg/server/tcp.go b/pkg/server/tcp.go
index 5f479b1..ddc4846 100644
--- a/pkg/server/tcp.go
+++ b/pkg/server/tcp.go
@@ -101,14 +101,14 @@ func ServeTCP(l net.Listener, h Handler, opts TCPServerOpts) error {
 						c.Close() // abort the connection
 						return
 					}
-					b, buf, err := pool.PackBuffer(r)
+					b, err := pool.PackTCPBuffer(r)
 					if err != nil {
 						logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
 						return
 					}
-					defer pool.ReleaseBuf(buf)
+					defer pool.ReleaseBuf(b)
 
-					if _, err := dnsutils.WriteRawMsgToTCP(c, b); err != nil {
+					if _, err := c.Write(*b); err != nil {
 						logger.Warn("failed to write response", zap.Stringer("client", c.RemoteAddr()), zap.Error(err))
 						return
 					}
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
index 4dc1087..22e8d2b 100644
--- a/pkg/server/udp.go
+++ b/pkg/server/udp.go
@@ -95,18 +95,18 @@ func ServeUDP(c *net.UDPConn, h Handler, opts UDPServerOpts) error {
 			}
 			if r != nil {
 				r.Truncate(getUDPSize(q))
-				b, buf, err := pool.PackBuffer(r)
+				b, err := pool.PackBuffer(r)
 				if err != nil {
 					logger.Error("failed to unpack handler's response", zap.Error(err), zap.Stringer("msg", r))
 					return
 				}
-				defer pool.ReleaseBuf(buf)
+				defer pool.ReleaseBuf(b)
 
 				var oob []byte
 				if oobWriter != nil && dstIpFromCm != nil {
 					oob = oobWriter(dstIpFromCm)
 				}
-				if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil {
+				if _, _, err := c.WriteMsgUDPAddrPort(*b, oob, remoteAddr); err != nil {
 					logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
 				}
 			}
diff --git a/pkg/upstream/doh/upstream.go b/pkg/upstream/doh/upstream.go
index abc124b..9cc72c4 100644
--- a/pkg/upstream/doh/upstream.go
+++ b/pkg/upstream/doh/upstream.go
@@ -54,11 +54,12 @@ var (
 )
 
 func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, error) {
-	wire, buf, err := pool.PackBuffer(q)
+	bp, err := pool.PackBuffer(q)
 	if err != nil {
 		return nil, fmt.Errorf("failed to pack query msg, %w", err)
 	}
-	defer pool.ReleaseBuf(buf)
+	defer pool.ReleaseBuf(bp)
+	wire := *bp
 
 	// In order to maximize HTTP cache friendliness, DoH clients using media
 	// formats that include the ID field from the DNS message header, such
-- 
2.34.8