dns-api/dns.go
DataHoarder e258da67dd
All checks were successful
continuous-integration/drone/push Build is passing
Added DNS zones, DNSSEC signing
2022-06-14 20:20:13 +02:00

436 lines
9.5 KiB
Go

package dns_api
import (
"encoding/base64"
"errors"
"fmt"
"git.gammaspectra.live/givna.me/dns-api/ed25519"
"github.com/miekg/dns"
"math"
"net"
"strings"
"time"
)
// RRSet is a flat list of resource records that may span several owner names
// and record types. Entries may be nil: Delete blanks matching entries instead
// of removing them, so code iterating an RRSet must skip nil entries.
type RRSet []dns.RR

// Get returns the non-nil records in l whose type is rrt and whose owner name
// equals name exactly (byte-for-byte, so the match is case-sensitive). The
// result is nil when nothing matches.
func (l RRSet) Get(rrt uint16, name string) (rrset RRSet) {
	for i := range l {
		rr := l[i]
		if rr == nil {
			continue
		}
		if hdr := rr.Header(); hdr.Rrtype == rrt && hdr.Name == name {
			rrset = append(rrset, rr)
		}
	}
	return rrset
}

// Delete sets every entry of l with type rrt and owner name name to nil.
// The receiver is a slice value, so the length cannot be shortened here; the
// resulting nil holes are skipped by Get.
func (l RRSet) Delete(rrt uint16, name string) {
	for i := range l {
		if l[i] == nil {
			continue
		}
		if hdr := l[i].Header(); hdr.Rrtype == rrt && hdr.Name == name {
			l[i] = nil
		}
	}
}
// Zone holds the keys and resource records of a single zone. The zone's name
// is built from its public key (PublicKeyToOnionV3(publicKey) + "." + baseZone,
// see NewZoneFromPublicKey) and is also the owner name of its DNSKEY.
type Zone struct {
	// publicKey is the zone's public key, published through dnsKey.
	publicKey ed25519.PublicKey
	// privateKey signs RRsets. It is nil for zones created with
	// NewZoneFromPublicKey, in which case SignRRSet returns an error.
	privateKey ed25519.PrivateKey
	// dnsKey is the zone's own DNSKEY record; its owner name is the zone name.
	dnsKey *dns.DNSKEY
	// rrset holds every record of the zone. Deleted entries may remain as nil
	// holes (RRSet.Delete), so iteration must skip nils.
	rrset RRSet
	// NOTE(review): enforcedEntries is not referenced by any code in this file;
	// confirm whether it is still needed.
	enforcedEntries RRSet
}
// NewZoneFromPrivateKey creates a signing-capable zone under baseZone for the
// given private key. It derives the public key, builds the zone with
// NewZoneFromPublicKey and attaches the private key. It returns nil when
// NewZoneFromPublicKey does (baseZone empty or not ending in '.').
func NewZoneFromPrivateKey(baseZone string, privateKey ed25519.PrivateKey) *Zone {
	pub := privateKey.Public().(ed25519.PublicKey)
	if zone := NewZoneFromPublicKey(baseZone, pub); zone != nil {
		zone.privateKey = privateKey
		return zone
	}
	return nil
}
// NewZoneFromPublicKey creates a verify-only zone (no private key) for
// publicKey under baseZone.
//
// baseZone must be fully qualified (end in '.'); otherwise nil is returned.
// The zone's name is PublicKeyToOnionV3(publicKey) + "." + baseZone, and the
// zone's DNSKEY is an Ed25519 KSK (flags 257, protocol 3) owned by that name.
func NewZoneFromPublicKey(baseZone string, publicKey ed25519.PublicKey) *Zone {
	if !strings.HasSuffix(baseZone, ".") {
		return nil
	}

	key := &dns.DNSKEY{
		Hdr: dns.RR_Header{
			Name:   PublicKeyToOnionV3(publicKey) + "." + baseZone,
			Rrtype: dns.TypeDNSKEY,
			Class:  dns.ClassINET,
			// The public key never expires, so use the largest TTL value (2^31 - 1).
			Ttl: 2147483647,
		},
		Flags:     257, // KSK
		Protocol:  3,
		Algorithm: dns.ED25519,
		PublicKey: base64.StdEncoding.EncodeToString(publicKey),
	}

	return &Zone{
		publicKey: publicKey,
		dnsKey:    key,
	}
}
// DeleteRR removes every record of type rrtype (a mnemonic such as "A" or
// "TXT") owned by name. It returns an error only when rrtype is not a known
// record type; deleting records that do not exist is not an error.
func (z *Zone) DeleteRR(rrtype string, name string) error {
	if rrt, known := dns.StringToType[rrtype]; known {
		z.rrset.Delete(rrt, name)
		return nil
	}
	return fmt.Errorf("could not find record type for %s", rrtype)
}
// GetRR returns the records of type rrtype (a mnemonic such as "A" or "TXT")
// owned by name. It returns nil when rrtype is unknown or nothing matches.
func (z *Zone) GetRR(rrtype string, name string) (rrset RRSet) {
	if rrt, known := dns.StringToType[rrtype]; known {
		rrset = z.rrset.Get(rrt, name)
	}
	return rrset
}
// AddRR appends rr to the zone after checking that its owner name is the zone
// name itself or a name beneath it.
//
// Two checks are applied:
//   - a byte-for-byte suffix match against z.Name(), which keeps the
//     case-sensitive behaviour the rest of this file relies on;
//   - dns.IsSubDomain, which adds the label boundary. A plain suffix test
//     would also accept names such as "x"+z.Name() that only share trailing
//     characters with the zone without being inside it.
//
// It returns an error if rr is nil or lies outside the zone.
func (z *Zone) AddRR(rr dns.RR) error {
	if rr == nil {
		return errors.New("no valid record found")
	}
	zone := z.Name()
	name := rr.Header().Name
	if !strings.HasSuffix(name, zone) || !dns.IsSubDomain(zone, name) {
		return fmt.Errorf("%s is not part of %s", name, zone)
	}
	z.rrset = append(z.rrset, rr)
	return nil
}
// AddRecord sets rr's owner name, class (IN) and TTL, then adds it via AddRR.
// rr's header is modified in place, even if AddRR then rejects the record.
func (z *Zone) AddRecord(name string, ttl uint32, rr dns.RR) error {
	if rr == nil {
		return errors.New("no valid record found")
	}
	h := rr.Header()
	h.Name, h.Class, h.Ttl = name, dns.ClassINET, ttl
	return z.AddRR(rr)
}
// AddRecordA adds an A record for the IPv4 address ip at name with the given
// TTL. It returns an error when ip has no 4-byte form.
func (z *Zone) AddRecordA(name string, ip net.IP, ttl uint32) error {
	v4 := ip.To4()
	if v4 == nil {
		return errors.New("invalid ipv4")
	}
	return z.AddRecord(name, ttl, &dns.A{
		Hdr: dns.RR_Header{Rrtype: dns.TypeA},
		A:   v4,
	})
}
// AddRecordAAAA adds an AAAA record for the IPv6 address ip at name with the
// given TTL.
//
// IPv4 addresses are rejected with the same "invalid ipv6" error as
// unparseable input. This includes IPv4-mapped IPv6 addresses such as
// ::ffff:192.0.2.1, which net.IP stores identically to a plain IPv4 address.
// Use AddRecordA for IPv4. Without this check, To16 would silently turn an
// IPv4 address into a v4-mapped AAAA record.
func (z *Zone) AddRecordAAAA(name string, ip net.IP, ttl uint32) error {
	if ip.To4() != nil {
		return errors.New("invalid ipv6")
	}
	ip = ip.To16()
	if ip == nil {
		return errors.New("invalid ipv6")
	}
	record := new(dns.AAAA)
	record.Hdr.Rrtype = dns.TypeAAAA
	record.AAAA = ip
	return z.AddRecord(name, ttl, record)
}
// AddRecordCNAME adds a CNAME record at name pointing to target, with the given TTL.
func (z *Zone) AddRecordCNAME(name string, target string, ttl uint32) error {
	cname := &dns.CNAME{
		Hdr:    dns.RR_Header{Rrtype: dns.TypeCNAME},
		Target: target,
	}
	return z.AddRecord(name, ttl, cname)
}
// AddRecordTXT adds a TXT record at name holding the strings in txt, with the
// given TTL. The txt slice is stored as-is, not copied.
func (z *Zone) AddRecordTXT(name string, txt []string, ttl uint32) error {
	return z.AddRecord(name, ttl, &dns.TXT{
		Hdr: dns.RR_Header{Rrtype: dns.TypeTXT},
		Txt: txt,
	})
}
// AddRecordCAA adds a CAA record at name with the given flag, tag and value,
// and the given TTL.
func (z *Zone) AddRecordCAA(name string, flag uint8, tag, value string, ttl uint32) error {
	caa := &dns.CAA{
		Hdr:   dns.RR_Header{Rrtype: dns.TypeCAA},
		Flag:  flag,
		Tag:   tag,
		Value: value,
	}
	return z.AddRecord(name, ttl, caa)
}
// AddRecordTLSA adds a TLSA record at name with the given certificate usage,
// selector, matching type and certificate association data, and the given TTL.
func (z *Zone) AddRecordTLSA(name string, usage, selector, matchingType uint8, certificate string, ttl uint32) error {
	tlsa := &dns.TLSA{
		Hdr:          dns.RR_Header{Rrtype: dns.TypeTLSA},
		Usage:        usage,
		Selector:     selector,
		MatchingType: matchingType,
		Certificate:  certificate,
	}
	return z.AddRecord(name, ttl, tlsa)
}
// AddRecordSVCB adds an SVCB record at name with the given priority, target
// and service parameters, and the given TTL.
func (z *Zone) AddRecordSVCB(name string, priority uint16, target string, value []dns.SVCBKeyValue, ttl uint32) error {
	svcb := &dns.SVCB{
		Hdr:      dns.RR_Header{Rrtype: dns.TypeSVCB},
		Priority: priority,
		Target:   target,
		Value:    value,
	}
	return z.AddRecord(name, ttl, svcb)
}
// AddRecordHTTPS adds an HTTPS record at name with the given priority, target
// and service parameters, and the given TTL. It shares its layout with SVCB;
// only the record type differs.
func (z *Zone) AddRecordHTTPS(name string, priority uint16, target string, value []dns.SVCBKeyValue, ttl uint32) error {
	https := &dns.HTTPS{
		SVCB: dns.SVCB{
			Hdr:      dns.RR_Header{Rrtype: dns.TypeHTTPS},
			Priority: priority,
			Target:   target,
			Value:    value,
		},
	}
	return z.AddRecord(name, ttl, https)
}
// AddRecordSSHFP adds an SSHFP record at name with the given key algorithm,
// fingerprint type and fingerprint, and the given TTL.
func (z *Zone) AddRecordSSHFP(name string, algorithm, ktype uint8, fingerprint string, ttl uint32) error {
	sshfp := &dns.SSHFP{
		Hdr:         dns.RR_Header{Rrtype: dns.TypeSSHFP},
		Algorithm:   algorithm,
		Type:        ktype,
		FingerPrint: fingerprint,
	}
	return z.AddRecord(name, ttl, sshfp)
}
// AddRecordMX adds an MX record at name with the given preference and mail
// exchanger host, and the given TTL.
func (z *Zone) AddRecordMX(name string, preference uint16, mx string, ttl uint32) error {
	return z.AddRecord(name, ttl, &dns.MX{
		Hdr:        dns.RR_Header{Rrtype: dns.TypeMX},
		Preference: preference,
		Mx:         mx,
	})
}
// SetRR replaces every record of type rrtype owned by name with the non-nil
// records in rrset.
//
// rrtype is a type mnemonic such as "A" or "TXT" (looked up in
// dns.StringToType). Every non-nil record in rrset must have that type. Each
// one is added through AddRR, so it must also belong to the zone. An empty or
// all-nil rrset simply deletes the existing records.
//
// The replacement is all-or-nothing. If rrtype is unknown, a record has the
// wrong type, or AddRR rejects a record, SetRR returns an error and leaves the
// zone's records unchanged.
func (z *Zone) SetRR(rrtype string, name string, rrset RRSet) error {
	rrt, ok := dns.StringToType[rrtype]
	if !ok {
		return fmt.Errorf("could not find record type for %s", rrtype)
	}

	// Validate before touching the zone so a bad rrset cannot wipe the
	// existing records.
	for i, rr := range rrset {
		if rr != nil && rr.Header().Rrtype != rrt {
			return fmt.Errorf("expected %s, got record type %s at index %d", rrtype, dns.TypeToString[rr.Header().Rrtype], i)
		}
	}

	// Build the new record list on a fresh slice so the previous list stays
	// intact and can be restored if AddRR fails part-way. Skipping nil holes
	// here also compacts the list.
	previous := z.rrset
	staged := make(RRSet, 0, len(previous)+len(rrset))
	for _, rr := range previous {
		if rr == nil {
			continue
		}
		if h := rr.Header(); h.Rrtype == rrt && h.Name == name {
			continue
		}
		staged = append(staged, rr)
	}

	z.rrset = staged
	for _, rr := range rrset {
		if rr == nil {
			continue
		}
		if err := z.AddRR(rr); err != nil {
			z.rrset = previous
			return err
		}
	}
	return nil
}
// GetRRSet returns a fresh slice with every non-nil record in the zone, in
// insertion order. The slice is a copy, but the records themselves are shared
// with the zone.
func (z *Zone) GetRRSet() RRSet {
	out := make(RRSet, 0, len(z.rrset))
	for i := range z.rrset {
		if z.rrset[i] == nil {
			continue
		}
		out = append(out, z.rrset[i])
	}
	return out
}
// AddMissingRecords makes sure the zone name carries the zone's own DNSKEY, a
// matching CDNSKEY and a CDS (SHA-256 digest of the DNSKEY), adding whichever
// is absent.
//
// A DNSKEY or CDNSKEY counts as present when its flags, public key and
// algorithm equal the zone key's. A CDS counts as present when its key tag and
// digest equal the computed DS. The CDS step is skipped if ToDS returns nil.
func (z *Zone) AddMissingRecords() {
	apex := z.Name()
	key := z.dnsKey

	// The added records are owned by z.Name(), so AddRR's zone check cannot
	// reject them; its error result is deliberately ignored.

	// Step 1: our DNSKEY.
	present := false
	for _, rr := range z.rrset.Get(dns.TypeDNSKEY, apex) {
		if k, ok := rr.(*dns.DNSKEY); ok && k.Flags == key.Flags && k.PublicKey == key.PublicKey && k.Algorithm == key.Algorithm {
			present = true
			break
		}
	}
	if !present {
		_ = z.AddRR(key)
	}

	// Step 2: a CDNSKEY mirroring our DNSKEY.
	present = false
	for _, rr := range z.rrset.Get(dns.TypeCDNSKEY, apex) {
		if k, ok := rr.(*dns.CDNSKEY); ok && k.Flags == key.Flags && k.PublicKey == key.PublicKey && k.Algorithm == key.Algorithm {
			present = true
			break
		}
	}
	if !present {
		cdnskey := &dns.CDNSKEY{DNSKEY: *key}
		cdnskey.Hdr.Rrtype = dns.TypeCDNSKEY
		_ = z.AddRR(cdnskey)
	}

	// Step 3: a CDS carrying the SHA-256 DS of our DNSKEY.
	ds := key.ToDS(dns.SHA256)
	if ds == nil {
		return
	}
	for _, rr := range z.rrset.Get(dns.TypeCDS, apex) {
		if k, ok := rr.(*dns.CDS); ok && k.KeyTag == ds.KeyTag && k.Digest == ds.Digest {
			return
		}
	}
	cds := &dns.CDS{DS: *ds}
	cds.Hdr.Rrtype = dns.TypeCDS
	_ = z.AddRR(cds)
}
// Sign first calls AddMissingRecords, then produces one RRSIG per
// (owner name, record type) RRset in the zone using SignRRSet.
//
// Two modes:
//   - Normal (CSK): every type is signed except RRSIG, DS, TA, SOA and RP.
//   - KSK-only: used when the zone name also publishes a ZSK (flags 256) with
//     our algorithm but a different public key. Only the DS, CDS, DNSKEY and
//     CDNSKEY RRsets are signed.
//
// Signatures are returned ordered by owner name (first-seen order in the zone),
// then by type (first-seen order). The first SignRRSet failure aborts signing
// and is returned with a nil slice. The result is nil if nothing was signed.
func (z *Zone) Sign() (rrsigs []*dns.RRSIG, err error) {
	z.AddMissingRecords()

	// A foreign ZSK at the zone name means this key acts only as KSK.
	kskOnly := false
	for _, rr := range z.rrset.Get(dns.TypeDNSKEY, z.Name()) {
		k, ok := rr.(*dns.DNSKEY)
		if ok && k.Flags == 256 && k.PublicKey != z.dnsKey.PublicKey && k.Algorithm == z.dnsKey.Algorithm {
			kskOnly = true
			break
		}
	}

	// Collect distinct owner names and signable types, keeping first-seen order.
	var names []string
	var types []uint16
	seenName := map[string]bool{}
	seenType := map[uint16]bool{}
	for _, rr := range z.rrset {
		if rr == nil {
			continue
		}
		hdr := rr.Header()
		if !seenName[hdr.Name] {
			seenName[hdr.Name] = true
			names = append(names, hdr.Name)
		}
		switch hdr.Rrtype {
		case dns.TypeRRSIG, dns.TypeDS, dns.TypeTA, dns.TypeSOA, dns.TypeRP:
			// Never signed in normal mode.
			// TODO: turn this into an allowlist and reject disallowed entries
			// (RRSIG, DNSKEY, CDNSKEY, DS, CDS, NS, TA, RP, SOA) when they are
			// added, possibly allowing some of them after verifying their values.
		default:
			if !seenType[hdr.Rrtype] {
				seenType[hdr.Rrtype] = true
				types = append(types, hdr.Rrtype)
			}
		}
	}

	if kskOnly {
		types = []uint16{dns.TypeDS, dns.TypeCDS, dns.TypeDNSKEY, dns.TypeCDNSKEY}
	}

	for _, owner := range names {
		for _, t := range types {
			set := z.rrset.Get(t, owner)
			if len(set) == 0 {
				continue
			}
			sig, signErr := z.SignRRSet(set)
			if signErr != nil {
				return nil, signErr
			}
			rrsigs = append(rrsigs, sig)
		}
	}
	return rrsigs, nil
}
//TODO: implement NSEC or similar

