From 65bf1f77a56fe481cacf3a1cada155b66949578f Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Fri, 22 Sep 2023 16:10:24 +0800 Subject: [PATCH 5/9] query_context: add QueryMeta --- pkg/query_context/client_addr.go | 38 ------------------- pkg/query_context/context.go | 23 ++++++++--- pkg/server_handler/entry_handler.go | 5 +-- .../dual_selector/dual_selector_test.go | 9 +++-- plugin/executable/ipset/ipset_test.go | 11 +++--- plugin/executable/sequence/sequence_test.go | 5 ++- plugin/matcher/client_ip/client_ip_matcher.go | 4 +- 7 files changed, 34 insertions(+), 61 deletions(-) delete mode 100644 pkg/query_context/client_addr.go diff --git a/pkg/query_context/client_addr.go b/pkg/query_context/client_addr.go deleted file mode 100644 index 7793fe6..0000000 --- a/pkg/query_context/client_addr.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (C) 2020-2022, IrineSistiana - * - * This file is part of mosdns. - * - * mosdns is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * mosdns is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see <https://www.gnu.org/licenses/>. 
- */ - -package query_context - -import ( - "net/netip" -) - -var clientAddrKey = RegKey() - -func SetClientAddr(qCtx *Context, addr *netip.Addr) { - qCtx.StoreValue(clientAddrKey, addr) -} - -func GetClientAddr(qCtx *Context) (*netip.Addr, bool) { - v, ok := qCtx.GetValue(clientAddrKey) - if !ok { - return nil, false - } - return v.(*netip.Addr), true -} diff --git a/pkg/query_context/context.go b/pkg/query_context/context.go index d3e67ae..9fa3fd7 100644 --- a/pkg/query_context/context.go +++ b/pkg/query_context/context.go @@ -20,11 +20,13 @@ package query_context import ( + "sync/atomic" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/server" "github.com/miekg/dns" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "sync/atomic" - "time" ) // Context is a query context that pass through plugins @@ -34,6 +36,7 @@ import ( type Context struct { startTime time.Time // when was this Context created q *dns.Msg + queryMeta QueryMeta // id for this Context. Not for the dns query. This id is mainly for logging. id uint32 @@ -48,14 +51,17 @@ type Context struct { var contextUid atomic.Uint32 +type QueryMeta = server.QueryMeta + // NewContext creates a new query Context. // q is the query dns msg. It cannot be nil, or NewContext will panic. -func NewContext(q *dns.Msg) *Context { +func NewContext(q *dns.Msg, qm QueryMeta) *Context { if q == nil { panic("handler: query msg is nil") } ctx := &Context{ q: q, + queryMeta: qm, id: contextUid.Add(1), startTime: time.Now(), } @@ -68,6 +74,11 @@ func (ctx *Context) Q() *dns.Msg { return ctx.q } +// QueryMeta returns the meta data of the query. +func (ctx *Context) QueryMeta() QueryMeta { + return ctx.queryMeta +} + // R returns the response. It might be nil. 
func (ctx *Context) R() *dns.Msg { return ctx.r @@ -164,8 +175,8 @@ func (ctx *Context) DeleteMark(m uint32) { func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error { encoder.AddUint32("uqid", ctx.id) - if addr, ok := GetClientAddr(ctx); ok && addr.IsValid() { - zap.Stringer("client", addr).AddTo(encoder) + if clientAddr := ctx.queryMeta.ClientAddr; clientAddr.IsValid() { + zap.Stringer("client", clientAddr).AddTo(encoder) } q := ctx.Q() @@ -180,7 +191,7 @@ func (ctx *Context) MarshalLogObject(encoder zapcore.ObjectEncoder) error { if r := ctx.R(); r != nil { encoder.AddInt("rcode", r.Rcode) } - encoder.AddDuration("elapsed", time.Now().Sub(ctx.StartTime())) + encoder.AddDuration("elapsed", time.Since(ctx.StartTime())) return nil } diff --git a/pkg/server_handler/entry_handler.go b/pkg/server_handler/entry_handler.go index 9e3a386..c12d852 100644 --- a/pkg/server_handler/entry_handler.go +++ b/pkg/server_handler/entry_handler.go @@ -79,10 +79,7 @@ func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, qInfo server.Quer defer cancel() // exec entry - qCtx := query_context.NewContext(q) - if qInfo.ClientAddr.IsValid() { - query_context.SetClientAddr(qCtx, &qInfo.ClientAddr) - } + qCtx := query_context.NewContext(q, qInfo) err := h.opts.Entry.Exec(ctx, qCtx) respMsg := qCtx.R() if err != nil { diff --git a/plugin/executable/dual_selector/dual_selector_test.go b/plugin/executable/dual_selector/dual_selector_test.go index 6a5ae92..524e739 100644 --- a/plugin/executable/dual_selector/dual_selector_test.go +++ b/plugin/executable/dual_selector/dual_selector_test.go @@ -21,14 +21,15 @@ package dual_selector import ( "context" + "net" + "testing" + "time" + "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" "github.com/miekg/dns" "go.uber.org/zap" - "net" - "testing" - "time" ) type dummyNext struct { @@ -158,7 +159,7 @@ func 
TestSelector_Exec(t *testing.T) { q := new(dns.Msg) q.SetQuestion("example.", tt.qtype) - qCtx := query_context.NewContext(q) + qCtx := query_context.NewContext(q, query_context.QueryMeta{}) cw := sequence.NewChainWalker([]*sequence.ChainNode{{E: tt.next}}, nil) if err := s.Exec(context.Background(), qCtx, cw); (err != nil) != tt.wantErr { t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) diff --git a/plugin/executable/ipset/ipset_test.go b/plugin/executable/ipset/ipset_test.go index cb92eb2..c5ad508 100644 --- a/plugin/executable/ipset/ipset_test.go +++ b/plugin/executable/ipset/ipset_test.go @@ -24,15 +24,16 @@ package ipset import ( "context" "fmt" - "github.com/IrineSistiana/mosdns/v5/pkg/query_context" - "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" - "github.com/miekg/dns" - "github.com/vishvananda/netlink" "math/rand" "net" "os" "strconv" "testing" + + "github.com/IrineSistiana/mosdns/v5/pkg/query_context" + "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" + "github.com/miekg/dns" + "github.com/vishvananda/netlink" ) func skipTest(t *testing.T) { @@ -85,7 +86,7 @@ func Test_ipset(t *testing.T) { r.Answer = append(r.Answer, &dns.A{A: net.ParseIP("127.0.0.2")}) r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::1")}) r.Answer = append(r.Answer, &dns.AAAA{AAAA: net.ParseIP("::2")}) - qCtx := query_context.NewContext(q) + qCtx := query_context.NewContext(q, query_context.QueryMeta{}) qCtx.SetResponse(r) if err := p.Exec(context.Background(), qCtx); err != nil { t.Fatal(err) diff --git a/plugin/executable/sequence/sequence_test.go b/plugin/executable/sequence/sequence_test.go index ea7704d..16b1360 100644 --- a/plugin/executable/sequence/sequence_test.go +++ b/plugin/executable/sequence/sequence_test.go @@ -22,10 +22,11 @@ package sequence import ( "context" "errors" + "testing" + "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/miekg/dns" - "testing" 
) type dummy struct { @@ -186,7 +187,7 @@ func Test_sequence_Exec(t *testing.T) { if err != nil { t.Fatal(err) } - qCtx := query_context.NewContext(new(dns.Msg)) + qCtx := query_context.NewContext(new(dns.Msg), query_context.QueryMeta{}) if err := s.Exec(context.Background(), qCtx); (err != nil) != tt.wantErr { t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/plugin/matcher/client_ip/client_ip_matcher.go b/plugin/matcher/client_ip/client_ip_matcher.go index 357df9b..b308b5d 100644 --- a/plugin/matcher/client_ip/client_ip_matcher.go +++ b/plugin/matcher/client_ip/client_ip_matcher.go @@ -39,9 +39,9 @@ func QuickSetup(bq sequence.BQ, s string) (sequence.Matcher, error) { } func matchClientAddr(qCtx *query_context.Context, m netlist.Matcher) (bool, error) { - addr, _ := query_context.GetClientAddr(qCtx) + addr := qCtx.QueryMeta().ClientAddr if !addr.IsValid() { return false, nil } - return m.Match(*addr), nil + return m.Match(addr), nil } -- 2.34.8