From f81a617d6bc3ad05bd9c9edd343083f4b4c09cd4 Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Sat, 23 Sep 2023 08:28:10 +0800 Subject: [PATCH 8/9] transport: fixed eol pipelineConn wasn't closed when calling PipelineTransport.Close() --- pkg/upstream/transport/pipeline.go | 54 ++++++++++++++----------- pkg/upstream/transport/pipeline_test.go | 2 +- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/pkg/upstream/transport/pipeline.go b/pkg/upstream/transport/pipeline.go index 10c3d23..70f9c8d 100644 --- a/pkg/upstream/transport/pipeline.go +++ b/pkg/upstream/transport/pipeline.go @@ -33,10 +33,11 @@ import ( type PipelineTransport struct { PipelineOpts - m sync.Mutex // protect following fields - closed bool - r *rand.Rand - conns []*pipelineConn + m sync.Mutex // protect following fields + closed bool + r *rand.Rand + activeConns []*pipelineConn + conns map[*pipelineConn]struct{} } type PipelineOpts struct { @@ -66,6 +67,7 @@ func NewPipelineTransport(opt PipelineOpts) *PipelineTransport { return &PipelineTransport{ PipelineOpts: opt, r: rand.New(rand.NewSource(time.Now().Unix())), + conns: make(map[*pipelineConn]struct{}), } } @@ -73,13 +75,13 @@ func (t *PipelineTransport) ExchangeContext(ctx context.Context, m *dns.Msg) (*d const maxAttempt = 3 attempt := 0 for { - conn, allocatedQid, isNewConn, wg, err := t.getPipelineConn() + pc, allocatedQid, isNewConn, err := t.getPipelineConn() if err != nil { return nil, err } - r, err := conn.exchangePipeline(ctx, m, allocatedQid) - wg.Done() + r, err := pc.dc.exchangePipeline(ctx, m, allocatedQid) + pc.wg.Done() if err != nil { // Reused connection may not stable. 
@@ -103,7 +105,7 @@ func (t *PipelineTransport) Close() error { return nil } t.closed = true - for _, conn := range t.conns { + for conn := range t.conns { conn.dc.closeWithErr(errClosedTransport) } return nil @@ -113,10 +115,9 @@ func (t *PipelineTransport) Close() error { // Caller must call wg.Done() after dnsConn.exchangePipeline. // The returned dnsConn is ready to serve queries. func (t *PipelineTransport) getPipelineConn() ( - dc *dnsConn, + pc *pipelineConn, allocatedQid uint16, isNewConn bool, - wg *sync.WaitGroup, err error, ) { t.m.Lock() @@ -128,21 +129,19 @@ func (t *PipelineTransport) getPipelineConn() ( pci, pc := t.pickPipelineConnLocked() - // Dail a new connection if (conn pool is empty), or - // (the picked conn is busy, and we are allowed to dail more connections). + // Dial a new connection if (conn pool is empty), or + // (the picked conn is busy, and we are allowed to dial more connections). maxConn := t.MaxConn if maxConn <= 0 { maxConn = defaultPipelineMaxConns } - if pc == nil || (pc.dc.queueLen() > pipelineBusyQueueLen && len(t.conns) < maxConn) { - dc = newDnsConn(t.IOOpts) + if pc == nil || (pc.dc.queueLen() > pipelineBusyQueueLen && len(t.activeConns) < maxConn) { + dc := newDnsConn(t.IOOpts) pc = newPipelineConn(dc) isNewConn = true - pci = sliceAdd(&t.conns, pc) - } else { - dc = pc.dc + pci = sliceAdd(&t.activeConns, pc) + t.conns[pc] = struct{}{} } - wg = &pc.wg pc.wg.Add(1) pc.servedLocked++ @@ -152,13 +151,20 @@ func (t *PipelineTransport) getPipelineConn() ( // This connection has served too many queries. // Note: the connection should be closed only after all its queries finished. // We can't close it here. Some queries may still on that connection. - sliceDel(&t.conns, pci) + sliceDel(&t.activeConns, pci) // remove from active conns + } + t.m.Unlock() + + if eol { + // Clean up when all queries are finished. 
go func() { - wg.Wait() - dc.closeWithErr(errEOL) + pc.wg.Wait() + pc.dc.closeWithErr(errEOL) + t.m.Lock() + delete(t.conns, pc) + t.m.Unlock() }() } - t.m.Unlock() return } @@ -167,9 +173,9 @@ func (t *PipelineTransport) getPipelineConn() ( // Require holding PipelineTransport.m. func (t *PipelineTransport) pickPipelineConnLocked() (int, *pipelineConn) { for { - pci, pc := sliceRandGet(t.conns, t.r) + pci, pc := sliceRandGet(t.activeConns, t.r) if pc != nil && pc.dc.isClosed() { // closed conn, delete it and retry - sliceDel(&t.conns, pci) + sliceDel(&t.activeConns, pci) continue } return pci, pc // conn pool is empty or we got a pc diff --git a/pkg/upstream/transport/pipeline_test.go b/pkg/upstream/transport/pipeline_test.go index 653d779..c595288 100644 --- a/pkg/upstream/transport/pipeline_test.go +++ b/pkg/upstream/transport/pipeline_test.go @@ -86,7 +86,7 @@ func testPipelineTransport(t *testing.T, ioOpts IOOpts) { wg.Wait() pt.m.Lock() - pl := len(pt.conns) + pl := len(pt.activeConns) pt.m.Unlock() if pl > po.MaxConn { t.Fatalf("max %d active conn, but got %d active conn(s)", po.MaxConn, pl) -- 2.34.8