diff --git a/remote_list.go b/remote_list.go index 60a1afdaf..0259555a2 100644 --- a/remote_list.go +++ b/remote_list.go @@ -5,8 +5,10 @@ import ( "context" "net" "net/netip" + "regexp" "sort" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -85,21 +87,41 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, // Fastrack IP addresses to ensure they're immediately available for use. // DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine. + performBackgroundLookup, ips, err := preprocessHostPorts(hostPorts, r) + if err != nil { + return nil, err + } + r.ips.Store(&ips) + + if performBackgroundLookup { + setupBackgroundLookup(ctx, r, d, onUpdate) + } + + return r, nil +} + +func preprocessHostPorts(hostPorts []string, results *hostnamesResults) (bool, map[netip.AddrPort]struct{}, error) { performBackgroundLookup := false ips := map[netip.AddrPort]struct{}{} + for idx, hostPort := range hostPorts { + if isSRV(hostPort) { + results.hostnames[idx] = hostnamePort{name: hostPort, port: uint16(0)} + performBackgroundLookup = true + continue + } rIp, sPort, err := net.SplitHostPort(hostPort) if err != nil { - return nil, err + return false, nil, err } iPort, err := strconv.Atoi(sPort) if err != nil { - return nil, err + return false, nil, err } - r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)} + results.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)} addr, err := netip.ParseAddr(rIp) if err != nil { // This address is a hostname, not an IP address @@ -110,61 +132,135 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, // Save the IP address immediately ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{} } - r.ips.Store(&ips) - // Time for the DNS lookup goroutine - if performBackgroundLookup { - newCtx, cancel := context.WithCancel(ctx) - r.cancelFn = cancel - ticker := time.NewTicker(d) - go func() { - defer ticker.Stop() - for { - netipAddrs := map[netip.AddrPort]struct{}{} - for _, hostPort := range r.hostnames { - timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout) - addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) - timeoutCancel() - if err != nil { - l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") - continue - } - for _, a := range addrs { - netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} - } - } - origSet := r.ips.Load() - different := false - for a := range *origSet { - if _, ok := netipAddrs[a]; !ok { - different = true - break - } - } - if !different { - for a := range netipAddrs { - if _, ok := (*origSet)[a]; !ok { - different = true - break - } - } - } - if different { - l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") - r.ips.Store(&netipAddrs) - onUpdate() - } - select { - case <-newCtx.Done(): - return - case <-ticker.C: - continue - } - } - }() + return performBackgroundLookup, ips, nil +} + +func isSRV(s string) bool { + re := regexp.MustCompile(`^_[A-Za-z0-9-]+?\._[A-Za-z0-9-]+?\..+$`) + return re.MatchString(s) +} + +func setupBackgroundLookup(ctx context.Context, r *hostnamesResults, duration time.Duration, onUpdate func()) { + newCtx, cancel := context.WithCancel(ctx) + r.cancelFn = cancel + ticker := time.NewTicker(duration) + go performLookup(newCtx, r, ticker, onUpdate) +} + +func performLookup(ctx context.Context, results *hostnamesResults, ticker *time.Ticker, onUpdate func()) { + defer ticker.Stop() + for { + if lookupAndUpdate(ctx, results) { + onUpdate() + } + + select { + case <-ctx.Done(): + return + case <-ticker.C: + continue + } } +} - return r, nil +func lookupAndUpdate(ctx context.Context, results *hostnamesResults) bool { + var netipAddrs map[netip.AddrPort]struct{} + for _, hostPort := range results.hostnames { + netipAddrs = resolveHostPort(ctx, hostPort, results.network, results.lookupTimeout, results.l) + } + + if len(netipAddrs) == 0 { + return false + } + + origSet := results.ips.Load() + different := isDifferent(origSet, netipAddrs) + if different { + results.l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") + results.ips.Store(&netipAddrs) + } + + return different +} + +func isDifferent(origSet *map[netip.AddrPort]struct{}, newSet map[netip.AddrPort]struct{}) bool { + for a := range *origSet { + if _, ok := newSet[a]; !ok { + return true + } + } + + for a := range newSet { + if _, ok := (*origSet)[a]; !ok { + return true + } + } + + return false +} + +func resolveHostPort(ctx context.Context, hostPort hostnamePort, network string, timeout time.Duration, logger *logrus.Logger) map[netip.AddrPort]struct{} { + if isSRV(hostPort.name) { + return resolveSRV(ctx, hostPort, network, timeout, logger) + } else { + return resolveIP(ctx, hostPort, network, timeout, logger) + } +} + +func resolveSRV(ctx context.Context, hostPort hostnamePort, network string, timeout time.Duration, logger *logrus.Logger) map[netip.AddrPort]struct{} { + netipAddrs := make(map[netip.AddrPort]struct{}) + + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + _, srvs, err := net.DefaultResolver.LookupSRV(timeoutCtx, "", "", hostPort.name) + if err != nil { + logger.WithFields(logrus.Fields{"srv": hostPort.name, "network": network}).WithError(err).Error("SRV resolution failed for static_map host") + return netipAddrs + } + + for _, srv := range srvs { + var domain string + if strings.HasSuffix(srv.Target, ".") { + domain = srv.Target[:len(srv.Target)-1] + } else { + domain = srv.Target + } + + ipTimeoutCtx, ipTimeoutCancel := context.WithTimeout(ctx, timeout) + addrs, err := net.DefaultResolver.LookupNetIP(ipTimeoutCtx, network, domain) + ipTimeoutCancel() + if err != nil { + logger.WithFields(logrus.Fields{"hostname": srv.Target, "network": network}).WithError(err).Error("DNS resolution failed for static_map host") + continue + } + + for _, addr := range addrs { + netipAddrs[netip.AddrPortFrom(addr, srv.Port)] = struct{}{} + } + } + + return netipAddrs +} + +func resolveIP(ctx context.Context, hostPort hostnamePort, network string, timeout time.Duration, logger *logrus.Logger) map[netip.AddrPort]struct{} { + netipAddrs := make(map[netip.AddrPort]struct{}) + + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, network, hostPort.name) + if err != nil { + logger.WithFields(logrus.Fields{"hostname": hostPort.name, "network": network}).WithError(err).Error("DNS resolution failed") + return netipAddrs + } + + for _, addr := range addrs { + netipAddrs[netip.AddrPortFrom(addr, hostPort.port)] = struct{}{} + } + + return netipAddrs } func (hr *hostnamesResults) Cancel() {