From 46d5c95107f80cac7c6b18788a7d5d3f72d0bf7b Mon Sep 17 00:00:00 2001 From: brettlangdon Date: Sat, 23 Jan 2016 09:47:32 -0500 Subject: [PATCH] use registry for looking up records --- registry.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++++ server.go | 10 ++-- zone.go | 46 +----------------- 3 files changed, 141 insertions(+), 49 deletions(-) create mode 100644 registry.go diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..5f1b046 --- /dev/null +++ b/registry.go @@ -0,0 +1,134 @@ +package realm + +import "github.com/miekg/dns" + +type RecordsEntry map[uint16][]dns.RR + +func (entry RecordsEntry) GetRecords(rrType uint16) []dns.RR { + var records []dns.RR + records = make([]dns.RR, 0) + + if rrType == dns.TypeANY { + for _, rrs := range entry { + records = append(records, rrs...) + } + } else if rrs, ok := entry[rrType]; ok { + records = append(records, rrs...) + } + + return records +} + +type DomainEntry map[uint16]RecordsEntry + +func (entry DomainEntry) AddEntry(record dns.RR) { + var header *dns.RR_Header + header = record.Header() + + if _, ok := entry[header.Class]; !ok { + entry[header.Class] = make(RecordsEntry) + } + if _, ok := entry[header.Class][header.Rrtype]; !ok { + entry[header.Class][header.Rrtype] = make([]dns.RR, 0) + } + + entry[header.Class][header.Rrtype] = append(entry[header.Class][header.Rrtype], record) +} + +func (entry DomainEntry) GetEntries(rrClass uint16) []RecordsEntry { + var entries []RecordsEntry + entries = make([]RecordsEntry, 0) + + if rrClass == dns.ClassANY { + for _, entry := range entry { + entries = append(entries, entry) + } + } else if entry, ok := entry[rrClass]; ok { + entries = append(entries, entry) + } + + return entries +} + +type Registry struct { + records map[string]DomainEntry +} + +func NewRegistry() *Registry { + return &Registry{ + records: make(map[string]DomainEntry), + } +} + +func (r *Registry) addRecord(record dns.RR) { + var header *dns.RR_Header + header = record.Header() + + var name string + name = dns.Fqdn(header.Name) + + if _, ok := r.records[name]; !ok { + r.records[name] = make(DomainEntry) + } + r.records[name].AddEntry(record) + + // If this record is an SOA record then also store under the Mbox name + if header.Rrtype == dns.TypeSOA { + var soa *dns.SOA + soa = record.(*dns.SOA) + + if _, ok := r.records[soa.Mbox]; !ok { + r.records[soa.Mbox] = make(DomainEntry) + } + r.records[soa.Mbox].AddEntry(record) + } +} + +func (r *Registry) AddZone(z *Zone) { + for _, record := range z.Records() { + r.addRecord(record) + } +} + +// Lookup will find all records which we should respond with for the given name, request type, and request class. +func (r *Registry) Lookup(name string, reqType uint16, reqClass uint16) []dns.RR { + name = dns.Fqdn(name) + var records []dns.RR + records = make([]dns.RR, 0) + + var domainEntry DomainEntry + var ok bool + domainEntry, ok = r.records[name] + if !ok { + return records + } + + var recordEntries []RecordsEntry + recordEntries = domainEntry.GetEntries(reqClass) + + for _, recordEntry := range recordEntries { + var rrs []dns.RR + rrs = recordEntry.GetRecords(reqType) + records = append(records, rrs...) + + if len(rrs) == 0 && reqType == dns.TypeA { + rrs = recordEntry.GetRecords(dns.TypeCNAME) + for _, rr := range rrs { + records = append(records, rr) + var header *dns.RR_Header + header = rr.Header() + if header.Rrtype == dns.TypeCNAME && reqType != dns.TypeCNAME { + // Attempt to resolve this CNAME record + var cname *dns.CNAME + cname = rr.(*dns.CNAME) + var cnameRecords []dns.RR + cnameRecords = r.Lookup(dns.Fqdn(cname.Target), reqType, reqClass) + records = append(records, cnameRecords...) + } + + } + } + + } + return records +} diff --git a/server.go b/server.go index 9454ff7..2b7a647 100644 --- a/server.go +++ b/server.go @@ -4,14 +4,14 @@ import "github.com/miekg/dns" // A Server listens for DNS requests over UDP and responds with answers from the provided Zone. type Server struct { - server *dns.Server - zone *Zone + server *dns.Server + registry *Registry } // NewServer returns a new initialized *Server that will bind to listen and will look up answers from zone. -func NewServer(listen string, zone *Zone) *Server { +func NewServer(listen string, registry *Registry) *Server { var s *Server - s = &Server{zone: zone} + s = &Server{registry: registry} s.server = &dns.Server{ Addr: listen, Net: "udp", @@ -37,7 +37,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, request *dns.Msg) { // Lookup answers to any of the questions for _, question := range request.Question { var records []dns.RR - records = s.zone.Lookup(question.Name, question.Qtype, question.Qclass) + records = s.registry.Lookup(question.Name, question.Qtype, question.Qclass) response.Answer = append(response.Answer, records...) } diff --git a/zone.go b/zone.go index be3fa66..13d0a8e 100644 --- a/zone.go +++ b/zone.go @@ -42,48 +42,6 @@ func ParseZone(filename string) (*Zone, error) { return zone, nil } -// Lookup will find all records which we should respond with for the given name, request type, and request class. -func (zone *Zone) Lookup(name string, reqType uint16, reqClass uint16) []dns.RR { - name = dns.Fqdn(name) - var records []dns.RR - records = make([]dns.RR, 0) - for _, record := range zone.records { - var header *dns.RR_Header - header = record.Header() - - // Skip this record if the class does not match up - if header.Class != reqClass && reqClass != dns.ClassANY { - continue - } - - // If this record is an SOA then check name against Mbox - if header.Rrtype == dns.TypeSOA { - var soa *dns.SOA - soa = record.(*dns.SOA) - if soa.Mbox == name { - records = append(records, soa) - } - } - - // Skip this record if the name does not match - if header.Name != name { - continue - } - - // Collect this record if the types match or this record is a CNAME - if reqType == dns.TypeANY || reqType == header.Rrtype { - records = append(records, record) - } else if header.Rrtype == dns.TypeCNAME { - // Append this CNAME record as a response - records = append(records, record) - - // Attempt to resolve this CNAME record - var cname *dns.CNAME - cname = record.(*dns.CNAME) - var cnameRecords []dns.RR - cnameRecords = zone.Lookup(dns.Fqdn(cname.Target), reqType, reqClass) - records = append(records, cnameRecords...) - } - } - return records +func (zone *Zone) Records() []dns.RR { + return zone.records }