package dns_api

import (
	"encoding/base64"
	"errors"
	"fmt"
	"math"
	"net"
	"strings"
	"time"

	"git.gammaspectra.live/givna.me/dns-api/ed25519"
	"github.com/miekg/dns"
)

// Zone holds the key material and the resource records of one DNS zone. The
// zone's apex name is the owner name of its DNSKEY record (see Name).
type Zone struct {
	publicKey  ed25519.PublicKey
	privateKey ed25519.PrivateKey
	dnsKey     *dns.DNSKEY

	baseZone    string
	zoneAliases []string

	rrset           RRSet
	enforcedEntries RRSet
}

// NewZoneFromPrivateKey builds a Zone under baseZone that is able to sign
// records with privateKey. It returns nil when baseZone is rejected by
// NewZoneFromPublicKey.
func NewZoneFromPrivateKey(baseZone string, privateKey ed25519.PrivateKey) *Zone {
	zone := NewZoneFromPublicKey(baseZone, privateKey.Public().(ed25519.PublicKey))
	if zone == nil {
		return nil
	}
	zone.privateKey = privateKey
	return zone
}

// NewZoneFromPublicKey builds a Zone under baseZone for publicKey. The zone has
// no private key, so SignRRSet will fail on it. baseZone must be non-empty and
// end with a dot; otherwise nil is returned.
//
// The zone name is PublicKeyToOnionV3(publicKey) + "." + baseZone.
func NewZoneFromPublicKey(baseZone string, publicKey ed25519.PublicKey) *Zone {
	if !strings.HasSuffix(baseZone, ".") {
		return nil
	}

	return &Zone{
		publicKey: publicKey,
		baseZone:  baseZone,
		dnsKey: &dns.DNSKEY{
			Hdr: dns.RR_Header{
				Name:   PublicKeyToOnionV3(publicKey) + "." + baseZone,
				Rrtype: dns.TypeDNSKEY,
				Class:  dns.ClassINET,
				// public key cannot expire
				Ttl: 2147483647,
			},
			Flags:     257, // KSK
			Protocol:  3,
			Algorithm: dns.ED25519,
			PublicKey: base64.StdEncoding.EncodeToString(publicKey),
		},
	}
}

// DeleteRR removes the records of type rrtype (a mnemonic such as "A" or
// "TXT") stored under name. name is expanded with GetFullName. An error is
// returned when rrtype is not a known record type.
func (z *Zone) DeleteRR(rrtype string, name string) error {
	rrt, ok := dns.StringToType[rrtype]
	if !ok {
		return fmt.Errorf("could not find record type for %s", rrtype)
	}
	z.rrset.Delete(rrt, z.GetFullName(name))
	return nil
}

// GetRR returns the records of type rrtype stored under name (expanded with
// GetFullName), or nil when rrtype is not a known record type.
func (z *Zone) GetRR(rrtype string, name string) (rrset RRSet) {
	rrt, ok := dns.StringToType[rrtype]
	if !ok {
		return nil
	}
	return z.rrset.Get(rrt, z.GetFullName(name))
}

// IsInZone reports whether n is the zone apex or a name below it. The match is
// made on a label boundary, so a name that merely ends with the same characters
// as the zone name (for example "evil" + Name()) is not considered in-zone.
func (z *Zone) IsInZone(n string) bool {
	zone := z.Name()
	return n == zone || strings.HasSuffix(n, "."+zone)
}

// AddRR appends rr to the zone. It fails when rr is nil or when its owner name
// is not inside the zone.
func (z *Zone) AddRR(rr dns.RR) error {
	if rr == nil {
		return errors.New("no valid record found")
	}
	if owner := rr.Header().Name; !z.IsInZone(owner) {
		return fmt.Errorf("%s is not part of %s", owner, z.Name())
	}
	z.rrset = append(z.rrset, rr)
	return nil
}

// GetFullName expands name into a fully qualified owner name:
//   - "@" and "" become the zone apex,
//   - a name that is neither in-zone nor dot-terminated is treated as relative
//     and gets the zone name appended,
//   - anything else is returned unchanged.
func (z *Zone) GetFullName(name string) string {
	switch {
	case name == "@" || name == "":
		// $ORIGIN
		return z.Name()
	case !z.IsInZone(name) && !strings.HasSuffix(name, "."):
		// Lone names
		return name + "." + z.Name()
	default:
		return name
	}
}

