From 588e83131a850a927982ca97d123ed6077174ba9 Mon Sep 17 00:00:00 2001 From: WeebDataHoarder <57538841+WeebDataHoarder@users.noreply.github.com> Date: Wed, 15 Jun 2022 11:23:55 +0200 Subject: [PATCH] Added zone alias support, glue records --- dns.go | 276 ++++++++++++++++++++++++++++++++++------------------ dns_test.go | 57 +++++++++-- rr.go | 121 +++++++++++++++++++++++ 3 files changed, 353 insertions(+), 101 deletions(-) create mode 100644 rr.go diff --git a/dns.go b/dns.go index d9e1790..ae78bed 100644 --- a/dns.go +++ b/dns.go @@ -12,32 +12,16 @@ import ( "time" ) -type RRSet []dns.RR - -func (l RRSet) Get(rrt uint16, name string) (rrset RRSet) { - for _, rr := range l { - if rr != nil && rr.Header().Rrtype == rrt && rr.Header().Name == name { - rrset = append(rrset, rr) - } - } - - return -} - -func (l RRSet) Delete(rrt uint16, name string) { - for i, rr := range l { - if rr != nil && rr.Header().Rrtype == rrt && rr.Header().Name == name { - l[i] = nil - } - } -} - type Zone struct { publicKey ed25519.PublicKey privateKey ed25519.PrivateKey dnsKey *dns.DNSKEY + baseZone string + + zoneAliases []string + rrset RRSet enforcedEntries RRSet @@ -61,6 +45,7 @@ func NewZoneFromPublicKey(baseZone string, publicKey ed25519.PublicKey) *Zone { zone := &Zone{ publicKey: publicKey, + baseZone: baseZone, } zone.dnsKey = new(dns.DNSKEY) @@ -83,7 +68,7 @@ func (z *Zone) DeleteRR(rrtype string, name string) error { return fmt.Errorf("could not find record type for %s", rrtype) } - z.rrset.Delete(rrt, name) + z.rrset.Delete(rrt, z.GetFullName(name)) return nil } @@ -93,7 +78,11 @@ func (z *Zone) GetRR(rrtype string, name string) (rrset RRSet) { return nil } - return z.rrset.Get(rrt, name) + return z.rrset.Get(rrt, z.GetFullName(name)) +} + +func (z *Zone) IsInZone(n string) bool { + return strings.HasSuffix(n, z.Name()) } func (z *Zone) AddRR(rr dns.RR) error { @@ -101,7 +90,7 @@ func (z *Zone) AddRR(rr dns.RR) error { return errors.New("no valid record found") } - if !strings.HasSuffix(rr.Header().Name, z.Name()) { + if !z.IsInZone(rr.Header().Name) { return fmt.Errorf("%s is not part of %s", rr.Header().Name, z.Name()) } @@ -110,12 +99,24 @@ func (z *Zone) AddRR(rr dns.RR) error { return nil } +func (z *Zone) GetFullName(name string) string { + if name == "@" || name == "" { + //$ORIGIN + name = z.Name() + } else if !z.IsInZone(name) && name[len(name)-1] != '.' { + //Lone names + name = name + "." + z.Name() + } + return name +} + func (z *Zone) AddRecord(name string, ttl uint32, rr dns.RR) error { if rr == nil { return errors.New("no valid record found") } + hdr := rr.Header() - hdr.Name = name + hdr.Name = z.GetFullName(name) hdr.Class = dns.ClassINET hdr.Ttl = ttl @@ -153,6 +154,26 @@ func (z *Zone) AddRecordCNAME(name string, target string, ttl uint32) error { return z.AddRecord(name, ttl, record) } +func (z *Zone) AddRecordNS(name string, ns string, ttl uint32) error { + record := new(dns.NS) + record.Hdr.Rrtype = dns.TypeNS + record.Ns = ns + return z.AddRecord(name, ttl, record) +} + +func (z *Zone) AddRecordSOA(name string, ns, mbox string, serial, refresh, retry, expire, minttl, ttl uint32) error { + record := new(dns.SOA) + record.Hdr.Rrtype = dns.TypeSOA + record.Ns = ns + record.Mbox = mbox + record.Serial = serial + record.Refresh = refresh + record.Retry = retry + record.Expire = expire + record.Minttl = minttl + return z.AddRecord(name, ttl, record) +} + func (z *Zone) AddRecordTXT(name string, txt []string, ttl uint32) error { record := new(dns.TXT) record.Hdr.Rrtype = dns.TypeTXT @@ -221,7 +242,7 @@ func (z *Zone) SetRR(rrtype string, name string, rrset RRSet) error { return fmt.Errorf("could not find record type for %s", rrtype) } - z.rrset.Delete(rrt, name) + z.rrset.Delete(rrt, z.GetFullName(name)) for i, rr := range rrset { if rr != nil && rr.Header().Rrtype != rrt { @@ -240,15 +261,53 @@ func (z *Zone) SetRR(rrtype string, name string, rrset RRSet) error { return nil } -func (z *Zone) GetRRSet() RRSet { - - set := make(RRSet, 0, len(z.rrset)) - for _, rr := range z.rrset { - if rr != nil { - set = append(set, rr) +func (z *Zone) AddZoneAlias(alias string) { + for _, a := range z.zoneAliases { + if alias == a { + return } } - return set + z.zoneAliases = append(z.zoneAliases, alias) +} + +func (z *Zone) GetRRSet() RRSet { + return z.rrset.NotNil() +} + +func (z *Zone) SetEnforcedEntries(set RRSet) { + z.enforcedEntries = set +} + +func (z *Zone) GetEnforcedEntries() RRSet { + return z.enforcedEntries.NotNil() +} + +func (z *Zone) GetGlueRecords() (result RRSet) { + + //get key records to pin them. Only main Ed25519 KSK is applied + result = append(result, z.dnsKey.ToDS(dns.SHA256)) + + var nameservers []string + + for _, rr := range z.rrset.Get(dns.TypeNS, z.Name()) { + if ns, ok := rr.(*dns.NS); ok { + if z.IsInZone(ns.Ns) { + nameservers = append(nameservers, ns.Ns) + } + + result = append(result, ns) + } + } + + //get A / AAAA records from nameservers under the zone to glue them + func() { + for _, ns := range nameservers { + result = append(result, z.rrset.Get(dns.TypeA, ns)...) + result = append(result, z.rrset.Get(dns.TypeAAAA, ns)...) + } + }() + + return } func (z *Zone) AddMissingRecords() { @@ -293,73 +352,93 @@ func (z *Zone) AddMissingRecords() { cds.Hdr.Rrtype = dns.TypeCDS z.AddRR(&cds) }() + + func() { + for _, nameset := range z.enforcedEntries.SplitByName() { + for _, typeset := range nameset.SplitByType() { + z.rrset.Delete(typeset[0].Header().Rrtype, typeset[0].Header().Name) + z.rrset = append(z.rrset, typeset...) + } + } + }() + + //clear nil + z.rrset = z.rrset.NotNil() } -func (z *Zone) Sign() (rrsigs []*dns.RRSIG, err error) { +var TopLevelCNAMEError = errors.New("top-level CNAME not allowed") - var types []uint16 - var names []string - - inTypes := func(i uint16) bool { - if i == dns.TypeRRSIG || - i == dns.TypeDS || - i == dns.TypeTA || - i == dns.TypeSOA || - i == dns.TypeRP { - return true - } - for _, j := range types { - if i == j { - return true - } - } - return false - } - - inNames := func(i string) bool { - for _, j := range names { - if i == j { - return true - } - } - return false - } +func (z *Zone) Sign() (zones []*Zone, err error) { + //SOA z.AddMissingRecords() + rrset := z.rrset.RemoveTypes(dns.TypeRRSIG, dns.TypeDS, dns.TypeTA) - for _, rr := range z.rrset { - if rr == nil { - continue - } - h := rr.Header() - if !inTypes(h.Rrtype) { - types = append(types, h.Rrtype) - } - if !inNames(h.Name) { - names = append(names, h.Name) - } + //check if a top-level CNAME exists, comply with RFC 1912 + if len(z.rrset.Get(dns.TypeCNAME, z.Name())) > 0 { + return nil, TopLevelCNAMEError } - //TODO: do check, remove entries - /* - //TODO: make this an allowlist - if rrt == dns.TypeRRSIG || - rrt == dns.TypeDNSKEY || - rrt == dns.TypeCDNSKEY || - rrt == dns.TypeDS || - rrt == dns.TypeCDS || - rrt == dns.TypeNS || - rrt == dns.TypeTA || - rrt == dns.TypeRP || - rrt == dns.TypeSOA { //TODO: allow changing some but with verification of values - return fmt.Errorf("type %s not allowed", rrtype) + newZone, err := signOneZone(z, rrset) + if err != nil { + return nil, err + } + zones = append(zones, newZone) + + for _, alias := range z.zoneAliases { + + aliasZone := *z + aliasZone.zoneAliases = nil + aliasZone.baseZone = strings.Join(strings.Split(alias, ".")[1:], ".") + aliasZone.dnsKey = dns.Copy(aliasZone.dnsKey).(*dns.DNSKEY) + aliasZone.dnsKey.Hdr.Name = alias + aliasZone.rrset = aliasZone.rrset.NotNil() + + fixName := func(n string) string { + return strings.TrimSuffix(n, z.Name()) + aliasZone.Name() } - */ + + for i, rr := range aliasZone.rrset { + rcopy := dns.Copy(rr) + rcopy.Header().Name = fixName(rcopy.Header().Name) + + if ns, ok := rcopy.(*dns.NS); ok { + ns.Ns = fixName(ns.Ns) + } + if ns, ok := rcopy.(*dns.SOA); ok { + ns.Mbox = fixName(ns.Mbox) + ns.Ns = fixName(ns.Ns) + } + if ns, ok := rcopy.(*dns.CNAME); ok { + ns.Target = fixName(ns.Target) + } + + aliasZone.rrset[i] = rcopy + } + + //Add owner record + if err = aliasZone.AddRecordTXT("_owner", []string{z.Name()}, 3600*24); err != nil { + return nil, err + } + + newZone, err = signOneZone(&aliasZone, aliasZone.rrset) + if err != nil { + return nil, err + } + zones = append(zones, newZone) + + } + + return zones, nil +} + +func signOneZone(z *Zone, rrset RRSet) (*Zone, error) { + signSet := rrset //Check if another ZSK key exists for zones, otherwise behave as CSK var isKSKOnly bool - for _, rr := range z.rrset.Get(dns.TypeDNSKEY, z.Name()) { + for _, rr := range rrset.Get(dns.TypeDNSKEY, z.Name()) { if k, ok := rr.(*dns.DNSKEY); ok { //found ZSK key that is not ours if k.Flags == 256 && k.PublicKey != z.dnsKey.PublicKey && k.Algorithm == z.dnsKey.Algorithm { @@ -371,24 +450,29 @@ func (z *Zone) Sign() (rrsigs []*dns.RRSIG, err error) { //only sign DNSKEY, CDNSKEY, DS, CDS as KSK if isKSKOnly { - types = []uint16{dns.TypeDS, dns.TypeCDS, dns.TypeDNSKEY, dns.TypeCDNSKEY} + signSet = signSet.OnlyTypes(dns.TypeDS, dns.TypeCDS, dns.TypeDNSKEY, dns.TypeCDNSKEY) } - for _, rname := range names { - for _, rtype := range types { - rrset := z.rrset.Get(rtype, rname) - if len(rrset) > 0 { - rrsig, err := z.SignRRSet(rrset) + zone := *z + zone.rrset = make(RRSet, len(rrset)) + copy(zone.rrset, rrset) + + for _, nameset := range signSet.SplitByName() { + for _, typeset := range nameset.SplitByType() { + if len(typeset) > 0 { + rrsig, err := z.SignRRSet(typeset) if err != nil { return nil, err } else { - rrsigs = append(rrsigs, rrsig) + if err = zone.AddRR(rrsig); err != nil { + return nil, err + } } } } } - return rrsigs, nil + return &zone, nil } //TODO: implement NSEC or similar @@ -426,6 +510,10 @@ func (z *Zone) Name() string { return z.dnsKey.Hdr.Name } +func (z *Zone) BaseZoneName() string { + return z.baseZone +} + func (z *Zone) DNSKEY() *dns.DNSKEY { return z.dnsKey } diff --git a/dns_test.go b/dns_test.go index 4a86b65..763fea3 100644 --- a/dns_test.go +++ b/dns_test.go @@ -12,21 +12,64 @@ func TestRecordSign(t *testing.T) { zone := NewZoneFromPrivateKey(testZone, DecodeTorPrivateKey(testPrivateKey)) t.Logf("zone name %s", zone.Name()) - if err := zone.AddRecordA("test."+zone.Name(), net.IPv4(1, 2, 3, 4), 3600); err != nil { + if err := zone.AddRecordA("test", net.IPv4(1, 2, 3, 4), 3600); err != nil { + t.Error(err) + } + if err := zone.AddRecordNS("@", "ns1."+zone.Name(), 3600*24); err != nil { + t.Error(err) + } + if err := zone.AddRecordNS("@", "ns2."+zone.Name(), 3600*24); err != nil { + t.Error(err) + } + if err := zone.AddRecordA("ns1", net.IPv4(1, 1, 1, 1), 3600*24); err != nil { + t.Error(err) + } + if err := zone.AddRecordA("ns2", net.IPv4(2, 2, 2, 2), 3600*24); err != nil { t.Error(err) } - rrsigs, err := zone.Sign() + enforcedZone := NewZoneFromPublicKey(zone.BaseZoneName(), zone.PubKey()) + if err := enforcedZone.AddRecordSOA("@", "ns1."+zone.Name(), "mail."+zone.Name(), 31337, 10800, 3600, 604800, 3600, 3600); err != nil { + t.Error(err) + } + + zone.SetEnforcedEntries(enforcedZone.GetRRSet()) + + zone.AddZoneAlias("testalias." + zone.BaseZoneName()) + + zones, err := zone.Sign() if err != nil { t.Error(err) } - for _, rr := range zone.GetRRSet() { - t.Logf(" %s\n", rr.String()) - } + for _, newZone := range zones { + t.Logf("====== ZONE %s ======", newZone.Name()) + for _, rr := range newZone.GetRRSet() { + t.Logf(" %s\n", rr.String()) + } - for _, rr := range rrsigs { - t.Logf(" %s\n", rr.String()) + t.Logf("auto-generated GLUE records:") + for _, rr := range newZone.GetGlueRecords() { + t.Logf("GLUE: %s\n", rr.String()) + } + } +} + +func TestRecordTopLevelCNAME(t *testing.T) { + + zone := NewZoneFromPrivateKey(testZone, DecodeTorPrivateKey(testPrivateKey)) + t.Logf("zone name %s", zone.Name()) + + if err := zone.AddRecordCNAME("@", "example.com.", 3600); err != nil { + t.Error(err) + } + + _, err := zone.Sign() + + if err == nil { + t.Error("expected CNAME error") + } else if err != TopLevelCNAMEError { + t.Error(err) } } diff --git a/rr.go b/rr.go new file mode 100644 index 0000000..3d57026 --- /dev/null +++ b/rr.go @@ -0,0 +1,121 @@ +package dns_api + +import "github.com/miekg/dns" + +type RRSet []dns.RR + +func (s RRSet) Get(rrt uint16, name string) (rrset RRSet) { + for _, rr := range s { + if rr != nil && (rr.Header().Rrtype == rrt || rrt == dns.TypeANY) && (rr.Header().Name == name || name == "") { + rrset = append(rrset, rr) + } + } + + return +} + +func (s RRSet) Delete(rrt uint16, name string) { + for i, rr := range s { + if rr != nil && rr.Header().Rrtype == rrt && rr.Header().Name == name { + s[i] = nil + } + } +} + +func (s RRSet) SplitByType() (result []RRSet) { + + var types []uint16 + + inTypes := func(i uint16) bool { + for _, j := range types { + if i == j { + return true + } + } + return false + } + + for _, rr := range s { + if rr != nil && !inTypes(rr.Header().Rrtype) { + result = append(result, s.Get(rr.Header().Rrtype, "")) + types = append(types, rr.Header().Rrtype) + } + } + + return +} + +func (s RRSet) SplitByName() (result []RRSet) { + var names []string + + inNames := func(i string) bool { + for _, j := range names { + if i == j { + return true + } + } + return false + } + + for _, rr := range s { + if rr != nil && !inNames(rr.Header().Name) { + result = append(result, s.Get(dns.TypeANY, rr.Header().Name)) + names = append(names, rr.Header().Name) + } + } + + return +} + +func (s RRSet) NotNil() (result RRSet) { + result = make(RRSet, 0, len(s)) + for _, rr := range s { + if rr != nil { + result = append(result, rr) + } + } + + return +} + +func (s RRSet) RemoveTypes(disallow ...uint16) (result RRSet) { + + inList := func(i uint16) bool { + + for _, j := range disallow { + if i == j { + return true + } + } + return false + } + + for _, rr := range s { + if rr != nil && !inList(rr.Header().Rrtype) { + result = append(result, rr) + } + } + + return +} + +func (s RRSet) OnlyTypes(allow ...uint16) (result RRSet) { + + inList := func(i uint16) bool { + + for _, j := range allow { + if i == j { + return true + } + } + return false + } + + for _, rr := range s { + if rr != nil && inList(rr.Header().Rrtype) { + result = append(result, rr) + } + } + + return +}