// SignRRSet signs rrset with the zone's private key and returns the RRSIG.
//
// The records in rrset are expected to form one RRset: same owner name, class,
// type and TTL. (*dns.RRSIG).Sign takes the owner, class and covered type from
// rrset[0]. The RRSIG's own TTL is set here from rrset[0] as well, because
// RFC 4034 §3 requires it to equal the TTL of the covered RRset.
//
// The signature is valid from one hour before now until 67 years from now,
// clamped to the largest uint32 timestamp.
//
// It returns an error if rrset is empty, the zone has no private key, or
// signing fails.
func (z *Zone) SignRRSet(rrset RRSet) (*dns.RRSIG, error) {
	if len(rrset) == 0 {
		return nil, errors.New("no records to sign")
	}
	if z.privateKey == nil {
		return nil, errors.New("nil private key")
	}

	// Take a single timestamp so inception and expiration are consistent.
	now := time.Now().UTC()

	// 67 years or max uint32. RFC 4034 points to RFC 1982 serial arithmetic,
	// which limits the validity period to about 68 years.
	expirationTime := now.AddDate(67, 0, 0).Unix()
	if expirationTime > math.MaxUint32 {
		expirationTime = math.MaxUint32
	}

	sig := &dns.RRSIG{
		Hdr:        dns.RR_Header{Ttl: rrset[0].Header().Ttl},
		Expiration: uint32(expirationTime),
		// Back-dated one hour, presumably to tolerate clock skew on validators.
		Inception:  uint32(now.Add(-time.Hour).Unix()),
		KeyTag:     z.dnsKey.KeyTag(),
		SignerName: z.dnsKey.Hdr.Name,
		Algorithm:  z.dnsKey.Algorithm,
	}
	if err := sig.Sign(z.privateKey, rrset); err != nil {
		return nil, err
	}
	return sig, nil
}
// Name returns the zone name, which is the owner name of the zone's DNSKEY.
func (z *Zone) Name() string {
	return z.DNSKEY().Hdr.Name
}

// DNSKEY returns the zone's own DNSKEY record. The pointer is shared with the
// zone (AddMissingRecords adds this same pointer to the zone's records), so
// callers must not modify it.
func (z *Zone) DNSKEY() *dns.DNSKEY {
	key := z.dnsKey
	return key
}

// PubKey returns the public key the zone was created with.
func (z *Zone) PubKey() ed25519.PublicKey {
	key := z.publicKey
	return key
}