From 9bb813bf50ba0eed3a35acc8f41764613c11687d Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:29:31 +0800 Subject: [PATCH 04/10] use go-bytes-pool --- go.mod | 1 + go.sum | 2 + pkg/dnsutils/net_io.go | 45 +++++---- pkg/pool/allocator.go | 78 +--------------- pkg/pool/allocator_test.go | 119 ------------------------ pkg/pool/msg_buf.go | 4 +- pkg/pool/msg_buf_test.go | 4 +- pkg/server/udp.go | 6 +- pkg/upstream/transport/dns_conn_test.go | 4 +- plugin/executable/cache/cache.go | 8 +- 10 files changed, 43 insertions(+), 228 deletions(-) delete mode 100644 pkg/pool/allocator_test.go diff --git a/go.mod b/go.mod index b56078b..2c359e3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 toolchain go1.21.1 require ( + github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686 github.com/go-chi/chi/v5 v5.0.10 github.com/google/nftables v0.1.0 github.com/kardianos/service v1.2.2 diff --git a/go.sum b/go.sum index 633f864..d4bbaa2 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686 h1:5R32cCep3VUDTKf3aurFKfgbvg+RScuBmZsw/DyyXco= +github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI= github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI= github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go index 1b8a7b1..8fc769b 100644 --- a/pkg/dnsutils/net_io.go +++ b/pkg/dnsutils/net_io.go @@ -23,9 +23,10 @@ import ( "encoding/binary" "errors" "fmt" + "io" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" "github.com/miekg/dns" - "io" ) var ( @@ -35,45 +36,43 @@ var ( // ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed // with a two byte length field). // n represents how many bytes are read from c. -// The returned the []byte should be released by pool.ReleaseBuf. -func ReadRawMsgFromTCP(c io.Reader) ([]byte, int, error) { - n := 0 +// The returned the *[]byte should be released by pool.ReleaseBuf. +func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) { h := pool.GetBuf(2) defer pool.ReleaseBuf(h) - nh, err := io.ReadFull(c, h) - n += nh + _, err := io.ReadFull(c, *h) + if err != nil { - return nil, n, err + return nil, err } // dns length - length := binary.BigEndian.Uint16(h) + length := binary.BigEndian.Uint16(*h) if length == 0 { - return nil, 0, errZeroLenMsg + return nil, errZeroLenMsg } - buf := pool.GetBuf(int(length)) - nm, err := io.ReadFull(c, buf) - n += nm + b := pool.GetBuf(int(length)) + _, err = io.ReadFull(c, *b) if err != nil { - pool.ReleaseBuf(buf) - return nil, n, err + pool.ReleaseBuf(b) + return nil, err } - return buf, n, nil + return b, nil } // ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed // with a two byte length field). // n represents how many bytes are read from c. func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) { - b, n, err := ReadRawMsgFromTCP(c) + b, err := ReadRawMsgFromTCP(c) if err != nil { return nil, 0, err } defer pool.ReleaseBuf(b) - m, err := unpackMsgWithDetailedErr(b) - return m, n, err + m, err := unpackMsgWithDetailedErr(*b) + return m, len(*b) + 2, err } // WriteMsgToTCP packs and writes m to c in RFC 1035 format. @@ -96,9 +95,9 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) { buf := pool.GetBuf(len(b) + 2) defer pool.ReleaseBuf(buf) - binary.BigEndian.PutUint16(buf[:2], uint16(len(b))) - copy(buf[2:], b) - return c.Write(buf) + binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b))) + copy((*buf)[2:], b) + return c.Write((*buf)) } func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) { @@ -118,12 +117,12 @@ func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) { b := pool.GetBuf(bufSize) defer pool.ReleaseBuf(b) - n, err := c.Read(b) + n, err := c.Read(*b) if err != nil { return nil, n, err } - m, err := unpackMsgWithDetailedErr(b[:n]) + m, err := unpackMsgWithDetailedErr((*b)[:n]) return m, n, err } diff --git a/pkg/pool/allocator.go b/pkg/pool/allocator.go index 0bfc0bc..d8ac4ef 100644 --- a/pkg/pool/allocator.go +++ b/pkg/pool/allocator.go @@ -20,79 +20,11 @@ package pool import ( - "fmt" - "math" - "math/bits" - "sync" + bytesPool "github.com/IrineSistiana/go-bytes-pool" ) // defaultBufPool is an Allocator that has a maximum capacity. -var defaultBufPool = NewAllocator() - -// GetBuf returns a []byte from pool with most appropriate cap. -// It panics if size < 0. -func GetBuf(size int) []byte { - return defaultBufPool.Get(size) -} - -// ReleaseBuf puts the buf to the pool. -func ReleaseBuf(b []byte) { - defaultBufPool.Release(b) -} - -type Allocator struct { - buffers []sync.Pool -} - -// NewAllocator initiates a []byte Allocator. -// The waste(memory fragmentation) of space allocation is guaranteed to be -// no more than 50%. -func NewAllocator() *Allocator { - alloc := &Allocator{ - buffers: make([]sync.Pool, bits.UintSize+1), - } - - for i := range alloc.buffers { - var bufSize uint - if i == bits.UintSize { - bufSize = math.MaxUint - } else { - bufSize = 1 << i - } - alloc.buffers[i].New = func() any { - b := make([]byte, bufSize) - return &b - } - } - return alloc -} - -// Get returns a []byte from pool with most appropriate cap -func (alloc *Allocator) Get(size int) []byte { - if size < 0 { - panic(fmt.Sprintf("invalid slice size %d", size)) - } - - i := shard(size) - v := alloc.buffers[i].Get() - buf := v.(*[]byte) - return (*buf)[0:size] -} - -// Release releases the buf to the allocatorL. -func (alloc *Allocator) Release(buf []byte) { - c := cap(buf) - i := shard(c) - if c == 0 || c != 1<. - */ - -package pool - -import ( - "fmt" - "strconv" - "testing" -) - -func TestAllocator_Get(t *testing.T) { - alloc := NewAllocator() - tests := []struct { - size int - wantCap int - wantPanic bool - }{ - {-1, 0, true}, // invalid - {0, 1, false}, - {1, 1, false}, - {2, 2, false}, - {12, 16, false}, - {256, 256, false}, - {257, 512, false}, - } - for _, tt := range tests { - t.Run(strconv.Itoa(tt.size), func(t *testing.T) { - if tt.wantPanic { - defer func() { - msg := recover() - if msg == nil { - t.Error("no panic") - } - }() - } - - for i := 0; i < 5; i++ { - b := alloc.Get(tt.size) - if len(b) != tt.size { - t.Fatalf("buffer size, want %d, got %d", tt.size, len(b)) - } - if cap(b) != tt.wantCap { - t.Fatalf("buffer cap, want %d, got %d", tt.wantCap, cap(b)) - } - alloc.Release(b) - } - }) - } -} - -func Test_shard(t *testing.T) { - tests := []struct { - size int - want int - }{ - {-1, 0}, - {0, 0}, - {1, 0}, - {2, 1}, - {3, 2}, - {4, 2}, - {5, 3}, - {8, 3}, - {1023, 10}, - {1024, 10}, - {1025, 11}, - } - for _, tt := range tests { - t.Run(strconv.Itoa(tt.size), func(t *testing.T) { - if got := shard(tt.size); got != tt.want { - t.Errorf("shard() = %v, want %v", got, tt.want) - } - }) - } -} - -func Benchmark_Allocator(b *testing.B) { - allocator := NewAllocator() - - for l := 0; l <= 16; l += 4 { - bufLen := 1 << l - b.Run(fmt.Sprintf("length %d", bufLen), func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - buf := allocator.Get(bufLen) - allocator.Release(buf) - } - }) - } -} - -func Benchmark_MakeByteSlice(b *testing.B) { - for l := 0; l <= 8; l++ { - bufLen := 1 << l - b.Run(fmt.Sprintf("length %d", bufLen), func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = make([]byte, bufLen) - } - }) - } -} diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go index 83eea80..980a08b 100644 --- a/pkg/pool/msg_buf.go +++ b/pkg/pool/msg_buf.go @@ -31,9 +31,9 @@ const packBufSize = 4096 // PackBuffer packs the dns msg m to wire format. // Callers should release the buf by calling ReleaseBuf after they have done // with the wire []byte. -func PackBuffer(m *dns.Msg) (wire, buf []byte, err error) { +func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) { buf = GetBuf(packBufSize) - wire, err = m.PackBuffer(buf) + wire, err = m.PackBuffer(*buf) if err != nil { ReleaseBuf(buf) return nil, nil, err diff --git a/pkg/pool/msg_buf_test.go b/pkg/pool/msg_buf_test.go index 0685864..02d7348 100644 --- a/pkg/pool/msg_buf_test.go +++ b/pkg/pool/msg_buf_test.go @@ -32,7 +32,7 @@ func TestPackBuffer_No_Allocation(t *testing.T) { t.Fatal(err) } - if cap(wire) != cap(buf) { - t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(buf)) + if cap(wire) != cap(*buf) { + t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(*buf)) } } diff --git a/pkg/server/udp.go b/pkg/server/udp.go index 8980a08..c1c9aa9 100644 --- a/pkg/server/udp.go +++ b/pkg/server/udp.go @@ -73,15 +73,15 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error { } for { - n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob) + n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob) if err != nil { return fmt.Errorf("unexpected read err: %w", err) } clientAddr := remoteAddr.Addr() q := new(dns.Msg) - if err := q.Unpack(rb[:n]); err != nil { - s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", remoteAddr)) + if err := q.Unpack((*rb)[:n]); err != nil { + s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr)) continue } diff --git a/pkg/upstream/transport/dns_conn_test.go b/pkg/upstream/transport/dns_conn_test.go index 8c49131..c677797 100644 --- a/pkg/upstream/transport/dns_conn_test.go +++ b/pkg/upstream/transport/dns_conn_test.go @@ -40,13 +40,13 @@ var ( c1, c2 := net.Pipe() go func() { for { - m, _, readErr := dnsutils.ReadRawMsgFromTCP(c2) + m, readErr := dnsutils.ReadRawMsgFromTCP(c2) if m != nil { go func() { defer pool.ReleaseBuf(m) latency := time.Millisecond * time.Duration(rand.Intn(20)) time.Sleep(latency) - _, _ = dnsutils.WriteRawMsgToTCP(c2, m) + _, _ = dnsutils.WriteRawMsgToTCP(c2, *m) }() } if readErr != nil { diff --git a/plugin/executable/cache/cache.go b/plugin/executable/cache/cache.go index f67d740..4091b50 100644 --- a/plugin/executable/cache/cache.go +++ b/plugin/executable/cache/cache.go @@ -420,27 +420,27 @@ func (c *Cache) readDump(r io.Reader) (int, error) { readBlock := func() error { h := pool.GetBuf(8) defer pool.ReleaseBuf(h) - _, err := io.ReadFull(gr, h) + _, err := io.ReadFull(gr, *h) if err != nil { if errors.Is(err, io.EOF) { return errReadHeaderEOF } return fmt.Errorf("failed to read block header, %w", err) } - u := binary.BigEndian.Uint64(h) + u := binary.BigEndian.Uint64(*h) if u > dumpMaximumBlockLength { return fmt.Errorf("invalid header, block length is big, %d", u) } b := pool.GetBuf(int(u)) defer pool.ReleaseBuf(b) - _, err = io.ReadFull(gr, b) + _, err = io.ReadFull(gr, *b) if err != nil { return fmt.Errorf("failed to read block data, %w", err) } block := new(CacheDumpBlock) - if err := proto.Unmarshal(b, block); err != nil { + if err := proto.Unmarshal(*b, block); err != nil { return fmt.Errorf("failed to decode block data, %w", err) } -- 2.34.8