Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for DNS SRV #1141

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 151 additions & 55 deletions remote_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"context"
"net"
"net/netip"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -85,21 +87,41 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,

// Fastrack IP addresses to ensure they're immediately available for use.
// DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine.
performBackgroundLookup, ips, err := preprocessHostPorts(hostPorts, r)
if err != nil {
return nil, err
}
r.ips.Store(&ips)

if performBackgroundLookup {
setupBackgroundLookup(ctx, r, d, onUpdate)
}

return r, nil
}

func preprocessHostPorts(hostPorts []string, results *hostnamesResults) (bool, map[netip.AddrPort]struct{}, error) {
performBackgroundLookup := false
ips := map[netip.AddrPort]struct{}{}

for idx, hostPort := range hostPorts {
if isSRV(hostPort) {
results.hostnames[idx] = hostnamePort{name: hostPort, port: uint16(0)}
performBackgroundLookup = true
continue
}

rIp, sPort, err := net.SplitHostPort(hostPort)
if err != nil {
return nil, err
return false, nil, err
}

iPort, err := strconv.Atoi(sPort)
if err != nil {
return nil, err
return false, nil, err
}

r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
results.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)}
addr, err := netip.ParseAddr(rIp)
if err != nil {
// This address is a hostname, not an IP address
Expand All @@ -110,61 +132,135 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
// Save the IP address immediately
ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{}
}
r.ips.Store(&ips)

// Time for the DNS lookup goroutine
if performBackgroundLookup {
newCtx, cancel := context.WithCancel(ctx)
r.cancelFn = cancel
ticker := time.NewTicker(d)
go func() {
defer ticker.Stop()
for {
netipAddrs := map[netip.AddrPort]struct{}{}
for _, hostPort := range r.hostnames {
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout)
addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name)
timeoutCancel()
if err != nil {
l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host")
continue
}
for _, a := range addrs {
netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
}
}
origSet := r.ips.Load()
different := false
for a := range *origSet {
if _, ok := netipAddrs[a]; !ok {
different = true
break
}
}
if !different {
for a := range netipAddrs {
if _, ok := (*origSet)[a]; !ok {
different = true
break
}
}
}
if different {
l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
r.ips.Store(&netipAddrs)
onUpdate()
}
select {
case <-newCtx.Done():
return
case <-ticker.C:
continue
}
}
}()
return performBackgroundLookup, ips, nil
}

func isSRV(s string) bool {
re := regexp.MustCompile(`^_[A-Za-z0-9-]+?\._[A-Za-z0-9-]+?\..+$`)
return re.MatchString(s)
}

func setupBackgroundLookup(ctx context.Context, r *hostnamesResults, duration time.Duration, onUpdate func()) {
newCtx, cancel := context.WithCancel(ctx)
r.cancelFn = cancel
ticker := time.NewTicker(duration)
go performLookup(newCtx, r, ticker, onUpdate)
}

func performLookup(ctx context.Context, results *hostnamesResults, ticker *time.Ticker, onUpdate func()) {
defer ticker.Stop()
for {
if lookupAndUpdate(ctx, results) {
onUpdate()
}

select {
case <-ctx.Done():
return
case <-ticker.C:
continue
}
}
}

return r, nil
func lookupAndUpdate(ctx context.Context, results *hostnamesResults) bool {
var netipAddrs map[netip.AddrPort]struct{}
for _, hostPort := range results.hostnames {
netipAddrs = resolveHostPort(ctx, hostPort, results.network, results.lookupTimeout, results.l)
}

if len(netipAddrs) == 0 {
return false
}

origSet := results.ips.Load()
different := isDifferent(origSet, netipAddrs)
if different {
results.l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list")
results.ips.Store(&netipAddrs)
}

return different
}

func isDifferent(origSet *map[netip.AddrPort]struct{}, newSet map[netip.AddrPort]struct{}) bool {
for a := range *origSet {
if _, ok := newSet[a]; !ok {
return true
}
}

for a := range newSet {
if _, ok := (*origSet)[a]; !ok {
return true
}
}

return false
}

func resolveHostPort(ctx context.Context, hostPort hostnamePort, network string, timeout time.Duration, logger *logrus.Logger) map[netip.AddrPort]struct{} {
if isSRV(hostPort.name) {
return resolveSRV(ctx, hostPort, network, timeout, logger)
} else {
return resolveIP(ctx, hostPort, network, timeout, logger)
}
}

func resolveSRV(ctx context.Context, hostPort hostnamePort, network string, timeout time.Duration, logger *logrus.Logger) map[netip.AddrPort]struct{} {
netipAddrs := make(map[netip.AddrPort]struct{})

timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

_, srvs, err := net.DefaultResolver.LookupSRV(timeoutCtx, "", "", hostPort.name)
if err != nil {
logger.WithFields(logrus.Fields{"srv": hostPort.name, "network": network}).WithError(err).Error("SRV resolution failed for static_map host")
return netipAddrs
}

for _, srv := range srvs {
var domain string
if strings.HasSuffix(srv.Target, ".") {
domain = srv.Target[:len(srv.Target)-1]
} else {
domain = srv.Target
}

ipTimeoutCtx, ipTimeoutCancel := context.WithTimeout(ctx, timeout)
addrs, err := net.DefaultResolver.LookupNetIP(ipTimeoutCtx, network, domain)
ipTimeoutCancel()
if err != nil {
logger.WithFields(logrus.Fields{"hostname": srv.Target, "network": network}).WithError(err).Error("DNS resolution failed for static_map host")
continue
}

for _, addr := range addrs {
netipAddrs[netip.AddrPortFrom(addr, srv.Port)] = struct{}{}
}
}

return netipAddrs
}

func resolveIP(ctx context.Context, hostPort hostnamePort, network string, timeout time.Duration, logger *logrus.Logger) map[netip.AddrPort]struct{} {
netipAddrs := make(map[netip.AddrPort]struct{})

timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, network, hostPort.name)
if err != nil {
logger.WithFields(logrus.Fields{"hostname": hostPort.name, "network": network}).WithError(err).Error("DNS resolution failed")
return netipAddrs
}

for _, addr := range addrs {
netipAddrs[netip.AddrPortFrom(addr, hostPort.port)] = struct{}{}
}

return netipAddrs
}

func (hr *hostnamesResults) Cancel() {
Expand Down