// AddRecord fills in the header of rr (owner name from GetFullName(name),
// class IN, the given ttl) and adds it to the zone. The record type must
// already be set by the caller.
func (z *Zone) AddRecord(name string, ttl uint32, rr dns.RR) error {
	if rr == nil {
		return errors.New("no valid record found")
	}
	hdr := rr.Header()
	hdr.Name = z.GetFullName(name)
	hdr.Class = dns.ClassINET
	hdr.Ttl = ttl
	return z.AddRR(rr)
}

// AddRecordA adds an A record. ip must be convertible to a 4-byte IPv4 address.
func (z *Zone) AddRecordA(name string, ip net.IP, ttl uint32) error {
	ip = ip.To4()
	if ip == nil {
		return errors.New("invalid ipv4")
	}
	record := new(dns.A)
	record.Hdr.Rrtype = dns.TypeA
	record.A = ip
	return z.AddRecord(name, ttl, record)
}

// AddRecordAAAA adds an AAAA record. ip must be convertible with net.IP.To16.
//
// NOTE(review): To16 also succeeds for IPv4 addresses (yielding the 16-byte
// form), so IPv4 input is not rejected here — confirm whether that is intended.
func (z *Zone) AddRecordAAAA(name string, ip net.IP, ttl uint32) error {
	ip = ip.To16()
	if ip == nil {
		return errors.New("invalid ipv6")
	}
	record := new(dns.AAAA)
	record.Hdr.Rrtype = dns.TypeAAAA
	record.AAAA = ip
	return z.AddRecord(name, ttl, record)
}

// AddRecordCNAME adds a CNAME record pointing name at target. target is stored
// as given; it is not expanded with GetFullName.
func (z *Zone) AddRecordCNAME(name string, target string, ttl uint32) error {
	record := new(dns.CNAME)
	record.Hdr.Rrtype = dns.TypeCNAME
	record.Target = target
	return z.AddRecord(name, ttl, record)
}

// AddRecordNS adds an NS record naming ns as a nameserver for name. ns is
// stored as given.
func (z *Zone) AddRecordNS(name string, ns string, ttl uint32) error {
	record := new(dns.NS)
	record.Hdr.Rrtype = dns.TypeNS
	record.Ns = ns
	return z.AddRecord(name, ttl, record)
}

// AddRecordSOA adds an SOA record built from the given fields.
func (z *Zone) AddRecordSOA(name string, ns, mbox string, serial, refresh, retry, expire, minttl, ttl uint32) error {
	record := new(dns.SOA)
	record.Hdr.Rrtype = dns.TypeSOA
	record.Ns = ns
	record.Mbox = mbox
	record.Serial = serial
	record.Refresh = refresh
	record.Retry = retry
	record.Expire = expire
	record.Minttl = minttl
	return z.AddRecord(name, ttl, record)
}

// AddRecordTXT adds a TXT record carrying the strings in txt.
func (z *Zone) AddRecordTXT(name string, txt []string, ttl uint32) error {
	record := new(dns.TXT)
	record.Hdr.Rrtype = dns.TypeTXT
	record.Txt = txt
	return z.AddRecord(name, ttl, record)
}

// AddRecordCAA adds a CAA record with the given flag, tag and value.
func (z *Zone) AddRecordCAA(name string, flag uint8, tag, value string, ttl uint32) error {
	record := new(dns.CAA)
	record.Hdr.Rrtype = dns.TypeCAA
	record.Flag = flag
	record.Tag = tag
	record.Value = value
	return z.AddRecord(name, ttl, record)
}

// AddRecordTLSA adds a TLSA record built from the given fields.
func (z *Zone) AddRecordTLSA(name string, usage, selector, matchingType uint8, certificate string, ttl uint32) error {
	record := new(dns.TLSA)
	record.Hdr.Rrtype = dns.TypeTLSA
	record.Usage = usage
	record.Selector = selector
	record.MatchingType = matchingType
	record.Certificate = certificate
	return z.AddRecord(name, ttl, record)
}

// AddRecordSVCB adds an SVCB record built from the given fields.
func (z *Zone) AddRecordSVCB(name string, priority uint16, target string, value []dns.SVCBKeyValue, ttl uint32) error {
	record := new(dns.SVCB)
	record.Hdr.Rrtype = dns.TypeSVCB
	record.Priority = priority
	record.Target = target
	record.Value = value
	return z.AddRecord(name, ttl, record)
}

// AddRecordHTTPS adds an HTTPS record built from the given fields.
func (z *Zone) AddRecordHTTPS(name string, priority uint16, target string, value []dns.SVCBKeyValue, ttl uint32) error {
	record := new(dns.HTTPS)
	record.Hdr.Rrtype = dns.TypeHTTPS
	record.Priority = priority
	record.Target = target
	record.Value = value
	return z.AddRecord(name, ttl, record)
}

// AddRecordSSHFP adds an SSHFP record built from the given fields.
func (z *Zone) AddRecordSSHFP(name string, algorithm, ktype uint8, fingerprint string, ttl uint32) error {
	record := new(dns.SSHFP)
	record.Hdr.Rrtype = dns.TypeSSHFP
	record.Algorithm = algorithm
	record.Type = ktype
	record.FingerPrint = fingerprint
	return z.AddRecord(name, ttl, record)
}

// AddRecordMX adds an MX record with the given preference and mail exchanger.
func (z *Zone) AddRecordMX(name string, preference uint16, mx string, ttl uint32) error {
	record := new(dns.MX)
	record.Hdr.Rrtype = dns.TypeMX
	record.Preference = preference
	record.Mx = mx
	return z.AddRecord(name, ttl, record)
}

// SetRR replaces the records of type rrtype stored under name with the non-nil
// entries of rrset.
//
// The replacement set is validated completely (record type and zone
// membership) before the zone is touched, so a rejected call leaves the
// existing records in place. Note that the new records keep their own owner
// names; name only selects which existing records are removed.
func (z *Zone) SetRR(rrtype string, name string, rrset RRSet) error {
	rrt, ok := dns.StringToType[rrtype]
	if !ok {
		return fmt.Errorf("could not find record type for %s", rrtype)
	}

	for i, rr := range rrset {
		if rr == nil {
			continue
		}
		hdr := rr.Header()
		if hdr.Rrtype != rrt {
			return fmt.Errorf("expected %s, got record type %s at index %d", rrtype, dns.TypeToString[hdr.Rrtype], i)
		}
		if !z.IsInZone(hdr.Name) {
			return fmt.Errorf("%s is not part of %s", hdr.Name, z.Name())
		}
	}

	z.rrset.Delete(rrt, z.GetFullName(name))
	for _, rr := range rrset {
		if rr == nil {
			continue
		}
		if err := z.AddRR(rr); err != nil {
			return err
		}
	}
	return nil
}

// AddZoneAlias registers alias as an additional name under which Sign produces
// a signed copy of the zone. Adding the same alias twice has no effect.
func (z *Zone) AddZoneAlias(alias string) {
	for _, known := range z.zoneAliases {
		if known == alias {
			return
		}
	}
	z.zoneAliases = append(z.zoneAliases, alias)
}

// GetRRSet returns the zone's records as filtered by RRSet.NotNil.
func (z *Zone) GetRRSet() RRSet {
	return z.rrset.NotNil()
}

// SetEnforcedEntries stores the records that AddMissingRecords forces into the
// zone, replacing whatever the zone holds for the same name and type.
func (z *Zone) SetEnforcedEntries(set RRSet) {
	z.enforcedEntries = set
}

// GetEnforcedEntries returns the enforced records as filtered by RRSet.NotNil.
func (z *Zone) GetEnforcedEntries() RRSet {
	return z.enforcedEntries.NotNil()
}

// GetGlueRecords returns the records a parent zone needs for this zone: the
// SHA-256 DS of the zone key, the NS records at the apex, and the A/AAAA
// records of those nameservers that live inside the zone.
func (z *Zone) GetGlueRecords() (result RRSet) {
	// get key records to pin them. Only main Ed25519 KSK is applied.
	// A nil *dns.DS must not be appended: wrapped in the dns.RR interface it
	// would compare as non-nil while holding a nil pointer.
	if ds := z.dnsKey.ToDS(dns.SHA256); ds != nil {
		result = append(result, ds)
	}

	var nameservers []string
	for _, rr := range z.rrset.Get(dns.TypeNS, z.Name()) {
		ns, ok := rr.(*dns.NS)
		if !ok {
			continue
		}
		if z.IsInZone(ns.Ns) {
			nameservers = append(nameservers, ns.Ns)
		}
		result = append(result, ns)
	}

	// get A / AAAA records from nameservers under the zone to glue them
	for _, ns := range nameservers {
		result = append(result, z.rrset.Get(dns.TypeA, ns)...)
		result = append(result, z.rrset.Get(dns.TypeAAAA, ns)...)
	}

	return result
}

// AddMissingRecords makes sure the apex carries a DNSKEY, CDNSKEY and CDS
// record for the zone key (adding whichever is absent), then overwrites the
// zone's records with the enforced entries, per name and type.
func (z *Zone) AddMissingRecords() {
	matchesZoneKey := func(k *dns.DNSKEY) bool {
		return k.Flags == z.dnsKey.Flags &&
			k.PublicKey == z.dnsKey.PublicKey &&
			k.Algorithm == z.dnsKey.Algorithm
	}

	// AddRR only rejects nil or out-of-zone records. The records added below
	// are non-nil and named after the zone key, whose owner name is the zone
	// apex, so the returned errors are deliberately discarded.

	hasDNSKEY := false
	for _, rr := range z.rrset.Get(dns.TypeDNSKEY, z.Name()) {
		if k, ok := rr.(*dns.DNSKEY); ok && matchesZoneKey(k) {
			hasDNSKEY = true
			break
		}
	}
	if !hasDNSKEY {
		_ = z.AddRR(z.dnsKey)
	}

	hasCDNSKEY := false
	for _, rr := range z.rrset.Get(dns.TypeCDNSKEY, z.Name()) {
		if k, ok := rr.(*dns.CDNSKEY); ok && matchesZoneKey(&k.DNSKEY) {
			hasCDNSKEY = true
			break
		}
	}
	if !hasCDNSKEY {
		cdnskey := dns.CDNSKEY{DNSKEY: *z.dnsKey}
		cdnskey.Hdr.Rrtype = dns.TypeCDNSKEY
		_ = z.AddRR(&cdnskey)
	}

	if ds := z.dnsKey.ToDS(dns.SHA256); ds != nil {
		hasCDS := false
		for _, rr := range z.rrset.Get(dns.TypeCDS, z.Name()) {
			if k, ok := rr.(*dns.CDS); ok && k.KeyTag == ds.KeyTag && k.Digest == ds.Digest {
				hasCDS = true
				break
			}
		}
		if !hasCDS {
			// presumably ToDS carries the DNSKEY owner name over to the DS
			cds := dns.CDS{DS: *ds}
			cds.Hdr.Rrtype = dns.TypeCDS
			_ = z.AddRR(&cds)
		}
	}

	// Enforced entries bypass AddRR on purpose: they replace whatever is
	// stored for the same type and name.
	for _, nameset := range z.enforcedEntries.SplitByName() {
		for _, typeset := range nameset.SplitByType() {
			if len(typeset) == 0 {
				continue
			}
			hdr := typeset[0].Header()
			z.rrset.Delete(hdr.Rrtype, hdr.Name)
			z.rrset = append(z.rrset, typeset...)
		}
	}

	// clear nil
	z.rrset = z.rrset.NotNil()
}

// TopLevelCNAMEError is returned by Sign when the zone apex holds a CNAME.
var TopLevelCNAMEError = errors.New("top-level CNAME not allowed")

// Sign completes the zone with AddMissingRecords and returns signed copies of
// it: first the zone itself, then one copy per registered alias. The receiver's
// own record set does not receive the signatures.
//
// RRSIG, DS and TA records are left out of every signed copy. In alias copies
// owner names and in-zone NS/SOA/CNAME targets are rewritten to the alias
// name, names outside the zone are kept as they are, and an "_owner" TXT
// record naming the original zone is added.
func (z *Zone) Sign() (zones []*Zone, err error) {
	z.AddMissingRecords()

	rrset := z.rrset.RemoveTypes(dns.TypeRRSIG, dns.TypeDS, dns.TypeTA)

	// check if a top-level CNAME exists, comply with RFC 1912
	if len(z.rrset.Get(dns.TypeCNAME, z.Name())) > 0 {
		return nil, TopLevelCNAMEError
	}

	newZone, err := signOneZone(z, rrset)
	if err != nil {
		return nil, err
	}
	zones = append(zones, newZone)

	for _, alias := range z.zoneAliases {
		aliasZone := *z
		aliasZone.zoneAliases = nil
		// the base zone of an alias is the alias minus its first label
		aliasZone.baseZone = strings.Join(strings.Split(alias, ".")[1:], ".")
		aliasZone.dnsKey = dns.Copy(aliasZone.dnsKey).(*dns.DNSKEY)
		aliasZone.dnsKey.Hdr.Name = alias

		// Only names that belong to the original zone are moved under the
		// alias; external targets (e.g. an out-of-zone nameserver) stay as is.
		fixName := func(n string) string {
			if !z.IsInZone(n) {
				return n
			}
			return strings.TrimSuffix(n, z.Name()) + aliasZone.Name()
		}

		// Build the alias records in a fresh slice so that neither the
		// rewrite nor the later append can touch the receiver's records.
		aliasZone.rrset = make(RRSet, 0, len(rrset)+1)
		for _, rr := range rrset {
			if rr == nil {
				continue
			}
			rcopy := dns.Copy(rr)
			rcopy.Header().Name = fixName(rcopy.Header().Name)
			switch rec := rcopy.(type) {
			case *dns.NS:
				rec.Ns = fixName(rec.Ns)
			case *dns.SOA:
				rec.Mbox = fixName(rec.Mbox)
				rec.Ns = fixName(rec.Ns)
			case *dns.CNAME:
				rec.Target = fixName(rec.Target)
			}
			aliasZone.rrset = append(aliasZone.rrset, rcopy)
		}

		// Add owner record
		if err = aliasZone.AddRecordTXT("_owner", []string{z.Name()}, 3600*24); err != nil {
			return nil, err
		}

		newZone, err = signOneZone(&aliasZone, aliasZone.rrset)
		if err != nil {
			return nil, err
		}
		zones = append(zones, newZone)
	}

	return zones, nil
}

// signOneZone returns a copy of z whose record set is rrset plus one RRSIG per
// signed name/type group. z itself is not modified.
func signOneZone(z *Zone, rrset RRSet) (*Zone, error) {
	signSet := rrset

	// Check if another ZSK key exists for zones, otherwise behave as CSK
	var isKSKOnly bool
	for _, rr := range rrset.Get(dns.TypeDNSKEY, z.Name()) {
		k, ok := rr.(*dns.DNSKEY)
		if !ok {
			continue
		}
		// found ZSK key that is not ours
		if k.Flags == 256 && k.PublicKey != z.dnsKey.PublicKey && k.Algorithm == z.dnsKey.Algorithm {
			isKSKOnly = true
			break
		}
	}

	// only sign DNSKEY, CDNSKEY, DS, CDS as KSK
	if isKSKOnly {
		signSet = signSet.OnlyTypes(dns.TypeDS, dns.TypeCDS, dns.TypeDNSKEY, dns.TypeCDNSKEY)
	}

	zone := *z
	zone.rrset = make(RRSet, len(rrset))
	copy(zone.rrset, rrset)

	for _, nameset := range signSet.SplitByName() {
		for _, typeset := range nameset.SplitByType() {
			if len(typeset) == 0 {
				continue
			}
			rrsig, err := z.SignRRSet(typeset)
			if err != nil {
				return nil, err
			}
			if err = zone.AddRR(rrsig); err != nil {
				return nil, err
			}
		}
	}

	return &zone, nil
}

//TODO: implement NSEC or similar

// SignRRSet signs rrset with the zone's private key and returns the resulting
// RRSIG. It fails when rrset is empty, when the zone has no private key, or
// when the underlying signing call fails.
func (z *Zone) SignRRSet(rrset RRSet) (*dns.RRSIG, error) {
	if len(rrset) == 0 {
		return nil, errors.New("no records to sign")
	}
	if z.privateKey == nil {
		return nil, errors.New("nil private key")
	}

	// read the clock once so inception and expiration share a reference point
	now := time.Now().UTC()

	// 67 years or max uint32, RFC 4034 points to RFC1982 for max of 68y period
	expirationTime := now.AddDate(67, 0, 0).Unix()
	if expirationTime > math.MaxUint32 {
		expirationTime = math.MaxUint32
	}

	sig := new(dns.RRSIG)
	sig.Expiration = uint32(expirationTime)
	sig.Inception = uint32(now.Add(-time.Hour).Unix()) // one hour in the past
	sig.KeyTag = z.dnsKey.KeyTag()
	sig.SignerName = z.dnsKey.Hdr.Name
	sig.Algorithm = z.dnsKey.Algorithm

	if err := sig.Sign(z.privateKey, rrset); err != nil {
		return nil, err
	}
	return sig, nil
}

// Name returns the zone apex, which is the owner name of the zone's DNSKEY.
func (z *Zone) Name() string {
	return z.dnsKey.Hdr.Name
}

// BaseZoneName returns the base zone this zone was created under.
func (z *Zone) BaseZoneName() string {
	return z.baseZone
}

// DNSKEY returns the zone's DNSKEY record. The returned pointer is the zone's
// own record, not a copy.
func (z *Zone) DNSKEY() *dns.DNSKEY {
	return z.dnsKey
}

// PubKey returns the Ed25519 public key the zone was created with.
func (z *Zone) PubKey() ed25519.PublicKey {
	return z.publicKey
}