package dnsserver
import (
"fmt"
"net"
"strings"
)
type zoneAddr struct {
Zone string
Port string
Transport string // dns, tls, quic, grpc or https
Address string // the bind address of a bound zoneAddr; used when validating overlapping listeners
}
// String returns the string representation of z.
func (z zoneAddr) String() string {
s := z.Transport + "://" + z.Zone + ":" + z.Port
if z.Address != "" {
s += " on " + z.Address
}
return s
}
// SplitProtocolHostPort splits a fully formed address like "dns://[::1]:53" into its parts.
func SplitProtocolHostPort(address string) (protocol string, ip string, port string, err error) {
parts := strings.Split(address, "://")
switch len(parts) {
case 1:
ip, port, err := net.SplitHostPort(parts[0])
return "", ip, port, err
case 2:
ip, port, err := net.SplitHostPort(parts[1])
return parts[0], ip, port, err
default:
return "", "", "", fmt.Errorf("provided value is not in an address format : %s", address)
}
}
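// Illustrative usage (a sketch, not called anywhere in this package; the literal addresses are
// assumptions chosen for the example):
//
//	proto, ip, port, err := SplitProtocolHostPort("dns://[::1]:53")
//	// proto == "dns", ip == "::1", port == "53", err == nil
//	proto, ip, port, err = SplitProtocolHostPort("127.0.0.1:1053")
//	// proto == "", ip == "127.0.0.1", port == "1053", err == nil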
type zoneOverlap struct {
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAddr is registered by its original key
}
func newOverlapZone() *zoneOverlap {
return &zoneOverlap{registeredAddr: make(map[zoneAddr]zoneAddr), unboundOverlap: make(map[zoneAddr]zoneAddr)}
}
// registerAndCheck adds a new zoneAddr for validation and returns any already registered zoneAddr
// that is identical to, or overlaps with, the new one.
// We consider an unbound address to overlap all bound addresses for the same zone and port.
func (zo *zoneOverlap) registerAndCheck(z zoneAddr) (existingZone *zoneAddr, overlappingZone *zoneAddr) {
existingZone, overlappingZone = zo.check(z)
if existingZone != nil || overlappingZone != nil {
return existingZone, overlappingZone
}
// there is no overlap, keep the current zoneAddr for future checks
zo.registeredAddr[z] = z
zo.unboundOverlap[z.unbound()] = z
return nil, nil
}
// check validates a zoneAddr for overlap without registering it
func (zo *zoneOverlap) check(z zoneAddr) (existingZone *zoneAddr, overlappingZone *zoneAddr) {
if exist, ok := zo.registeredAddr[z]; ok {
// exact same zone already registered
return &exist, nil
}
uz := z.unbound()
if already, ok := zo.unboundOverlap[uz]; ok {
if z.Address == "" {
// current is not bound to an address, but there is already another zone with a bind address registered
return nil, &already
}
if _, ok := zo.registeredAddr[uz]; ok {
// current zone is bound to an address, but there is already an overlapping zone+port with no bind address
return nil, &uz
}
}
// there is no overlap
return nil, nil
}
// unbound returns an unbound version of the zoneAddr
func (z zoneAddr) unbound() zoneAddr {
return zoneAddr{Zone: z.Zone, Address: "", Port: z.Port, Transport: z.Transport}
}
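// Overlap semantics in brief (an illustrative sketch; the zone, port and address literals are assumptions):
//
//	zo := newOverlapZone()
//	e, o := zo.registerAndCheck(zoneAddr{Zone: "example.org.", Port: "53", Transport: "dns", Address: "127.0.0.1"})
//	// e == nil, o == nil: first registration, nothing overlaps.
//	e, o = zo.registerAndCheck(zoneAddr{Zone: "example.org.", Port: "53", Transport: "dns", Address: ""})
//	// o now points at the bound 127.0.0.1 entry: an unbound listener overlaps every bound
//	// listener for the same zone, port and transport.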
package dnsserver
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
)
// Config is the configuration for a single server.
type Config struct {
// The zone of the site.
Zone string
// One or several hostnames to bind the server to.
// Defaults to a single empty string, which denotes the wildcard address.
ListenHosts []string
// The port to listen on.
Port string
// The number of servers that will listen on one port.
// By default, one server will be running.
NumSockets int
// Root points to a base directory where we find user defined "things".
// The first consumer is the file plugin, which looks for zone files in this place.
Root string
// Debug controls the panic/recover mechanism that is enabled by default.
Debug bool
// Stacktrace controls whether a stacktrace is included in the recover mechanism's log output; it is disabled by default.
Stacktrace bool
// The transport we implement, normally just "dns" over TCP/UDP, but could be
// DNS-over-TLS or DNS-over-gRPC.
Transport string
// If this function is not nil it will be used to inspect and validate
// HTTP requests. Although this isn't referenced in-tree, external plugins
// may depend on it.
HTTPRequestValidateFunc func(*http.Request) bool
// FilterFuncs is used to further filter access
// to this handler. E.g. to limit access to a reverse zone
// on a non-octet boundary, i.e. /17
FilterFuncs []FilterFunc
// ViewName is the name of the Viewer plugin defined in the Config.
ViewName string
// TLSConfig when listening for encrypted connections (gRPC, DNS-over-TLS).
TLSConfig *tls.Config
// MaxQUICStreams defines the maximum number of concurrent QUIC streams for a QUIC server.
// This is nil if not specified, allowing for a default to be used.
MaxQUICStreams *int
// MaxQUICWorkerPoolSize defines the size of the worker pool for processing QUIC streams.
// This is nil if not specified, allowing for a default to be used.
MaxQUICWorkerPoolSize *int
// Timeouts for TCP, TLS and HTTPS servers.
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
// TSIG secrets, [name]key.
TsigSecret map[string]string
// Plugin stack.
Plugin []plugin.Plugin
// Compiled plugin stack.
pluginChain plugin.Handler
// Plugins interested in announcing that they exist, so other plugins can call methods
// on them, should register themselves here. The name should be the name as returned by the
// Handler's Name method.
registry map[string]plugin.Handler
// firstConfigInBlock is used to reference the first config in a server block, for the
// purpose of sharing a single instance of each plugin among all zones in a server block.
firstConfigInBlock *Config
// metaCollector references the first MetadataCollector plugin, if one exists
metaCollector MetadataCollector
}
// FilterFunc is a function that filters requests from the Config
type FilterFunc func(context.Context, *request.Request) bool
// keyForConfig builds a key for identifying the configs during setup time
func keyForConfig(blocIndex int, blocKeyIndex int) string {
return fmt.Sprintf("%d:%d", blocIndex, blocKeyIndex)
}
// GetConfig gets the Config that corresponds to c.
// If none exists, a new Config is created, saved and returned; this normally only happens in tests.
func GetConfig(c *caddy.Controller) *Config {
ctx := c.Context().(*dnsContext)
key := keyForConfig(c.ServerBlockIndex, c.ServerBlockKeyIndex)
if cfg, ok := ctx.keysToConfigs[key]; ok {
return cfg
}
// we should only get here during tests because directive
// actions typically skip the server blocks where we make
// the configs.
ctx.saveConfig(key, &Config{ListenHosts: []string{""}})
return GetConfig(c)
}
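// Typical use from a plugin's setup function (a hedged sketch; examplePlugin and the setup wiring
// are hypothetical, but GetConfig and AddPlugin are used this way by in-tree plugins):
//
//	func setup(c *caddy.Controller) error {
//		cfg := dnsserver.GetConfig(c)
//		cfg.AddPlugin(func(next plugin.Handler) plugin.Handler {
//			return examplePlugin{Next: next}
//		})
//		return nil
//	}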
package dnsserver
import (
"net"
"net/http"
"github.com/miekg/dns"
)
// DoHWriter is a dns.ResponseWriter that adds more specific LocalAddr and RemoteAddr methods.
type DoHWriter struct {
// raddr is the remote's address. This can be optionally set.
raddr net.Addr
// laddr is our address. This can be optionally set.
laddr net.Addr
// request is the HTTP request we're currently handling.
request *http.Request
// Msg is a response to be written to the client.
Msg *dns.Msg
}
// WriteMsg stores the message to be written to the client.
func (d *DoHWriter) WriteMsg(m *dns.Msg) error {
d.Msg = m
return nil
}
// Write stores the message to be written to the client.
func (d *DoHWriter) Write(b []byte) (int, error) {
d.Msg = new(dns.Msg)
return len(b), d.Msg.Unpack(b)
}
// RemoteAddr returns the remote address.
func (d *DoHWriter) RemoteAddr() net.Addr {
return d.raddr
}
// LocalAddr returns the local address.
func (d *DoHWriter) LocalAddr() net.Addr {
return d.laddr
}
// Network no-op implementation.
func (d *DoHWriter) Network() string {
return ""
}
// Request returns the HTTP request.
func (d *DoHWriter) Request() *http.Request {
return d.request
}
// Close no-op implementation.
func (d *DoHWriter) Close() error {
return nil
}
// TsigStatus no-op implementation.
func (d *DoHWriter) TsigStatus() error {
return nil
}
// TsigTimersOnly no-op implementation.
func (d *DoHWriter) TsigTimersOnly(_ bool) {}
// Hijack no-op implementation.
func (d *DoHWriter) Hijack() {}
package dnsserver
import (
"fmt"
"regexp"
"sort"
"strings"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
)
// checkZoneSyntax checks whether the given string matches the RFC 1035 preferred syntax.
// The root zone and all reverse zones always return true, even though they technically don't meet the RFC 1035 preferred syntax.
func checkZoneSyntax(zone string) bool {
if zone == "." || dnsutil.IsReverse(zone) != 0 {
return true
}
regex1035PreferredSyntax, _ := regexp.MatchString(`^(([A-Za-z]([A-Za-z0-9-]*[A-Za-z0-9])?)\.)+$`, zone)
return regex1035PreferredSyntax
}
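// A few illustrative inputs (the values are assumptions):
//
//	checkZoneSyntax(".")                // true: the root zone is always accepted
//	checkZoneSyntax("10.in-addr.arpa.") // true: reverse zones are always accepted
//	checkZoneSyntax("example.org.")     // true: matches the RFC 1035 preferred syntax
//	checkZoneSyntax("1example.org.")    // false: labels may not start with a digit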
// startUpZones creates the text that we show when starting up:
// grpc://example.com.:1055
// example.com.:1053 on 127.0.0.1
func startUpZones(protocol, addr string, zones map[string][]*Config) string {
keys := make([]string, len(zones))
i := 0
for k := range zones {
keys[i] = k
i++
}
sort.Strings(keys)
var sb strings.Builder
for _, zone := range keys {
if !checkZoneSyntax(zone) {
sb.WriteString(fmt.Sprintf("Warning: Domain %q does not follow RFC1035 preferred syntax\n", zone))
}
// split addr into protocol, IP and Port
_, ip, port, err := SplitProtocolHostPort(addr)
if err != nil {
// this should not happen, but we need to take care of it anyway
sb.WriteString(fmt.Sprintln(protocol + zone + ":" + addr))
continue
}
if ip == "" {
sb.WriteString(fmt.Sprintln(protocol + zone + ":" + port))
continue
}
// if the server is listening on a specific address let's make it visible in the log,
// so one can differentiate between all active listeners
sb.WriteString(fmt.Sprintln(protocol + zone + ":" + port + " on " + ip))
}
return sb.String()
}
package dnsserver
import (
"encoding/binary"
"errors"
"net"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
type DoQWriter struct {
localAddr net.Addr
remoteAddr net.Addr
stream *quic.Stream
Msg *dns.Msg
}
func (w *DoQWriter) Write(b []byte) (int, error) {
if w.stream == nil {
return 0, errors.New("stream is nil")
}
b = AddPrefix(b)
return w.stream.Write(b)
}
func (w *DoQWriter) WriteMsg(m *dns.Msg) error {
bytes, err := m.Pack()
if err != nil {
return err
}
_, err = w.Write(bytes)
if err != nil {
return err
}
return w.Close()
}
// Close sends the STREAM FIN signal.
// The server MUST send the response(s) on the same stream and MUST
// indicate, after the last response, through the STREAM FIN
// mechanism that no further data will be sent on that stream.
// See https://www.rfc-editor.org/rfc/rfc9250#section-4.2-7
func (w *DoQWriter) Close() error {
if w.stream == nil {
return errors.New("stream is nil")
}
return w.stream.Close()
}
// AddPrefix adds a 2-byte prefix with the DNS message length.
func AddPrefix(b []byte) (m []byte) {
m = make([]byte, 2+len(b))
binary.BigEndian.PutUint16(m, uint16(len(b)))
copy(m[2:], b)
return m
}
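// Framing sketch (the payload bytes are an assumption): a 5-byte DNS message is prefixed with its
// length as a big-endian 16-bit integer, as RFC 9250 requires.
//
//	framed := AddPrefix([]byte{0x12, 0x34, 0x01, 0x00, 0x00})
//	// framed == []byte{0x00, 0x05, 0x12, 0x34, 0x01, 0x00, 0x00}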
// These methods implement the dns.ResponseWriter interface from Go DNS.
func (w *DoQWriter) TsigStatus() error { return nil }
func (w *DoQWriter) TsigTimersOnly(b bool) {}
func (w *DoQWriter) Hijack() {}
func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr }
func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr }
func (w *DoQWriter) Network() string { return "" }
package dnsserver
import (
"fmt"
"net"
"time"
"github.com/coredns/caddy"
"github.com/coredns/caddy/caddyfile"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
const serverType = "dns"
func init() {
caddy.RegisterServerType(serverType, caddy.ServerType{
Directives: func() []string { return Directives },
DefaultInput: func() caddy.Input {
return caddy.CaddyfileInput{
Filepath: "Corefile",
Contents: []byte(".:" + Port + " {\nwhoami\nlog\n}\n"),
ServerTypeName: serverType,
}
},
NewContext: newContext,
})
}
func newContext(i *caddy.Instance) caddy.Context {
return &dnsContext{keysToConfigs: make(map[string]*Config)}
}
type dnsContext struct {
keysToConfigs map[string]*Config
// configs is the master list of all site configs.
configs []*Config
}
func (h *dnsContext) saveConfig(key string, cfg *Config) {
h.configs = append(h.configs, cfg)
h.keysToConfigs[key] = cfg
}
// Compile-time check to ensure dnsContext implements the caddy.Context interface
var _ caddy.Context = &dnsContext{}
// InspectServerBlocks makes sure that everything checks out before
// executing directives and otherwise prepares the directives to
// be parsed and executed.
func (h *dnsContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
// Normalize and check all the zone names and check for duplicates
for ib, s := range serverBlocks {
// Walk s.Keys and expand any reverse address into its proper DNS in-addr zone. If the expansion leads to
// more than one reverse zone, replace the current value and add the rest to s.Keys.
zoneAddrs := []zoneAddr{}
for ik, k := range s.Keys {
trans, k1 := parse.Transport(k) // get rid of any dns:// or other scheme.
hosts, port, err := plugin.SplitHostPort(k1)
// We need to make this a fully qualified domain name to catch all errors here and not later when
// plugin.Normalize is called again on these strings, with the prime difference being that the domain
// name is fully qualified. This was found by fuzzing where "ȶ" is deemed OK, but "ȶ." is not (might be a
// bug in miekg/dns actually). But here we were checking ȶ, which is OK, and later we barf on ȶ., leading to
// "index out of range".
for ih := range hosts {
_, _, err := plugin.SplitHostPort(dns.Fqdn(hosts[ih]))
if err != nil {
return nil, err
}
}
if err != nil {
return nil, err
}
if port == "" {
switch trans {
case transport.DNS:
port = Port
case transport.TLS:
port = transport.TLSPort
case transport.QUIC:
port = transport.QUICPort
case transport.GRPC:
port = transport.GRPCPort
case transport.HTTPS:
port = transport.HTTPSPort
}
}
if len(hosts) > 1 {
s.Keys[ik] = hosts[0] + ":" + port // replace for the first
for _, h := range hosts[1:] { // add the rest
s.Keys = append(s.Keys, h+":"+port)
}
}
for i := range hosts {
zoneAddrs = append(zoneAddrs, zoneAddr{Zone: dns.Fqdn(hosts[i]), Port: port, Transport: trans})
}
}
serverBlocks[ib].Keys = s.Keys // important to save back the new keys that are potentially created here.
var firstConfigInBlock *Config
for ik := range s.Keys {
za := zoneAddrs[ik]
s.Keys[ik] = za.String()
// Save the config to our master list, and key it for lookups.
cfg := &Config{
Zone: za.Zone,
ListenHosts: []string{""},
Port: za.Port,
Transport: za.Transport,
}
// Set reference to the first config in the current block.
// This is used later by MakeServers to share a single plugin list
// for all zones in a server block.
if ik == 0 {
firstConfigInBlock = cfg
}
cfg.firstConfigInBlock = firstConfigInBlock
keyConfig := keyForConfig(ib, ik)
h.saveConfig(keyConfig, cfg)
}
}
return serverBlocks, nil
}
// MakeServers uses the newly-created siteConfigs to create and return a list of server instances.
func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
// Copy parameters from the first config in the block to all other configs in the same block
propagateConfigParams(h.configs)
// we must map (group) each config to a bind address
groups, err := groupConfigsByListenAddr(h.configs)
if err != nil {
return nil, err
}
// then we create a server for each group
var servers []caddy.Server
for addr, group := range groups {
serversForGroup, err := makeServersForGroup(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, serversForGroup...)
}
// For each server config, check for View Filter plugins
for _, c := range h.configs {
// Add filters in the plugin.cfg order for consistent filter func evaluation order.
for _, d := range Directives {
if vf, ok := c.registry[d].(Viewer); ok {
if c.ViewName != "" {
return nil, fmt.Errorf("multiple views defined in server block")
}
c.ViewName = vf.ViewName()
c.FilterFuncs = append(c.FilterFuncs, vf.Filter)
}
}
}
// Verify that there is no overlap on the zones and listen addresses
// for unfiltered server configs
errValid := h.validateZonesAndListeningAddresses()
if errValid != nil {
return nil, errValid
}
return servers, nil
}
// AddPlugin adds a plugin to a site's plugin stack.
func (c *Config) AddPlugin(m plugin.Plugin) {
c.Plugin = append(c.Plugin, m)
}
// registerHandler adds a handler to a site's handler registration. Handlers
// use this to announce that they exist to other plugins.
func (c *Config) registerHandler(h plugin.Handler) {
if c.registry == nil {
c.registry = make(map[string]plugin.Handler)
}
// Just overwrite...
c.registry[h.Name()] = h
}
// Handler returns the plugin handler that has been added to the config under its name.
// This is useful to inspect if a certain plugin is active in this server.
// Note that this is order dependent and the order is defined in directives.go, i.e. if your plugin
// comes before the plugin you are checking, it will not be there (yet).
func (c *Config) Handler(name string) plugin.Handler {
if c.registry == nil {
return nil
}
if h, ok := c.registry[name]; ok {
return h
}
return nil
}
// Handlers returns a slice of plugins that have been registered. This can be used to
// inspect and interact with registered plugins but cannot be used to remove or add plugins.
// Note that this is order dependent and the order is defined in directives.go, i.e. if your plugin
// comes before the plugin you are checking, it will not be there (yet).
func (c *Config) Handlers() []plugin.Handler {
if c.registry == nil {
return nil
}
hs := make([]plugin.Handler, 0, len(c.registry))
for _, k := range Directives {
registry := c.Handler(k)
if registry != nil {
hs = append(hs, registry)
}
}
return hs
}
func (h *dnsContext) validateZonesAndListeningAddresses() error {
// Validate zones and addresses.
checker := newOverlapZone()
for _, conf := range h.configs {
for _, h := range conf.ListenHosts {
// Validate the overlapping of ZoneAddr
akey := zoneAddr{Transport: conf.Transport, Zone: conf.Zone, Address: h, Port: conf.Port}
var existZone, overlapZone *zoneAddr
if len(conf.FilterFuncs) > 0 {
// This config has filters. Check for overlap with other (unfiltered) configs.
existZone, overlapZone = checker.check(akey)
} else {
// This config has no filters. Check for overlap with other (unfiltered) configs,
// and register the zone to prevent subsequent zones from overlapping with it.
existZone, overlapZone = checker.registerAndCheck(akey)
}
if existZone != nil {
return fmt.Errorf("cannot serve %s - it is already defined", akey.String())
}
if overlapZone != nil {
return fmt.Errorf("cannot serve %s - zone overlap listener capacity with %v", akey.String(), overlapZone.String())
}
}
}
return nil
}
// propagateConfigParams copies the necessary parameters from the first config in the block
// to all other configs in the same block. Doing this results in zones
// sharing the same plugin instances and settings as other zones in
// the same block.
func propagateConfigParams(configs []*Config) {
for _, c := range configs {
c.Plugin = c.firstConfigInBlock.Plugin
c.ListenHosts = c.firstConfigInBlock.ListenHosts
c.Debug = c.firstConfigInBlock.Debug
c.Stacktrace = c.firstConfigInBlock.Stacktrace
c.NumSockets = c.firstConfigInBlock.NumSockets
// Fork TLSConfig for each encrypted connection
c.TLSConfig = c.firstConfigInBlock.TLSConfig.Clone()
c.ReadTimeout = c.firstConfigInBlock.ReadTimeout
c.WriteTimeout = c.firstConfigInBlock.WriteTimeout
c.IdleTimeout = c.firstConfigInBlock.IdleTimeout
c.TsigSecret = c.firstConfigInBlock.TsigSecret
}
}
// groupConfigsByListenAddr groups site configs by their listen
// (bind) address, so sites that use the same listener can be served
// on the same server instance. The return value maps the listen
// address (what you pass into net.Listen) to the list of site configs.
// This function does NOT vet the configs to ensure they are compatible.
func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
groups := make(map[string][]*Config)
for _, conf := range configs {
for _, h := range conf.ListenHosts {
addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(h, conf.Port))
if err != nil {
return nil, err
}
addrstr := conf.Transport + "://" + addr.String()
groups[addrstr] = append(groups[addrstr], conf)
}
}
return groups, nil
}
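// As an illustration (hypothetical configs): configs that share transport, bind host and port are
// grouped under the same key and therefore end up on the same listener.
//
//	groups, _ := groupConfigsByListenAddr([]*Config{
//		{Zone: "example.org.", Transport: "dns", ListenHosts: []string{"127.0.0.1"}, Port: "53"},
//		{Zone: "example.net.", Transport: "dns", ListenHosts: []string{"127.0.0.1"}, Port: "53"},
//	})
//	// len(groups) == 1; both configs live under the key "dns://127.0.0.1:53".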
// makeServersForGroup creates servers for a specific transport and group.
// It creates as many servers as specified in the NumSockets configuration.
// If the NumSockets param is not specified, one server is created by default.
func makeServersForGroup(addr string, group []*Config) ([]caddy.Server, error) {
// that is impossible, but better to check
if len(group) == 0 {
return nil, fmt.Errorf("no configs for group defined")
}
// create one server by default if no NumSockets specified
numSockets := 1
if group[0].NumSockets > 0 {
numSockets = group[0].NumSockets
}
var servers []caddy.Server
for range numSockets {
// switch on addr
switch tr, _ := parse.Transport(addr); tr {
case transport.DNS:
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case transport.TLS:
s, err := NewServerTLS(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case transport.QUIC:
s, err := NewServerQUIC(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case transport.GRPC:
s, err := NewServergRPC(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case transport.HTTPS:
s, err := NewServerHTTPS(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
}
}
return servers, nil
}
// DefaultPort is the default port.
const DefaultPort = transport.Port
// These "soft defaults" are configurable by
// command line flags, etc.
var (
// Port is the port we listen on by default.
Port = DefaultPort
// GracefulTimeout is the maximum duration of a graceful shutdown.
GracefulTimeout time.Duration
)
// Package dnsserver implements all the interfaces from Caddy, so that CoreDNS can be a servertype plugin.
package dnsserver
import (
"context"
"fmt"
"maps"
"net"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/trace"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
)
// Server represents an instance of a server, which serves
// DNS requests at a particular address (host and port). A
// server is capable of serving numerous zones on
// the same address and the listener may be stopped for
// graceful termination (POSIX only).
type Server struct {
Addr string // Address we listen on
IdleTimeout time.Duration // Idle timeout for TCP
ReadTimeout time.Duration // Read timeout for TCP
WriteTimeout time.Duration // Write timeout for TCP
server [2]*dns.Server // 0 is a net.Listener, 1 is a net.PacketConn (a *UDPConn) in our case.
m sync.Mutex // protects the servers
zones map[string][]*Config // zones keyed by their address
graceTimeout time.Duration // the maximum duration of a graceful shutdown
trace trace.Trace // the trace plugin for the server
debug bool // disable recover()
stacktrace bool // enable stacktrace in recover error log
classChaos bool // allow non-INET class queries
tsigSecret map[string]string
// Ensure Stop is idempotent when invoked concurrently (e.g., during reload and SIGTERM).
stopOnce sync.Once
stopErr error
}
// MetadataCollector is a plugin that can retrieve metadata functions from all metadata providing plugins
type MetadataCollector interface {
Collect(context.Context, request.Request) context.Context
}
// NewServer returns a new CoreDNS server and compiles all plugins into it. By default CH class
// queries are blocked unless a plugin listed in EnableChaos is loaded.
func NewServer(addr string, group []*Config) (*Server, error) {
s := &Server{
Addr: addr,
zones: make(map[string][]*Config),
graceTimeout: 5 * time.Second,
IdleTimeout: 10 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 5 * time.Second,
tsigSecret: make(map[string]string),
}
for _, site := range group {
if site.Debug {
s.debug = true
log.D.Set()
}
s.stacktrace = site.Stacktrace
// append the config to the zone's configs
s.zones[site.Zone] = append(s.zones[site.Zone], site)
// set timeouts
if site.ReadTimeout != 0 {
s.ReadTimeout = site.ReadTimeout
}
if site.WriteTimeout != 0 {
s.WriteTimeout = site.WriteTimeout
}
if site.IdleTimeout != 0 {
s.IdleTimeout = site.IdleTimeout
}
// copy tsig secrets
maps.Copy(s.tsigSecret, site.TsigSecret)
// compile custom plugin for everything
var stack plugin.Handler
for i := len(site.Plugin) - 1; i >= 0; i-- {
stack = site.Plugin[i](stack)
// register the *handler* also
site.registerHandler(stack)
// If the current plugin is a MetadataCollector, bookmark it for later use. This loop traverses the plugin
// list backwards, so the first MetadataCollector plugin wins.
if mdc, ok := stack.(MetadataCollector); ok {
site.metaCollector = mdc
}
if s.trace == nil && stack.Name() == "trace" {
// we have to stash away the plugin, not the
// Tracer object, because the Tracer won't be initialized yet
if t, ok := stack.(trace.Trace); ok {
s.trace = t
}
}
// Unblock CH class queries when any of these plugins are loaded.
if _, ok := EnableChaos[stack.Name()]; ok {
s.classChaos = true
}
}
site.pluginChain = stack
}
if !s.debug {
// When reloading we need to explicitly disable debug logging if it is now disabled.
log.D.Clear()
}
return s, nil
}
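// The plugin chain compiled above is plain middleware composition: each plugin.Plugin wraps the
// handler that follows it, and because the loop walks the list backwards the first directive ends
// up outermost. A reduced sketch (the plugin names are assumptions):
//
//	// site.Plugin == [cache, forward]
//	// pluginChain == cache(forward(nil))
//	// A query therefore enters cache first, which may hand it to forward via its Next handler.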
// Compile-time check to ensure Server implements the caddy.GracefulServer interface
var _ caddy.GracefulServer = &Server{}
// Serve starts the server with an existing listener. It blocks until the server stops.
// This implements caddy.TCPServer interface.
func (s *Server) Serve(l net.Listener) error {
s.m.Lock()
s.server[tcp] = &dns.Server{Listener: l,
Net: "tcp",
TsigSecret: s.tsigSecret,
MaxTCPQueries: tcpMaxQueries,
ReadTimeout: s.ReadTimeout,
WriteTimeout: s.WriteTimeout,
IdleTimeout: func() time.Duration {
return s.IdleTimeout
},
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
return s.server[tcp].ActivateAndServe()
}
// ServePacket starts the server with an existing packetconn. It blocks until the server stops.
// This implements caddy.UDPServer interface.
func (s *Server) ServePacket(p net.PacketConn) error {
s.m.Lock()
s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
}), TsigSecret: s.tsigSecret}
s.m.Unlock()
return s.server[udp].ActivateAndServe()
}
// Listen implements caddy.TCPServer interface.
func (s *Server) Listen() (net.Listener, error) {
l, err := reuseport.Listen("tcp", s.Addr[len(transport.DNS+"://"):])
if err != nil {
return nil, err
}
return l, nil
}
// WrapListener implements the caddy.GracefulServer interface.
func (s *Server) WrapListener(ln net.Listener) net.Listener {
return ln
}
// ListenPacket implements caddy.UDPServer interface.
func (s *Server) ListenPacket() (net.PacketConn, error) {
p, err := reuseport.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):])
if err != nil {
return nil, err
}
return p, nil
}
// Stop attempts to gracefully stop the server.
// It waits until the server is stopped and its connections are closed,
// up to a max timeout of a few seconds. If unsuccessful, an error is returned.
//
// This implements Caddy.Stopper interface.
func (s *Server) Stop() error {
s.stopOnce.Do(func() {
ctx, cancelCtx := context.WithTimeout(context.Background(), s.graceTimeout)
defer cancelCtx()
var wg sync.WaitGroup
s.m.Lock()
for _, s1 := range s.server {
// We might not have started and initialized the full set of servers
if s1 == nil {
continue
}
wg.Add(1)
go func() {
s1.ShutdownContext(ctx)
wg.Done()
}()
}
s.m.Unlock()
wg.Wait()
s.stopErr = ctx.Err()
})
return s.stopErr
}
// Address together with Stop() implement caddy.GracefulServer.
func (s *Server) Address() string { return s.Addr }
// ServeDNS is the entry point for every request to the address this server
// is bound to. It acts as a multiplexer for the zone name as
// defined in the request so that the correct zone
// (configuration and plugin stack) will handle the request.
func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
// The default dns.Mux checks the question section size, but we have our
// own mux here. Check if we have a question section. If not, drop the request here.
if r == nil || len(r.Question) == 0 {
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
return
}
if !s.debug {
defer func() {
// In case the user doesn't enable the error plugin, we still
// need to make sure that we stay alive up here
if rec := recover(); rec != nil {
if s.stacktrace {
log.Errorf("Recovered from panic in server: %q %v\n%s", s.Addr, rec, string(debug.Stack()))
} else {
log.Errorf("Recovered from panic in server: %q %v", s.Addr, rec)
}
vars.Panic.Inc()
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
}
}()
}
if !s.classChaos && r.Question[0].Qclass != dns.ClassINET {
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
return
}
if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once.
w.WriteMsg(m)
return
}
// Wrap the response writer in a ScrubWriter so we automatically make the reply fit in the client's buffer.
w = request.NewScrubWriter(r, w)
q := strings.ToLower(r.Question[0].Name)
var (
off int
end bool
dshandler *Config
)
for {
if z, ok := s.zones[q[off:]]; ok {
for _, h := range z {
if h.pluginChain == nil { // zone defined, but has not got any plugins
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
return
}
if h.metaCollector != nil {
// Collect metadata now, so it can be used before we send a request down the plugin chain.
ctx = h.metaCollector.Collect(ctx, request.Request{Req: r, W: w})
}
// If all filter funcs pass, use this config.
if passAllFilterFuncs(ctx, h.FilterFuncs, &request.Request{Req: r, W: w}) {
if h.ViewName != "" {
// if there was a view defined for this Config, set the view name in the context
ctx = context.WithValue(ctx, ViewKey{}, h.ViewName)
}
if r.Question[0].Qtype != dns.TypeDS {
rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
if !plugin.ClientWrite(rcode) {
errorFunc(s.Addr, w, r, rcode)
}
return
}
// The type is DS, keep the handler, but keep on searching as maybe we are serving
// the parent as well and the DS should be routed to it - this will probably *misroute* DS
// queries to a possible grandparent, but there is no way for us to know at this point
// if there is an actual delegation from grandparent -> parent -> zone.
// In all fairness: direct DS queries should not be needed.
dshandler = h
}
}
}
off, end = dns.NextLabel(q, off)
if end {
break
}
}
if r.Question[0].Qtype == dns.TypeDS && dshandler != nil && dshandler.pluginChain != nil {
// DS request, and we found a zone, use the handler for the query.
rcode, _ := dshandler.pluginChain.ServeDNS(ctx, w, r)
if !plugin.ClientWrite(rcode) {
errorFunc(s.Addr, w, r, rcode)
}
return
}
// Wildcard match; if we found nothing, try the root zone as a last resort.
if z, ok := s.zones["."]; ok {
for _, h := range z {
if h.pluginChain == nil {
continue
}
if h.metaCollector != nil {
// Collect metadata now, so it can be used before we send a request down the plugin chain.
ctx = h.metaCollector.Collect(ctx, request.Request{Req: r, W: w})
}
// If all filter funcs pass, use this config.
if passAllFilterFuncs(ctx, h.FilterFuncs, &request.Request{Req: r, W: w}) {
if h.ViewName != "" {
// if there was a view defined for this Config, set the view name in the context
ctx = context.WithValue(ctx, ViewKey{}, h.ViewName)
}
rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
if !plugin.ClientWrite(rcode) {
errorFunc(s.Addr, w, r, rcode)
}
return
}
}
}
// Still here? Error out with REFUSED.
errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
}
// passAllFilterFuncs returns true if all filter funcs evaluate to true for the given request
func passAllFilterFuncs(ctx context.Context, filterFuncs []FilterFunc, req *request.Request) bool {
for _, ff := range filterFuncs {
if !ff(ctx, req) {
return false
}
}
return true
}
// OnStartupComplete lists the sites served by this server
// and any relevant information, assuming Quiet is false.
func (s *Server) OnStartupComplete() {
if Quiet {
return
}
out := startUpZones("", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}
}
// Tracer returns the tracer in the server if defined.
func (s *Server) Tracer() ot.Tracer {
if s.trace == nil {
return nil
}
return s.trace.Tracer()
}
// errorFunc responds to a DNS request with an error.
func errorFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
state := request.Request{W: w, Req: r}
answer := new(dns.Msg)
answer.SetRcode(r, rc)
state.SizeAndDo(answer)
w.WriteMsg(answer)
}
func errorAndMetricsFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
state := request.Request{W: w, Req: r}
answer := new(dns.Msg)
answer.SetRcode(r, rc)
state.SizeAndDo(answer)
vars.Report(server, state, vars.Dropped, "", rcode.ToString(rc), "" /* plugin */, answer.Len(), time.Now())
w.WriteMsg(answer)
}
const (
tcp = 0
udp = 1
tcpMaxQueries = -1
)
type (
// Key is the context key for the current server added to the context.
Key struct{}
// LoopKey is the context key to detect server wide loops.
LoopKey struct{}
// ViewKey is the context key for the current view, if defined
ViewKey struct{}
)
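// A plugin further down the chain can read these values back from the request context. A minimal
// sketch from a plugin's point of view (examplePlugin, its Next field and the clog alias are
// hypothetical; the context keys and value types are the ones set in ServeDNS above):
//
//	func (p examplePlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
//		if srv, ok := ctx.Value(dnsserver.Key{}).(*dnsserver.Server); ok {
//			clog.Debugf("handled by server %s", srv.Address())
//		}
//		if view, ok := ctx.Value(dnsserver.ViewKey{}).(string); ok {
//			clog.Debugf("matched view %q", view)
//		}
//		return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
//	}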
// EnableChaos is a map with plugin names for which we should open CH class queries as we block these by default.
var EnableChaos = map[string]struct{}{
"chaos": {},
"forward": {},
"proxy": {},
}
// Quiet mode will not show any informative output on initialization.
var Quiet bool
package dnsserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
"github.com/opentracing/opentracing-go"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
// ServergRPC represents an instance of a DNS-over-gRPC server.
type ServergRPC struct {
*Server
*pb.UnimplementedDnsServiceServer
grpcServer *grpc.Server
listenAddr net.Addr
tlsConfig *tls.Config
}
// NewServergRPC returns a new CoreDNS gRPC server and compiles all plugins into it.
func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
// The *tls* plugin must make sure that multiple conflicting
// TLS configurations return an error: TLS can only be specified once.
var tlsConfig *tls.Config
for _, z := range s.zones {
for _, conf := range z {
// Should we error if some configs *don't* have TLS?
tlsConfig = conf.TLSConfig
}
}
// http/2 is required when using gRPC. We need to specify it in next protos
// or the upgrade won't happen.
if tlsConfig != nil {
tlsConfig.NextProtos = []string{"h2"}
}
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
}
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface
var _ caddy.GracefulServer = &ServergRPC{}
// Serve implements caddy.TCPServer interface.
func (s *ServergRPC) Serve(l net.Listener) error {
s.m.Lock()
s.listenAddr = l.Addr()
s.m.Unlock()
if s.Tracer() != nil {
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
return parentSpanCtx != nil
}
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
} else {
s.grpcServer = grpc.NewServer()
}
pb.RegisterDnsServiceServer(s.grpcServer, s)
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
return s.grpcServer.Serve(l)
}
// ServePacket implements caddy.UDPServer interface.
func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServergRPC) Listen() (net.Listener, error) {
l, err := reuseport.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
if err != nil {
return nil, err
}
return l, nil
}
// ListenPacket implements caddy.UDPServer interface.
func (s *ServergRPC) ListenPacket() (net.PacketConn, error) { return nil, nil }
// OnStartupComplete lists the sites served by this server
// and any relevant information, assuming Quiet is false.
func (s *ServergRPC) OnStartupComplete() {
if Quiet {
return
}
out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}
}
// Stop stops the server. It blocks until the server is
// totally stopped.
func (s *ServergRPC) Stop() (err error) {
s.m.Lock()
defer s.m.Unlock()
if s.grpcServer != nil {
s.grpcServer.GracefulStop()
}
return
}
// Query is the main entry-point into the gRPC server. From here we call ServeDNS like
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
// back to the client as a protobuf.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
msg := new(dns.Msg)
err := msg.Unpack(in.GetMsg())
if err != nil {
return nil, err
}
p, ok := peer.FromContext(ctx)
if !ok {
return nil, errors.New("no peer in gRPC context")
}
a, ok := p.Addr.(*net.TCPAddr)
if !ok {
return nil, fmt.Errorf("no TCP peer in gRPC context: %v", p.Addr)
}
w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg}
dnsCtx := context.WithValue(ctx, Key{}, s.Server)
dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0)
s.ServeDNS(dnsCtx, w, msg)
packed, err := w.Msg.Pack()
if err != nil {
return nil, err
}
return &pb.DnsPacket{Msg: packed}, nil
}
// Shutdown stops the server (non-gracefully).
func (s *ServergRPC) Shutdown() error {
if s.grpcServer != nil {
s.grpcServer.Stop()
}
return nil
}
type gRPCresponse struct {
localAddr net.Addr
remoteAddr net.Addr
Msg *dns.Msg
}
// Write is the hack that makes this work. It does not actually write the message
// but unpacks the bytes into r.Msg. We can then pick this up in Query
// and write a proper protobuf back to the client.
func (r *gRPCresponse) Write(b []byte) (int, error) {
r.Msg = new(dns.Msg)
return len(b), r.Msg.Unpack(b)
}
// These methods implement the dns.ResponseWriter interface from Go DNS.
func (r *gRPCresponse) Close() error { return nil }
func (r *gRPCresponse) TsigStatus() error { return nil }
func (r *gRPCresponse) TsigTimersOnly(b bool) {}
func (r *gRPCresponse) Hijack() {}
func (r *gRPCresponse) LocalAddr() net.Addr { return r.localAddr }
func (r *gRPCresponse) RemoteAddr() net.Addr { return r.remoteAddr }
func (r *gRPCresponse) Network() string { return "" }
func (r *gRPCresponse) WriteMsg(m *dns.Msg) error { r.Msg = m; return nil }
package dnsserver
import (
"context"
"crypto/tls"
"fmt"
stdlog "log"
"net"
"net/http"
"strconv"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/doh"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
type ServerHTTPS struct {
*Server
httpsServer *http.Server
listenAddr net.Addr
tlsConfig *tls.Config
validRequest func(*http.Request) bool
}
// loggerAdapter is a simple adapter around the CoreDNS logger that implements io.Writer, so errors from the HTTP server can be logged through it.
type loggerAdapter struct {
}
func (l *loggerAdapter) Write(p []byte) (n int, err error) {
clog.Debug(string(p))
return len(p), nil
}
// HTTPRequestKey is the context key for the HTTP request when processing DNS-over-HTTPS.
// Plugins can access the original HTTP request to retrieve headers, client IP, and metadata.
type HTTPRequestKey struct{}
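// For example, a plugin can recover the underlying HTTP request when the query arrived over DoH.
// A hedged sketch (the plugin type and clog alias are hypothetical; the key and value type match
// what ServeHTTP stores below):
//
//	func (p examplePlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
//		if req, ok := ctx.Value(dnsserver.HTTPRequestKey{}).(*http.Request); ok {
//			clog.Debugf("DoH query with User-Agent %q", req.Header.Get("User-Agent"))
//		}
//		return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r)
//	}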
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins into it.
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
// The *tls* plugin must make sure that multiple conflicting
// TLS configurations return an error: TLS can only be specified once.
var tlsConfig *tls.Config
for _, z := range s.zones {
for _, conf := range z {
// Should we error if some configs *don't* have TLS?
tlsConfig = conf.TLSConfig
}
}
// http/2 is recommended when using DoH. We need to specify it in next protos
// or the upgrade won't happen.
if tlsConfig != nil {
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
}
// Use a custom request validation func or use the standard DoH path check.
var validator func(*http.Request) bool
for _, z := range s.zones {
for _, conf := range z {
validator = conf.HTTPRequestValidateFunc
}
}
if validator == nil {
validator = func(r *http.Request) bool { return r.URL.Path == doh.Path }
}
srv := &http.Server{
ReadTimeout: s.ReadTimeout,
WriteTimeout: s.WriteTimeout,
IdleTimeout: s.IdleTimeout,
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
}
sh := &ServerHTTPS{
Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator,
}
sh.httpsServer.Handler = sh
return sh, nil
}
// Compile-time check to ensure ServerHTTPS implements the caddy.GracefulServer interface
var _ caddy.GracefulServer = &ServerHTTPS{}
// Serve implements caddy.TCPServer interface.
func (s *ServerHTTPS) Serve(l net.Listener) error {
s.m.Lock()
s.listenAddr = l.Addr()
s.m.Unlock()
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
return s.httpsServer.Serve(l)
}
// ServePacket implements caddy.UDPServer interface.
func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerHTTPS) Listen() (net.Listener, error) {
l, err := reuseport.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
if err != nil {
return nil, err
}
return l, nil
}
// ListenPacket implements caddy.UDPServer interface.
func (s *ServerHTTPS) ListenPacket() (net.PacketConn, error) { return nil, nil }
// OnStartupComplete lists the sites served by this server
// and any relevant information, assuming Quiet is false.
func (s *ServerHTTPS) OnStartupComplete() {
if Quiet {
return
}
out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}
}
// Stop stops the server. It blocks until the server is totally stopped.
func (s *ServerHTTPS) Stop() error {
s.m.Lock()
defer s.m.Unlock()
if s.httpsServer != nil {
s.httpsServer.Shutdown(context.Background())
}
return nil
}
// ServeHTTP is the handler that gets the HTTP request, converts it to the DNS format, calls the plugin
// chain, converts the response back and writes it to the client.
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !s.validRequest(r) {
http.Error(w, "", http.StatusNotFound)
s.countResponse(http.StatusNotFound)
return
}
msg, err := doh.RequestToMsg(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
s.countResponse(http.StatusBadRequest)
return
}
// Create a DoHWriter with the correct addresses in it.
h, p, _ := net.SplitHostPort(r.RemoteAddr)
port, _ := strconv.Atoi(p)
dw := &DoHWriter{
laddr: s.listenAddr,
raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port},
request: r,
}
// We just call the normal chain handler - all error handling is done there.
// We should expect a packet to be returned that we can send to the client.
// Propagate HTTP request context to DNS processing chain. This ensures that
// HTTP request timeouts, cancellations, and other context values are properly
// inherited by the DNS processing pipeline.
ctx := context.WithValue(r.Context(), Key{}, s.Server)
ctx = context.WithValue(ctx, LoopKey{}, 0)
ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
s.ServeDNS(ctx, dw, msg)
// See section 4.2.1 of RFC 8484.
// We are using code 500 to indicate an unexpected situation when the chain
// handler has not provided any response message.
if dw.Msg == nil {
http.Error(w, "No response", http.StatusInternalServerError)
s.countResponse(http.StatusInternalServerError)
return
}
buf, _ := dw.Msg.Pack()
mt, _ := response.Typify(dw.Msg, time.Now().UTC())
age := dnsutil.MinimalTTL(dw.Msg, mt)
w.Header().Set("Content-Type", doh.MimeType)
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", uint32(age.Seconds())))
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
w.WriteHeader(http.StatusOK)
s.countResponse(http.StatusOK)
w.Write(buf)
}
func (s *ServerHTTPS) countResponse(status int) {
vars.HTTPSResponsesCount.WithLabelValues(s.Addr, strconv.Itoa(status)).Inc()
}
// Shutdown stops the server (non-gracefully).
func (s *ServerHTTPS) Shutdown() error {
if s.httpsServer != nil {
s.httpsServer.Shutdown(context.Background())
}
return nil
}
package dnsserver
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"github.com/coredns/coredns/plugin/metrics/vars"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
const (
// DoQCodeNoError is used when the connection or stream needs to be
// closed, but there is no error to signal.
DoQCodeNoError quic.ApplicationErrorCode = 0
// DoQCodeInternalError signals that the DoQ implementation encountered
// an internal error and is incapable of pursuing the transaction or the
// connection.
DoQCodeInternalError quic.ApplicationErrorCode = 1
// DoQCodeProtocolError signals that the DoQ implementation encountered
// a protocol error and is forcibly aborting the connection.
DoQCodeProtocolError quic.ApplicationErrorCode = 2
// DefaultMaxQUICStreams is the default maximum number of concurrent QUIC streams
// on a per-connection basis. RFC 9250 (DNS-over-QUIC) does not require a high
// concurrent-stream limit; normal stub or recursive resolvers open only a handful
// of streams in parallel. This default (256) is a safe upper bound.
DefaultMaxQUICStreams = 256
// DefaultQUICStreamWorkers is the default number of workers for processing QUIC streams.
DefaultQUICStreamWorkers = 1024
)
// ServerQUIC represents an instance of a DNS-over-QUIC server.
type ServerQUIC struct {
*Server
listenAddr net.Addr
tlsConfig *tls.Config
quicConfig *quic.Config
quicListener *quic.Listener
maxStreams int
streamProcessPool chan struct{}
}
// NewServerQUIC returns a new CoreDNS QUIC server and compiles all plugins into it.
func NewServerQUIC(addr string, group []*Config) (*ServerQUIC, error) {
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
// The *tls* plugin must make sure that multiple conflicting
// TLS configurations return an error: TLS can only be specified once.
var tlsConfig *tls.Config
for _, z := range s.zones {
for _, conf := range z {
// Should we error if some configs *don't* have TLS?
tlsConfig = conf.TLSConfig
}
}
if tlsConfig != nil {
tlsConfig.NextProtos = []string{"doq"}
}
maxStreams := DefaultMaxQUICStreams
if len(group) > 0 && group[0] != nil && group[0].MaxQUICStreams != nil {
maxStreams = *group[0].MaxQUICStreams
}
streamProcessPoolSize := DefaultQUICStreamWorkers
if len(group) > 0 && group[0] != nil && group[0].MaxQUICWorkerPoolSize != nil {
streamProcessPoolSize = *group[0].MaxQUICWorkerPoolSize
}
var quicConfig = &quic.Config{
MaxIdleTimeout: s.IdleTimeout,
MaxIncomingStreams: int64(maxStreams),
MaxIncomingUniStreams: int64(maxStreams),
// Enable 0-RTT by default for all connections on the server-side.
Allow0RTT: true,
}
return &ServerQUIC{
Server: s,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
maxStreams: maxStreams,
streamProcessPool: make(chan struct{}, streamProcessPoolSize),
}, nil
}
// ServePacket implements caddy.UDPServer interface.
func (s *ServerQUIC) ServePacket(p net.PacketConn) error {
s.m.Lock()
s.listenAddr = s.quicListener.Addr()
s.m.Unlock()
return s.ServeQUIC()
}
// ServeQUIC accepts incoming QUIC connections and serves each one in its own goroutine.
func (s *ServerQUIC) ServeQUIC() error {
for {
conn, err := s.quicListener.Accept(context.Background())
if err != nil {
if s.isExpectedErr(err) {
s.closeQUICConn(conn, DoQCodeNoError)
return err
}
s.closeQUICConn(conn, DoQCodeInternalError)
return err
}
go s.serveQUICConnection(conn)
}
}
// serveQUICConnection handles a new QUIC connection. It waits for new streams
// and passes them to serveQUICStream.
func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) {
if conn == nil {
return
}
for {
// In DoQ, one query consumes one stream.
// The client MUST select the next available client-initiated bidirectional
// stream for each subsequent query on a QUIC connection.
stream, err := conn.AcceptStream(context.Background())
if err != nil {
if s.isExpectedErr(err) {
s.closeQUICConn(conn, DoQCodeNoError)
return
}
s.closeQUICConn(conn, DoQCodeInternalError)
return
}
// Use a bounded worker pool
s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block
go func(st *quic.Stream, cn *quic.Conn) {
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
}(stream, conn)
}
}
func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
if conn == nil {
return
}
if stream == nil {
s.closeQUICConn(conn, DoQCodeInternalError)
return
}
buf, err := readDOQMessage(stream)
// io.EOF does not really mean that there's any error, it is just
// the STREAM FIN indicating that there will be no data to read
// anymore from this stream.
if err != nil && err != io.EOF {
s.closeQUICConn(conn, DoQCodeProtocolError)
return
}
req := &dns.Msg{}
err = req.Unpack(buf)
if err != nil {
clog.Debugf("unpacking quic packet: %s", err)
s.closeQUICConn(conn, DoQCodeProtocolError)
return
}
if !validRequest(req) {
// If a peer encounters such an error condition, it is considered a
// fatal error. It SHOULD forcibly abort the connection using QUIC's
// CONNECTION_CLOSE mechanism and SHOULD use the DoQ error code
// DOQ_PROTOCOL_ERROR.
// See https://www.rfc-editor.org/rfc/rfc9250#section-4.3.3-3
s.closeQUICConn(conn, DoQCodeProtocolError)
return
}
w := &DoQWriter{
localAddr: conn.LocalAddr(),
remoteAddr: conn.RemoteAddr(),
stream: stream,
Msg: req,
}
dnsCtx := context.WithValue(stream.Context(), Key{}, s.Server)
dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0)
s.ServeDNS(dnsCtx, w, req)
s.countResponse(DoQCodeNoError)
}
// ListenPacket implements caddy.UDPServer interface.
func (s *ServerQUIC) ListenPacket() (net.PacketConn, error) {
p, err := reuseport.ListenPacket("udp", s.Addr[len(transport.QUIC+"://"):])
if err != nil {
return nil, err
}
s.m.Lock()
defer s.m.Unlock()
s.quicListener, err = quic.Listen(p, s.tlsConfig, s.quicConfig)
if err != nil {
return nil, err
}
return p, nil
}
// OnStartupComplete lists the sites served by this server
// and any relevant information, assuming Quiet is false.
func (s *ServerQUIC) OnStartupComplete() {
if Quiet {
return
}
out := startUpZones(transport.QUIC+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}
}
// Stop stops the server non-gracefully. It blocks until the server is totally stopped.
func (s *ServerQUIC) Stop() error {
s.m.Lock()
defer s.m.Unlock()
if s.quicListener != nil {
return s.quicListener.Close()
}
return nil
}
// Serve implements caddy.TCPServer interface.
func (s *ServerQUIC) Serve(l net.Listener) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerQUIC) Listen() (net.Listener, error) { return nil, nil }
// closeQUICConn quietly closes the QUIC connection.
func (s *ServerQUIC) closeQUICConn(conn *quic.Conn, code quic.ApplicationErrorCode) {
if conn == nil {
return
}
clog.Debugf("closing quic conn %s with code %d", conn.LocalAddr(), code)
err := conn.CloseWithError(code, "")
if err != nil {
clog.Debugf("failed to close quic connection with code %d: %s", code, err)
}
// DoQCodeNoError metrics are already registered after s.ServeDNS()
if code != DoQCodeNoError {
s.countResponse(code)
}
}
// validRequest checks for protocol errors in the unpacked DNS message.
// See https://www.rfc-editor.org/rfc/rfc9250.html#name-protocol-errors
func validRequest(req *dns.Msg) (ok bool) {
// 1. a client or server receives a message with a non-zero Message ID.
if req.Id != 0 {
return false
}
// 2. an implementation receives a message containing the edns-tcp-keepalive
// EDNS(0) Option [RFC7828].
if opt := req.IsEdns0(); opt != nil {
for _, option := range opt.Option {
if option.Option() == dns.EDNS0TCPKEEPALIVE {
clog.Debug("client sent EDNS0 TCP keepalive option")
return false
}
}
}
// 3. the client or server does not indicate the expected STREAM FIN after
// sending requests or responses.
//
// It is quite problematic to validate this case, since it would imply
// waiting until the STREAM FIN arrives before we start processing
// the message. So we're consciously ignoring this case in this
// implementation.
// 4. a server receives a "replayable" transaction in 0-RTT data
//
// The information necessary to validate this is not exposed by quic-go.
return true
}
// readDOQMessage reads a DNS-over-QUIC (DoQ) message from the given stream
// and returns the message bytes.
// Drafts of RFC 9250 did not require the 2-byte prefixed message length.
// Thus, we only support the official version (DoQ v1).
func readDOQMessage(r io.Reader) ([]byte, error) {
// All DNS messages (queries and responses) sent over DoQ connections MUST
// be encoded as a 2-octet length field followed by the message content as
// specified in [RFC1035].
// See https://www.rfc-editor.org/rfc/rfc9250.html#section-4.2-4
sizeBuf := make([]byte, 2)
_, err := io.ReadFull(r, sizeBuf)
if err != nil {
return nil, err
}
size := binary.BigEndian.Uint16(sizeBuf)
if size == 0 {
return nil, fmt.Errorf("message size is 0: probably unsupported DoQ version")
}
buf := make([]byte, size)
n, err := io.ReadFull(r, buf)
// A client or server receives a STREAM FIN before receiving all the bytes
// for a message indicated in the 2-octet length field.
// See https://www.rfc-editor.org/rfc/rfc9250#section-4.3.3-2.2
if uint16(n) != size {
return nil, fmt.Errorf("message size does not match 2-byte prefix")
}
return buf, err
}
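// On the wire a DoQ query therefore looks roughly like this (byte values are an assumption; see
// AddPrefix for the sending side):
//
//	// 0x00 0x15           <- 2-octet big-endian length: 21 bytes of DNS message follow
//	// 0x00 0x00 0x01 0x00 <- start of the DNS message: Message ID 0x0000 (required by RFC 9250), flags, ...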
// isExpectedErr returns true if err is an expected error, likely related to
// the current implementation.
func (s *ServerQUIC) isExpectedErr(err error) bool {
if err == nil {
return false
}
// This error is returned when the QUIC listener was closed by us. As
// graceful shutdown is not implemented, the connection will be abruptly
// closed but there is no error to signal.
if errors.Is(err, quic.ErrServerClosed) {
return true
}
// This error happens when the connection was closed due to a DoQ
// protocol error but there's still something to read in the closed stream.
// For example, when the message was sent without the prefixed length.
var qAppErr *quic.ApplicationError
if errors.As(err, &qAppErr) && qAppErr.ErrorCode == DoQCodeProtocolError {
return true
}
// When a connection hits the idle timeout, quic.AcceptStream() returns
// an IdleTimeoutError. In this case, we should just drop the connection
// with DoQCodeNoError.
var qIdleErr *quic.IdleTimeoutError
return errors.As(err, &qIdleErr)
}
func (s *ServerQUIC) countResponse(code quic.ApplicationErrorCode) {
switch code {
case DoQCodeNoError:
vars.QUICResponsesCount.WithLabelValues(s.Addr, "0x0").Inc()
case DoQCodeInternalError:
vars.QUICResponsesCount.WithLabelValues(s.Addr, "0x1").Inc()
case DoQCodeProtocolError:
vars.QUICResponsesCount.WithLabelValues(s.Addr, "0x2").Inc()
}
}
package dnsserver
import (
"context"
"crypto/tls"
"fmt"
"net"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// ServerTLS represents an instance of a DNS-over-TLS server.
type ServerTLS struct {
*Server
tlsConfig *tls.Config
}
// NewServerTLS returns a new CoreDNS TLS server and compiles all plugins into it.
func NewServerTLS(addr string, group []*Config) (*ServerTLS, error) {
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
// The *tls* plugin must make sure that multiple conflicting
// TLS configurations return an error: TLS can only be specified once.
var tlsConfig *tls.Config
for _, z := range s.zones {
for _, conf := range z {
// Should we error if some configs *don't* have TLS?
tlsConfig = conf.TLSConfig
}
}
return &ServerTLS{Server: s, tlsConfig: tlsConfig}, nil
}
// Compile-time check to ensure ServerTLS implements the caddy.GracefulServer interface
var _ caddy.GracefulServer = &ServerTLS{}
// Serve implements caddy.TCPServer interface.
func (s *ServerTLS) Serve(l net.Listener) error {
s.m.Lock()
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
// Only fill out the TCP server for this one.
s.server[tcp] = &dns.Server{Listener: l,
Net: "tcp-tls",
MaxTCPQueries: tlsMaxQueries,
ReadTimeout: s.ReadTimeout,
WriteTimeout: s.WriteTimeout,
IdleTimeout: func() time.Duration {
return s.IdleTimeout
},
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s.Server)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}
s.m.Unlock()
return s.server[tcp].ActivateAndServe()
}
// ServePacket implements caddy.UDPServer interface.
func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerTLS) Listen() (net.Listener, error) {
l, err := reuseport.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
if err != nil {
return nil, err
}
return l, nil
}
// ListenPacket implements caddy.UDPServer interface.
func (s *ServerTLS) ListenPacket() (net.PacketConn, error) { return nil, nil }
// OnStartupComplete lists the sites served by this server
// and any relevant information, assuming Quiet is false.
func (s *ServerTLS) OnStartupComplete() {
if Quiet {
return
}
out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}
}
const (
tlsMaxQueries = -1
)
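// Illustrative sketch (not part of the original source): how a client could
// query a DNS-over-TLS server such as ServerTLS above, using the miekg/dns
// client with the same "tcp-tls" transport. The server address and
// certificate name are assumptions.
package main

import (
	"crypto/tls"
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	c := &dns.Client{
		Net:       "tcp-tls", // matches the Net the server registers above
		TLSConfig: &tls.Config{ServerName: "dns.example.org"},
	}

	q := new(dns.Msg)
	q.SetQuestion("example.org.", dns.TypeA)

	r, rtt, err := c.Exchange(q, "dns.example.org:853")
	if err != nil {
		fmt.Println("exchange failed:", err)
		return
	}
	fmt.Println(dns.RcodeToString[r.Rcode], rtt)
}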
// Package coremain contains the functions for starting CoreDNS.
package coremain
import (
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"go.uber.org/automaxprocs/maxprocs"
)
func init() {
caddy.DefaultConfigFile = "Corefile"
caddy.Quiet = true // don't show init stuff from caddy
setVersion()
flag.StringVar(&conf, "conf", "", "Corefile to load (default \""+caddy.DefaultConfigFile+"\")")
flag.BoolVar(&plugins, "plugins", false, "List installed plugins")
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
flag.BoolVar(&version, "version", false, "Show version")
flag.BoolVar(&dnsserver.Quiet, "quiet", false, "Quiet mode (no initialization output)")
caddy.RegisterCaddyfileLoader("flag", caddy.LoaderFunc(confLoader))
caddy.SetDefaultCaddyfileLoader("default", caddy.LoaderFunc(defaultLoader))
flag.StringVar(&dnsserver.Port, serverType+".port", dnsserver.DefaultPort, "Default port")
flag.StringVar(&dnsserver.Port, "p", dnsserver.DefaultPort, "Default port")
caddy.AppName = CoreName
caddy.AppVersion = CoreVersion
}
// Run is CoreDNS's main() function.
func Run() {
caddy.TrapSignals()
flag.Parse()
if len(flag.Args()) > 0 {
mustLogFatal(fmt.Errorf("extra command line arguments: %s", flag.Args()))
}
log.SetOutput(os.Stdout)
log.SetFlags(LogFlags)
if version {
showVersion()
os.Exit(0)
}
if plugins {
fmt.Println(caddy.DescribePlugins())
os.Exit(0)
}
_, err := maxprocs.Set(maxprocs.Logger(log.Printf))
if err != nil {
log.Println("[WARNING] Failed to set GOMAXPROCS:", err)
}
// Get Corefile input
corefile, err := caddy.LoadCaddyfile(serverType)
if err != nil {
mustLogFatal(err)
}
// Start your engines
instance, err := caddy.Start(corefile)
if err != nil {
mustLogFatal(err)
}
if !dnsserver.Quiet {
showVersion()
}
// Twiddle your thumbs
instance.Wait()
}
// mustLogFatal wraps log.Fatal() in a way that ensures the
// output is always printed to stderr so the user can see it,
// even if the process log was not enabled. If this process is
// an upgrade, however, the user might not be there anymore,
// so it just logs to the process log and exits.
func mustLogFatal(args ...any) {
if !caddy.IsUpgrade() {
log.SetOutput(os.Stderr)
}
log.Fatal(args...)
}
// confLoader loads the Caddyfile using the -conf flag.
func confLoader(serverType string) (caddy.Input, error) {
if conf == "" {
return nil, nil
}
if conf == "stdin" {
return caddy.CaddyfileFromPipe(os.Stdin, serverType)
}
contents, err := os.ReadFile(filepath.Clean(conf))
if err != nil {
return nil, err
}
return caddy.CaddyfileInput{
Contents: contents,
Filepath: conf,
ServerTypeName: serverType,
}, nil
}
// defaultLoader loads the Corefile from the current working directory.
func defaultLoader(serverType string) (caddy.Input, error) {
contents, err := os.ReadFile(caddy.DefaultConfigFile)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
return caddy.CaddyfileInput{
Contents: contents,
Filepath: caddy.DefaultConfigFile,
ServerTypeName: serverType,
}, nil
}
// showVersion prints the version that is starting.
func showVersion() {
fmt.Print(versionString())
fmt.Print(releaseString())
if devBuild && gitShortStat != "" {
fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
}
}
// versionString returns the CoreDNS version as a string.
func versionString() string {
return fmt.Sprintf("%s-%s\n", caddy.AppName, caddy.AppVersion)
}
// releaseString returns the release information related to CoreDNS version:
// <OS>/<ARCH>, <go version>, <commit>
// e.g.,
// linux/amd64, go1.8.3, a6d2d7b5
func releaseString() string {
return fmt.Sprintf("%s/%s, %s, %s\n", runtime.GOOS, runtime.GOARCH, runtime.Version(), GitCommit)
}
// setVersion figures out the version information
// based on variables set by -ldflags.
func setVersion() {
// A development build is one that's not at a tag or has uncommitted changes
devBuild = gitTag == "" || gitShortStat != ""
// Only set the appVersion if -ldflags was used
if gitNearestTag != "" || gitTag != "" {
if devBuild && gitNearestTag != "" {
appVersion = fmt.Sprintf("%s (+%s %s)", strings.TrimPrefix(gitNearestTag, "v"), GitCommit, buildDate)
} else if gitTag != "" {
appVersion = strings.TrimPrefix(gitTag, "v")
}
}
}
// Flags that control program flow or startup
var (
conf string
version bool
plugins bool
// LogFlags are initially set to 0 for no extra output
LogFlags int
)
// Build information obtained with the help of -ldflags
var (
appVersion = "(untracked dev build)" // inferred at startup
devBuild = true // inferred at startup
buildDate string // date -u
gitTag string // git describe --exact-match HEAD 2> /dev/null
gitNearestTag string // git describe --abbrev=0 --tags HEAD
gitShortStat string // git diff-index --shortstat
gitFilesModified string // git diff-index --name-only HEAD
// GitCommit contains the commit where we built CoreDNS from.
GitCommit string
)
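// Illustrative sketch (not part of the original source): a minimal main
// package that starts CoreDNS via coremain.Run(). The blank import is an
// assumption standing in for whatever package registers the compiled-in
// plugins before Run is called.
package main

import (
	_ "github.com/coredns/coredns/core/plugin" // registers the default plugin set (assumed)

	"github.com/coredns/coredns/coremain"
)

func main() {
	// Parses the flags defined in init() above, loads the Corefile and
	// blocks until the server instance stops.
	coremain.Run()
}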
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.19.4
// source: dns.proto
package pb
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type DnsPacket struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Msg []byte `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"`
}
func (x *DnsPacket) Reset() {
*x = DnsPacket{}
if protoimpl.UnsafeEnabled {
mi := &file_dns_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DnsPacket) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DnsPacket) ProtoMessage() {}
func (x *DnsPacket) ProtoReflect() protoreflect.Message {
mi := &file_dns_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DnsPacket.ProtoReflect.Descriptor instead.
func (*DnsPacket) Descriptor() ([]byte, []int) {
return file_dns_proto_rawDescGZIP(), []int{0}
}
func (x *DnsPacket) GetMsg() []byte {
if x != nil {
return x.Msg
}
return nil
}
var File_dns_proto protoreflect.FileDescriptor
var file_dns_proto_rawDesc = []byte{
0x0a, 0x09, 0x64, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x63, 0x6f, 0x72,
0x65, 0x64, 0x6e, 0x73, 0x2e, 0x64, 0x6e, 0x73, 0x22, 0x1d, 0x0a, 0x09, 0x44, 0x6e, 0x73, 0x50,
0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, 0x01, 0x20, 0x01,
0x28, 0x0c, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x32, 0x45, 0x0a, 0x0a, 0x44, 0x6e, 0x73, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x37, 0x0a, 0x05, 0x51, 0x75, 0x65, 0x72, 0x79, 0x12, 0x16,
0x2e, 0x63, 0x6f, 0x72, 0x65, 0x64, 0x6e, 0x73, 0x2e, 0x64, 0x6e, 0x73, 0x2e, 0x44, 0x6e, 0x73,
0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x64, 0x6e, 0x73,
0x2e, 0x64, 0x6e, 0x73, 0x2e, 0x44, 0x6e, 0x73, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x42, 0x06,
0x5a, 0x04, 0x2e, 0x3b, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_dns_proto_rawDescOnce sync.Once
file_dns_proto_rawDescData = file_dns_proto_rawDesc
)
func file_dns_proto_rawDescGZIP() []byte {
file_dns_proto_rawDescOnce.Do(func() {
file_dns_proto_rawDescData = protoimpl.X.CompressGZIP(file_dns_proto_rawDescData)
})
return file_dns_proto_rawDescData
}
var file_dns_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_dns_proto_goTypes = []interface{}{
(*DnsPacket)(nil), // 0: coredns.dns.DnsPacket
}
var file_dns_proto_depIdxs = []int32{
0, // 0: coredns.dns.DnsService.Query:input_type -> coredns.dns.DnsPacket
0, // 1: coredns.dns.DnsService.Query:output_type -> coredns.dns.DnsPacket
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_dns_proto_init() }
func file_dns_proto_init() {
if File_dns_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_dns_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DnsPacket); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_dns_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_dns_proto_goTypes,
DependencyIndexes: file_dns_proto_depIdxs,
MessageInfos: file_dns_proto_msgTypes,
}.Build()
File_dns_proto = out.File
file_dns_proto_rawDesc = nil
file_dns_proto_goTypes = nil
file_dns_proto_depIdxs = nil
}
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.19.4
// source: dns.proto
package pb
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// DnsServiceClient is the client API for DnsService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type DnsServiceClient interface {
Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error)
}
type dnsServiceClient struct {
cc grpc.ClientConnInterface
}
func NewDnsServiceClient(cc grpc.ClientConnInterface) DnsServiceClient {
return &dnsServiceClient{cc}
}
func (c *dnsServiceClient) Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error) {
out := new(DnsPacket)
err := c.cc.Invoke(ctx, "/coredns.dns.DnsService/Query", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DnsServiceServer is the server API for DnsService service.
// All implementations must embed UnimplementedDnsServiceServer
// for forward compatibility
type DnsServiceServer interface {
Query(context.Context, *DnsPacket) (*DnsPacket, error)
mustEmbedUnimplementedDnsServiceServer()
}
// UnimplementedDnsServiceServer must be embedded to have forward compatible implementations.
type UnimplementedDnsServiceServer struct {
}
func (UnimplementedDnsServiceServer) Query(context.Context, *DnsPacket) (*DnsPacket, error) {
return nil, status.Errorf(codes.Unimplemented, "method Query not implemented")
}
func (UnimplementedDnsServiceServer) mustEmbedUnimplementedDnsServiceServer() {}
// UnsafeDnsServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to DnsServiceServer will
// result in compilation errors.
type UnsafeDnsServiceServer interface {
mustEmbedUnimplementedDnsServiceServer()
}
func RegisterDnsServiceServer(s grpc.ServiceRegistrar, srv DnsServiceServer) {
s.RegisterService(&DnsService_ServiceDesc, srv)
}
func _DnsService_Query_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DnsPacket)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DnsServiceServer).Query(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/coredns.dns.DnsService/Query",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DnsServiceServer).Query(ctx, req.(*DnsPacket))
}
return interceptor(ctx, in, info, handler)
}
// DnsService_ServiceDesc is the grpc.ServiceDesc for DnsService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var DnsService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "coredns.dns.DnsService",
HandlerType: (*DnsServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Query",
Handler: _DnsService_Query_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "dns.proto",
}
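// Illustrative sketch (not part of the original source): using the generated
// client above to send a packed DNS query over gRPC. The import path for the
// generated pb package, the server address and the use of a plaintext
// connection are assumptions made for this sketch.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/miekg/dns"
	"google.golang.org/grpc"

	"github.com/coredns/coredns/pb"
)

func main() {
	conn, err := grpc.Dial("127.0.0.1:5553", grpc.WithInsecure()) // plaintext, sketch only
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	q := new(dns.Msg)
	q.SetQuestion("example.org.", dns.TypeA)
	wire, err := q.Pack()
	if err != nil {
		panic(err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// DnsPacket simply wraps the DNS wire format; the reply is unpacked the same way.
	reply, err := pb.NewDnsServiceClient(conn).Query(ctx, &pb.DnsPacket{Msg: wire})
	if err != nil {
		panic(err)
	}

	r := new(dns.Msg)
	if err := r.Unpack(reply.GetMsg()); err != nil {
		panic(err)
	}
	fmt.Println(r)
}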
package acl
import (
"context"
"net"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/infobloxopen/go-trees/iptree"
"github.com/miekg/dns"
)
// ACL enforces access control policies on DNS queries.
type ACL struct {
Next plugin.Handler
Rules []rule
}
// rule defines a list of Zones and some ACL policies which will be
// enforced on them.
type rule struct {
zones []string
policies []policy
}
// action defines the action against queries.
type action int
// policy defines the ACL policy for DNS queries.
// A policy performs the specified action (block/allow) on all DNS queries
// matched by source IP or QTYPE.
type policy struct {
action action
qtypes map[uint16]struct{}
filter *iptree.Tree
}
const (
// actionNone does nothing on the queries.
actionNone = iota
// actionAllow allows authorized queries to recurse.
actionAllow
// actionBlock blocks unauthorized queries towards protected DNS zones.
actionBlock
// actionFilter returns empty sets for queries towards protected DNS zones.
actionFilter
// actionDrop does not respond for queries towards the protected DNS zones.
actionDrop
)
var log = clog.NewWithPlugin("acl")
// ServeDNS implements the plugin.Handler interface.
func (a ACL) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
RulesCheckLoop:
for _, rule := range a.Rules {
// check zone.
zone := plugin.Zones(rule.zones).Matches(state.Name())
if zone == "" {
continue
}
action := matchWithPolicies(rule.policies, w, r)
switch action {
case actionDrop:
{
RequestDropCount.WithLabelValues(metrics.WithServer(ctx), zone, metrics.WithView(ctx)).Inc()
return dns.RcodeSuccess, nil
}
case actionBlock:
{
m := new(dns.Msg).
SetRcode(r, dns.RcodeRefused).
SetEdns0(4096, true)
ede := dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeBlocked}
m.IsEdns0().Option = append(m.IsEdns0().Option, &ede)
w.WriteMsg(m)
RequestBlockCount.WithLabelValues(metrics.WithServer(ctx), zone, metrics.WithView(ctx)).Inc()
return dns.RcodeSuccess, nil
}
case actionAllow:
{
break RulesCheckLoop
}
case actionFilter:
{
m := new(dns.Msg).
SetRcode(r, dns.RcodeSuccess).
SetEdns0(4096, true)
ede := dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeFiltered}
m.IsEdns0().Option = append(m.IsEdns0().Option, &ede)
w.WriteMsg(m)
RequestFilterCount.WithLabelValues(metrics.WithServer(ctx), zone, metrics.WithView(ctx)).Inc()
return dns.RcodeSuccess, nil
}
}
}
RequestAllowCount.WithLabelValues(metrics.WithServer(ctx), metrics.WithView(ctx)).Inc()
return plugin.NextOrFailure(state.Name(), a.Next, ctx, w, r)
}
// matchWithPolicies matches the DNS query with a list of ACL polices and returns suitable
// action against the query.
func matchWithPolicies(policies []policy, w dns.ResponseWriter, r *dns.Msg) action {
state := request.Request{W: w, Req: r}
var ip net.IP
if idx := strings.IndexByte(state.IP(), '%'); idx >= 0 {
ip = net.ParseIP(state.IP()[:idx])
} else {
ip = net.ParseIP(state.IP())
}
// if the source IP cannot be parsed, simply return 'actionBlock' to
// block the query
if ip == nil {
log.Errorf("Blocking request. Unable to parse source address: %v", state.IP())
return actionBlock
}
qtype := state.QType()
for _, policy := range policies {
// dns.TypeNone matches all query types.
_, matchAll := policy.qtypes[dns.TypeNone]
_, match := policy.qtypes[qtype]
if !matchAll && !match {
continue
}
_, contained := policy.filter.GetByIP(ip)
if !contained {
continue
}
// matched.
return policy.action
}
return actionNone
}
// Name implements the plugin.Handler interface.
func (a ACL) Name() string {
return "acl"
}
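// Illustrative sketch (not part of the original source): the iptree-based
// source filter that matchWithPolicies consults, built the same way parse()
// below builds a `net` section. The CIDR and the client address are arbitrary.
package main

import (
	"fmt"
	"net"

	"github.com/infobloxopen/go-trees/iptree"
)

func main() {
	filter := iptree.NewTree()
	_, source, err := net.ParseCIDR("192.168.0.0/16")
	if err != nil {
		panic(err)
	}
	filter.InplaceInsertNet(source, struct{}{})

	// The same containment check matchWithPolicies performs per policy.
	if _, contained := filter.GetByIP(net.ParseIP("192.168.1.2")); contained {
		fmt.Println("source address matches the policy's net filter")
	}
}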
package acl
import (
"net"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/infobloxopen/go-trees/iptree"
"github.com/miekg/dns"
)
const pluginName = "acl"
func init() { plugin.Register(pluginName, setup) }
func newDefaultFilter() *iptree.Tree {
defaultFilter := iptree.NewTree()
_, IPv4All, _ := net.ParseCIDR("0.0.0.0/0")
_, IPv6All, _ := net.ParseCIDR("::/0")
defaultFilter.InplaceInsertNet(IPv4All, struct{}{})
defaultFilter.InplaceInsertNet(IPv6All, struct{}{})
return defaultFilter
}
func setup(c *caddy.Controller) error {
a, err := parse(c)
if err != nil {
return plugin.Error(pluginName, err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
a.Next = next
return a
})
return nil
}
func parse(c *caddy.Controller) (ACL, error) {
a := ACL{}
for c.Next() {
r := rule{}
args := c.RemainingArgs()
r.zones = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)
for c.NextBlock() {
p := policy{}
action := strings.ToLower(c.Val())
switch action {
case "allow":
p.action = actionAllow
case "block":
p.action = actionBlock
case "filter":
p.action = actionFilter
case "drop":
p.action = actionDrop
default:
return a, c.Errf("unexpected token %q; expect 'allow', 'block', 'filter' or 'drop'", c.Val())
}
p.qtypes = make(map[uint16]struct{})
p.filter = iptree.NewTree()
hasTypeSection := false
hasNetSection := false
remainingTokens := c.RemainingArgs()
for len(remainingTokens) > 0 {
if !isPreservedIdentifier(remainingTokens[0]) {
return a, c.Errf("unexpected token %q; expect 'type | net'", remainingTokens[0])
}
section := strings.ToLower(remainingTokens[0])
i := 1
var tokens []string
for ; i < len(remainingTokens) && !isPreservedIdentifier(remainingTokens[i]); i++ {
tokens = append(tokens, remainingTokens[i])
}
remainingTokens = remainingTokens[i:]
if len(tokens) == 0 {
return a, c.Errf("no token specified in %q section", section)
}
switch section {
case "type":
hasTypeSection = true
for _, token := range tokens {
if token == "*" {
p.qtypes[dns.TypeNone] = struct{}{}
break
}
qtype, ok := dns.StringToType[token]
if !ok {
return a, c.Errf("unexpected token %q; expect legal QTYPE", token)
}
p.qtypes[qtype] = struct{}{}
}
case "net":
hasNetSection = true
for _, token := range tokens {
if token == "*" {
p.filter = newDefaultFilter()
break
}
token = normalize(token)
_, source, err := net.ParseCIDR(token)
if err != nil {
return a, c.Errf("illegal CIDR notation %q", token)
}
p.filter.InplaceInsertNet(source, struct{}{})
}
default:
return a, c.Errf("unexpected token %q; expect 'type | net'", section)
}
}
// optional `type` section means all record types.
if !hasTypeSection {
p.qtypes[dns.TypeNone] = struct{}{}
}
// optional `net` means all ip addresses.
if !hasNetSection {
p.filter = newDefaultFilter()
}
r.policies = append(r.policies, p)
}
a.Rules = append(a.Rules, r)
}
return a, nil
}
func isPreservedIdentifier(token string) bool {
identifier := strings.ToLower(token)
return identifier == "type" || identifier == "net"
}
// normalize appends '/32' for any single IPv4 address and '/128' for IPv6.
func normalize(rawNet string) string {
if idx := strings.IndexAny(rawNet, "/"); idx >= 0 {
return rawNet
}
if idx := strings.IndexAny(rawNet, ":"); idx >= 0 {
return rawNet + "/128"
}
return rawNet + "/32"
}
package any
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
// Any is a plugin that returns a HINFO reply to ANY queries.
type Any struct {
Next plugin.Handler
}
// ServeDNS implements the plugin.Handler interface.
func (a Any) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if r.Question[0].Qtype != dns.TypeANY {
return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r)
}
m := new(dns.Msg)
m.SetReply(r)
hdr := dns.RR_Header{Name: r.Question[0].Name, Ttl: 8482, Class: dns.ClassINET, Rrtype: dns.TypeHINFO}
m.Answer = []dns.RR{&dns.HINFO{Hdr: hdr, Cpu: "ANY obsoleted", Os: "See RFC 8482"}}
w.WriteMsg(m)
return 0, nil
}
// Name implements the Handler interface.
func (a Any) Name() string { return "any" }
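// Illustrative sketch (not part of the original source): the HINFO answer a
// client would receive from ServeDNS above when it sends an ANY query, built
// here directly with the miekg/dns types. The query name is arbitrary.
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	q := new(dns.Msg)
	q.SetQuestion("example.org.", dns.TypeANY)

	m := new(dns.Msg)
	m.SetReply(q)
	hdr := dns.RR_Header{Name: q.Question[0].Name, Ttl: 8482, Class: dns.ClassINET, Rrtype: dns.TypeHINFO}
	m.Answer = []dns.RR{&dns.HINFO{Hdr: hdr, Cpu: "ANY obsoleted", Os: "See RFC 8482"}}

	fmt.Println(m)
}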
package any
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("any", setup) }
func setup(c *caddy.Controller) error {
a := Any{}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
a.Next = next
return a
})
return nil
}
// Package auto implements an on-the-fly loading file backend.
package auto
import (
"context"
"regexp"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/plugin/transfer"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
type (
// Auto holds the zones and the loader configuration for automatically loading zones.
Auto struct {
Next plugin.Handler
*Zones
metrics *metrics.Metrics
transfer *transfer.Transfer
loader
}
loader struct {
directory string
template string
re *regexp.Regexp
ReloadInterval time.Duration
upstream *upstream.Upstream // Upstream for looking up names during the resolution process.
}
)
// ServeDNS implements the plugin.Handler interface.
func (a Auto) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
// Precheck with the origins, i.e. are we allowed to look here?
zone := plugin.Zones(a.Zones.Origins()).Matches(qname)
if zone == "" {
return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r)
}
// Now the real zone.
zone = plugin.Zones(a.Zones.Names()).Matches(qname)
if zone == "" {
// If no next plugin is configured, it's more correct to return REFUSED as auto acts as an authoritative server
if a.Next == nil {
return dns.RcodeRefused, nil
}
return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r)
}
a.RLock()
z, ok := a.Z[zone]
a.RUnlock()
if !ok || z == nil {
return dns.RcodeServerFailure, nil
}
// If the transfer plugin is not loaded, we'll see these queries here; answer with REFUSED (no transfer allowed).
if state.QType() == dns.TypeAXFR || state.QType() == dns.TypeIXFR {
return dns.RcodeRefused, nil
}
answer, ns, extra, result := z.Lookup(ctx, state, qname)
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Answer, m.Ns, m.Extra = answer, ns, extra
switch result {
case file.Success:
case file.NoData:
case file.NameError:
m.Rcode = dns.RcodeNameError
case file.Delegation:
m.Authoritative = false
case file.ServerFailure:
// If the result is SERVFAIL and the answer is non-empty, then the SERVFAIL came from an
// external CNAME lookup and the answer contains the CNAME with no target record. We should
// write the CNAME record to the client instead of sending an empty SERVFAIL response.
if len(m.Answer) == 0 {
return dns.RcodeServerFailure, nil
}
// The rcode in the response should be the rcode received from the target lookup. RFC 6604 section 3
m.Rcode = dns.RcodeServerFailure
}
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
// Name implements the Handler interface.
func (a Auto) Name() string { return "auto" }
package auto
import (
"strings"
)
// rewriteToExpand rewrites our template string to one that we can give to regexp.ExpandString. This basically
// involves prefixing any '{' with a '$'.
func rewriteToExpand(s string) string {
// Pretty dumb at the moment: every '{' gets a '$' prefixed.
// We rebuild the string rune by rune via a strings.Builder, which is OK
// as we only do this during config parsing.
var copySb strings.Builder
for _, c := range s {
if c == '{' {
copySb.WriteString("$")
}
copySb.WriteString(string(c))
}
return copySb.String()
}
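// Illustrative sketch (not part of the original source): how the rewritten
// template drives regexp.ExpandString, mirroring what the matches() function
// later in this package does with the default `db\.(.*)` expression. The file
// name is arbitrary.
package main

import (
	"fmt"
	"regexp"

	"github.com/miekg/dns"
)

func main() {
	re := regexp.MustCompile(`db\.(.*)`)
	template := "${1}" // what rewriteToExpand("{1}") produces

	base := "db.example.org"
	m := re.FindStringSubmatchIndex(base)
	if m == nil {
		return
	}
	origin := dns.Fqdn(string(re.ExpandString(nil, template, base, m)))
	fmt.Println(origin) // example.org.
}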
package auto
import (
"errors"
"os"
"path/filepath"
"regexp"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/plugin/transfer"
)
var log = clog.NewWithPlugin("auto")
func init() { plugin.Register("auto", setup) }
func setup(c *caddy.Controller) error {
a, err := autoParse(c)
if err != nil {
return plugin.Error("auto", err)
}
c.OnStartup(func() error {
m := dnsserver.GetConfig(c).Handler("prometheus")
if m != nil {
(&a).metrics = m.(*metrics.Metrics)
}
t := dnsserver.GetConfig(c).Handler("transfer")
if t != nil {
(&a).transfer = t.(*transfer.Transfer)
}
return nil
})
walkChan := make(chan bool)
c.OnStartup(func() error {
err := a.Walk()
if err != nil {
return err
}
if err := a.Notify(); err != nil {
log.Warning(err)
}
if a.ReloadInterval == 0 {
return nil
}
go func() {
ticker := time.NewTicker(a.ReloadInterval)
defer ticker.Stop()
for {
select {
case <-walkChan:
return
case <-ticker.C:
a.Walk()
if err := a.Notify(); err != nil {
log.Warning(err)
}
}
}
}()
return nil
})
c.OnShutdown(func() error {
close(walkChan)
for _, z := range a.Z {
z.Lock()
z.OnShutdown()
z.Unlock()
}
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
a.Next = next
return a
})
return nil
}
func autoParse(c *caddy.Controller) (Auto, error) {
nilInterval := -1 * time.Second
var a = Auto{
loader: loader{
template: "${1}",
re: regexp.MustCompile(`db\.(.*)`),
ReloadInterval: nilInterval,
},
Zones: &Zones{},
}
config := dnsserver.GetConfig(c)
for c.Next() {
// auto [ZONES...]
args := c.RemainingArgs()
a.origins = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)
a.upstream = upstream.New()
for c.NextBlock() {
switch c.Val() {
case "directory": // directory DIR [REGEXP TEMPLATE]
if !c.NextArg() {
return a, c.ArgErr()
}
a.directory = c.Val()
if !filepath.IsAbs(a.directory) && config.Root != "" {
a.directory = filepath.Join(config.Root, a.directory)
}
_, err := os.Stat(a.directory)
if err != nil {
if !os.IsNotExist(err) {
return a, c.Errf("Unable to access root path '%s': %v", a.directory, err)
}
log.Warningf("Directory does not exist: %s", a.directory)
}
// regexp template
if c.NextArg() {
a.re, err = regexp.Compile(c.Val())
if err != nil {
return a, err
}
if a.re.NumSubexp() == 0 {
return a, c.Errf("Need at least one sub expression")
}
if !c.NextArg() {
return a, c.ArgErr()
}
a.template = rewriteToExpand(c.Val())
}
if c.NextArg() {
return Auto{}, c.ArgErr()
}
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return a, errors.New("reload duration value is expected")
}
d, err := time.ParseDuration(t[0])
if d < 0 {
err = errors.New("invalid duration")
}
if err != nil {
return a, plugin.Error("file", err)
}
a.ReloadInterval = d
case "upstream":
// remove soon
c.RemainingArgs() // eat remaining args
default:
return Auto{}, c.Errf("unknown property '%s'", c.Val())
}
}
}
if a.ReloadInterval == nilInterval {
a.ReloadInterval = 60 * time.Second
}
return a, nil
}
package auto
import (
"os"
"path/filepath"
"regexp"
"github.com/coredns/coredns/plugin/file"
"github.com/miekg/dns"
)
// Walk recursively walks the files under the loader's directory and adds the ones that match its regular expression.
func (a Auto) Walk() error {
// TODO(miek): should add something so that we don't stomp on each other.
toDelete := make(map[string]bool)
for _, n := range a.Names() {
toDelete[n] = true
}
filepath.Walk(a.directory, func(path string, info os.FileInfo, e error) error {
if e != nil {
log.Warningf("error reading %v: %v", path, e)
}
if info == nil || info.IsDir() {
return nil
}
match, origin := matches(a.re, info.Name(), a.template)
if !match {
return nil
}
if z, ok := a.Z[origin]; ok {
// we already have this zone
toDelete[origin] = false
z.SetFile(path)
return nil
}
reader, err := os.Open(filepath.Clean(path))
if err != nil {
log.Warningf("Opening %s failed: %s", path, err)
return nil
}
defer reader.Close()
// Serial for loading a zone is 0, because it is a new zone.
zo, err := file.Parse(reader, origin, path, 0)
if err != nil {
log.Warningf("Parse zone `%s': %v", origin, err)
return nil
}
zo.ReloadInterval = a.ReloadInterval
zo.Upstream = a.upstream
a.Add(zo, origin, a.transfer)
if a.metrics != nil {
a.metrics.AddZone(origin)
}
log.Infof("Inserting zone `%s' from: %s", origin, path)
toDelete[origin] = false
return nil
})
for origin, ok := range toDelete {
if !ok {
continue
}
if a.metrics != nil {
a.metrics.RemoveZone(origin)
}
a.Remove(origin)
log.Infof("Deleting zone `%s'", origin)
}
return nil
}
// matches matches re against filename; on a match the subexpressions are used to expand
// template into an origin. When match is true that origin is returned, fully qualified.
func matches(re *regexp.Regexp, filename, template string) (match bool, origin string) {
base := filepath.Base(filename)
matches := re.FindStringSubmatchIndex(base)
if matches == nil {
return false, ""
}
by := re.ExpandString(nil, template, base, matches)
if by == nil {
return false, ""
}
origin = dns.Fqdn(string(by))
return true, origin
}
package auto
import (
"github.com/coredns/coredns/plugin/transfer"
"github.com/miekg/dns"
)
// Transfer implements the transfer.Transfer interface.
func (a Auto) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
a.RLock()
z, ok := a.Z[zone]
a.RUnlock()
if !ok || z == nil {
return nil, transfer.ErrNotAuthoritative
}
return z.Transfer(serial)
}
// Notify sends NOTIFY messages for all zones that have secondaries configured via the transfer plugin.
func (a Auto) Notify() error {
var err error
for _, origin := range a.Names() {
e := a.transfer.Notify(origin)
if e != nil {
err = e
}
}
return err
}
package auto
import (
"sync"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/transfer"
)
// Zones maps zone names to a *Zone. This keeps track of what zones we have loaded at
// any one time.
type Zones struct {
Z map[string]*file.Zone // A map mapping zone (origin) to the Zone's data.
names []string // All the keys from the map Z as a string slice.
origins []string // Any origins from the server block.
sync.RWMutex
}
// Names returns the names from z.
func (z *Zones) Names() []string {
z.RLock()
n := z.names
z.RUnlock()
return n
}
// Origins returns the origins from z.
func (z *Zones) Origins() []string {
// doesn't need locking, because there aren't multiple goroutines accessing it.
return z.origins
}
// Zones returns a zone with origin name from z, nil when not found.
func (z *Zones) Zones(name string) *file.Zone {
z.RLock()
zo := z.Z[name]
z.RUnlock()
return zo
}
// Add adds a new zone into z. If zo.ReloadInterval is not zero, the
// zone's reload goroutine is started.
func (z *Zones) Add(zo *file.Zone, name string, t *transfer.Transfer) {
z.Lock()
if z.Z == nil {
z.Z = make(map[string]*file.Zone)
}
z.Z[name] = zo
z.names = append(z.names, name)
zo.Reload(t)
z.Unlock()
}
// Remove removes the zone named name from z. It also stops the zone's reload goroutine.
func (z *Zones) Remove(name string) {
z.Lock()
if zo, ok := z.Z[name]; ok {
zo.OnShutdown()
}
delete(z.Z, name)
// TODO(miek): just regenerate Names (might be bad if you have a lot of zones...)
z.names = []string{}
for n := range z.Z {
z.names = append(z.names, n)
}
z.Unlock()
}
/*
Package autopath implements autopathing. This is a hack; it shortcuts the
client's search path resolution by performing these lookups on the server...
The server has a copy (via AutoPathFunc) of the client's search path and on
receiving a query it first establishes if the suffix matches the FIRST configured
element. If no match can be found the query will be forwarded up the plugin
chain without interference (if, and only if, 'fallthrough' has been set).
If the query is deemed to fall in the search path the server will perform the
queries with each element of the search path appended in sequence until a
non-NXDOMAIN answer has been found. That reply will then be returned to the
client - with some CNAME hackery to let the client accept the reply.
If all queries return NXDOMAIN we return the original as-is and let the client
continue searching. The client will go to the next element in the search path,
but we won’t do any more autopathing. It means that in the failure case, you do
more work, since the server looks it up, then the client still needs to go
through the search path.
It is assumed the search path ordering is identical between server and client.
Plugins implementing autopath must have a function called `AutoPath` of type
autopath.Func. Note the search path must end with the empty string.
I.e.:
func (m Plugins) AutoPath(state request.Request) []string {
return []string{"first", "second", "last", ""}
}
*/
package autopath
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Func defines the function plugin should implement to return a search
// path to the autopath plugin. The last element of the slice must be the empty string.
// If Func returns a nil slice, no autopathing will be done.
type Func func(request.Request) []string
// AutoPather defines the interface that a plugin should implement in order to be
// used by AutoPath.
type AutoPather interface {
AutoPath(request.Request) []string
}
// AutoPath performs autopath: service side search path completion.
type AutoPath struct {
Next plugin.Handler
Zones []string
// Search always includes "" as the last element, so we also try the base query without any search path appended.
search []string
searchFunc Func
}
// ServeDNS implements the plugin.Handle interface.
func (a *AutoPath) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
zone := plugin.Zones(a.Zones).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r)
}
// Check if autopath should be done; searchFunc takes precedence over the locally configured search path.
var err error
searchpath := a.search
if a.searchFunc != nil {
searchpath = a.searchFunc(state)
}
if len(searchpath) == 0 {
return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r)
}
if !firstInSearchPath(state.Name(), searchpath) {
return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r)
}
origQName := state.QName()
// Establish the base name of the query, i.e. what was originally asked.
base, err := dnsutil.TrimZone(state.QName(), searchpath[0])
if err != nil {
return dns.RcodeServerFailure, err
}
firstReply := new(dns.Msg)
firstRcode := 0
var firstErr error
ar := r.Copy()
// Walk the search path and see if we can get a non-NXDOMAIN answer - if they all fail we return the
// reply to the first query as-is. This means the client will do the search path walk again...
for i, s := range searchpath {
newQName := base + "." + s
ar.Question[0].Name = newQName
nw := nonwriter.New(w)
rcode, err := plugin.NextOrFailure(a.Name(), a.Next, ctx, nw, ar)
if err != nil {
// Return now - not sure if this is the best. We should also check if the write has happened.
return rcode, err
}
if i == 0 {
firstReply = nw.Msg
firstRcode = rcode
firstErr = err
}
if !plugin.ClientWrite(rcode) {
continue
}
if nw.Msg.Rcode == dns.RcodeNameError {
continue
}
msg := nw.Msg
cnamer(msg, origQName)
// Write whatever non-nxdomain answer we've found.
w.WriteMsg(msg)
autoPathCount.WithLabelValues(metrics.WithServer(ctx)).Add(1)
return rcode, err
}
if plugin.ClientWrite(firstRcode) {
w.WriteMsg(firstReply)
}
return firstRcode, firstErr
}
// Name implements the Handler interface.
func (a *AutoPath) Name() string { return "autopath" }
// firstInSearchPath checks if name is equal to, or a subdomain of, the first element in the search path.
func firstInSearchPath(name string, searchpath []string) bool {
if name == searchpath[0] {
return true
}
if dns.IsSubDomain(searchpath[0], name) {
return true
}
return false
}
package autopath
import (
"strings"
"github.com/miekg/dns"
)
// cnamer will prefix the answer section with a cname that points from original qname to the
// name of the first RR. It will also update the question section and put original in there.
func cnamer(m *dns.Msg, original string) {
for _, a := range m.Answer {
if strings.EqualFold(original, a.Header().Name) {
continue
}
m.Answer = append(m.Answer, nil)
copy(m.Answer[1:], m.Answer)
m.Answer[0] = &dns.CNAME{
Hdr: dns.RR_Header{Name: original, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: a.Header().Ttl},
Target: a.Header().Name,
}
break
}
m.Question[0].Name = original
}
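// Illustrative sketch (not part of the original source): what the CNAME
// "hackery" above produces. A client with a Kubernetes-style search path asked
// for the expanded name, the answer was found at the bare name, and the reply
// is rewritten so the client accepts it. All names are hypothetical.
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	// Name the client actually sent (first search element appended).
	original := "www.example.org.default.svc.cluster.local."

	// Answer found for a later search element (here: the empty one).
	a, err := dns.NewRR("www.example.org. 30 IN A 192.0.2.1")
	if err != nil {
		panic(err)
	}

	m := new(dns.Msg)
	m.SetQuestion(a.Header().Name, dns.TypeA)
	m.Answer = []dns.RR{a}

	// Mirror of cnamer: glue the original qname to the name that resolved
	// and restore the original question.
	cname := &dns.CNAME{
		Hdr:    dns.RR_Header{Name: original, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: a.Header().Ttl},
		Target: a.Header().Name,
	}
	m.Answer = append([]dns.RR{cname}, m.Answer...)
	m.Question[0].Name = original

	fmt.Println(m)
}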
package autopath
import (
"fmt"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
func init() { plugin.Register("autopath", setup) }
func setup(c *caddy.Controller) error {
ap, mw, err := autoPathParse(c)
if err != nil {
return plugin.Error("autopath", err)
}
// Do this in OnStartup, so all plugins have been initialized.
c.OnStartup(func() error {
m := dnsserver.GetConfig(c).Handler(mw)
if m == nil {
return nil
}
if x, ok := m.(AutoPather); ok {
ap.searchFunc = x.AutoPath
} else {
return plugin.Error("autopath", fmt.Errorf("%s does not implement the AutoPather interface", mw))
}
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
ap.Next = next
return ap
})
return nil
}
func autoPathParse(c *caddy.Controller) (*AutoPath, string, error) {
ap := &AutoPath{}
mw := ""
for c.Next() {
zoneAndresolv := c.RemainingArgs()
if len(zoneAndresolv) < 1 {
return ap, "", fmt.Errorf("no resolv-conf specified")
}
resolv := zoneAndresolv[len(zoneAndresolv)-1]
if strings.HasPrefix(resolv, "@") {
mw = resolv[1:]
} else {
// assume file on disk
rc, err := dns.ClientConfigFromFile(resolv)
if err != nil {
return ap, "", fmt.Errorf("failed to parse %q: %v", resolv, err)
}
ap.search = rc.Search
plugin.Zones(ap.search).Normalize()
ap.search = append(ap.search, "") // sentinel value as demanded.
}
zones := zoneAndresolv[:len(zoneAndresolv)-1]
ap.Zones = plugin.OriginsFromArgsOrServerBlock(zones, c.ServerBlockKeys)
}
return ap, mw, nil
}
package azure
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
publicdns "github.com/Azure/azure-sdk-for-go/profiles/latest/dns/mgmt/dns"
privatedns "github.com/Azure/azure-sdk-for-go/profiles/latest/privatedns/mgmt/privatedns"
"github.com/miekg/dns"
)
type zone struct {
id string
z *file.Zone
zone string
private bool
}
type zones map[string][]*zone
// Azure is the core struct of the azure plugin.
type Azure struct {
zoneNames []string
publicClient publicdns.RecordSetsClient
privateClient privatedns.RecordSetsClient
upstream *upstream.Upstream
zMu sync.RWMutex
zones zones
Next plugin.Handler
Fall fall.F
}
// New validates the input DNS zones and initializes the Azure struct.
func New(ctx context.Context, publicClient publicdns.RecordSetsClient, privateClient privatedns.RecordSetsClient, keys map[string][]string, accessMap map[string]string) (*Azure, error) {
zones := make(map[string][]*zone, len(keys))
names := make([]string, 0, len(keys)) // zero length so that append below doesn't leave empty leading entries
var private bool
for resourceGroup, znames := range keys {
for _, name := range znames {
switch accessMap[resourceGroup+name] {
case "public":
if _, err := publicClient.ListAllByDNSZone(context.Background(), resourceGroup, name, nil, ""); err != nil {
return nil, err
}
private = false
case "private":
if _, err := privateClient.ListComplete(context.Background(), resourceGroup, name, nil, ""); err != nil {
return nil, err
}
private = true
}
fqdn := dns.Fqdn(name)
if _, ok := zones[fqdn]; !ok {
names = append(names, fqdn)
}
zones[fqdn] = append(zones[fqdn], &zone{id: resourceGroup, zone: name, private: private, z: file.NewZone(fqdn, "")})
}
}
return &Azure{
publicClient: publicClient,
privateClient: privateClient,
zones: zones,
zoneNames: names,
upstream: upstream.New(),
}, nil
}
// Run updates the zone from azure.
func (h *Azure) Run(ctx context.Context) error {
if err := h.updateZones(ctx); err != nil {
return err
}
go func() {
delay := 1 * time.Minute
timer := time.NewTimer(delay)
defer timer.Stop()
for {
timer.Reset(delay)
select {
case <-ctx.Done():
log.Debugf("Breaking out of Azure update loop for %v: %v", h.zoneNames, ctx.Err())
return
case <-timer.C:
if err := h.updateZones(ctx); err != nil && ctx.Err() == nil {
log.Errorf("Failed to update zones %v: %v", h.zoneNames, err)
}
}
}
}()
return nil
}
func (h *Azure) updateZones(ctx context.Context) error {
var err error
var publicSet publicdns.RecordSetListResultPage
var privateSet privatedns.RecordSetListResultPage
errs := make([]string, 0)
for zName, z := range h.zones {
for i, hostedZone := range z {
newZ := file.NewZone(zName, "")
if hostedZone.private {
for privateSet, err = h.privateClient.List(ctx, hostedZone.id, hostedZone.zone, nil, ""); privateSet.NotDone(); err = privateSet.NextWithContext(ctx) {
updateZoneFromPrivateResourceSet(privateSet, newZ)
}
} else {
for publicSet, err = h.publicClient.ListByDNSZone(ctx, hostedZone.id, hostedZone.zone, nil, ""); publicSet.NotDone(); err = publicSet.NextWithContext(ctx) {
updateZoneFromPublicResourceSet(publicSet, newZ)
}
}
if err != nil {
errs = append(errs, fmt.Sprintf("failed to list resource records for %v from azure: %v", hostedZone.zone, err))
}
newZ.Upstream = h.upstream
h.zMu.Lock()
(*z[i]).z = newZ
h.zMu.Unlock()
}
}
if len(errs) != 0 {
return fmt.Errorf("errors updating zones: %v", errs)
}
return nil
}
func updateZoneFromPublicResourceSet(recordSet publicdns.RecordSetListResultPage, newZ *file.Zone) {
for _, result := range *(recordSet.Response().Value) {
resultFqdn := *(result.Fqdn)
resultTTL := uint32(*(result.TTL))
if result.ARecords != nil {
for _, A := range *(result.ARecords) {
a := &dns.A{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: resultTTL},
A: net.ParseIP(*(A.Ipv4Address))}
newZ.Insert(a)
}
}
if result.AaaaRecords != nil {
for _, AAAA := range *(result.AaaaRecords) {
aaaa := &dns.AAAA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: resultTTL},
AAAA: net.ParseIP(*(AAAA.Ipv6Address))}
newZ.Insert(aaaa)
}
}
if result.MxRecords != nil {
for _, MX := range *(result.MxRecords) {
mx := &dns.MX{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: resultTTL},
Preference: uint16(*(MX.Preference)),
Mx: dns.Fqdn(*(MX.Exchange))}
newZ.Insert(mx)
}
}
if result.PtrRecords != nil {
for _, PTR := range *(result.PtrRecords) {
ptr := &dns.PTR{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: resultTTL},
Ptr: dns.Fqdn(*(PTR.Ptrdname))}
newZ.Insert(ptr)
}
}
if result.SrvRecords != nil {
for _, SRV := range *(result.SrvRecords) {
srv := &dns.SRV{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: resultTTL},
Priority: uint16(*(SRV.Priority)),
Weight: uint16(*(SRV.Weight)),
Port: uint16(*(SRV.Port)),
Target: dns.Fqdn(*(SRV.Target))}
newZ.Insert(srv)
}
}
if result.TxtRecords != nil {
for _, TXT := range *(result.TxtRecords) {
txt := &dns.TXT{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: resultTTL},
Txt: *(TXT.Value)}
newZ.Insert(txt)
}
}
if result.NsRecords != nil {
for _, NS := range *(result.NsRecords) {
ns := &dns.NS{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: resultTTL},
Ns: *(NS.Nsdname)}
newZ.Insert(ns)
}
}
if result.SoaRecord != nil {
SOA := result.SoaRecord
soa := &dns.SOA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: resultTTL},
Minttl: uint32(*(SOA.MinimumTTL)),
Expire: uint32(*(SOA.ExpireTime)),
Retry: uint32(*(SOA.RetryTime)),
Refresh: uint32(*(SOA.RefreshTime)),
Serial: uint32(*(SOA.SerialNumber)),
Mbox: dns.Fqdn(*(SOA.Email)),
Ns: *(SOA.Host)}
newZ.Insert(soa)
}
if result.CnameRecord != nil {
CNAME := result.CnameRecord.Cname
cname := &dns.CNAME{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: resultTTL},
Target: dns.Fqdn(*CNAME)}
newZ.Insert(cname)
}
}
}
func updateZoneFromPrivateResourceSet(recordSet privatedns.RecordSetListResultPage, newZ *file.Zone) {
for _, result := range *(recordSet.Response().Value) {
resultFqdn := *(result.Fqdn)
resultTTL := uint32(*(result.TTL))
if result.ARecords != nil {
for _, A := range *(result.ARecords) {
a := &dns.A{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: resultTTL},
A: net.ParseIP(*(A.Ipv4Address))}
newZ.Insert(a)
}
}
if result.AaaaRecords != nil {
for _, AAAA := range *(result.AaaaRecords) {
aaaa := &dns.AAAA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: resultTTL},
AAAA: net.ParseIP(*(AAAA.Ipv6Address))}
newZ.Insert(aaaa)
}
}
if result.MxRecords != nil {
for _, MX := range *(result.MxRecords) {
mx := &dns.MX{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: resultTTL},
Preference: uint16(*(MX.Preference)),
Mx: dns.Fqdn(*(MX.Exchange))}
newZ.Insert(mx)
}
}
if result.PtrRecords != nil {
for _, PTR := range *(result.PtrRecords) {
ptr := &dns.PTR{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: resultTTL},
Ptr: dns.Fqdn(*(PTR.Ptrdname))}
newZ.Insert(ptr)
}
}
if result.SrvRecords != nil {
for _, SRV := range *(result.SrvRecords) {
srv := &dns.SRV{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: resultTTL},
Priority: uint16(*(SRV.Priority)),
Weight: uint16(*(SRV.Weight)),
Port: uint16(*(SRV.Port)),
Target: dns.Fqdn(*(SRV.Target))}
newZ.Insert(srv)
}
}
if result.TxtRecords != nil {
for _, TXT := range *(result.TxtRecords) {
txt := &dns.TXT{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: resultTTL},
Txt: *(TXT.Value)}
newZ.Insert(txt)
}
}
if result.SoaRecord != nil {
SOA := result.SoaRecord
soa := &dns.SOA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: resultTTL},
Minttl: uint32(*(SOA.MinimumTTL)),
Expire: uint32(*(SOA.ExpireTime)),
Retry: uint32(*(SOA.RetryTime)),
Refresh: uint32(*(SOA.RefreshTime)),
Serial: uint32(*(SOA.SerialNumber)),
Mbox: dns.Fqdn(*(SOA.Email)),
Ns: dns.Fqdn(*(SOA.Host))}
newZ.Insert(soa)
}
if result.CnameRecord != nil {
CNAME := result.CnameRecord.Cname
cname := &dns.CNAME{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: resultTTL},
Target: dns.Fqdn(*CNAME)}
newZ.Insert(cname)
}
}
}
// ServeDNS implements the plugin.Handler interface.
func (h *Azure) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
zone := plugin.Zones(h.zoneNames).Matches(qname)
if zone == "" {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
zones, ok := h.zones[zone] // ok true if we are authoritative for the zone.
if !ok || zones == nil {
return dns.RcodeServerFailure, nil
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
var result file.Result
for _, z := range zones {
h.zMu.RLock()
m.Answer, m.Ns, m.Extra, result = z.z.Lookup(ctx, state, qname)
h.zMu.RUnlock()
// record type exists for this name (NODATA).
if len(m.Answer) != 0 || result == file.NoData {
break
}
}
if len(m.Answer) == 0 && result != file.NoData && h.Fall.Through(qname) {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
switch result {
case file.Success:
case file.NoData:
case file.NameError:
m.Rcode = dns.RcodeNameError
case file.Delegation:
m.Authoritative = false
case file.ServerFailure:
return dns.RcodeServerFailure, nil
}
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
// Name implements plugin.Handler.Name.
func (h *Azure) Name() string { return "azure" }
package azure
import (
"context"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
clog "github.com/coredns/coredns/plugin/pkg/log"
publicAzureDNS "github.com/Azure/azure-sdk-for-go/profiles/latest/dns/mgmt/dns"
privateAzureDNS "github.com/Azure/azure-sdk-for-go/profiles/latest/privatedns/mgmt/privatedns"
azurerest "github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/auth"
)
var log = clog.NewWithPlugin("azure")
func init() { plugin.Register("azure", setup) }
func setup(c *caddy.Controller) error {
env, keys, accessMap, fall, err := parse(c)
if err != nil {
return plugin.Error("azure", err)
}
ctx, cancel := context.WithCancel(context.Background())
publicDNSClient := publicAzureDNS.NewRecordSetsClient(env.Values[auth.SubscriptionID])
if publicDNSClient.Authorizer, err = env.GetAuthorizer(); err != nil {
cancel()
return plugin.Error("azure", err)
}
privateDNSClient := privateAzureDNS.NewRecordSetsClient(env.Values[auth.SubscriptionID])
if privateDNSClient.Authorizer, err = env.GetAuthorizer(); err != nil {
cancel()
return plugin.Error("azure", err)
}
h, err := New(ctx, publicDNSClient, privateDNSClient, keys, accessMap)
if err != nil {
cancel()
return plugin.Error("azure", err)
}
h.Fall = fall
if err := h.Run(ctx); err != nil {
cancel()
return plugin.Error("azure", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
h.Next = next
return h
})
c.OnShutdown(func() error { cancel(); return nil })
return nil
}
func parse(c *caddy.Controller) (auth.EnvironmentSettings, map[string][]string, map[string]string, fall.F, error) {
resourceGroupMapping := map[string][]string{}
accessMap := map[string]string{}
resourceGroupSet := map[string]struct{}{}
azureEnv := azurerest.PublicCloud
env := auth.EnvironmentSettings{Values: map[string]string{}}
var fall fall.F
var access string
var resourceGroup string
var zoneName string
for c.Next() {
args := c.RemainingArgs()
for i := range args {
parts := strings.SplitN(args[i], ":", 2)
if len(parts) != 2 {
return env, resourceGroupMapping, accessMap, fall, c.Errf("invalid resource group/zone: %q", args[i])
}
resourceGroup, zoneName = parts[0], parts[1]
if resourceGroup == "" || zoneName == "" {
return env, resourceGroupMapping, accessMap, fall, c.Errf("invalid resource group/zone: %q", args[i])
}
if _, ok := resourceGroupSet[resourceGroup+zoneName]; ok {
return env, resourceGroupMapping, accessMap, fall, c.Errf("conflicting zone: %q", args[i])
}
resourceGroupSet[resourceGroup+zoneName] = struct{}{}
accessMap[resourceGroup+zoneName] = "public"
resourceGroupMapping[resourceGroup] = append(resourceGroupMapping[resourceGroup], zoneName)
}
for c.NextBlock() {
switch c.Val() {
case "subscription":
if !c.NextArg() {
return env, resourceGroupMapping, accessMap, fall, c.ArgErr()
}
env.Values[auth.SubscriptionID] = c.Val()
case "tenant":
if !c.NextArg() {
return env, resourceGroupMapping, accessMap, fall, c.ArgErr()
}
env.Values[auth.TenantID] = c.Val()
case "client":
if !c.NextArg() {
return env, resourceGroupMapping, accessMap, fall, c.ArgErr()
}
env.Values[auth.ClientID] = c.Val()
case "secret":
if !c.NextArg() {
return env, resourceGroupMapping, accessMap, fall, c.ArgErr()
}
env.Values[auth.ClientSecret] = c.Val()
case "environment":
if !c.NextArg() {
return env, resourceGroupMapping, accessMap, fall, c.ArgErr()
}
var err error
if azureEnv, err = azurerest.EnvironmentFromName(c.Val()); err != nil {
return env, resourceGroupMapping, accessMap, fall, c.Errf("cannot set azure environment: %q", err.Error())
}
case "fallthrough":
fall.SetZonesFromArgs(c.RemainingArgs())
case "access":
if !c.NextArg() {
return env, resourceGroupMapping, accessMap, fall, c.ArgErr()
}
access = c.Val()
if access != "public" && access != "private" {
return env, resourceGroupMapping, accessMap, fall, c.Errf("invalid access value: can be public/private, found: %s", access)
}
accessMap[resourceGroup+zoneName] = access
default:
return env, resourceGroupMapping, accessMap, fall, c.Errf("unknown property: %q", c.Val())
}
}
}
env.Values[auth.Resource] = azureEnv.ResourceManagerEndpoint
env.Environment = azureEnv
return env, resourceGroupMapping, accessMap, fall, nil
}
package plugin
import (
"context"
"fmt"
"math"
"net"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
const maxCnameChainLength = 10
// A returns A records from Backend or an error.
func A(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) {
services, err := checkForApex(ctx, b, zone, state, opt)
if err != nil {
return nil, false, err
}
dup := make(map[string]struct{})
for _, serv := range services {
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
// x CNAME x is a direct loop, don't add those.
// In etcd/skydns, w.x CNAME x is also a direct loop due to the "recursive" nature of search results.
continue
}
newRecord := serv.NewCNAME(state.QName(), serv.Host)
if len(previousRecords) > maxCnameChainLength {
// don't add it, and just continue
continue
}
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
state1 := state.NewWithQuestion(serv.Host, state.QType())
state1.Zone = zone
nextRecords, tc, err := A(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
if err == nil {
// We found something, so add both the CNAME and the IP addresses.
if len(nextRecords) > 0 {
records = append(records, newRecord)
records = append(records, nextRecords...)
}
}
if tc {
truncated = true
}
continue
}
// This means we cannot complete the CNAME here; try to look it up elsewhere.
target := newRecord.Target
// Lookup
m1, e1 := b.Lookup(ctx, state, target, state.QType())
if e1 != nil || m1 == nil {
continue
}
if m1.Truncated {
truncated = true
}
// Should we also require len(m1.Answer) > 0 here?
records = append(records, newRecord)
records = append(records, m1.Answer...)
continue
case dns.TypeA:
if _, ok := dup[serv.Host]; !ok {
dup[serv.Host] = struct{}{}
records = append(records, serv.NewA(state.QName(), ip))
}
case dns.TypeAAAA:
// nada
}
}
return records, truncated, nil
}
// AAAA returns AAAA records from Backend or an error.
func AAAA(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) {
services, err := checkForApex(ctx, b, zone, state, opt)
if err != nil {
return nil, false, err
}
dup := make(map[string]struct{})
for _, serv := range services {
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
// Try to resolve as CNAME if it's not an IP, but only if we don't create loops.
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
// x CNAME x is a direct loop, don't add those.
// In etcd/skydns, w.x CNAME x is also a direct loop due to the "recursive" nature of search results.
continue
}
newRecord := serv.NewCNAME(state.QName(), serv.Host)
if len(previousRecords) > maxCnameChainLength {
// don't add it, and just continue
continue
}
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
state1 := state.NewWithQuestion(serv.Host, state.QType())
state1.Zone = zone
nextRecords, tc, err := AAAA(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
if err == nil {
// We found something, so add both the CNAME and the IP addresses.
if len(nextRecords) > 0 {
records = append(records, newRecord)
records = append(records, nextRecords...)
}
}
if tc {
truncated = true
}
continue
}
// This means we cannot complete the CNAME here; try to look elsewhere.
target := newRecord.Target
m1, e1 := b.Lookup(ctx, state, target, state.QType())
if e1 != nil || m1 == nil {
continue
}
if m1.Truncated {
truncated = true
}
// Should we also require len(m1.Answer) > 0 here?
records = append(records, newRecord)
records = append(records, m1.Answer...)
continue
// Same as in A() above, but with the A and AAAA cases swapped.
case dns.TypeA:
// nada
case dns.TypeAAAA:
if _, ok := dup[serv.Host]; !ok {
dup[serv.Host] = struct{}{}
records = append(records, serv.NewAAAA(state.QName(), ip))
}
}
}
return records, truncated, nil
}
// SRV returns SRV records from the Backend.
// If the Target is not a name but an IP address, a name is created on the fly.
func SRV(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
services, err := b.Services(ctx, state, false, opt)
if err != nil {
return nil, nil, err
}
dup := make(map[item]struct{})
lookup := make(map[string]struct{})
// Loop twice to get the weight vs. priority right. This might break because we may drop duplicate SRV records later on.
w := make(map[int]int)
for _, serv := range services {
weight := 100
if serv.Weight != 0 {
weight = serv.Weight
}
if _, ok := w[serv.Priority]; !ok {
w[serv.Priority] = weight
continue
}
w[serv.Priority] += weight
}
for _, serv := range services {
// Don't add the entry if the port is -1 (invalid). The kubernetes plugin uses port -1 when a service/endpoint
// does not have any declared ports.
if serv.Port == -1 {
continue
}
w1 := 100.0 / float64(w[serv.Priority])
if serv.Weight == 0 {
w1 *= 100
} else {
w1 *= float64(serv.Weight)
}
weight := uint16(math.Floor(w1))
// weight should be at least 1
if weight == 0 {
weight = 1
}
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
srv := serv.NewSRV(state.QName(), weight)
records = append(records, srv)
if _, ok := lookup[srv.Target]; ok {
break
}
lookup[srv.Target] = struct{}{}
if !dns.IsSubDomain(zone, srv.Target) {
m1, e1 := b.Lookup(ctx, state, srv.Target, dns.TypeA)
if e1 == nil && m1 != nil {
extra = append(extra, m1.Answer...)
}
m1, e1 = b.Lookup(ctx, state, srv.Target, dns.TypeAAAA)
if e1 == nil && m1 != nil {
// If we have seen CNAMEs we *assume* that they are already added.
for _, a := range m1.Answer {
if _, ok := a.(*dns.CNAME); !ok {
extra = append(extra, a)
}
}
}
break
}
// Internal name: we should have address info for it, either v4 or v6.
// Clients expect a complete answer, because we are a recursor in their view.
state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
addr, _, e1 := A(ctx, b, zone, state1, nil, opt)
if e1 == nil {
extra = append(extra, addr...)
}
// TODO(miek): AAAA as well here.
case dns.TypeA, dns.TypeAAAA:
addr := serv.Host
serv.Host = msg.Domain(serv.Key)
srv := serv.NewSRV(state.QName(), weight)
if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok {
records = append(records, srv)
}
if ok := isDuplicate(dup, srv.Target, addr, 0); !ok {
extra = append(extra, newAddress(serv, srv.Target, ip, what))
}
}
}
return records, extra, nil
}
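// Illustrative sketch (not part of the original source): the weight handling
// above first sums the weights per priority, then scales every record so the
// weights within one priority roughly add up to 100, never dropping below 1.
// With two made-up services at the same priority carrying weights 100 and
// 300, the sum is 400 and the resulting SRV weights are 25 and 75:
//
//	total := 400.0
//	for _, sw := range []float64{100, 300} {
//		w := uint16(math.Floor(100.0 / total * sw))
//		if w == 0 {
//			w = 1 // weight should be at least 1
//		}
//		fmt.Println(w) // prints 25, then 75
//	}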
// MX returns MX records from the Backend. If the Target is not a name but an IP address, a name is created on the fly.
func MX(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
services, err := b.Services(ctx, state, false, opt)
if err != nil {
return nil, nil, err
}
dup := make(map[item]struct{})
lookup := make(map[string]struct{})
for _, serv := range services {
if !serv.Mail {
continue
}
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
mx := serv.NewMX(state.QName())
records = append(records, mx)
if _, ok := lookup[mx.Mx]; ok {
break
}
lookup[mx.Mx] = struct{}{}
if !dns.IsSubDomain(zone, mx.Mx) {
m1, e1 := b.Lookup(ctx, state, mx.Mx, dns.TypeA)
if e1 == nil && m1 != nil {
extra = append(extra, m1.Answer...)
}
m1, e1 = b.Lookup(ctx, state, mx.Mx, dns.TypeAAAA)
if e1 == nil && m1 != nil {
// If we have seen CNAMEs we *assume* that they are already added.
for _, a := range m1.Answer {
if _, ok := a.(*dns.CNAME); !ok {
extra = append(extra, a)
}
}
}
break
}
// Internal name
state1 := state.NewWithQuestion(mx.Mx, dns.TypeA)
addr, _, e1 := A(ctx, b, zone, state1, nil, opt)
if e1 == nil {
extra = append(extra, addr...)
}
// TODO(miek): AAAA as well here.
case dns.TypeA, dns.TypeAAAA:
addr := serv.Host
serv.Host = msg.Domain(serv.Key)
mx := serv.NewMX(state.QName())
if ok := isDuplicate(dup, mx.Mx, "", mx.Preference); !ok {
records = append(records, mx)
}
// Use a fake port of 0 for the address record.
if ok := isDuplicate(dup, serv.Host, addr, 0); !ok {
extra = append(extra, newAddress(serv, serv.Host, ip, what))
}
}
}
return records, extra, nil
}
// CNAME returns CNAME records from the backend or an error.
func CNAME(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
services, err := b.Services(ctx, state, true, opt)
if err != nil {
return nil, err
}
if len(services) > 0 {
serv := services[0]
if ip := net.ParseIP(serv.Host); ip == nil {
records = append(records, serv.NewCNAME(state.QName(), serv.Host))
}
}
return records, nil
}
// TXT returns TXT records from Backend or an error.
func TXT(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) {
services, err := b.Services(ctx, state, false, opt)
if err != nil {
return nil, false, err
}
dup := make(map[string]struct{})
for _, serv := range services {
what, _ := serv.HostType()
switch what {
case dns.TypeCNAME:
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
// x CNAME x is a direct loop, don't add those
// in etcd/skydns, w.x CNAME x is also a direct loop due to the "recursive" nature of the search results
continue
}
newRecord := serv.NewCNAME(state.QName(), serv.Host)
if len(previousRecords) > maxCnameChainLength {
// don't add it, and just continue
continue
}
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
state1 := state.NewWithQuestion(serv.Host, state.QType())
state1.Zone = zone
nextRecords, tc, err := TXT(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
if tc {
truncated = true
}
if err == nil {
// We found something, so add both the CNAME and the resolved records.
if len(nextRecords) > 0 {
records = append(records, newRecord)
records = append(records, nextRecords...)
}
}
continue
}
// This means we cannot complete the CNAME here; try to look elsewhere.
target := newRecord.Target
// Lookup
m1, e1 := b.Lookup(ctx, state, target, state.QType())
if e1 != nil || m1 == nil {
continue
}
// Should we also require len(m1.Answer) > 0 here?
records = append(records, newRecord)
records = append(records, m1.Answer...)
continue
case dns.TypeTXT:
if _, ok := dup[serv.Text]; !ok {
dup[serv.Text] = struct{}{}
records = append(records, serv.NewTXT(state.QName()))
}
}
}
return records, truncated, nil
}
// PTR returns the PTR records from the backend, only services that have a domain name as host are included.
func PTR(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
services, err := b.Reverse(ctx, state, true, opt)
if err != nil {
return nil, err
}
dup := make(map[string]struct{})
for _, serv := range services {
if ip := net.ParseIP(serv.Host); ip == nil {
if _, ok := dup[serv.Host]; !ok {
dup[serv.Host] = struct{}{}
records = append(records, serv.NewPTR(state.QName(), serv.Host))
}
}
}
return records, nil
}
// NS returns NS records from the backend
func NS(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
// NS record for this zone lives in a special place, ns.dns.<zone>. Fake our lookup.
// Only a tad bit fishy...
old := state.QName()
state.Clear()
state.Req.Question[0].Name = dnsutil.Join("ns.dns.", zone)
services, err := b.Services(ctx, state, false, opt)
// reset the query name to the original
state.Req.Question[0].Name = old
if err != nil {
return nil, nil, err
}
seen := map[string]bool{}
for _, serv := range services {
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
return nil, nil, fmt.Errorf("NS record must be an IP address: %s", serv.Host)
case dns.TypeA, dns.TypeAAAA:
serv.Host = msg.Domain(serv.Key)
ns := serv.NewNS(state.QName())
extra = append(extra, newAddress(serv, ns.Ns, ip, what))
if _, ok := seen[ns.Ns]; ok {
continue
}
seen[ns.Ns] = true
records = append(records, ns)
}
}
return records, extra, nil
}
// SOA returns a SOA record from the backend.
func SOA(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]dns.RR, error) {
minTTL := b.MinTTL(state)
ttl := min(minTTL, uint32(300))
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: ttl, Class: dns.ClassINET}
Mbox := dnsutil.Join(hostmaster, zone)
Ns := dnsutil.Join("ns.dns", zone)
soa := &dns.SOA{Hdr: header,
Mbox: Mbox,
Ns: Ns,
Serial: b.Serial(state),
Refresh: 7200,
Retry: 1800,
Expire: 86400,
Minttl: minTTL,
}
return []dns.RR{soa}, nil
}
// BackendError writes an error response to the client.
func BackendError(ctx context.Context, b ServiceBackend, zone string, rcode int, state request.Request, err error, opt Options) (int, error) {
m := new(dns.Msg)
m.SetRcode(state.Req, rcode)
m.Authoritative = true
m.Ns, _ = SOA(ctx, b, zone, state, opt)
state.W.WriteMsg(m)
// Return success as the rcode to signal we have written to the client.
return dns.RcodeSuccess, err
}
func newAddress(s msg.Service, name string, ip net.IP, what uint16) dns.RR {
hdr := dns.RR_Header{Name: name, Rrtype: what, Class: dns.ClassINET, Ttl: s.TTL}
if what == dns.TypeA {
return &dns.A{Hdr: hdr, A: ip}
}
// Should always be dns.TypeAAAA
return &dns.AAAA{Hdr: hdr, AAAA: ip}
}
// checkForApex checks the special apex.dns directory for records that will be returned as A or AAAA.
func checkForApex(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]msg.Service, error) {
if state.Name() != zone {
return b.Services(ctx, state, false, opt)
}
// If the zone name itself is queried we fake the query to search for a special entry
// this is equivalent to the NS search code.
old := state.QName()
state.Clear()
state.Req.Question[0].Name = dnsutil.Join("apex.dns", zone)
services, err := b.Services(ctx, state, false, opt)
if err == nil {
state.Req.Question[0].Name = old
return services, err
}
state.Req.Question[0].Name = old
return b.Services(ctx, state, false, opt)
}
// item identifies a record for deduplication purposes.
type item struct {
name string // name of the record (either owner or something else unique).
port uint16 // port of the record (used for address records, A and AAAA).
addr string // address of the record (A and AAAA).
}
// isDuplicate uses m to see if the combo (name, addr, port) already exists. If it does
// not exist already, isDuplicate will also add the record to the map.
func isDuplicate(m map[item]struct{}, name, addr string, port uint16) bool {
if addr != "" {
_, ok := m[item{name, 0, addr}]
if !ok {
m[item{name, 0, addr}] = struct{}{}
}
return ok
}
_, ok := m[item{name, port, ""}]
if !ok {
m[item{name, port, ""}] = struct{}{}
}
return ok
}
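// Illustrative sketch (not part of the original source): isDuplicate keys on
// either (name, addr) or (name, port), so the same SRV target can be recorded
// once per port while its address record is only added once. Made-up values:
//
//	m := make(map[item]struct{})
//	fmt.Println(isDuplicate(m, "a.example.org.", "", 8080))      // false, now registered
//	fmt.Println(isDuplicate(m, "a.example.org.", "", 8080))      // true, seen before
//	fmt.Println(isDuplicate(m, "a.example.org.", "10.0.0.1", 0)) // false, the address key is separate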
const hostmaster = "hostmaster"
// Package bind allows binding to a specific interface instead of binding to all of them.
package bind
import (
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("bind", setup) }
type bind struct {
Next plugin.Handler
addrs []string
except []string
}
// Name implements plugin.Handler.
func (b *bind) Name() string { return "bind" }
package bind
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/log"
)
func setup(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
// addresses will be consolidated over all bind directives available in that server block
all := []string{}
ifaces, err := net.Interfaces()
if err != nil {
log.Warning(plugin.Error("bind", fmt.Errorf("failed to get interfaces list, cannot bind by interface name: %s", err)))
}
for c.Next() {
b, err := parse(c)
if err != nil {
return plugin.Error("bind", err)
}
ips, err := listIP(b.addrs, ifaces)
if err != nil {
return plugin.Error("bind", err)
}
except, err := listIP(b.except, ifaces)
if err != nil {
return plugin.Error("bind", err)
}
for _, ip := range ips {
if !slices.Contains(except, ip) {
all = append(all, ip)
}
}
}
config.ListenHosts = all
return nil
}
func parse(c *caddy.Controller) (*bind, error) {
b := &bind{}
b.addrs = c.RemainingArgs()
if len(b.addrs) == 0 {
return nil, errors.New("at least one address or interface name is expected")
}
for c.NextBlock() {
switch c.Val() {
case "except":
b.except = c.RemainingArgs()
if len(b.except) == 0 {
return nil, errors.New("at least one address or interface must be given to except subdirective")
}
default:
return nil, fmt.Errorf("invalid option %q", c.Val())
}
}
return b, nil
}
// listIP returns a list of IP addresses from a list of arguments which can be either IP-Address or Interface-Name.
func listIP(args []string, ifaces []net.Interface) ([]string, error) {
all := []string{}
var isIface bool
for _, a := range args {
isIface = false
for _, iface := range ifaces {
if a == iface.Name {
isIface = true
addrs, err := iface.Addrs()
if err != nil {
return nil, fmt.Errorf("failed to get the IP addresses of the interface: %q", a)
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok {
ipa, err := net.ResolveIPAddr("ip", ipnet.IP.String())
if err == nil {
if ipnet.IP.To4() == nil &&
(ipnet.IP.IsLinkLocalMulticast() || ipnet.IP.IsLinkLocalUnicast()) {
if ipa.Zone == "" {
ipa.Zone = iface.Name
}
}
all = append(all, ipa.String())
}
}
}
}
}
if !isIface {
if net.ParseIP(a) == nil {
return nil, fmt.Errorf("not a valid IP address or interface name: %q", a)
}
all = append(all, a)
}
}
return all, nil
}
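// Illustrative sketch (not part of the original source): listIP accepts a mix
// of literal IP addresses and interface names; an interface name expands to
// all addresses on that interface, with a zone appended to IPv6 link-local
// addresses. The interface name "eth0" below is an assumption for the example.
//
//	ifaces, _ := net.Interfaces()
//	ips, err := listIP([]string{"127.0.0.1", "eth0"}, ifaces)
//	if err != nil {
//		// "eth0" is neither a valid IP nor a known interface on this host
//	}
//	fmt.Println(ips) // e.g. [127.0.0.1 10.0.0.5 fe80::1%eth0]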
// Package bufsize implements a plugin that clamps the EDNS0 buffer size, preventing packet fragmentation.
package bufsize
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
// Bufsize implements bufsize plugin.
type Bufsize struct {
Next plugin.Handler
Size int
}
// ServeDNS implements the plugin.Handler interface.
func (buf Bufsize) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if option := r.IsEdns0(); option != nil && int(option.UDPSize()) > buf.Size {
option.SetUDPSize(uint16(buf.Size))
}
return plugin.NextOrFailure(buf.Name(), buf.Next, ctx, w, r)
}
// Name implements the Handler interface.
func (buf Bufsize) Name() string { return "bufsize" }
package bufsize
import (
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("bufsize", setup) }
func setup(c *caddy.Controller) error {
bufsize, err := parse(c)
if err != nil {
return plugin.Error("bufsize", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Bufsize{Next: next, Size: bufsize}
})
return nil
}
func parse(c *caddy.Controller) (int, error) {
// value from http://www.dnsflagday.net/2020/
const defaultBufSize = 1232
for c.Next() {
args := c.RemainingArgs()
switch len(args) {
case 0:
// Nothing specified; use defaultBufSize
return defaultBufSize, nil
case 1:
// The specified value needs to be verified
bufsize, err := strconv.Atoi(args[0])
if err != nil {
return -1, plugin.Error("bufsize", c.ArgErr())
}
// Follows RFC 6891
if bufsize < 512 || bufsize > 4096 {
return -1, plugin.Error("bufsize", c.ArgErr())
}
return bufsize, nil
default:
// Only 1 argument is acceptable
return -1, plugin.Error("bufsize", c.ArgErr())
}
}
return -1, plugin.Error("bufsize", c.ArgErr())
}
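// Illustrative sketch (not part of the original source): the accepted range
// follows RFC 6891, so "bufsize 1232" parses while values outside [512, 4096]
// are rejected. The snippet uses caddy.NewTestController, as the plugin's
// tests do.
//
//	c := caddy.NewTestController("dns", "bufsize 1232")
//	size, err := parse(c)
//	fmt.Println(size, err) // 1232 <nil>
//	c = caddy.NewTestController("dns", "bufsize 511")
//	_, err = parse(c)      // err != nil: below the RFC 6891 minimum of 512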
// Package cache implements a cache.
package cache
import (
"hash/fnv"
"net"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Cache is a plugin that looks up responses in a cache and caches replies.
// It has a success and a denial of existence cache.
type Cache struct {
Next plugin.Handler
Zones []string
zonesMetricLabel string
viewMetricLabel string
ncache *cache.Cache
ncap int
nttl time.Duration
minnttl time.Duration
pcache *cache.Cache
pcap int
pttl time.Duration
minpttl time.Duration
failttl time.Duration // TTL for caching SERVFAIL responses
// Prefetch.
prefetch int
duration time.Duration
percentage int
// Stale serve
staleUpTo time.Duration
verifyStale bool
// Positive/negative zone exceptions
pexcept []string
nexcept []string
// Keep ttl option
keepttl bool
// Testing.
now func() time.Time
}
// New returns an initialized Cache with default settings. It's up to the
// caller to set the Next handler.
func New() *Cache {
return &Cache{
Zones: []string{"."},
pcap: defaultCap,
pcache: cache.New(defaultCap),
pttl: maxTTL,
minpttl: minTTL,
ncap: defaultCap,
ncache: cache.New(defaultCap),
nttl: maxNTTL,
minnttl: minNTTL,
failttl: minNTTL,
prefetch: 0,
duration: 1 * time.Minute,
percentage: 10,
now: time.Now,
}
}
// key returns the key under which we store the item; false is returned if we don't store the message.
// Currently we do not cache truncated responses, errors, zone transfers or dynamic update messages.
// qname holds the already lowercased qname.
func key(qname string, m *dns.Msg, t response.Type, do, cd bool) (bool, uint64) {
// We don't store truncated responses.
if m.Truncated {
return false, 0
}
// Nor errors or Meta or Update.
if t == response.OtherError || t == response.Meta || t == response.Update {
return false, 0
}
return true, hash(qname, m.Question[0].Qtype, do, cd)
}
var one = []byte("1")
var zero = []byte("0")
func hash(qname string, qtype uint16, do, cd bool) uint64 {
h := fnv.New64()
if do {
h.Write(one)
} else {
h.Write(zero)
}
if cd {
h.Write(one)
} else {
h.Write(zero)
}
h.Write([]byte{byte(qtype >> 8)})
h.Write([]byte{byte(qtype)})
h.Write([]byte(qname))
return h.Sum64()
}
func computeTTL(msgTTL, minTTL, maxTTL time.Duration) time.Duration {
ttl := min(max(msgTTL, minTTL), maxTTL)
return ttl
}
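// Illustrative sketch (not part of the original source): the cache key is a
// 64-bit FNV-1 hash over the DO bit, the CD bit, the qtype and the lowercased
// qname, and computeTTL clamps the message TTL into [minTTL, maxTTL]. Made-up
// values:
//
//	k := hash("example.org.", dns.TypeA, true, false) // identical inputs always yield the same key
//	_ = k
//	fmt.Println(computeTTL(2*time.Hour, 5*time.Second, time.Hour))   // 1h0m0s, capped at maxTTL
//	fmt.Println(computeTTL(1*time.Second, 5*time.Second, time.Hour)) // 5s, raised to minTTL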
// ResponseWriter is a response writer that caches the reply message.
type ResponseWriter struct {
dns.ResponseWriter
*Cache
state request.Request
server string // Server handling the request.
do bool // When true the original request had the DO bit set.
cd bool // When true the original request had the CD bit set.
ad bool // When true the original request had the AD bit set.
prefetch bool // When true write nothing back to the client.
remoteAddr net.Addr
wildcardFunc func() string // function to retrieve wildcard name that synthesized the result.
pexcept []string // positive zone exceptions
nexcept []string // negative zone exceptions
}
// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in
// prefetch requests. It ensures RemoteAddr() can be called even after the
// original connection has already been closed.
func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter {
// Resolve the address now, the connection might be already closed when the
// actual prefetch request is made.
addr := state.W.RemoteAddr()
// The protocol of the client triggering a cache prefetch doesn't matter.
// The address type is used by request.Proto to determine the response size,
// and using TCP ensures the message isn't unnecessarily truncated.
if u, ok := addr.(*net.UDPAddr); ok {
addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone}
}
return &ResponseWriter{
ResponseWriter: state.W,
Cache: c,
state: state,
server: server,
do: state.Do(),
cd: state.Req.CheckingDisabled,
prefetch: true,
remoteAddr: addr,
}
}
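// Illustrative sketch (not part of the original source): converting the
// remembered address to a *net.TCPAddr makes the prefetch request look like a
// TCP request, so the refreshed answer is not size-limited the way a UDP
// response would be.
//
//	u := &net.UDPAddr{IP: net.ParseIP("192.0.2.1"), Port: 53}
//	addr := net.Addr(&net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone})
//	fmt.Println(addr.Network()) // "tcp"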
// RemoteAddr implements the dns.ResponseWriter interface.
func (w *ResponseWriter) RemoteAddr() net.Addr {
if w.remoteAddr != nil {
return w.remoteAddr
}
return w.ResponseWriter.RemoteAddr()
}
// WriteMsg implements the dns.ResponseWriter interface.
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
res = res.Copy()
mt, _ := response.Typify(res, w.now().UTC())
// key returns false for anything we don't want to cache.
hasKey, key := key(w.state.Name(), res, mt, w.do, w.cd)
msgTTL := dnsutil.MinimalTTL(res, mt)
var duration time.Duration
switch mt {
case response.NameError, response.NoData:
duration = computeTTL(msgTTL, w.minnttl, w.nttl)
case response.ServerError:
duration = w.failttl
default:
duration = computeTTL(msgTTL, w.minpttl, w.pttl)
}
// Apply the capped TTL to this reply to avoid a jarring TTL drop, e.g. 1799 -> 8.
ttl := uint32(duration.Seconds())
res.Answer = filterRRSlice(res.Answer, ttl, false)
res.Ns = filterRRSlice(res.Ns, ttl, false)
res.Extra = filterRRSlice(res.Extra, ttl, false)
if !w.do && !w.ad {
// unset AD bit if requester is not OK with DNSSEC
// But retain AD bit if requester set the AD bit in the request, per RFC6840 5.7-5.8
res.AuthenticatedData = false
}
if hasKey && duration > 0 {
if w.state.Match(res) {
w.set(res, key, mt, duration)
cacheSize.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.pcache.Len()))
cacheSize.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.ncache.Len()))
} else {
// Don't log it, but increment counter
cacheDrops.WithLabelValues(w.server, w.zonesMetricLabel, w.viewMetricLabel).Inc()
}
}
if w.prefetch {
return nil
}
return w.ResponseWriter.WriteMsg(res)
}
func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) {
// duration is expected > 0
// and key is valid
switch mt {
case response.NoError, response.Delegation:
if plugin.Zones(w.pexcept).Matches(m.Question[0].Name) != "" {
// zone is in exception list, do not cache
return
}
i := newItem(m, w.now(), duration)
if w.wildcardFunc != nil {
i.wildcard = w.wildcardFunc()
}
if w.pcache.Add(key, i) {
evictions.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Inc()
}
// when pre-fetching, remove the negative cache entry if it exists
if w.prefetch {
w.ncache.Remove(key)
}
case response.NameError, response.NoData, response.ServerError:
if plugin.Zones(w.nexcept).Matches(m.Question[0].Name) != "" {
// zone is in exception list, do not cache
return
}
i := newItem(m, w.now(), duration)
if w.wildcardFunc != nil {
i.wildcard = w.wildcardFunc()
}
if w.ncache.Add(key, i) {
evictions.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Inc()
}
case response.OtherError:
// don't cache these
default:
log.Warningf("Caching called with unknown classification: %d", mt)
}
}
// Write implements the dns.ResponseWriter interface.
func (w *ResponseWriter) Write(buf []byte) (int, error) {
log.Warning("Caching called with Write: not caching reply")
if w.prefetch {
return 0, nil
}
n, err := w.ResponseWriter.Write(buf)
return n, err
}
// verifyStaleResponseWriter is a response writer that only writes messages if they should replace a
// stale cache entry, and otherwise discards them.
type verifyStaleResponseWriter struct {
*ResponseWriter
refreshed bool // set to true if the last WriteMsg wrote to ResponseWriter, false otherwise.
}
// newVerifyStaleResponseWriter returns a ResponseWriter to be used when verifying stale cache
// entries. It only forwards writes if an entry was successfully refreshed according to RFC 8767,
// section 4 (response is NoError or NXDomain), and ignores any other response.
func newVerifyStaleResponseWriter(w *ResponseWriter) *verifyStaleResponseWriter {
return &verifyStaleResponseWriter{
w,
false,
}
}
// WriteMsg implements the dns.ResponseWriter interface.
func (w *verifyStaleResponseWriter) WriteMsg(res *dns.Msg) error {
w.refreshed = false
if res.Rcode == dns.RcodeSuccess || res.Rcode == dns.RcodeNameError {
w.refreshed = true
return w.ResponseWriter.WriteMsg(res) // stores to the cache and sends to the client
}
return nil // else discard
}
const (
maxTTL = dnsutil.MaximumDefaulTTL
minTTL = dnsutil.MinimalDefaultTTL
maxNTTL = dnsutil.MaximumDefaulTTL / 2
minNTTL = dnsutil.MinimalDefaultTTL
defaultCap = 10000 // default capacity of the cache.
// Success is the class used for positive caching.
Success = "success"
// Denial is the class defined for negative caching.
Denial = "denial"
)
package cache
import "github.com/miekg/dns"
// filterRRSlice filters out OPT RRs, and sets all RR TTLs to ttl.
// If dup is true the RRs in rrs are _copied_ before adjusting their
// TTL and the slice of copied RRs is returned.
func filterRRSlice(rrs []dns.RR, ttl uint32, dup bool) []dns.RR {
j := 0
rs := make([]dns.RR, len(rrs))
for _, r := range rrs {
if r.Header().Rrtype == dns.TypeOPT {
continue
}
if dup {
rs[j] = dns.Copy(r)
} else {
rs[j] = r
}
rs[j].Header().Ttl = ttl
j++
}
return rs[:j]
}
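// Illustrative sketch (not part of the original source): filterRRSlice drops
// OPT records and rewrites the remaining TTLs; with dup=true the caller gets
// copies, so the cached RRs themselves stay untouched. Made-up record:
//
//	a, _ := dns.NewRR("example.org. 3600 IN A 192.0.2.1")
//	opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
//	out := filterRRSlice([]dns.RR{a, opt}, 30, true)
//	fmt.Println(len(out), out[0].Header().Ttl, a.Header().Ttl) // 1 30 3600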
// Package freq keeps track of last X seen events. The events themselves are not stored
// here. So the Freq type should be added next to the thing it is tracking.
package freq
import (
"sync"
"time"
)
// Freq tracks the frequencies of things.
type Freq struct {
// Last time we saw a query for this element.
last time.Time
// Number of hits in the last time slice.
hits int
sync.RWMutex
}
// New returns a new initialized Freq.
func New(t time.Time) *Freq {
return &Freq{last: t, hits: 0}
}
// Update updates the number of hits. Last time seen will be set to now.
// If the last time we've seen this entity is within now - d, we increment hits, otherwise
// we reset hits to 1. It returns the number of hits.
func (f *Freq) Update(d time.Duration, now time.Time) int {
earliest := now.Add(-1 * d)
f.Lock()
defer f.Unlock()
if f.last.Before(earliest) {
f.last = now
f.hits = 1
return f.hits
}
f.last = now
f.hits++
return f.hits
}
// Hits returns the number of hits that we have seen, according to the updates we have done to f.
func (f *Freq) Hits() int {
f.RLock()
defer f.RUnlock()
return f.hits
}
// Reset resets f to time t and hits to hits.
func (f *Freq) Reset(t time.Time, hits int) {
f.Lock()
defer f.Unlock()
f.last = t
f.hits = hits
}
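// Illustrative sketch (not part of the original source): Update counts hits
// that fall within the sliding window d and restarts the count once the last
// hit is older than that window.
//
//	now := time.Now()
//	f := New(now)
//	f.Update(time.Minute, now)                     // first hit
//	f.Update(time.Minute, now.Add(10*time.Second)) // still inside the window
//	fmt.Println(f.Hits()) // 2
//	f.Update(time.Minute, now.Add(5*time.Minute))  // window expired, count restarts
//	fmt.Println(f.Hits()) // 1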
//go:build gofuzz
package cache
import (
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes cache.
func Fuzz(data []byte) int {
return fuzz.Do(New(), data)
}
package cache
import (
"context"
"math"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the plugin.Handler interface.
func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rc := r.Copy() // We potentially modify r, to prevent other plugins from seeing this (r is a pointer), copy r into rc.
state := request.Request{W: w, Req: rc}
do := state.Do()
cd := r.CheckingDisabled
ad := r.AuthenticatedData
zone := plugin.Zones(c.Zones).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, rc)
}
now := c.now().UTC()
server := metrics.WithServer(ctx)
// On cache refresh, we will just use the DO bit from the incoming query for the refresh since we key our cache
// with the query DO bit. That means two separate cache items for the query DO bit true or false. In the situation
// in which upstream doesn't support DNSSEC, the two cache items will effectively be the same. Regardless, any
// DNSSEC RRs in the response are written to cache with the response.
i := c.getIgnoreTTL(now, state, server)
if i == nil {
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, ad: ad, cd: cd,
nexcept: c.nexcept, pexcept: c.pexcept, wildcardFunc: wildcardFunc(ctx)}
return c.doRefresh(ctx, state, crr)
}
ttl := i.ttl(now)
if ttl < 0 {
// serve stale behavior
if c.verifyStale {
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, cd: cd}
cw := newVerifyStaleResponseWriter(crr)
ret, err := c.doRefresh(ctx, state, cw)
if cw.refreshed {
return ret, err
}
}
// Adjust the time to get a 0 TTL in the reply built from a stale item.
now = now.Add(time.Duration(ttl) * time.Second)
if !c.verifyStale {
cw := newPrefetchResponseWriter(server, state, c)
go c.doPrefetch(ctx, state, cw, i, now)
}
servedStale.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
} else if c.shouldPrefetch(i, now) {
cw := newPrefetchResponseWriter(server, state, c)
go c.doPrefetch(ctx, state, cw, i, now)
}
if i.wildcard != "" {
// Set wildcard source record name to metadata
metadata.SetValueFunc(ctx, "zone/wildcard", func() string {
return i.wildcard
})
}
if c.keepttl {
// If keepttl is enabled we fake the current time to the stored
// one so that we always get the original TTL
now = i.stored
}
resp := i.toMsg(r, now, do, ad)
w.WriteMsg(resp)
return dns.RcodeSuccess, nil
}
func wildcardFunc(ctx context.Context) func() string {
return func() string {
// Get wildcard source record name from metadata
if f := metadata.ValueFunc(ctx, "zone/wildcard"); f != nil {
return f()
}
return ""
}
}
func (c *Cache) doPrefetch(ctx context.Context, state request.Request, cw *ResponseWriter, i *item, now time.Time) {
// Use a fresh metadata map to avoid concurrent writes to the original request's metadata.
ctx = metadata.ContextWithMetadata(ctx)
cachePrefetches.WithLabelValues(cw.server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
c.doRefresh(ctx, state, cw)
// When prefetching we lose the item i, and with it the frequency
// that we've gathered so far. So we copy the frequency info back
// into the new item that was stored in the cache.
if i1 := c.exists(state); i1 != nil {
i1.Reset(now, i.Hits())
}
}
func (c *Cache) doRefresh(ctx context.Context, state request.Request, cw dns.ResponseWriter) (int, error) {
return plugin.NextOrFailure(c.Name(), c.Next, ctx, cw, state.Req)
}
func (c *Cache) shouldPrefetch(i *item, now time.Time) bool {
if c.prefetch <= 0 {
return false
}
i.Update(c.duration, now)
threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL)))
return i.Hits() >= c.prefetch && i.ttl(now) <= threshold
}
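// Illustrative sketch (not part of the original source): assuming prefetch is
// configured to 2 and percentage is left at the default 10, an item with an
// original TTL of 300 seconds is refreshed once it has been hit at least
// twice and no more than 10% of that TTL remains:
//
//	threshold := int(math.Ceil(10.0 / 100 * 300))
//	fmt.Println(threshold) // 30: prefetch fires when Hits() >= 2 and ttl(now) <= 30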
// Name implements the Handler interface.
func (c *Cache) Name() string { return "cache" }
// getIgnoreTTL unconditionally returns an item if it exists in the cache.
func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item {
k := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled)
cacheRequests.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
if i, ok := c.ncache.Get(k); ok {
itm := i.(*item)
ttl := itm.ttl(now)
if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) {
cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel, c.viewMetricLabel).Inc()
return i.(*item)
}
}
if i, ok := c.pcache.Get(k); ok {
itm := i.(*item)
ttl := itm.ttl(now)
if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) {
cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel, c.viewMetricLabel).Inc()
return i.(*item)
}
}
cacheMisses.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
return nil
}
func (c *Cache) exists(state request.Request) *item {
k := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled)
if i, ok := c.ncache.Get(k); ok {
return i.(*item)
}
if i, ok := c.pcache.Get(k); ok {
return i.(*item)
}
return nil
}
package cache
import (
"strings"
"time"
"github.com/coredns/coredns/plugin/cache/freq"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
type item struct {
Name string
QType uint16
Rcode int
AuthenticatedData bool
RecursionAvailable bool
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
wildcard string
origTTL uint32
stored time.Time
*freq.Freq
}
func newItem(m *dns.Msg, now time.Time, d time.Duration) *item {
i := new(item)
if len(m.Question) != 0 {
i.Name = m.Question[0].Name
i.QType = m.Question[0].Qtype
}
i.Rcode = m.Rcode
i.AuthenticatedData = m.AuthenticatedData
i.RecursionAvailable = m.RecursionAvailable
i.Answer = m.Answer
i.Ns = m.Ns
i.Extra = make([]dns.RR, len(m.Extra))
// Don't copy OPT records as these are hop-by-hop.
j := 0
for _, e := range m.Extra {
if e.Header().Rrtype == dns.TypeOPT {
continue
}
i.Extra[j] = e
j++
}
i.Extra = i.Extra[:j]
i.origTTL = uint32(d.Seconds())
i.stored = now.UTC()
i.Freq = new(freq.Freq)
return i
}
// toMsg turns i into a message; it tailors the reply to m.
// The Authoritative bit should be set to 0, but some client stub resolver implementations will
// discard answers that do not have this bit set. Most notably, on some legacy systems
// (e.g. Ubuntu 14.04 with glibc 2.20) the low-level glibc function `getaddrinfo`,
// used by Python/Ruby/etc., behaves this way.
// So we're forced to always set it to 1, regardless of whether the answer came from the cache or not.
// On newer systems (e.g. Ubuntu 16.04 with glibc 2.23) this issue is resolved,
// so we may set this bit back to 0 in the future.
func (i *item) toMsg(m *dns.Msg, now time.Time, do bool, ad bool) *dns.Msg {
m1 := new(dns.Msg)
m1.SetReply(m)
// Set this to true as some DNS clients discard the *entire* packet when it's non-authoritative.
// This is probably not according to spec, but the bit itself is not super useful at this point, so
// just set it to true.
m1.Authoritative = true
m1.AuthenticatedData = i.AuthenticatedData
if !do && !ad {
// When DNSSEC was not wanted, it can't be authenticated data.
// However, retain the AD bit if the requester set the AD bit, per RFC6840 5.7-5.8
m1.AuthenticatedData = false
}
m1.RecursionAvailable = i.RecursionAvailable
m1.Rcode = i.Rcode
m1.Answer = make([]dns.RR, len(i.Answer))
m1.Ns = make([]dns.RR, len(i.Ns))
m1.Extra = make([]dns.RR, len(i.Extra))
ttl := uint32(i.ttl(now))
m1.Answer = filterRRSlice(i.Answer, ttl, true)
m1.Ns = filterRRSlice(i.Ns, ttl, true)
m1.Extra = filterRRSlice(i.Extra, ttl, true)
return m1
}
func (i *item) ttl(now time.Time) int {
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
return ttl
}
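// Illustrative sketch (not part of the original source): ttl subtracts the
// item's age from origTTL, so it goes negative once the entry is stale;
// getIgnoreTTL relies on that sign to decide whether serve_stale still
// applies. Made-up item:
//
//	i := &item{origTTL: 60, stored: time.Now().Add(-90 * time.Second)}
//	fmt.Println(i.ttl(time.Now())) // roughly -30: stale for about 30 seconds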
func (i *item) matches(state request.Request) bool {
if state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name) {
return true
}
return false
}
package cache
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
var log = clog.NewWithPlugin("cache")
func init() { plugin.Register("cache", setup) }
func setup(c *caddy.Controller) error {
ca, err := cacheParse(c)
if err != nil {
return plugin.Error("cache", err)
}
c.OnStartup(func() error {
ca.viewMetricLabel = dnsserver.GetConfig(c).ViewName
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
ca.Next = next
return ca
})
return nil
}
func cacheParse(c *caddy.Controller) (*Cache, error) {
ca := New()
j := 0
for c.Next() {
if j > 0 {
return nil, plugin.ErrOnce
}
j++
// cache [ttl] [zones..]
args := c.RemainingArgs()
if len(args) > 0 {
// the first arg may be just a number; if so it is the ttl, otherwise it is a zone
ttl, err := strconv.Atoi(args[0])
if err == nil {
// Reserve 0 (and smaller for future things)
if ttl <= 0 {
return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", ttl)
}
ca.pttl = time.Duration(ttl) * time.Second
ca.nttl = time.Duration(ttl) * time.Second
args = args[1:]
}
}
origins := plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)
// Refinements? In an extra block.
for c.NextBlock() {
switch c.Val() {
// first number is the cap, second is a new ttl
case Success:
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
pcap, err := strconv.Atoi(args[0])
if err != nil {
return nil, err
}
ca.pcap = pcap
if len(args) > 1 {
pttl, err := strconv.Atoi(args[1])
if err != nil {
return nil, err
}
// Reserve 0 (and smaller for future things)
if pttl <= 0 {
return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", pttl)
}
ca.pttl = time.Duration(pttl) * time.Second
if len(args) > 2 {
minpttl, err := strconv.Atoi(args[2])
if err != nil {
return nil, err
}
// Reserve < 0
if minpttl < 0 {
return nil, fmt.Errorf("cache min TTL can not be negative: %d", minpttl)
}
ca.minpttl = time.Duration(minpttl) * time.Second
}
}
case Denial:
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
ncap, err := strconv.Atoi(args[0])
if err != nil {
return nil, err
}
ca.ncap = ncap
if len(args) > 1 {
nttl, err := strconv.Atoi(args[1])
if err != nil {
return nil, err
}
// Reserve 0 (and smaller for future things)
if nttl <= 0 {
return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", nttl)
}
ca.nttl = time.Duration(nttl) * time.Second
if len(args) > 2 {
minnttl, err := strconv.Atoi(args[2])
if err != nil {
return nil, err
}
// Reserve < 0
if minnttl < 0 {
return nil, fmt.Errorf("cache min TTL can not be negative: %d", minnttl)
}
ca.minnttl = time.Duration(minnttl) * time.Second
}
}
case "prefetch":
args := c.RemainingArgs()
if len(args) == 0 || len(args) > 3 {
return nil, c.ArgErr()
}
amount, err := strconv.Atoi(args[0])
if err != nil {
return nil, err
}
if amount < 0 {
return nil, fmt.Errorf("prefetch amount should be positive: %d", amount)
}
ca.prefetch = amount
if len(args) > 1 {
dur, err := time.ParseDuration(args[1])
if err != nil {
return nil, err
}
ca.duration = dur
}
if len(args) > 2 {
pct := args[2]
if x := pct[len(pct)-1]; x != '%' {
return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x)
}
pct = pct[:len(pct)-1]
num, err := strconv.Atoi(pct)
if err != nil {
return nil, err
}
if num < 10 || num > 90 {
return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num)
}
ca.percentage = num
}
case "serve_stale":
args := c.RemainingArgs()
if len(args) > 2 {
return nil, c.ArgErr()
}
ca.staleUpTo = 1 * time.Hour
if len(args) > 0 {
d, err := time.ParseDuration(args[0])
if err != nil {
return nil, err
}
if d < 0 {
return nil, errors.New("invalid negative duration for serve_stale")
}
ca.staleUpTo = d
}
ca.verifyStale = false
if len(args) > 1 {
mode := strings.ToLower(args[1])
if mode != "immediate" && mode != "verify" {
return nil, fmt.Errorf("invalid value for serve_stale refresh mode: %s", mode)
}
ca.verifyStale = mode == "verify"
}
case "servfail":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
d, err := time.ParseDuration(args[0])
if err != nil {
return nil, err
}
if d < 0 {
return nil, errors.New("invalid negative ttl for servfail")
}
if d > 5*time.Minute {
// RFC 2308 prohibits caching SERVFAIL longer than 5 minutes
return nil, errors.New("caching SERVFAIL responses over 5 minutes is not permitted")
}
ca.failttl = d
case "disable":
// disable [success|denial] [zones]...
args := c.RemainingArgs()
if len(args) < 1 {
return nil, c.ArgErr()
}
var zones []string
if len(args) > 1 {
for _, z := range args[1:] { // args[1:] define the list of zones to disable
nz := plugin.Name(z).Normalize()
if nz == "" {
return nil, fmt.Errorf("invalid disabled zone: %s", z)
}
zones = append(zones, nz)
}
} else {
// if no zones specified, default to root
zones = []string{"."}
}
switch args[0] { // args[0] defines which cache to disable
case Denial:
ca.nexcept = zones
case Success:
ca.pexcept = zones
default:
return nil, fmt.Errorf("cache type for disable must be %q or %q", Success, Denial)
}
case "keepttl":
args := c.RemainingArgs()
if len(args) != 0 {
return nil, c.ArgErr()
}
ca.keepttl = true
default:
return nil, c.ArgErr()
}
}
ca.Zones = origins
ca.zonesMetricLabel = strings.Join(origins, ",")
ca.pcache = cache.New(ca.pcap)
ca.ncache = cache.New(ca.ncap)
}
return ca, nil
}
// Package cancel implements a plugin that adds a canceling context to each request.
package cancel
import (
"context"
"fmt"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
func init() { plugin.Register("cancel", setup) }
func setup(c *caddy.Controller) error {
ca := Cancel{}
for c.Next() {
args := c.RemainingArgs()
switch len(args) {
case 0:
ca.timeout = 5001 * time.Millisecond
case 1:
dur, err := time.ParseDuration(args[0])
if err != nil {
return plugin.Error("cancel", fmt.Errorf("invalid duration: %q", args[0]))
}
if dur <= 0 {
return plugin.Error("cancel", fmt.Errorf("invalid negative duration: %q", args[0]))
}
ca.timeout = dur
default:
return plugin.Error("cancel", c.ArgErr())
}
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
ca.Next = next
return ca
})
return nil
}
// Cancel is a plugin that adds a canceling context to each request's context.
type Cancel struct {
timeout time.Duration
Next plugin.Handler
}
// ServeDNS implements the plugin.Handler interface.
func (c Cancel) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
ctx, cancel := context.WithTimeout(ctx, c.timeout)
code, err := plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
cancel()
return code, err
}
// Name implements the Handler interface.
func (c Cancel) Name() string { return "cancel" }
// Package chaos implements a plugin that answers 'CH version.bind TXT' type queries.
package chaos
import (
"context"
"math/rand"
"os"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Chaos allows CoreDNS to reply to CH TXT queries and return author or
// version information.
type Chaos struct {
Next plugin.Handler
Version string
Authors []string
}
// ServeDNS implements the plugin.Handler interface.
func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT {
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
}
m := new(dns.Msg)
m.SetReply(r)
hdr := dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
switch state.Name() {
default:
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
case "authors.bind.":
rnd := rand.New(rand.NewSource(time.Now().Unix()))
for _, i := range rnd.Perm(len(c.Authors)) {
m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{c.Authors[i]}})
}
case "version.bind.", "version.server.":
m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{c.Version}}}
case "hostname.bind.", "id.server.":
hostname, err := os.Hostname()
if err != nil {
hostname = "localhost"
}
m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{trim(hostname)}}}
}
w.WriteMsg(m)
return 0, nil
}
// Name implements the Handler interface.
func (c Chaos) Name() string { return "chaos" }
//go:build gofuzz
package chaos
import (
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes chaos.
func Fuzz(data []byte) int {
c := Chaos{}
return fuzz.Do(c, data)
}
//go:generate go run owners_generate.go
package chaos
import (
"sort"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("chaos", setup) }
func setup(c *caddy.Controller) error {
version, authors, err := parse(c)
if err != nil {
return plugin.Error("chaos", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Chaos{Next: next, Version: version, Authors: authors}
})
return nil
}
func parse(c *caddy.Controller) (string, []string, error) {
// Set here so we pick up AppName and AppVersion that get set in coremain's init().
chaosVersion = caddy.AppName + "-" + caddy.AppVersion
version := ""
if c.Next() {
args := c.RemainingArgs()
if len(args) == 0 {
return trim(chaosVersion), Owners, nil
}
if len(args) == 1 {
return trim(args[0]), Owners, nil
}
version = args[0]
authors := make(map[string]struct{})
for _, a := range args[1:] {
authors[a] = struct{}{}
}
list := []string{}
for k := range authors {
k = trim(k) // limit size to 255 chars
list = append(list, k)
}
sort.Strings(list)
return version, list, nil
}
return version, Owners, nil
}
func trim(s string) string {
if len(s) < 256 {
return s
}
return s[:255]
}
var chaosVersion string
// Package clouddns implements a plugin that returns resource records
// from GCP Cloud DNS.
package clouddns
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
gcp "google.golang.org/api/dns/v1"
)
// CloudDNS is a plugin that returns RR from GCP Cloud DNS.
type CloudDNS struct {
Next plugin.Handler
Fall fall.F
zoneNames []string
client gcpDNS
upstream *upstream.Upstream
zMu sync.RWMutex
zones zones
}
type zone struct {
projectName string
zoneName string
z *file.Zone
dns string
}
type zones map[string][]*zone
// New reads from the keys map, which uses domain names as its keys and lists of colon-separated
// "project:hostedZone" strings as its values. It validates that each domain name/zone pair exists
// and returns a new *CloudDNS.
// In addition, upstream is passed for doing recursive queries against CNAMEs.
// Returns an error if it cannot verify any given domain name/zone pair.
func New(ctx context.Context, c gcpDNS, keys map[string][]string, up *upstream.Upstream) (*CloudDNS, error) {
zones := make(map[string][]*zone, len(keys))
zoneNames := make([]string, 0, len(keys))
for dnsName, hostedZoneDetails := range keys {
for _, hostedZone := range hostedZoneDetails {
ss := strings.SplitN(hostedZone, ":", 2)
if len(ss) != 2 {
return nil, errors.New("either project or zone name missing")
}
err := c.zoneExists(ss[0], ss[1])
if err != nil {
return nil, err
}
fqdnDNSName := dns.Fqdn(dnsName)
if _, ok := zones[fqdnDNSName]; !ok {
zoneNames = append(zoneNames, fqdnDNSName)
}
zones[fqdnDNSName] = append(zones[fqdnDNSName], &zone{projectName: ss[0], zoneName: ss[1], dns: fqdnDNSName, z: file.NewZone(fqdnDNSName, "")})
}
}
return &CloudDNS{
client: c,
zoneNames: zoneNames,
zones: zones,
upstream: up,
}, nil
}
// Run executes first update, spins up an update forever-loop.
// Returns error if first update fails.
func (h *CloudDNS) Run(ctx context.Context) error {
if err := h.updateZones(ctx); err != nil {
return err
}
go func() {
delay := 1 * time.Minute
timer := time.NewTimer(delay)
defer timer.Stop()
for {
timer.Reset(delay)
select {
case <-ctx.Done():
log.Debugf("Breaking out of CloudDNS update loop for %v: %v", h.zoneNames, ctx.Err())
return
case <-timer.C:
if err := h.updateZones(ctx); err != nil && ctx.Err() == nil /* Don't log error if ctx expired. */ {
log.Errorf("Failed to update zones %v: %v", h.zoneNames, err)
}
}
}
}()
return nil
}
// ServeDNS implements the plugin.Handler interface.
func (h *CloudDNS) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
zName := plugin.Zones(h.zoneNames).Matches(qname)
if zName == "" {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
z, ok := h.zones[zName] // ok true if we are authoritative for the zone
if !ok || z == nil {
return dns.RcodeServerFailure, nil
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
var result file.Result
for _, hostedZone := range z {
h.zMu.RLock()
m.Answer, m.Ns, m.Extra, result = hostedZone.z.Lookup(ctx, state, qname)
h.zMu.RUnlock()
// Take the answer if it's non-empty OR if another
// record type exists for this name (NODATA).
if len(m.Answer) != 0 || result == file.NoData {
break
}
}
if len(m.Answer) == 0 && result != file.NoData && h.Fall.Through(qname) {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
switch result {
case file.Success:
case file.NoData:
case file.NameError:
m.Rcode = dns.RcodeNameError
case file.Delegation:
m.Authoritative = false
case file.ServerFailure:
return dns.RcodeServerFailure, nil
}
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
func updateZoneFromRRS(rrs *gcp.ResourceRecordSetsListResponse, z *file.Zone) error {
for _, rr := range rrs.Rrsets {
var rfc1035 string
var r dns.RR
var err error
for _, value := range rr.Rrdatas {
if rr.Type == "CNAME" || rr.Type == "PTR" {
value = dns.Fqdn(value)
}
// Assemble RFC 1035 conforming record to pass into dns scanner.
rfc1035 = fmt.Sprintf("%s %d IN %s %s", dns.Fqdn(rr.Name), rr.Ttl, rr.Type, value)
r, err = dns.NewRR(rfc1035)
if err != nil {
return fmt.Errorf("failed to parse resource record: %v", err)
}
err = z.Insert(r)
if err != nil {
return fmt.Errorf("failed to insert record: %v", err)
}
}
}
return nil
}
// updateZones re-queries resource record sets for each zone and updates the
// zone object.
// Returns an error if any zone errored out, but waits for the other zones to
// complete first.
func (h *CloudDNS) updateZones(ctx context.Context) error {
errc := make(chan error)
defer close(errc)
for zName, z := range h.zones {
go func(zName string, z []*zone) {
var err error
var rrListResponse *gcp.ResourceRecordSetsListResponse
defer func() {
errc <- err
}()
for i, hostedZone := range z {
newZ := file.NewZone(zName, "")
newZ.Upstream = h.upstream
rrListResponse, err = h.client.listRRSets(ctx, hostedZone.projectName, hostedZone.zoneName)
if err != nil {
err = fmt.Errorf("failed to list resource records for %v:%v:%v from gcp: %v", zName, hostedZone.projectName, hostedZone.zoneName, err)
return
}
updateZoneFromRRS(rrListResponse, newZ)
h.zMu.Lock()
(*z[i]).z = newZ
h.zMu.Unlock()
}
}(zName, z)
}
// Collect errors (if any). This will also sync on all zones updates
// completion.
var errs []string
for range len(h.zones) {
err := <-errc
if err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) != 0 {
return fmt.Errorf("errors updating zones: %v", errs)
}
return nil
}
// Name implements the Handler interface.
func (h *CloudDNS) Name() string { return "clouddns" }
package clouddns
import (
"context"
gcp "google.golang.org/api/dns/v1"
)
type gcpDNS interface {
zoneExists(projectName, hostedZoneName string) error
listRRSets(ctx context.Context, projectName, hostedZoneName string) (*gcp.ResourceRecordSetsListResponse, error)
}
type gcpClient struct {
*gcp.Service
}
// zoneExists is a wrapper method around `gcp.Service.ManagedZones.Get`
// it checks if the provided zone name for a given project exists.
func (c gcpClient) zoneExists(projectName, hostedZoneName string) error {
_, err := c.ManagedZones.Get(projectName, hostedZoneName).Do()
if err != nil {
return err
}
return nil
}
// listRRSets is a wrapper method around `gcp.Service.ResourceRecordSets.List`
// it fetches and returns the record sets for a hosted zone.
func (c gcpClient) listRRSets(ctx context.Context, projectName, hostedZoneName string) (*gcp.ResourceRecordSetsListResponse, error) {
req := c.ResourceRecordSets.List(projectName, hostedZoneName)
var rs []*gcp.ResourceRecordSet
if err := req.Pages(ctx, func(page *gcp.ResourceRecordSetsListResponse) error {
rs = append(rs, page.Rrsets...)
return nil
}); err != nil {
return nil, err
}
return &gcp.ResourceRecordSetsListResponse{Rrsets: rs}, nil
}
package clouddns
import (
"context"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/upstream"
gcp "google.golang.org/api/dns/v1"
"google.golang.org/api/option"
)
var log = clog.NewWithPlugin("clouddns")
func init() { plugin.Register("clouddns", setup) }
// exposed for testing
var f = func(ctx context.Context, opt option.ClientOption) (gcpDNS, error) {
var err error
var client *gcp.Service
if opt != nil {
client, err = gcp.NewService(ctx, opt)
} else {
// if credentials file is not provided in the Corefile
// authenticate the client using env variables
client, err = gcp.NewService(ctx)
}
return gcpClient{client}, err
}
func setup(c *caddy.Controller) error {
for c.Next() {
keyPairs := map[string]struct{}{}
keys := map[string][]string{}
var fall fall.F
up := upstream.New()
args := c.RemainingArgs()
for i := range args {
parts := strings.SplitN(args[i], ":", 3)
if len(parts) != 3 {
return plugin.Error("clouddns", c.Errf("invalid zone %q", args[i]))
}
dnsName, projectName, hostedZone := parts[0], parts[1], parts[2]
if dnsName == "" || projectName == "" || hostedZone == "" {
return plugin.Error("clouddns", c.Errf("invalid zone %q", args[i]))
}
if _, ok := keyPairs[args[i]]; ok {
return plugin.Error("clouddns", c.Errf("conflict zone %q", args[i]))
}
keyPairs[args[i]] = struct{}{}
keys[dnsName] = append(keys[dnsName], projectName+":"+hostedZone)
}
var opt option.ClientOption
for c.NextBlock() {
switch c.Val() {
case "upstream":
c.RemainingArgs()
case "credentials":
if c.NextArg() {
opt = option.WithCredentialsFile(c.Val())
} else {
return plugin.Error("clouddns", c.ArgErr())
}
case "fallthrough":
fall.SetZonesFromArgs(c.RemainingArgs())
default:
return plugin.Error("clouddns", c.Errf("unknown property %q", c.Val()))
}
}
ctx, cancel := context.WithCancel(context.Background())
client, err := f(ctx, opt)
if err != nil {
cancel()
return err
}
h, err := New(ctx, client, keys, up)
if err != nil {
cancel()
return plugin.Error("clouddns", c.Errf("failed to create plugin: %v", err))
}
h.Fall = fall
if err := h.Run(ctx); err != nil {
cancel()
return plugin.Error("clouddns", c.Errf("failed to initialize plugin: %v", err))
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
h.Next = next
return h
})
c.OnShutdown(func() error { cancel(); return nil })
}
return nil
}
package debug
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("debug", setup) }
func setup(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
for c.Next() {
if c.NextArg() {
return plugin.Error("debug", c.ArgErr())
}
config.Debug = true
}
return nil
}
package debug
import (
"bytes"
"fmt"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
// Hexdump converts the dns message m to a hex dump Wireshark can import.
// See https://www.wireshark.org/docs/man-pages/text2pcap.html.
// This output looks like this:
//
// 00000 dc bd 01 00 00 01 00 00 00 00 00 01 07 65 78 61
// 000010 6d 70 6c 65 05 6c 6f 63 61 6c 00 00 01 00 01 00
// 000020 00 29 10 00 00 00 80 00 00 00
// 00002a
//
// Hexdump will use log.Debug to write the dump to the log, each line
// is prefixed with 'debug: ' so the data can be easily extracted.
//
// Any arguments in v are logged before the dump.
func Hexdump(m *dns.Msg, v ...any) {
if !log.D.Value() {
return
}
buf, _ := m.Pack()
if len(buf) == 0 {
return
}
out := "\n" + string(hexdump(buf))
v = append(v, out)
log.Debug(v...)
}
// Hexdumpf dumps a DNS message as Hexdump, but allows a format string.
func Hexdumpf(m *dns.Msg, format string, v ...any) {
if !log.D.Value() {
return
}
buf, _ := m.Pack()
if len(buf) == 0 {
return
}
format += "\n%s"
v = append(v, hexdump(buf))
log.Debugf(format, v...)
}
func hexdump(data []byte) []byte {
b := new(bytes.Buffer)
newline := ""
for i := range data {
if i%16 == 0 {
fmt.Fprintf(b, "%s%s%06x", newline, prefix, i)
newline = "\n"
}
fmt.Fprintf(b, " %02x", data[i])
}
fmt.Fprintf(b, "\n%s%06x", prefix, len(data))
return b.Bytes()
}
const prefix = "debug: "
// Package dns64 implements a plugin that performs DNS64.
//
// See: RFC 6147 (https://tools.ietf.org/html/rfc6147)
package dns64
import (
"context"
"errors"
"net"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// UpstreamInt wraps the Upstream API for dependency injection during testing
type UpstreamInt interface {
Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
}
// DNS64 performs DNS64.
type DNS64 struct {
Next plugin.Handler
Prefix *net.IPNet
TranslateAll bool // Translate all responses, not complying with RFC 6147 5.1.1.
AllowIPv4 bool
Upstream UpstreamInt
}
// ServeDNS implements the plugin.Handler interface.
func (d *DNS64) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
// Don't proxy if we don't need to.
if !d.requestShouldIntercept(&request.Request{W: w, Req: r}) {
return plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r)
}
// Pass the request to the next plugin in the chain, but intercept the response.
nw := nonwriter.New(w)
origRc, origErr := d.Next.ServeDNS(ctx, nw, r)
if nw.Msg == nil { // somehow we didn't get a response (or raw bytes were written)
return origRc, origErr
}
// If the response doesn't need DNS64, short-circuit.
if !d.responseShouldDNS64(nw.Msg) {
w.WriteMsg(nw.Msg)
return origRc, origErr
}
// otherwise do the actual DNS64 request and response synthesis
msg, err := d.DoDNS64(ctx, w, r, nw.Msg)
if err != nil {
// err means we weren't able to even issue the A request
// to CoreDNS upstream
return dns.RcodeServerFailure, err
}
RequestsTranslatedCount.WithLabelValues(metrics.WithServer(ctx)).Inc()
w.WriteMsg(msg)
return msg.Rcode, nil
}
// Name implements the Handler interface.
func (d *DNS64) Name() string { return "dns64" }
// requestShouldIntercept returns true if the request represents one that is eligible
// for DNS64 rewriting:
// 1. The request came in over IPv6 or the 'allow_ipv4' option is set
// 2. The request is of type AAAA
// 3. The request is of class INET
func (d *DNS64) requestShouldIntercept(req *request.Request) bool {
// Only intercept requests that arrived over IPv6, unless the AllowIPv4 option is enabled.
// Translating requests without taking the client (source) IP into account can be problematic in dual-stack networks.
if !d.AllowIPv4 && req.Family() == 1 {
return false
}
// Do not modify if question is not AAAA or not of class IN. See RFC 6147 5.1
return req.QType() == dns.TypeAAAA && req.QClass() == dns.ClassINET
}
// responseShouldDNS64 returns true if the response indicates we should attempt
// DNS64 rewriting:
// 1. The response has no valid (RFC 6147 5.1.4) AAAA records (RFC 6147 5.1.1)
// 2. The response code (RCODE) is not 3 (Name Error) (RFC 6147 5.1.2)
//
// Note that requestShouldIntercept must also have been true, so the request
// is known to be of type AAAA.
func (d *DNS64) responseShouldDNS64(origResponse *dns.Msg) bool {
ty, _ := response.Typify(origResponse, time.Now().UTC())
// Handle NameError normally. See RFC 6147 5.1.2
// All other error types are "equivalent" to empty response
if ty == response.NameError {
return false
}
// If we've configured to always translate, well, then always translate.
if d.TranslateAll {
return true
}
// if response includes AAAA record, no need to rewrite
for _, rr := range origResponse.Answer {
if rr.Header().Rrtype == dns.TypeAAAA {
return false
}
}
return true
}
// DoDNS64 takes an (empty) response to an AAAA question, issues the A request,
// and synthesizes the answer. Returns the response message, or error on internal failure.
func (d *DNS64) DoDNS64(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, origResponse *dns.Msg) (*dns.Msg, error) {
req := request.Request{W: w, Req: r}
resp, err := d.Upstream.Lookup(ctx, req, req.Name(), dns.TypeA)
if err != nil {
return nil, err
}
out := d.Synthesize(r, origResponse, resp)
return out, nil
}
// Synthesize merges the AAAA response and the records from the A response
func (d *DNS64) Synthesize(origReq, origResponse, resp *dns.Msg) *dns.Msg {
ret := dns.Msg{}
ret.SetReply(origReq)
// persist the truncated state of the A response we synthesize from
ret.Truncated = resp.Truncated
// 5.3.2: DNS64 MUST pass the additional section unchanged
ret.Extra = resp.Extra
ret.Ns = resp.Ns
// 5.1.7: The TTL is the minimum of the A RR and the SOA RR. If SOA is
// unknown, then the TTL is the minimum of A TTL and 600
SOATtl := uint32(600) // default TTL when no SOA record is available (RFC 6147 5.1.7)
for _, ns := range origResponse.Ns {
if ns.Header().Rrtype == dns.TypeSOA {
SOATtl = ns.Header().Ttl
}
}
ret.Answer = make([]dns.RR, 0, len(resp.Answer))
// convert A records to AAAA records
for _, rr := range resp.Answer {
header := rr.Header()
// 5.3.3: All other RR's MUST be returned unchanged
if header.Rrtype != dns.TypeA {
ret.Answer = append(ret.Answer, rr)
continue
}
aaaa, _ := to6(d.Prefix, rr.(*dns.A).A)
// ttl is min of SOA TTL and A TTL
ttl := min(rr.Header().Ttl, SOATtl)
// Replace A answer with a DNS64 AAAA answer
ret.Answer = append(ret.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: header.Name,
Rrtype: dns.TypeAAAA,
Class: header.Class,
Ttl: ttl,
},
AAAA: aaaa,
})
}
return &ret
}
// to6 takes a prefix and IPv4 address and returns an IPv6 address according to RFC 6052.
func to6(prefix *net.IPNet, addr net.IP) (net.IP, error) {
addr = addr.To4()
if addr == nil {
return nil, errors.New("not a valid IPv4 address")
}
n, _ := prefix.Mask.Size()
// Assumes prefix has been validated during setup
v6 := make([]byte, 16)
i, j := 0, 0
for ; i < n/8; i++ {
v6[i] = prefix.IP[i]
}
for ; i < 8; i, j = i+1, j+1 {
v6[i] = addr[j]
}
if i == 8 {
i++
}
for ; j < 4; i, j = i+1, j+1 {
v6[i] = addr[j]
}
return v6, nil
}
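// Worked example (per RFC 6052): embedding 192.0.2.1 into the well-known prefix
// 64:ff9b::/96 yields 64:ff9b::c000:201, i.e. the four IPv4 octets land in the
// last 32 bits. For shorter prefixes such as /64 the "u" octet (bits 64-71) must
// stay zero, which is why the loop above skips index 8.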
package dns64
import (
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/upstream"
)
const pluginName = "dns64"
func init() { plugin.Register(pluginName, setup) }
func setup(c *caddy.Controller) error {
dns64, err := dns64Parse(c)
if err != nil {
return plugin.Error(pluginName, err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
dns64.Next = next
return dns64
})
return nil
}
func dns64Parse(c *caddy.Controller) (*DNS64, error) {
_, defaultPref, _ := net.ParseCIDR("64:ff9b::/96")
dns64 := &DNS64{
Upstream: upstream.New(),
Prefix: defaultPref,
}
for c.Next() {
args := c.RemainingArgs()
if len(args) == 1 {
pref, err := parsePrefix(c, args[0])
if err != nil {
return nil, err
}
dns64.Prefix = pref
continue
}
if len(args) > 0 {
return nil, c.ArgErr()
}
for c.NextBlock() {
switch c.Val() {
case "prefix":
if !c.NextArg() {
return nil, c.ArgErr()
}
pref, err := parsePrefix(c, c.Val())
if err != nil {
return nil, err
}
dns64.Prefix = pref
case "translate_all":
dns64.TranslateAll = true
case "allow_ipv4":
dns64.AllowIPv4 = true
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
}
return dns64, nil
}
func parsePrefix(c *caddy.Controller, addr string) (*net.IPNet, error) {
_, pref, err := net.ParseCIDR(addr)
if err != nil {
return nil, err
}
// Test for valid prefix
n, total := pref.Mask.Size()
if total != 128 {
return nil, c.Errf("invalid netmask %d IPv6 address: %q", total, pref)
}
if n%8 != 0 || n < 32 || n > 96 {
return nil, c.Errf("invalid prefix length %q", pref)
}
return pref, nil
}
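// Example Corefile usage (illustrative sketch; the upstream resolver address is
// hypothetical):
//
//   . {
//       dns64 {
//           prefix 64:ff9b::/96
//           allow_ipv4
//       }
//       forward . 2001:db8::1
//   }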
package dnssec
import (
"strings"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// nsec returns an NSEC useful for NXDOMAIN responses.
// See https://tools.ietf.org/html/draft-valsorda-dnsop-black-lies-00
// For example, a request for the non-existing name a.example.com would
// cause the following NSEC record to be generated:
//
// a.example.com. 3600 IN NSEC \000.a.example.com. ( RRSIG NSEC ... )
//
// This in turn makes every NXDOMAIN answer a NODATA one; don't forget to flip
// the header rcode to NOERROR.
func (d Dnssec) nsec(state request.Request, mt response.Type, ttl, incep, expir uint32, server string) ([]dns.RR, error) {
nsec := &dns.NSEC{}
nsec.Hdr = dns.RR_Header{Name: state.QName(), Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeNSEC}
nsec.NextDomain = "\\000." + state.QName()
if state.QName() == "." {
nsec.NextDomain = "\\000." // If You want to play as root server
}
if state.Name() == state.Zone {
nsec.TypeBitMap = filter18(state.QType(), apexBitmap, mt)
} else if mt == response.Delegation || state.QType() == dns.TypeDS {
nsec.TypeBitMap = delegationBitmap[:]
if mt == response.Delegation {
labels := dns.SplitDomainName(state.QName())
labels[0] += "\\000"
nsec.NextDomain = strings.Join(labels, ".") + "."
}
} else {
nsec.TypeBitMap = filter14(state.QType(), zoneBitmap, mt)
}
sigs, err := d.sign([]dns.RR{nsec}, state.Zone, ttl, incep, expir, server)
if err != nil {
return nil, err
}
return append(sigs, nsec), nil
}
// The NSEC bit maps we return.
var (
delegationBitmap = [...]uint16{dns.TypeA, dns.TypeNS, dns.TypeHINFO, dns.TypeTXT, dns.TypeAAAA, dns.TypeLOC, dns.TypeSRV, dns.TypeCERT, dns.TypeSSHFP, dns.TypeRRSIG, dns.TypeNSEC, dns.TypeTLSA, dns.TypeHIP, dns.TypeOPENPGPKEY, dns.TypeSPF}
zoneBitmap = [...]uint16{dns.TypeA, dns.TypeHINFO, dns.TypeTXT, dns.TypeAAAA, dns.TypeLOC, dns.TypeSRV, dns.TypeCERT, dns.TypeSSHFP, dns.TypeRRSIG, dns.TypeNSEC, dns.TypeTLSA, dns.TypeHIP, dns.TypeOPENPGPKEY, dns.TypeSPF}
apexBitmap = [...]uint16{dns.TypeA, dns.TypeNS, dns.TypeSOA, dns.TypeHINFO, dns.TypeMX, dns.TypeTXT, dns.TypeAAAA, dns.TypeLOC, dns.TypeSRV, dns.TypeCERT, dns.TypeSSHFP, dns.TypeRRSIG, dns.TypeNSEC, dns.TypeDNSKEY, dns.TypeTLSA, dns.TypeHIP, dns.TypeOPENPGPKEY, dns.TypeSPF}
)
// filter14 filters out t from bitmap (if it exists). If mt is not a NODATA or NXDOMAIN response, or t is NSEC itself, return the entire bitmap.
func filter14(t uint16, bitmap [14]uint16, mt response.Type) []uint16 {
if mt != response.NoData && mt != response.NameError || t == dns.TypeNSEC {
return zoneBitmap[:]
}
for i := range bitmap {
if bitmap[i] == t {
return append(bitmap[:i], bitmap[i+1:]...)
}
}
return zoneBitmap[:] // make a slice
}
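// filter18 is the apex-bitmap counterpart of filter14: it removes t from the
// 18-entry apex bitmap for NODATA/NXDOMAIN responses and otherwise returns the
// full bitmap.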
func filter18(t uint16, bitmap [18]uint16, mt response.Type) []uint16 {
if mt != response.NoData && mt != response.NameError || t == dns.TypeNSEC {
return apexBitmap[:]
}
for i := range bitmap {
if bitmap[i] == t {
return append(bitmap[:i], bitmap[i+1:]...)
}
}
return apexBitmap[:] // make a slice
}
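// For example, filter18(dns.TypeMX, apexBitmap, response.NoData) returns the apex
// bitmap without the MX type, so the NSEC proves MX does not exist at the apex
// while still advertising the other types.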
package dnssec
import (
"hash/fnv"
"io"
"time"
"github.com/coredns/coredns/plugin/pkg/cache"
"github.com/miekg/dns"
)
// hash serializes the RRset and returns a signature cache key.
func hash(rrs []dns.RR) uint64 {
h := fnv.New64()
// We need to hash the entire RRset to pick the correct signature; if the RRset
// changes for whatever reason we should re-sign.
// We could use the wire format or the string format; both create garbage when
// building the hash key. And of course: is a uint64 big enough?
for _, rr := range rrs {
io.WriteString(h, rr.String())
}
return h.Sum64()
}
func periodicClean(c *cache.Cache, stop <-chan struct{}) {
tick := time.NewTicker(8 * time.Hour)
defer tick.Stop()
for {
select {
case <-tick.C:
// We sign for 8 days; check whether a signature in the cache has reached 75% of
// that (i.e. 6 days) and, if so, delete it.
is75 := time.Now().UTC().Add(twoDays)
c.Walk(func(items map[uint64]any, key uint64) bool {
for _, rr := range items[key].([]dns.RR) {
if !rr.(*dns.RRSIG).ValidityPeriod(is75) {
delete(items, key)
}
}
return true
})
case <-stop:
return
}
}
}
package dnssec
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
"time"
"github.com/coredns/coredns/request"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
"github.com/miekg/dns"
"golang.org/x/crypto/ed25519"
)
// DNSKEY holds a DNSSEC public and private key used for on-the-fly signing.
type DNSKEY struct {
K *dns.DNSKEY
D *dns.DS
s crypto.Signer
tag uint16
}
// SecretKeyData represents the structure of the DNS keys stored in AWS Secrets Manager.
type SecretKeyData struct {
Key string `json:"key"`
Private string `json:"private"`
}
// ParseKeyFile reads a DNSSEC key pair as generated by dnssec-keygen or other
// utilities: the public key from pubFile (".key") and the private key from privFile (".private").
func ParseKeyFile(pubFile, privFile string) (*DNSKEY, error) {
f, e := os.Open(filepath.Clean(pubFile))
if e != nil {
return nil, e
}
defer f.Close()
k, e := dns.ReadRR(f, pubFile)
if e != nil {
return nil, e
}
f, e = os.Open(filepath.Clean(privFile))
if e != nil {
return nil, e
}
defer f.Close()
dk, ok := k.(*dns.DNSKEY)
if !ok {
return nil, errors.New("no public key found")
}
p, e := dk.ReadPrivateKey(f, privFile)
if e != nil {
return nil, e
}
if s, ok := p.(*rsa.PrivateKey); ok {
return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: dk.KeyTag()}, nil
}
if s, ok := p.(*ecdsa.PrivateKey); ok {
return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: dk.KeyTag()}, nil
}
if s, ok := p.(ed25519.PrivateKey); ok {
return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: dk.KeyTag()}, nil
}
return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: nil, tag: 0}, errors.New("no private key found")
}
// ParseKeyFromAWSSecretsManager retrieves and parses a DNSSEC key pair from AWS Secrets Manager.
func ParseKeyFromAWSSecretsManager(secretID string) (*DNSKEY, error) {
// Load the AWS SDK configuration
cfg, err := config.LoadDefaultConfig(context.TODO())
if err != nil {
return nil, err
}
// Create a Secrets Manager client
client := secretsmanager.NewFromConfig(cfg)
// Retrieve the secret value
input := &secretsmanager.GetSecretValueInput{
SecretId: &secretID,
}
result, err := client.GetSecretValue(context.TODO(), input)
if err != nil {
return nil, err
}
// Parse the secret string into SecretKeyData
var secretData SecretKeyData
err = json.Unmarshal([]byte(*result.SecretString), &secretData)
if err != nil {
return nil, err
}
// Parse the public key
rr, err := dns.NewRR(secretData.Key)
if err != nil {
return nil, err
}
dk, ok := rr.(*dns.DNSKEY)
if !ok {
return nil, errors.New("invalid public key format")
}
// Parse the private key
p, err := dk.ReadPrivateKey(strings.NewReader(secretData.Private), secretID)
if err != nil {
return nil, err
}
// Create the DNSKEY structure
var s crypto.Signer
var tag uint16
switch key := p.(type) {
case *rsa.PrivateKey:
s = key
tag = dk.KeyTag()
case *ecdsa.PrivateKey:
s = key
tag = dk.KeyTag()
case ed25519.PrivateKey:
s = key
tag = dk.KeyTag()
default:
return nil, errors.New("unsupported key type")
}
return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: tag}, nil
}
// getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true.
func (d Dnssec) getDNSKEY(state request.Request, zone string, do bool, server string) *dns.Msg {
keys := make([]dns.RR, len(d.keys))
for i, k := range d.keys {
keys[i] = dns.Copy(k.K)
keys[i].Header().Name = zone
}
m := new(dns.Msg)
m.SetReply(state.Req)
m.Answer = keys
if !do {
return m
}
incep, expir := incepExpir(time.Now().UTC())
if sigs, err := d.sign(keys, zone, 3600, incep, expir, server); err == nil {
m.Answer = append(m.Answer, sigs...)
}
return m
}
// Return true if, and only if, this is a zone key with the SEP bit unset. This implies a ZSK (rfc4034 2.1.1).
func (k DNSKEY) isZSK() bool {
return k.K.Flags&(1<<8) == (1<<8) && k.K.Flags&1 == 0
}
// Return true if, and only if, this is a zone key with the SEP bit set. This implies a KSK (rfc4034 2.1.1).
func (k DNSKEY) isKSK() bool {
return k.K.Flags&(1<<8) == (1<<8) && k.K.Flags&1 == 1
}
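// For reference: with this encoding a ZSK typically has Flags 256 (zone key bit
// only) and a KSK has Flags 257 (zone key + SEP), per RFC 4034.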
// Package dnssec implements a plugin that signs responses on-the-fly using
// NSEC black lies.
package dnssec
import (
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/singleflight"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Dnssec signs the reply on-the-fly.
type Dnssec struct {
Next plugin.Handler
zones []string
keys []*DNSKEY
splitkeys bool
inflight *singleflight.Group
cache *cache.Cache
}
// New returns a new Dnssec.
func New(zones []string, keys []*DNSKEY, splitkeys bool, next plugin.Handler, c *cache.Cache) Dnssec {
return Dnssec{Next: next,
zones: zones,
keys: keys,
splitkeys: splitkeys,
cache: c,
inflight: new(singleflight.Group),
}
}
// Sign signs the message in state. It takes care of negative or NODATA responses. It
// uses NSEC black lies for authenticated denial of existence. For delegations it
// will insert DS records and sign those.
// Signatures will be cached for a short while. By default we sign for 8 days,
// starting 3 hours ago.
func (d Dnssec) Sign(state request.Request, now time.Time, server string) *dns.Msg {
req := state.Req
incep, expir := incepExpir(now)
mt, _ := response.Typify(req, time.Now().UTC()) // TODO(miek): need opt record here?
if mt == response.Delegation {
// We either sign DS or NSEC of DS.
ttl := req.Ns[0].Header().Ttl
ds := []dns.RR{}
for i := range req.Ns {
if req.Ns[i].Header().Rrtype == dns.TypeDS {
ds = append(ds, req.Ns[i])
}
}
if len(ds) == 0 {
if sigs, err := d.nsec(state, mt, ttl, incep, expir, server); err == nil {
req.Ns = append(req.Ns, sigs...)
}
} else if sigs, err := d.sign(ds, state.Zone, ttl, incep, expir, server); err == nil {
req.Ns = append(req.Ns, sigs...)
}
return req
}
if mt == response.NameError || mt == response.NoData {
if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 {
return req
}
ttl := req.Ns[0].Header().Ttl
if sigs, err := d.sign(req.Ns, state.Zone, ttl, incep, expir, server); err == nil {
req.Ns = append(req.Ns, sigs...)
}
if sigs, err := d.nsec(state, mt, ttl, incep, expir, server); err == nil {
req.Ns = append(req.Ns, sigs...)
}
if len(req.Ns) > 1 { // actually added nsec and sigs, reset the rcode
req.Rcode = dns.RcodeSuccess
if state.QType() == dns.TypeNSEC { // If original query was NSEC move Ns to Answer without SOA
req.Answer = req.Ns[len(req.Ns)-2 : len(req.Ns)]
req.Ns = nil
}
}
return req
}
for _, r := range rrSets(req.Answer) {
ttl := r[0].Header().Ttl
if sigs, err := d.sign(r, state.Zone, ttl, incep, expir, server); err == nil {
req.Answer = append(req.Answer, sigs...)
}
}
for _, r := range rrSets(req.Ns) {
ttl := r[0].Header().Ttl
if sigs, err := d.sign(r, state.Zone, ttl, incep, expir, server); err == nil {
req.Ns = append(req.Ns, sigs...)
}
}
for _, r := range rrSets(req.Extra) {
ttl := r[0].Header().Ttl
if sigs, err := d.sign(r, state.Zone, ttl, incep, expir, server); err == nil {
req.Extra = append(req.Extra, sigs...)
}
}
return req
}
func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32, server string) ([]dns.RR, error) {
k := hash(rrs)
sgs, ok := d.get(k, server)
if ok {
return sgs, nil
}
sigs, err := d.inflight.Do(k, func() (any, error) {
var sigs []dns.RR
for _, k := range d.keys {
if d.splitkeys {
if len(rrs) > 0 && rrs[0].Header().Rrtype == dns.TypeDNSKEY {
// We are signing a DNSKEY RRSet. With split keys, we need to use a KSK here.
if !k.isKSK() {
continue
}
} else {
// For non-DNSKEY RRSets, we want to use a ZSK.
if !k.isZSK() {
continue
}
}
}
sig := k.newRRSIG(signerName, ttl, incep, expir)
if e := sig.Sign(k.s, rrs); e != nil {
return sigs, e
}
sigs = append(sigs, sig)
}
d.set(k, sigs)
return sigs, nil
})
return sigs.([]dns.RR), err
}
func (d Dnssec) set(key uint64, sigs []dns.RR) { d.cache.Add(key, sigs) }
func (d Dnssec) get(key uint64, server string) ([]dns.RR, bool) {
if s, ok := d.cache.Get(key); ok {
// we sign for 8 days, check if a signature in the cache reached 3/4 of that
is75 := time.Now().UTC().Add(twoDays)
for _, rr := range s.([]dns.RR) {
if !rr.(*dns.RRSIG).ValidityPeriod(is75) {
cacheMisses.WithLabelValues(server).Inc()
return nil, false
}
}
cacheHits.WithLabelValues(server).Inc()
return s.([]dns.RR), true
}
cacheMisses.WithLabelValues(server).Inc()
return nil, false
}
func incepExpir(now time.Time) (uint32, uint32) {
incep := uint32(now.Add(-3 * time.Hour).Unix()) // -(2+1) hours, be sure to catch daylight saving time and such
expir := uint32(now.Add(eightDays).Unix()) // sign for 8 days
return incep, expir
}
const (
eightDays = 8 * 24 * time.Hour
twoDays = 2 * 24 * time.Hour
defaultCap = 10000 // default capacity of the cache.
)
package dnssec
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the plugin.Handler interface.
func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
do := state.Do()
qname := state.Name()
qtype := state.QType()
zone := plugin.Zones(d.zones).Matches(qname)
if zone == "" {
return plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r)
}
state.Zone = zone
server := metrics.WithServer(ctx)
// Intercept queries for DNSKEY, but only if one of the zones matches the qname, otherwise we let
// the query through.
if qtype == dns.TypeDNSKEY {
for _, z := range d.zones {
if qname == z {
resp := d.getDNSKEY(state, z, do, server)
resp.Authoritative = true
w.WriteMsg(resp)
return dns.RcodeSuccess, nil
}
}
}
if do {
drr := &ResponseWriter{w, d, server}
return plugin.NextOrFailure(d.Name(), d.Next, ctx, drr, r)
}
return plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r)
}
// Name implements the Handler interface.
func (d Dnssec) Name() string { return "dnssec" }
package dnssec
import (
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ResponseWriter signs the response on the fly.
type ResponseWriter struct {
dns.ResponseWriter
d Dnssec
server string // server label for metrics.
}
// WriteMsg implements the dns.ResponseWriter interface.
func (d *ResponseWriter) WriteMsg(res *dns.Msg) error {
// By definition we should sign anything that comes back; we still need to figure out
// which zone it belongs to.
state := request.Request{W: d.ResponseWriter, Req: res}
zone := plugin.Zones(d.d.zones).Matches(state.Name())
if zone == "" {
return d.ResponseWriter.WriteMsg(res)
}
state.Zone = zone
res = d.d.Sign(state, time.Now().UTC(), d.server)
cacheSize.WithLabelValues(d.server, "signature").Set(float64(d.d.cache.Len()))
// No need for EDNS0 trickery, as that is handled by the server.
return d.ResponseWriter.WriteMsg(res)
}
// Write implements the dns.ResponseWriter interface.
func (d *ResponseWriter) Write(buf []byte) (int, error) {
log.Warning("Dnssec called with Write: not signing reply")
n, err := d.ResponseWriter.Write(buf)
return n, err
}
package dnssec
import "github.com/miekg/dns"
// newRRSIG returns a new RRSIG, with all fields filled out, except the signed data.
func (k *DNSKEY) newRRSIG(signerName string, ttl, incep, expir uint32) *dns.RRSIG {
sig := new(dns.RRSIG)
sig.Hdr.Rrtype = dns.TypeRRSIG
sig.Algorithm = k.K.Algorithm
sig.KeyTag = k.tag
sig.SignerName = signerName
sig.Hdr.Ttl = ttl
sig.OrigTtl = origTTL
sig.Inception = incep
sig.Expiration = expir
return sig
}
type rrset struct {
qname string
qtype uint16
}
// rrSets returns rrs as a map of RRsets. It skips RRSIG and OPT records as those don't need to be signed.
func rrSets(rrs []dns.RR) map[rrset][]dns.RR {
m := make(map[rrset][]dns.RR)
for _, r := range rrs {
if r.Header().Rrtype == dns.TypeRRSIG || r.Header().Rrtype == dns.TypeOPT {
continue
}
if s, ok := m[rrset{r.Header().Name, r.Header().Rrtype}]; ok {
s = append(s, r)
m[rrset{r.Header().Name, r.Header().Rrtype}] = s
continue
}
s := make([]dns.RR, 1, 3)
s[0] = r
m[rrset{r.Header().Name, r.Header().Rrtype}] = s
}
if len(m) > 0 {
return m
}
return nil
}
const origTTL = 3600
package dnssec
import (
"fmt"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
var log = clog.NewWithPlugin("dnssec")
func init() { plugin.Register("dnssec", setup) }
func setup(c *caddy.Controller) error {
zones, keys, capacity, splitkeys, err := dnssecParse(c)
if err != nil {
return plugin.Error("dnssec", err)
}
ca := cache.New(capacity)
stop := make(chan struct{})
c.OnShutdown(func() error {
close(stop)
return nil
})
c.OnStartup(func() error {
go periodicClean(ca, stop)
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return New(zones, keys, splitkeys, next, ca)
})
return nil
}
func dnssecParse(c *caddy.Controller) ([]string, []*DNSKEY, int, bool, error) {
zones := []string{}
keys := []*DNSKEY{}
capacity := defaultCap
i := 0
for c.Next() {
if i > 0 {
return nil, nil, 0, false, plugin.ErrOnce
}
i++
// dnssec [zones...]
zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
for c.NextBlock() {
switch x := c.Val(); x {
case "key":
k, e := keyParse(c)
if e != nil {
return nil, nil, 0, false, e
}
keys = append(keys, k...)
case "cache_capacity":
if !c.NextArg() {
return nil, nil, 0, false, c.ArgErr()
}
value := c.Val()
cacheCap, err := strconv.Atoi(value)
if err != nil {
return nil, nil, 0, false, err
}
capacity = cacheCap
default:
return nil, nil, 0, false, c.Errf("unknown property '%s'", x)
}
}
}
// Check if we have both KSKs and ZSKs.
zsk, ksk := 0, 0
for _, k := range keys {
if k.isKSK() {
ksk++
} else if k.isZSK() {
zsk++
}
}
splitkeys := zsk > 0 && ksk > 0
// Check if each keys owner name can actually sign the zones we want them to sign.
for _, k := range keys {
kname := plugin.Name(k.K.Header().Name)
ok := slices.ContainsFunc(zones, kname.Matches)
if !ok {
return zones, keys, capacity, splitkeys, fmt.Errorf("key %s (keyid: %d) can not sign any of the zones", string(kname), k.tag)
}
}
return zones, keys, capacity, splitkeys, nil
}
func keyParse(c *caddy.Controller) ([]*DNSKEY, error) {
keys := []*DNSKEY{}
config := dnsserver.GetConfig(c)
if !c.NextArg() {
return nil, c.ArgErr()
}
value := c.Val()
switch value {
case "file":
ks := c.RemainingArgs()
if len(ks) == 0 {
return nil, c.ArgErr()
}
for _, k := range ks {
base := k
// Kmiek.nl.+013+26205.key, handle .private or without extension: Kmiek.nl.+013+26205
if strings.HasSuffix(k, ".key") {
base = k[:len(k)-4]
}
if strings.HasSuffix(k, ".private") {
base = k[:len(k)-8]
}
if !filepath.IsAbs(base) && config.Root != "" {
base = filepath.Join(config.Root, base)
}
k, err := ParseKeyFile(base+".key", base+".private")
if err != nil {
return nil, err
}
keys = append(keys, k)
}
case "aws_secretsmanager":
ks := c.RemainingArgs()
if len(ks) == 0 {
return nil, c.ArgErr()
}
for _, k := range ks {
k, err := ParseKeyFromAWSSecretsManager(k)
if err != nil {
return nil, err
}
keys = append(keys, k)
}
}
return keys, nil
}
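// Example Corefile usage (illustrative sketch; the key file base name and tag are
// hypothetical and would normally come from dnssec-keygen):
//
//   example.org {
//       dnssec {
//           key file /etc/coredns/Kexample.org.+013+12345
//           cache_capacity 20000
//       }
//   }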
package dnstap
import (
"io"
"time"
tap "github.com/dnstap/golang-dnstap"
fs "github.com/farsightsec/golang-framestream"
"google.golang.org/protobuf/proto"
)
// encoder wraps a golang-framestream.Encoder.
type encoder struct {
fs *fs.Encoder
}
func newEncoder(w io.Writer, timeout time.Duration) (*encoder, error) {
fs, err := fs.NewEncoder(w, &fs.EncoderOptions{
ContentType: []byte("protobuf:dnstap.Dnstap"),
Bidirectional: true,
Timeout: timeout,
})
if err != nil {
return nil, err
}
return &encoder{fs}, nil
}
func (e *encoder) writeMsg(msg *tap.Dnstap) error {
buf, err := proto.Marshal(msg)
if err != nil {
return err
}
_, err = e.fs.Write(buf) // n < len(buf) should return an error?
return err
}
func (e *encoder) flush() error { return e.fs.Flush() }
func (e *encoder) close() error { return e.fs.Close() }
package dnstap
import (
"context"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/dnstap/msg"
"github.com/coredns/coredns/plugin/pkg/replacer"
"github.com/coredns/coredns/request"
tap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)
// Dnstap is the dnstap handler.
type Dnstap struct {
Next plugin.Handler
io tapper
repl replacer.Replacer
// IncludeRawMessage will include the raw DNS message into the dnstap messages if true.
IncludeRawMessage bool
Identity []byte
Version []byte
ExtraFormat string
MultipleTcpWriteBuf int // multiplier applied to the base 1 MiB TCP write buffer
MultipleQueue int // multiplier applied to the base 10,000-message queue
}
// TapMessage sends the message m to the dnstap interface, without populating "Extra" field.
func (h *Dnstap) TapMessage(m *tap.Message) {
if h.ExtraFormat == "" {
h.tapWithExtra(m, nil)
} else {
h.tapWithExtra(m, []byte(h.ExtraFormat))
}
}
// TapMessageWithMetadata sends the message m to the dnstap interface, with "Extra" field being populated.
func (h *Dnstap) TapMessageWithMetadata(ctx context.Context, m *tap.Message, state request.Request) {
if h.ExtraFormat == "" {
h.tapWithExtra(m, nil)
return
}
extraStr := h.repl.Replace(ctx, state, nil, h.ExtraFormat)
h.tapWithExtra(m, []byte(extraStr))
}
func (h *Dnstap) tapWithExtra(m *tap.Message, extra []byte) {
t := tap.Dnstap_MESSAGE
h.io.Dnstap(&tap.Dnstap{Type: &t, Message: m, Identity: h.Identity, Version: h.Version, Extra: extra})
}
func (h *Dnstap) tapQuery(ctx context.Context, w dns.ResponseWriter, query *dns.Msg, queryTime time.Time) {
q := new(tap.Message)
msg.SetQueryTime(q, queryTime)
msg.SetQueryAddress(q, w.RemoteAddr())
if h.IncludeRawMessage {
buf, _ := query.Pack()
q.QueryMessage = buf
}
msg.SetType(q, tap.Message_CLIENT_QUERY)
state := request.Request{W: w, Req: query}
h.TapMessageWithMetadata(ctx, q, state)
}
// ServeDNS logs the client query and response to dnstap and passes the dnstap Context.
func (h *Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rw := &ResponseWriter{
ResponseWriter: w,
Dnstap: h,
query: r,
ctx: ctx,
queryTime: time.Now(),
}
// The query tap message should be sent before sending the query to the
// forwarder. Otherwise, the tap messages will come out of order.
h.tapQuery(ctx, w, r, rw.queryTime)
return plugin.NextOrFailure(h.Name(), h.Next, ctx, rw, r)
}
// Name implements the plugin.Plugin interface.
func (h *Dnstap) Name() string { return "dnstap" }
package dnstap
import (
"crypto/tls"
"errors"
"net"
"sync/atomic"
"time"
tap "github.com/dnstap/golang-dnstap"
)
const (
tcpWriteBufSize = 1024 * 1024 // there is no good explanation for why this number has this value.
queueSize = 10000 // idem.
tcpTimeout = 4 * time.Second
flushTimeout = 1 * time.Second
errorCheckInterval = 10 * time.Second
skipVerify = false // by default, every tls connection is verified to be secure
)
// tapper interface is used in testing to mock the Dnstap method.
type tapper interface {
Dnstap(*tap.Dnstap)
}
type WarnLogger interface {
Warningf(format string, v ...any)
}
// dio implements the Tapper interface.
type dio struct {
endpoint string
proto string
enc *encoder
queue chan *tap.Dnstap
dropped uint32
quit chan struct{}
flushTimeout time.Duration
tcpTimeout time.Duration
skipVerify bool
tcpWriteBufSize int
logger WarnLogger
errorCheckInterval time.Duration
}
var errNoOutput = errors.New("dnstap not connected to output socket")
// newIO returns a new and initialized pointer to a dio.
func newIO(proto, endpoint string, multipleQueue int, multipleTcpWriteBuf int) *dio {
return &dio{
endpoint: endpoint,
proto: proto,
queue: make(chan *tap.Dnstap, multipleQueue*queueSize),
quit: make(chan struct{}),
flushTimeout: flushTimeout,
tcpTimeout: tcpTimeout,
skipVerify: skipVerify,
tcpWriteBufSize: multipleTcpWriteBuf * tcpWriteBufSize,
logger: log,
errorCheckInterval: errorCheckInterval,
}
}
func (d *dio) dial() error {
var conn net.Conn
var err error
if d.proto == "tls" {
config := &tls.Config{
InsecureSkipVerify: d.skipVerify,
}
dialer := &net.Dialer{
Timeout: d.tcpTimeout,
}
conn, err = tls.DialWithDialer(dialer, "tcp", d.endpoint, config)
if err != nil {
return err
}
} else {
conn, err = net.DialTimeout(d.proto, d.endpoint, d.tcpTimeout)
if err != nil {
return err
}
}
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetWriteBuffer(d.tcpWriteBufSize)
tcpConn.SetNoDelay(false)
}
d.enc, err = newEncoder(conn, d.tcpTimeout)
return err
}
// Connect connects to the dnstap endpoint.
func (d *dio) connect() error {
err := d.dial()
go d.serve()
return err
}
// Dnstap enqueues the payload for logging, dropping it if the queue is full.
func (d *dio) Dnstap(payload *tap.Dnstap) {
select {
case d.queue <- payload:
default:
atomic.AddUint32(&d.dropped, 1)
}
}
// close signals the I/O routine to flush and shut down.
func (d *dio) close() { close(d.quit) }
func (d *dio) write(payload *tap.Dnstap) error {
if d.enc == nil {
return errNoOutput
}
if err := d.enc.writeMsg(payload); err != nil {
return err
}
return nil
}
func (d *dio) serve() {
flushTicker := time.NewTicker(d.flushTimeout)
errorCheckTicker := time.NewTicker(d.errorCheckInterval)
defer flushTicker.Stop()
defer errorCheckTicker.Stop()
for {
select {
case <-d.quit:
if d.enc == nil {
return
}
d.enc.flush()
d.enc.close()
return
case payload := <-d.queue:
if err := d.write(payload); err != nil {
atomic.AddUint32(&d.dropped, 1)
if !errors.Is(err, errNoOutput) {
// Redial immediately if it's not an output connection error
d.dial()
}
}
case <-flushTicker.C:
if d.enc != nil {
d.enc.flush()
}
case <-errorCheckTicker.C:
if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 {
d.logger.Warningf("Dropped dnstap messages: %d\n", dropped)
}
if d.enc == nil {
d.dial()
}
}
}
}
package msg
import (
"fmt"
"net"
"time"
tap "github.com/dnstap/golang-dnstap"
)
var (
protoUDP = tap.SocketProtocol_UDP
protoTCP = tap.SocketProtocol_TCP
familyINET = tap.SocketFamily_INET
familyINET6 = tap.SocketFamily_INET6
)
// SetQueryAddress adds the query address to the message. This also sets the SocketFamily and SocketProtocol.
func SetQueryAddress(t *tap.Message, addr net.Addr) error {
t.SocketFamily = &familyINET
switch a := addr.(type) {
case *net.TCPAddr:
t.SocketProtocol = &protoTCP
t.QueryAddress = a.IP
p := uint32(a.Port)
t.QueryPort = &p
if a.IP.To4() == nil {
t.SocketFamily = &familyINET6
}
return nil
case *net.UDPAddr:
t.SocketProtocol = &protoUDP
t.QueryAddress = a.IP
p := uint32(a.Port)
t.QueryPort = &p
if a.IP.To4() == nil {
t.SocketFamily = &familyINET6
}
return nil
default:
return fmt.Errorf("unknown address type: %T", a)
}
}
// SetResponseAddress adds the response address to the message. This also sets the SocketFamily and SocketProtocol.
func SetResponseAddress(t *tap.Message, addr net.Addr) error {
t.SocketFamily = &familyINET
switch a := addr.(type) {
case *net.TCPAddr:
t.SocketProtocol = &protoTCP
t.ResponseAddress = a.IP
p := uint32(a.Port)
t.ResponsePort = &p
if a.IP.To4() == nil {
t.SocketFamily = &familyINET6
}
return nil
case *net.UDPAddr:
t.SocketProtocol = &protoUDP
t.ResponseAddress = a.IP
p := uint32(a.Port)
t.ResponsePort = &p
if a.IP.To4() == nil {
t.SocketFamily = &familyINET6
}
return nil
default:
return fmt.Errorf("unknown address type: %T", a)
}
}
// SetQueryTime sets the time of the query in t.
func SetQueryTime(t *tap.Message, ti time.Time) {
qts := uint64(ti.Unix())
qtn := uint32(ti.Nanosecond())
t.QueryTimeSec = &qts
t.QueryTimeNsec = &qtn
}
// SetResponseTime sets the time of the response in t.
func SetResponseTime(t *tap.Message, ti time.Time) {
rts := uint64(ti.Unix())
rtn := uint32(ti.Nanosecond())
t.ResponseTimeSec = &rts
t.ResponseTimeNsec = &rtn
}
// SetType sets the type in t.
func SetType(t *tap.Message, typ tap.Message_Type) { t.Type = &typ }
package dnstap
import (
"net/url"
"os"
"strconv"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/replacer"
)
var log = clog.NewWithPlugin("dnstap")
func init() { plugin.Register("dnstap", setup) }
const (
// Upper bounds chosen to keep memory use and kernel socket buffer requests reasonable
// while allowing large configurations. Write buffer multiple is in MiB units; queue
// multiple is applied to 10,000 messages. See plugin README for parameter semantics.
maxMultipleTcpWriteBuf = 1024 // up to 1 GiB write buffer per TCP connection
maxMultipleQueue = 4096 // up to 40,960,000 enqueued messages
)
func parseConfig(c *caddy.Controller) ([]*Dnstap, error) {
dnstaps := []*Dnstap{}
for c.Next() { // directive name
d := Dnstap{
MultipleTcpWriteBuf: 1,
MultipleQueue: 1,
}
d.repl = replacer.New()
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
endpoint := args[0]
if len(args) >= 3 {
tcpWriteBuf := args[2]
if v, err := strconv.Atoi(tcpWriteBuf); err == nil {
if v < 1 || v > maxMultipleTcpWriteBuf {
return nil, c.Errf("dnstap: MultipleTcpWriteBuf must be between 1 and %d (MiB units): %d", maxMultipleTcpWriteBuf, v)
}
d.MultipleTcpWriteBuf = v
} else {
return nil, c.Errf("dnstap: invalid MultipleTcpWriteBuf %q: %v", tcpWriteBuf, err)
}
}
if len(args) >= 4 {
qSize := args[3]
if v, err := strconv.Atoi(qSize); err == nil {
if v < 1 || v > maxMultipleQueue {
return nil, c.Errf("dnstap: MultipleQueue must be between 1 and %d (x10k messages): %d", maxMultipleQueue, v)
}
d.MultipleQueue = v
} else {
return nil, c.Errf("dnstap: invalid MultipleQueue %q: %v", qSize, err)
}
}
var dio *dio
if strings.HasPrefix(endpoint, "tls://") {
// remote network endpoint
endpointURL, err := url.Parse(endpoint)
if err != nil {
return nil, c.ArgErr()
}
dio = newIO("tls", endpointURL.Host, d.MultipleQueue, d.MultipleTcpWriteBuf)
d.io = dio
} else if strings.HasPrefix(endpoint, "tcp://") {
// remote network endpoint
endpointURL, err := url.Parse(endpoint)
if err != nil {
return nil, c.ArgErr()
}
dio = newIO("tcp", endpointURL.Host, d.MultipleQueue, d.MultipleTcpWriteBuf)
d.io = dio
} else {
endpoint = strings.TrimPrefix(endpoint, "unix://")
dio = newIO("unix", endpoint, d.MultipleQueue, d.MultipleTcpWriteBuf)
d.io = dio
}
d.IncludeRawMessage = len(args) >= 2 && args[1] == "full"
hostname, _ := os.Hostname()
d.Identity = []byte(hostname)
d.Version = []byte(caddy.AppName + "-" + caddy.AppVersion)
for c.NextBlock() {
switch c.Val() {
case "skipverify":
{
dio.skipVerify = true
}
case "identity":
{
if !c.NextArg() {
return nil, c.ArgErr()
}
d.Identity = []byte(c.Val())
}
case "version":
{
if !c.NextArg() {
return nil, c.ArgErr()
}
d.Version = []byte(c.Val())
}
case "extra":
{
if !c.NextArg() {
return nil, c.ArgErr()
}
d.ExtraFormat = c.Val()
}
}
}
dnstaps = append(dnstaps, &d)
}
return dnstaps, nil
}
func setup(c *caddy.Controller) error {
dnstaps, err := parseConfig(c)
if err != nil {
return plugin.Error("dnstap", err)
}
for i := range dnstaps {
dnstap := dnstaps[i]
c.OnStartup(func() error {
if err := dnstap.io.(*dio).connect(); err != nil {
log.Errorf("No connection to dnstap endpoint: %s", err)
}
return nil
})
c.OnRestart(func() error {
dnstap.io.(*dio).close()
return nil
})
c.OnFinalShutdown(func() error {
dnstap.io.(*dio).close()
return nil
})
if i == len(dnstaps)-1 {
// last dnstap plugin in block: point next to next plugin
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
dnstap.Next = next
return dnstap
})
} else {
// not last dnstap plugin in block: point next to next dnstap
nextDnstap := dnstaps[i+1]
dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler {
dnstap.Next = nextDnstap
return dnstap
})
}
}
return nil
}
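// Example Corefile usage (illustrative sketch; endpoints are hypothetical). The two
// optional trailing numbers are the write-buffer and queue multipliers parsed above:
//
//   dnstap /tmp/dnstap.sock full
//   dnstap tcp://10.0.0.1:6000 full 4 2 {
//       identity my-resolver
//   }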
package dnstap
import (
"context"
"time"
"github.com/coredns/coredns/plugin/dnstap/msg"
"github.com/coredns/coredns/request"
tap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)
// ResponseWriter captures the client response and logs the query to dnstap.
type ResponseWriter struct {
queryTime time.Time
query *dns.Msg
ctx context.Context
dns.ResponseWriter
*Dnstap
}
// WriteMsg writes back the response to the client and THEN works on logging the request and response to dnstap.
func (w *ResponseWriter) WriteMsg(resp *dns.Msg) error {
err := w.ResponseWriter.WriteMsg(resp)
if err != nil {
return err
}
r := new(tap.Message)
msg.SetQueryTime(r, w.queryTime)
msg.SetResponseTime(r, time.Now())
msg.SetQueryAddress(r, w.RemoteAddr())
if w.IncludeRawMessage {
buf, _ := resp.Pack()
r.ResponseMessage = buf
}
msg.SetType(r, tap.Message_CLIENT_RESPONSE)
state := request.Request{W: w.ResponseWriter, Req: w.query}
w.TapMessageWithMetadata(w.ctx, r, state)
return nil
}
package plugin
import "context"
// Done is a non-blocking function that returns true if the context has been canceled.
func Done(ctx context.Context) bool {
select {
case <-ctx.Done():
return true
default:
return false
}
}
package erratic
import "github.com/coredns/coredns/request"
// AutoPath implements the AutoPathFunc call from the autopath plugin.
func (e *Erratic) AutoPath(state request.Request) []string {
return []string{"a.example.org.", "b.example.org.", ""}
}
// Package erratic implements a plugin that returns erratic answers (delayed, dropped).
package erratic
import (
"context"
"sync/atomic"
"time"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Erratic is a plugin that returns erratic responses to each client.
type Erratic struct {
q uint64 // counter of queries
drop uint64
delay uint64
truncate uint64
duration time.Duration
large bool // undocumented feature; return large responses for A request (>512B, to test compression).
}
// ServeDNS implements the plugin.Handler interface.
func (e *Erratic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
drop := false
delay := false
trunc := false
queryNr := atomic.LoadUint64(&e.q)
atomic.AddUint64(&e.q, 1)
if e.drop > 0 && queryNr%e.drop == 0 {
drop = true
}
if e.delay > 0 && queryNr%e.delay == 0 {
delay = true
}
if e.truncate > 0 && queryNr%e.truncate == 0 {
trunc = true
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
if trunc {
m.Truncated = true
}
// small dance to copy rrA or rrAAAA into a non-pointer var that allows us to overwrite the ownername
// in a non-racy way.
switch state.QType() {
case dns.TypeA:
rr := *(rrA.(*dns.A))
rr.Header().Name = state.QName()
m.Answer = append(m.Answer, &rr)
if e.large {
for range 29 {
m.Answer = append(m.Answer, &rr)
}
}
case dns.TypeAAAA:
rr := *(rrAAAA.(*dns.AAAA))
rr.Header().Name = state.QName()
m.Answer = append(m.Answer, &rr)
case dns.TypeAXFR:
if drop {
return 0, nil
}
if delay {
time.Sleep(e.duration)
}
xfr(state, trunc)
return 0, nil
default:
if drop {
return 0, nil
}
if delay {
time.Sleep(e.duration)
}
// CoreDNS will return a SERVFAIL.
return dns.RcodeServerFailure, nil
}
if drop {
return 0, nil
}
if delay {
time.Sleep(e.duration)
}
w.WriteMsg(m)
return 0, nil
}
// Name implements the Handler interface.
func (e *Erratic) Name() string { return "erratic" }
var (
rrA, _ = dns.NewRR(". IN 0 A 192.0.2.53")
rrAAAA, _ = dns.NewRR(". IN 0 AAAA 2001:DB8::53")
)
package erratic
import "sync/atomic"
// Ready returns true if the number of received queries is in the range [3, 5). All other values return false.
// To aid in testing we want this to flip between ready and not ready.
func (e *Erratic) Ready() bool {
q := atomic.LoadUint64(&e.q)
if q >= 3 && q < 5 {
return true
}
return false
}
package erratic
import (
"fmt"
"strconv"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("erratic", setup) }
func setup(c *caddy.Controller) error {
e, err := parseErratic(c)
if err != nil {
return plugin.Error("erratic", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return e
})
return nil
}
func parseErratic(c *caddy.Controller) (*Erratic, error) {
e := &Erratic{drop: 2}
drop := false // true if we've seen the drop keyword
for c.Next() { // 'erratic'
for c.NextBlock() {
switch c.Val() {
case "drop":
args := c.RemainingArgs()
if len(args) > 1 {
return nil, c.ArgErr()
}
if len(args) == 0 {
continue
}
amount, err := strconv.ParseInt(args[0], 10, 32)
if err != nil {
return nil, err
}
if amount < 0 {
return nil, fmt.Errorf("illegal amount value given %q", args[0])
}
e.drop = uint64(amount)
drop = true
case "delay":
args := c.RemainingArgs()
if len(args) > 2 {
return nil, c.ArgErr()
}
// Defaults.
e.delay = 2
e.duration = 100 * time.Millisecond
if len(args) == 0 {
continue
}
amount, err := strconv.ParseInt(args[0], 10, 32)
if err != nil {
return nil, err
}
if amount < 0 {
return nil, fmt.Errorf("illegal amount value given %q", args[0])
}
e.delay = uint64(amount)
if len(args) > 1 {
duration, err := time.ParseDuration(args[1])
if err != nil {
return nil, err
}
e.duration = duration
}
case "truncate":
args := c.RemainingArgs()
if len(args) > 1 {
return nil, c.ArgErr()
}
if len(args) == 0 {
continue
}
amount, err := strconv.ParseInt(args[0], 10, 32)
if err != nil {
return nil, err
}
if amount < 0 {
return nil, fmt.Errorf("illegal amount value given %q", args[0])
}
e.truncate = uint64(amount)
case "large":
e.large = true
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
}
if (e.delay > 0 || e.truncate > 0) && !drop { // delay or truncate is set, but we haven't seen a drop keyword; remove the default drop
e.drop = 0
}
return e, nil
}
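// Example Corefile usage (illustrative sketch): drop every 2nd query, delay every
// 3rd query by 50ms, and truncate every 4th:
//
//   erratic {
//       drop 2
//       delay 3 50ms
//       truncate 4
//   }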
package erratic
import (
"strings"
"sync"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// allRecords returns a small zone file. The first RR must be a SOA.
func allRecords(name string) []dns.RR {
var rrs = []dns.RR{
test.SOA("xx. 0 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2018050825 7200 3600 1209600 3600"),
test.NS("xx. 0 IN NS b.xx."),
test.NS("xx. 0 IN NS a.xx."),
test.AAAA("a.xx. 0 IN AAAA 2001:bd8::53"),
test.AAAA("b.xx. 0 IN AAAA 2001:500::54"),
}
for _, r := range rrs {
r.Header().Name = strings.Replace(r.Header().Name, "xx.", name, 1)
if n, ok := r.(*dns.NS); ok {
n.Ns = strings.Replace(n.Ns, "xx.", name, 1)
}
}
return rrs
}
func xfr(state request.Request, truncate bool) {
rrs := allRecords(state.QName())
ch := make(chan *dns.Envelope)
tr := new(dns.Transfer)
go func() {
// The rrs we have lack a closing SOA. Only append it when truncate is false,
// so that a truncated transfer deliberately results in an incomplete AXFR.
if !truncate {
rrs = append(rrs, rrs[0])
}
ch <- &dns.Envelope{RR: rrs}
close(ch)
}()
wg := new(sync.WaitGroup)
wg.Add(1)
go func() {
tr.Out(state.W, state.Req, ch)
wg.Done()
}()
wg.Wait()
}
// Package errors implements an error handling plugin.
package errors
import (
"context"
"regexp"
"sync/atomic"
"time"
"unsafe"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("errors")
type pattern struct {
ptimer unsafe.Pointer
count uint32
period time.Duration
pattern *regexp.Regexp
logCallback func(format string, v ...any)
}
func (p *pattern) timer() *time.Timer {
return (*time.Timer)(atomic.LoadPointer(&p.ptimer))
}
func (p *pattern) setTimer(t *time.Timer) {
atomic.StorePointer(&p.ptimer, unsafe.Pointer(t))
}
// errorHandler handles DNS errors (and errors from other plugins).
type errorHandler struct {
patterns []*pattern
stopFlag uint32
Next plugin.Handler
}
func newErrorHandler() *errorHandler {
return &errorHandler{}
}
func (h *errorHandler) logPattern(i int) {
cnt := atomic.SwapUint32(&h.patterns[i].count, 0)
if cnt > 0 {
h.patterns[i].logCallback("%d errors like '%s' occurred in last %s",
cnt, h.patterns[i].pattern.String(), h.patterns[i].period)
}
}
func (h *errorHandler) inc(i int) bool {
if atomic.LoadUint32(&h.stopFlag) > 0 {
return false
}
if atomic.AddUint32(&h.patterns[i].count, 1) == 1 {
ind := i
t := time.AfterFunc(h.patterns[ind].period, func() {
h.logPattern(ind)
})
h.patterns[ind].setTimer(t)
if atomic.LoadUint32(&h.stopFlag) > 0 && t.Stop() {
h.logPattern(ind)
}
}
return true
}
func (h *errorHandler) stop() {
atomic.StoreUint32(&h.stopFlag, 1)
for i := range h.patterns {
t := h.patterns[i].timer()
if t != nil && t.Stop() {
h.logPattern(i)
}
}
}
// ServeDNS implements the plugin.Handler interface.
func (h *errorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rcode, err := plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
if err != nil {
strErr := err.Error()
for i := range h.patterns {
if h.patterns[i].pattern.MatchString(strErr) {
if h.inc(i) {
return rcode, err
}
break
}
}
state := request.Request{W: w, Req: r}
log.Errorf("%d %s %s: %s", rcode, state.Name(), state.Type(), strErr)
}
return rcode, err
}
// Name implements the plugin.Handler interface.
func (h *errorHandler) Name() string { return "errors" }
package errors
import (
"regexp"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("errors", setup) }
func setup(c *caddy.Controller) error {
handler, err := errorsParse(c)
if err != nil {
return plugin.Error("errors", err)
}
c.OnShutdown(func() error {
handler.stop()
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
handler.Next = next
return handler
})
return nil
}
func errorsParse(c *caddy.Controller) (*errorHandler, error) {
handler := newErrorHandler()
i := 0
for c.Next() {
if i > 0 {
return nil, plugin.ErrOnce
}
i++
args := c.RemainingArgs()
switch len(args) {
case 0:
case 1:
if args[0] != "stdout" {
return nil, c.Errf("invalid log file: %s", args[0])
}
default:
return nil, c.ArgErr()
}
for c.NextBlock() {
switch c.Val() {
case "stacktrace":
dnsserver.GetConfig(c).Stacktrace = true
case "consolidate":
pattern, err := parseConsolidate(c)
if err != nil {
return nil, err
}
handler.patterns = append(handler.patterns, pattern)
default:
return handler, c.SyntaxErr("Unknown field " + c.Val())
}
}
}
return handler, nil
}
func parseConsolidate(c *caddy.Controller) (*pattern, error) {
args := c.RemainingArgs()
if len(args) < 2 || len(args) > 3 {
return nil, c.ArgErr()
}
p, err := time.ParseDuration(args[0])
if err != nil {
return nil, c.Err(err.Error())
}
re, err := regexp.Compile(args[1])
if err != nil {
return nil, c.Err(err.Error())
}
lc, err := parseLogLevel(c, args)
if err != nil {
return nil, err
}
return &pattern{period: p, pattern: re, logCallback: lc}, nil
}
func parseLogLevel(c *caddy.Controller, args []string) (func(format string, v ...any), error) {
if len(args) != 3 {
return log.Errorf, nil
}
switch args[2] {
case "warning":
return log.Warningf, nil
case "error":
return log.Errorf, nil
case "info":
return log.Infof, nil
case "debug":
return log.Debugf, nil
default:
return nil, c.Errf("unknown log level argument in consolidate: %s", args[2])
}
}
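// Example Corefile usage (illustrative sketch; the pattern is hypothetical):
//
//   errors {
//       stacktrace
//       consolidate 5m "^.* i/o timeout$" warning
//   }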
// Package etcd provides the etcd version 3 backend plugin.
package etcd
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
"go.etcd.io/etcd/api/v3/mvccpb"
etcdcv3 "go.etcd.io/etcd/client/v3"
)
const (
defaultPriority = 10 // default priority when nothing is set
defaultTTL = 300 // default ttl when nothing is set
defaultLeaseMinTTL = 30 // default minimum TTL for lease-based records
defaultLeaseMaxTTL = 86400 // default maximum TTL for lease-based records
etcdTimeout = 5 * time.Second
)
var errKeyNotFound = errors.New("key not found")
// Etcd is a plugin that talks to an etcd cluster.
type Etcd struct {
Next plugin.Handler
Fall fall.F
Zones []string
PathPrefix string
Upstream *upstream.Upstream
Client *etcdcv3.Client
MinLeaseTTL uint32 // minimum TTL for lease-based records
MaxLeaseTTL uint32 // maximum TTL for lease-based records
endpoints []string // Stored here as well, to aid in testing.
}
// Services implements the ServiceBackend interface.
func (e *Etcd) Services(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (services []msg.Service, err error) {
services, err = e.Records(ctx, state, exact)
if err != nil {
return
}
services = msg.Group(services)
return
}
// Reverse implements the ServiceBackend interface.
func (e *Etcd) Reverse(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (services []msg.Service, err error) {
return e.Services(ctx, state, exact, opt)
}
// Lookup implements the ServiceBackend interface.
func (e *Etcd) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
return e.Upstream.Lookup(ctx, state, name, typ)
}
// IsNameError implements the ServiceBackend interface.
func (e *Etcd) IsNameError(err error) bool {
return err == errKeyNotFound
}
// Records looks up records in etcd. If exact is true, it will look up just this
// name. This is used to find matches when completing SRV lookups, for instance.
func (e *Etcd) Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) {
name := state.Name()
path, star := msg.PathWithWildcard(name, e.PathPrefix)
r, err := e.get(ctx, path, !exact)
if err != nil {
return nil, err
}
segments := strings.Split(msg.Path(name, e.PathPrefix), "/")
return e.loopNodes(r.Kvs, segments, star, state.QType())
}
func (e *Etcd) get(ctx context.Context, path string, recursive bool) (*etcdcv3.GetResponse, error) {
ctx, cancel := context.WithTimeout(ctx, etcdTimeout)
defer cancel()
if recursive {
if !strings.HasSuffix(path, "/") {
path = path + "/"
}
r, err := e.Client.Get(ctx, path, etcdcv3.WithPrefix())
if err != nil {
return nil, err
}
if r.Count == 0 {
path = strings.TrimSuffix(path, "/")
r, err = e.Client.Get(ctx, path)
if err != nil {
return nil, err
}
if r.Count == 0 {
return nil, errKeyNotFound
}
}
return r, nil
}
r, err := e.Client.Get(ctx, path)
if err != nil {
return nil, err
}
if r.Count == 0 {
return nil, errKeyNotFound
}
return r, nil
}
func (e *Etcd) loopNodes(kv []*mvccpb.KeyValue, nameParts []string, star bool, qType uint16) (sx []msg.Service, err error) {
bx := make(map[msg.Service]struct{})
Nodes:
for _, n := range kv {
if star {
s := string(n.Key)
keyParts := strings.Split(s, "/")
for i, n := range nameParts {
if i > len(keyParts)-1 {
// name is longer than key
continue Nodes
}
if n == "*" || n == "any" {
continue
}
if keyParts[i] != n {
continue Nodes
}
}
}
serv := new(msg.Service)
if err := json.Unmarshal(n.Value, serv); err != nil {
return nil, fmt.Errorf("%s: %s", n.Key, err.Error())
}
serv.Key = string(n.Key)
if _, ok := bx[*serv]; ok {
continue
}
bx[*serv] = struct{}{}
serv.TTL = e.TTL(n, serv)
if serv.Priority == 0 {
serv.Priority = defaultPriority
}
if shouldInclude(serv, qType) {
sx = append(sx, *serv)
}
}
return sx, nil
}
// TTL returns the smaller of the etcd TTL and the service's
// TTL. If neither of these is set (has a zero value), a default is used.
func (e *Etcd) TTL(kv *mvccpb.KeyValue, serv *msg.Service) uint32 {
var etcdTTL uint32
// Get actual lease TTL from etcd if lease exists and client is available
if kv.Lease != 0 && e.Client != nil {
if resp, err := e.Client.TimeToLive(context.Background(), etcdcv3.LeaseID(kv.Lease)); err == nil && resp.TTL > 0 {
leaseTTL := resp.TTL
// Get bounds with defaults
minTTL := e.MinLeaseTTL
if minTTL == 0 {
minTTL = defaultLeaseMinTTL
}
maxTTL := e.MaxLeaseTTL
if maxTTL == 0 {
maxTTL = defaultLeaseMaxTTL
}
// Clamp lease TTL to configured bounds
minTTL64 := int64(minTTL)
maxTTL64 := int64(maxTTL)
if leaseTTL < minTTL64 {
leaseTTL = minTTL64
} else if leaseTTL > maxTTL64 {
leaseTTL = maxTTL64
}
etcdTTL = uint32(leaseTTL)
}
}
if etcdTTL == 0 && serv.TTL == 0 {
return defaultTTL
}
if etcdTTL == 0 {
return serv.TTL
}
if serv.TTL == 0 {
return etcdTTL
}
if etcdTTL < serv.TTL {
return etcdTTL
}
return serv.TTL
}
// shouldInclude returns true if the service should be included in a list of records, given the qType. For all the
// currently supported lookup types, the only records that allow an empty Host field in the service are TXT records
// that resolve directly. If a TXT record is being resolved via CNAME, then we expect the Host field to have a
// value while the Text field will be empty.
func shouldInclude(serv *msg.Service, qType uint16) bool {
return (qType == dns.TypeTXT && serv.Text != "") || serv.Host != ""
}
// OnShutdown shuts down the etcd client when the caddy instance restarts.
func (e *Etcd) OnShutdown() error {
if e.Client != nil {
e.Client.Close()
}
return nil
}
package etcd
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the plugin.Handler interface.
func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
opt := plugin.Options{}
state := request.Request{W: w, Req: r}
zone := plugin.Zones(e.Zones).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r)
}
var (
records, extra []dns.RR
truncated bool
err error
)
switch state.QType() {
case dns.TypeA:
records, truncated, err = plugin.A(ctx, e, zone, state, nil, opt)
case dns.TypeAAAA:
records, truncated, err = plugin.AAAA(ctx, e, zone, state, nil, opt)
case dns.TypeTXT:
records, truncated, err = plugin.TXT(ctx, e, zone, state, nil, opt)
case dns.TypeCNAME:
records, err = plugin.CNAME(ctx, e, zone, state, opt)
case dns.TypePTR:
records, err = plugin.PTR(ctx, e, zone, state, opt)
case dns.TypeMX:
records, extra, err = plugin.MX(ctx, e, zone, state, opt)
case dns.TypeSRV:
records, extra, err = plugin.SRV(ctx, e, zone, state, opt)
case dns.TypeSOA:
records, err = plugin.SOA(ctx, e, zone, state, opt)
case dns.TypeNS:
if state.Name() == zone {
records, extra, err = plugin.NS(ctx, e, zone, state, opt)
break
}
fallthrough
default:
// Do a fake A lookup, so we can distinguish between NODATA and NXDOMAIN
_, _, err = plugin.A(ctx, e, zone, state, nil, opt)
}
if err != nil && e.IsNameError(err) {
if e.Fall.Through(state.Name()) {
return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r)
}
// Make err nil when returning here, so we don't log spam for NXDOMAIN.
return plugin.BackendError(ctx, e, zone, dns.RcodeNameError, state, nil /* err */, opt)
}
if err != nil {
return plugin.BackendError(ctx, e, zone, dns.RcodeServerFailure, state, err, opt)
}
if len(records) == 0 {
return plugin.BackendError(ctx, e, zone, dns.RcodeSuccess, state, err, opt)
}
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = truncated
m.Authoritative = true
m.Answer = records
m.Extra = extra
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
// Name implements the Handler interface.
func (e *Etcd) Name() string { return "etcd" }
package msg
import (
"path"
"strings"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/miekg/dns"
)
// Path converts a domain name to an etcd path. If s looks like service.staging.skydns.local.,
// the resulting key will be /skydns/local/skydns/staging/service.
func Path(s, prefix string) string {
l := dns.SplitDomainName(s)
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
return path.Join(append([]string{"/" + prefix + "/"}, l...)...)
}
// Domain is the opposite of Path.
func Domain(s string) string {
l := strings.Split(s, "/")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
}
// start with 1, to strip /skydns
for i, j := 1, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
return dnsutil.Join(l[1 : len(l)-1]...)
}
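// Example (illustrative): with the default prefix "skydns", Path and Domain are inverses of each other:
//
//    Path("service.staging.skydns.local.", "skydns")
//    // -> "/skydns/local/skydns/staging/service"
//    Domain("/skydns/local/skydns/staging/service")
//    // -> "service.staging.skydns.local."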
// PathWithWildcard acts as Path, but if a name contains wildcards (* or any), the name will be
// chopped off before the (first) wildcard, and we do a higher-level search and
// later find the matching names. So service.*.skydns.local will look for all
// services under skydns.local and will later check for names that match
// service.*.skydns.local. If a wildcard is found the returned bool is true.
func PathWithWildcard(s, prefix string) (string, bool) {
l := dns.SplitDomainName(s)
for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
for i, k := range l {
if k == "*" || k == "any" {
return path.Join(append([]string{"/" + prefix + "/"}, l[:i]...)...), true
}
}
return path.Join(append([]string{"/" + prefix + "/"}, l...)...), false
}
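// Example (illustrative): a wildcard truncates the key at the wildcard position and the
// boolean reports that a wildcard was seen:
//
//    PathWithWildcard("service.*.skydns.local.", "skydns")
//    // -> "/skydns/local/skydns", true
//    PathWithWildcard("service.staging.skydns.local.", "skydns")
//    // -> "/skydns/local/skydns/staging/service", false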
// Package msg defines the Service structure which is used for service discovery.
package msg
import (
"net"
"strings"
"github.com/miekg/dns"
)
// Service defines a discoverable service in etcd. It is the rdata from a SRV
// record, but with a twist. Host (Target in SRV) must be a domain name, but
// if it looks like an IP address (4/6), we will treat it like an IP address.
type Service struct {
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
Priority int `json:"priority,omitempty"`
Weight int `json:"weight,omitempty"`
Text string `json:"text,omitempty"`
Mail bool `json:"mail,omitempty"` // Be an MX record. Priority becomes Preference.
TTL uint32 `json:"ttl,omitempty"`
// When a SRV record with a "Host: IP-address" is added, we synthesize
// a srv.Target domain name. Normally we convert the full Key where
// the record lives to a DNS name and use this as the srv.Target. When
// TargetStrip > 0 we strip the leftmost TargetStrip labels from the
// DNS name.
TargetStrip int `json:"targetstrip,omitempty"`
// Group is used to group (or *not* to group) different services
// together. Services with an identical Group are returned in the same
// answer.
Group string `json:"group,omitempty"`
// Key is the etcd key where this service was found; it is ignored during JSON (un)marshalling.
Key string `json:"-"`
}
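// Illustrative sketch (key and values are made up): a Service is stored as a JSON document
// under an etcd key that mirrors the DNS name, for example the key
// /skydns/local/skydns/staging/service/x1 holding {"host":"10.0.0.1","port":8080}
// provides A and SRV data for x1.service.staging.skydns.local.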
// NewSRV returns a new SRV record based on the Service.
func (s *Service) NewSRV(name string, weight uint16) *dns.SRV {
host := dns.Fqdn(s.Host)
if s.TargetStrip > 0 {
host = targetStrip(host, s.TargetStrip)
}
return &dns.SRV{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: s.TTL},
Priority: uint16(s.Priority), Weight: weight, Port: uint16(s.Port), Target: host}
}
// NewMX returns a new MX record based on the Service.
func (s *Service) NewMX(name string) *dns.MX {
host := dns.Fqdn(s.Host)
if s.TargetStrip > 0 {
host = targetStrip(host, s.TargetStrip)
}
return &dns.MX{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: s.TTL},
Preference: uint16(s.Priority), Mx: host}
}
// NewA returns a new A record based on the Service.
func (s *Service) NewA(name string, ip net.IP) *dns.A {
return &dns.A{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.TTL}, A: ip}
}
// NewAAAA returns a new AAAA record based on the Service.
func (s *Service) NewAAAA(name string, ip net.IP) *dns.AAAA {
return &dns.AAAA{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.TTL}, AAAA: ip}
}
// NewCNAME returns a new CNAME record based on the Service.
func (s *Service) NewCNAME(name string, target string) *dns.CNAME {
return &dns.CNAME{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: s.TTL}, Target: dns.Fqdn(target)}
}
// NewTXT returns a new TXT record based on the Service.
func (s *Service) NewTXT(name string) *dns.TXT {
return &dns.TXT{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: s.TTL}, Txt: split255(s.Text)}
}
// NewPTR returns a new PTR record based on the Service.
func (s *Service) NewPTR(name string, target string) *dns.PTR {
return &dns.PTR{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: s.TTL}, Ptr: dns.Fqdn(target)}
}
// NewNS returns a new NS record based on the Service.
func (s *Service) NewNS(name string) *dns.NS {
host := dns.Fqdn(s.Host)
if s.TargetStrip > 0 {
host = targetStrip(host, s.TargetStrip)
}
return &dns.NS{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: s.TTL}, Ns: host}
}
// Group checks the services in sx. It looks for a Group attribute on the shortest
// keys. If there are multiple shortest keys *and* the group attribute disagrees (and
// is not empty), we don't consider it a group.
// If a group is found, only services with *that* group (or no group) will be returned.
func Group(sx []Service) []Service {
if len(sx) == 0 {
return sx
}
// Shortest key with group attribute sets the group for this set.
group := sx[0].Group
slashes := strings.Count(sx[0].Key, "/")
length := make([]int, len(sx))
for i, s := range sx {
x := strings.Count(s.Key, "/")
length[i] = x
if x < slashes {
if s.Group == "" {
break
}
slashes = x
group = s.Group
}
}
if group == "" {
return sx
}
ret := []Service{} // with slice-tricks in sx we can probably save this allocation (TODO)
for i, s := range sx {
if s.Group == "" {
ret = append(ret, s)
continue
}
// Disagreement on the same level
if length[i] == slashes && s.Group != group {
return sx
}
if s.Group == group {
ret = append(ret, s)
}
}
return ret
}
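// Example (illustrative, hypothetical keys): the group on the shortest keys wins; deeper
// services with a different group are dropped, while ungrouped services are always kept:
//
//    sx := []Service{
//        {Key: "/skydns/local/skydns/web/a", Group: "east"},
//        {Key: "/skydns/local/skydns/web/b", Group: "east"},
//        {Key: "/skydns/local/skydns/web/extra/c", Group: "west"},
//    }
//    Group(sx) // keeps a and b, drops c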
// split255 splits a string into 255-byte chunks, as required for TXT record character-strings.
func split255(s string) []string {
if len(s) <= 255 {
return []string{s}
}
sx := []string{}
p, i := 0, 255
for {
if i > len(s) {
sx = append(sx, s[p:])
break
}
sx = append(sx, s[p:i])
p, i = p+255, i+255
}
return sx
}
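// Example (illustrative): TXT character-strings are limited to 255 bytes, so a 600-byte
// string becomes three chunks:
//
//    split255(strings.Repeat("a", 600))
//    // -> []string{<255 bytes>, <255 bytes>, <90 bytes>}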
// targetStrip strips "targetstrip" labels from the left side of the fully qualified name.
func targetStrip(name string, targetStrip int) string {
offset, end := 0, false
for range targetStrip {
offset, end = dns.NextLabel(name, offset)
}
if end {
// We overshot the name, use the original one.
offset = 0
}
name = name[offset:]
return name
}
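// Example (illustrative, hypothetical name): stripping one label removes the leftmost
// label from the fully qualified name:
//
//    targetStrip("x1.web.skydns.local.", 1) // -> "web.skydns.local."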
package msg
import (
"net"
"github.com/miekg/dns"
)
// HostType returns the DNS type of what is encoded in the Service Host field. We're reusing
// dns.TypeXXX to not reinvent a new set of identifiers.
//
// dns.TypeA: the service's Host field contains an A record.
// dns.TypeAAAA: the service's Host field contains an AAAA record.
// dns.TypeCNAME: the service's Host field contains a name.
//
// Note that a service can double/triple as a TXT record or MX record.
func (s *Service) HostType() (what uint16, normalized net.IP) {
ip := net.ParseIP(s.Host)
switch {
case ip == nil:
if len(s.Text) == 0 {
return dns.TypeCNAME, nil
}
return dns.TypeTXT, nil
case ip.To4() != nil:
return dns.TypeA, ip.To4()
case ip.To4() == nil:
return dns.TypeAAAA, ip.To16()
}
// This should never be reached.
return dns.TypeNone, nil
}
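// Example (illustrative, hypothetical values):
//
//    (&Service{Host: "10.0.0.1"}).HostType()        // -> dns.TypeA, 10.0.0.1
//    (&Service{Host: "2001:db8::1"}).HostType()     // -> dns.TypeAAAA, 2001:db8::1
//    (&Service{Host: "web.example.org"}).HostType() // -> dns.TypeCNAME, nil
//    (&Service{Text: "v=spf1 -all"}).HostType()     // -> dns.TypeTXT, nil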
package etcd
import (
"crypto/tls"
"errors"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
mwtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/coredns/coredns/plugin/pkg/upstream"
etcdcv3 "go.etcd.io/etcd/client/v3"
)
func init() { plugin.Register("etcd", setup) }
func setup(c *caddy.Controller) error {
e, err := etcdParse(c)
if err != nil {
return plugin.Error("etcd", err)
}
c.OnShutdown(e.OnShutdown)
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
e.Next = next
return e
})
return nil
}
func etcdParse(c *caddy.Controller) (*Etcd, error) {
config := dnsserver.GetConfig(c)
etc := Etcd{
PathPrefix: "skydns",
MinLeaseTTL: defaultLeaseMinTTL,
MaxLeaseTTL: defaultLeaseMaxTTL,
}
var (
tlsConfig *tls.Config
err error
endpoints = []string{defaultEndpoint}
username string
password string
)
etc.Upstream = upstream.New()
if c.Next() {
etc.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
for c.NextBlock() {
switch c.Val() {
case "stubzones":
// ignored, remove later.
case "fallthrough":
etc.Fall.SetZonesFromArgs(c.RemainingArgs())
case "debug":
/* it is a noop now */
case "path":
if !c.NextArg() {
return &Etcd{}, c.ArgErr()
}
etc.PathPrefix = c.Val()
case "endpoint":
args := c.RemainingArgs()
if len(args) == 0 {
return &Etcd{}, c.ArgErr()
}
endpoints = args
case "upstream":
// remove soon
c.RemainingArgs()
case "tls": // cert key cacertfile
args := c.RemainingArgs()
for i := range args {
if !filepath.IsAbs(args[i]) && config.Root != "" {
args[i] = filepath.Join(config.Root, args[i])
}
}
tlsConfig, err = mwtls.NewTLSConfigFromArgs(args...)
if err != nil {
return &Etcd{}, err
}
case "credentials":
args := c.RemainingArgs()
if len(args) == 0 {
return &Etcd{}, c.ArgErr()
}
if len(args) != 2 {
return &Etcd{}, c.Errf("credentials requires 2 arguments, username and password")
}
username, password = args[0], args[1]
case "min-lease-ttl":
if !c.NextArg() {
return &Etcd{}, c.ArgErr()
}
minLeaseTTL, err := parseTTL(c.Val())
if err != nil {
return &Etcd{}, c.Errf("invalid min-lease-ttl value: %v", err)
}
etc.MinLeaseTTL = minLeaseTTL
case "max-lease-ttl":
if !c.NextArg() {
return &Etcd{}, c.ArgErr()
}
maxLeaseTTL, err := parseTTL(c.Val())
if err != nil {
return &Etcd{}, c.Errf("invalid max-lease-ttl value: %v", err)
}
etc.MaxLeaseTTL = maxLeaseTTL
default:
if c.Val() != "}" {
return &Etcd{}, c.Errf("unknown property '%s'", c.Val())
}
}
}
client, err := newEtcdClient(endpoints, tlsConfig, username, password)
if err != nil {
return &Etcd{}, err
}
etc.Client = client
etc.endpoints = endpoints
return &etc, nil
}
return &Etcd{}, nil
}
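// For reference, a Corefile stanza exercising the directives parsed above could look like
// the following sketch (endpoint and credentials are illustrative, not defaults):
//
//    etcd skydns.local. {
//        path skydns
//        endpoint http://localhost:2379
//        credentials myuser mypassword
//        min-lease-ttl 30s
//        max-lease-ttl 1h
//        fallthrough
//    }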
func newEtcdClient(endpoints []string, cc *tls.Config, username, password string) (*etcdcv3.Client, error) {
etcdCfg := etcdcv3.Config{
Endpoints: endpoints,
TLS: cc,
DialKeepAliveTime: etcdTimeout,
}
if username != "" && password != "" {
etcdCfg.Username = username
etcdCfg.Password = password
}
cli, err := etcdcv3.New(etcdCfg)
if err != nil {
return nil, err
}
return cli, nil
}
const defaultEndpoint = "http://localhost:2379"
// parseTTL parses a TTL value with flexible time units using Go's standard duration parsing.
// Supports formats like: "30", "30s", "5m", "1h", "90s", "2h30m", etc.
func parseTTL(s string) (uint32, error) {
s = strings.TrimSpace(s)
if s == "" {
return 0, nil
}
// Handle plain numbers (assume seconds)
if _, err := strconv.ParseUint(s, 10, 64); err == nil {
// If it's just a number, append "s" for seconds
s += "s"
}
// Use Go's standard time.ParseDuration for robust parsing
duration, err := time.ParseDuration(s)
if err != nil {
return 0, errors.New("invalid TTL format, use format like '30', '30s', '5m', '1h', or '2h30m'")
}
// Convert to seconds and check bounds
seconds := duration.Seconds()
if seconds < 0 {
return 0, errors.New("TTL must be non-negative")
}
if seconds > 4294967295 { // uint32 max value
return 0, errors.New("TTL too large, maximum is 4294967295 seconds")
}
return uint32(seconds), nil
}
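// Example (illustrative): plain numbers are treated as seconds, everything else goes
// through time.ParseDuration:
//
//    parseTTL("300")   // -> 300, nil
//    parseTTL("5m")    // -> 300, nil
//    parseTTL("2h30m") // -> 9000, nil
//    parseTTL("-5s")   // -> 0, error (TTL must be non-negative)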
package etcd
import (
"time"
"github.com/coredns/coredns/request"
)
// Serial returns the serial number to use.
func (e *Etcd) Serial(state request.Request) uint32 {
return uint32(time.Now().Unix())
}
// MinTTL returns the minimal TTL.
func (e *Etcd) MinTTL(state request.Request) uint32 {
return 30
}
package file
import (
"github.com/coredns/coredns/plugin/file/tree"
"github.com/miekg/dns"
)
// ClosestEncloser returns the closest encloser for qname.
func (z *Zone) ClosestEncloser(qname string) (*tree.Elem, bool) {
offset, end := dns.NextLabel(qname, 0)
for !end {
elem, _ := z.Search(qname)
if elem != nil {
return elem, true
}
qname = qname[offset:]
offset, end = dns.NextLabel(qname, 0)
}
return z.Search(z.origin)
}
package file
import (
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/miekg/dns"
)
// substituteDNAME performs the DNAME substitution defined by RFC 6672,
// assuming the QTYPE of the query is not DNAME. It returns an empty
// string if there is no match.
func substituteDNAME(qname, owner, target string) string {
if dns.IsSubDomain(owner, qname) && qname != owner {
labels := dns.SplitDomainName(qname)
labels = append(labels[0:len(labels)-dns.CountLabel(owner)], dns.SplitDomainName(target)...)
return dnsutil.Join(labels...)
}
return ""
}
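// Example (illustrative, hypothetical names): with a DNAME from example.com. to
// example.net., a query for www.example.com. substitutes to:
//
//    substituteDNAME("www.example.com.", "example.com.", "example.net.")
//    // -> "www.example.net."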
// synthesizeCNAME returns a CNAME RR pointing to the resulting name of
// the DNAME substitution. The owner name of the CNAME is the QNAME of
// the query and the TTL is the same as the corresponding DNAME RR.
//
// It returns nil if the DNAME substitution has no match.
func synthesizeCNAME(qname string, d *dns.DNAME) *dns.CNAME {
target := substituteDNAME(qname, d.Header().Name, d.Target)
if target == "" {
return nil
}
r := new(dns.CNAME)
r.Hdr = dns.RR_Header{
Name: qname,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: d.Header().Ttl,
}
r.Target = target
return r
}
// Package file implements a file backend.
package file
import (
"context"
"fmt"
"io"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/transfer"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("file")
type (
// File is the plugin that reads zone data from disk.
File struct {
Next plugin.Handler
Zones
transfer *transfer.Transfer
Fall fall.F
}
// Zones maps zone names to a *Zone.
Zones struct {
Z map[string]*Zone // A map mapping zone (origin) to the Zone's data
Names []string // All the keys from the map Z as a string slice.
}
)
// ServeDNS implements the plugin.Handler interface.
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
// TODO(miek): match the qname better in the map
zone := plugin.Zones(f.Zones.Names).Matches(qname)
if zone == "" {
// If no next plugin is configured, it's more correct to return REFUSED as file acts as an authoritative server
if f.Next == nil {
return dns.RcodeRefused, nil
}
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
z, ok := f.Z[zone]
if !ok || z == nil {
return dns.RcodeServerFailure, nil
}
// If the transfer plugin is not loaded, we'll see these queries here; answer with REFUSED (no transfer allowed).
if state.QType() == dns.TypeAXFR || state.QType() == dns.TypeIXFR {
return dns.RcodeRefused, nil
}
// This is only for when we are a secondary zone.
if r.Opcode == dns.OpcodeNotify {
if z.isNotify(state) {
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
w.WriteMsg(m)
log.Infof("Notify from %s for %s: checking transfer", state.IP(), zone)
ok, err := z.shouldTransfer()
if ok {
z.TransferIn()
} else {
log.Infof("Notify from %s for %s: no SOA serial increase seen", state.IP(), zone)
}
if err != nil {
log.Warningf("Notify from %s for %s: failed primary check: %s", state.IP(), zone, err)
}
return dns.RcodeSuccess, nil
}
log.Infof("Dropping notify from %s for %s", state.IP(), zone)
return dns.RcodeSuccess, nil
}
z.RLock()
exp := z.Expired
z.RUnlock()
if exp {
log.Errorf("Zone %s is expired", zone)
return dns.RcodeServerFailure, nil
}
answer, ns, extra, result := z.Lookup(ctx, state, qname)
// Only on NXDOMAIN will we fall through.
// `z.Lookup` can also return NOERROR for an NXDOMAIN; see the comment "Hacky way to get around empty-non-terminals" inside `Zone.Lookup`.
// It's safe to fall through with `result` Success (NOERROR) since all other return points in Lookup that return Success have answer(s).
if len(answer) == 0 && (result == NameError || result == Success) && f.Fall.Through(qname) {
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Answer, m.Ns, m.Extra = answer, ns, extra
switch result {
case Success:
case NoData:
case NameError:
m.Rcode = dns.RcodeNameError
case Delegation:
m.Authoritative = false
case ServerFailure:
// If the result is SERVFAIL and the answer is non-empty, then the SERVFAIL came from an
// external CNAME lookup and the answer contains the CNAME with no target record. We should
// write the CNAME record to the client instead of sending an empty SERVFAIL response.
if len(m.Answer) == 0 {
return dns.RcodeServerFailure, nil
}
// The rcode in the response should be the rcode received from the target lookup. RFC 6604 section 3
m.Rcode = dns.RcodeServerFailure
}
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
// Name implements the Handler interface.
func (f File) Name() string { return "file" }
type serialErr struct {
err string
zone string
origin string
serial int64
}
func (s *serialErr) Error() string {
return fmt.Sprintf("%s for origin %s in file %s, with %d SOA serial", s.err, s.origin, s.zone, s.serial)
}
// Parse parses the zone in filename and returns a new Zone or an error.
// If serial >= 0 this is treated as a reload: if the SOA serial hasn't changed,
// an error is returned indicating nothing was read.
func Parse(f io.Reader, origin, fileName string, serial int64) (*Zone, error) {
zp := dns.NewZoneParser(f, dns.Fqdn(origin), fileName)
zp.SetIncludeAllowed(true)
z := NewZone(origin, fileName)
seenSOA := false
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
if !seenSOA {
if s, ok := rr.(*dns.SOA); ok {
seenSOA = true
// -1 is a valid serial if we failed to load the file on startup.
if serial >= 0 && s.Serial == uint32(serial) { // same serial
return nil, &serialErr{err: "no change in SOA serial", origin: origin, zone: fileName, serial: serial}
}
}
}
if err := z.Insert(rr); err != nil {
return nil, err
}
}
if !seenSOA {
return nil, fmt.Errorf("file %q has no SOA record for origin %s", fileName, origin)
}
if err := zp.Err(); err != nil {
return nil, fmt.Errorf("failed to parse file %q for origin %s with error %v", fileName, origin, err)
}
return z, nil
}
//go:build gofuzz
package file
import (
"strings"
"github.com/coredns/coredns/plugin/pkg/fuzz"
"github.com/coredns/coredns/plugin/test"
)
// Fuzz fuzzes file.
func Fuzz(data []byte) int {
name := "miek.nl."
zone, _ := Parse(strings.NewReader(fuzzMiekNL), name, "stdin", 0)
f := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{name: zone}, Names: []string{name}}}
return fuzz.Do(f, data)
}
const fuzzMiekNL = `
$TTL 30M
$ORIGIN miek.nl.
@ IN SOA linode.atoom.net. miek.miek.nl. (
1282630057 ; Serial
4H ; Refresh
1H ; Retry
7D ; Expire
4H ) ; Negative Cache TTL
IN NS linode.atoom.net.
IN NS ns-ext.nlnetlabs.nl.
IN NS omval.tednet.nl.
IN NS ext.ns.whyscream.net.
IN MX 1 aspmx.l.google.com.
IN MX 5 alt1.aspmx.l.google.com.
IN MX 5 alt2.aspmx.l.google.com.
IN MX 10 aspmx2.googlemail.com.
IN MX 10 aspmx3.googlemail.com.
IN A 139.162.196.78
IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
a IN A 139.162.196.78
IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
www IN CNAME a
archive IN CNAME a
srv IN SRV 10 10 8080 a.miek.nl.
mx IN MX 10 a.miek.nl.`
package file
import (
"context"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin/file/rrutil"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Result is the result of a Lookup
type Result int
const (
// Success is a successful lookup.
Success Result = iota
// NameError indicates a nameerror
NameError
// Delegation indicates the lookup resulted in a delegation.
Delegation
// NoData indicates the lookup resulted in a NODATA.
NoData
// ServerFailure indicates a server failure during the lookup.
ServerFailure
)
// Lookup looks up qname and qtype in the zone. When do is true DNSSEC records are included.
// Three sets of records are returned, one for the answer, one for authority and one for the additional section.
func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) ([]dns.RR, []dns.RR, []dns.RR, Result) {
qtype := state.QType()
do := state.Do()
// If z is a secondary zone we might not have transferred it yet, meaning we have
// all the zone context set up, except the actual records. This means (for one thing) the apex
// is empty and we don't have a SOA record.
z.RLock()
ap := z.Apex
tr := z.Tree
z.RUnlock()
if ap.SOA == nil {
return nil, nil, nil, ServerFailure
}
if qname == z.origin {
switch qtype {
case dns.TypeSOA:
return ap.soa(do), ap.ns(do), nil, Success
case dns.TypeNS:
nsrrs := ap.ns(do)
glue := tr.Glue(nsrrs, do) // technically this isn't glue
return nsrrs, nil, glue, Success
}
}
var (
found, shot bool
parts string
i int
elem, wildElem *tree.Elem
)
loop, _ := ctx.Value(dnsserver.LoopKey{}).(int)
if loop > 8 {
// We're back here for the 9th time; we have a loop and need to bail out.
// Note the answer we're returning will be incomplete (more cnames to be followed) or
// illegal (wildcard cname with multiple identical records). For now it's more important
// to protect ourselves than to give the client a valid answer. We return with an error
// to let the server handle what to do.
return nil, nil, nil, ServerFailure
}
// Lookup:
// * Per label from the right, check if the name exists. We do this to find potential
// delegation records.
// * If the per-label search finds nothing, we will look for the wildcard at that
// level. If found we keep it around. If we don't find the complete name we will
// use the wildcard.
//
// The main for-loop handles delegation and finding or not finding the qname.
// If found, we check if it is a CNAME/DNAME and do CNAME processing.
// We also check if we have the queried type; if not we do a NODATA response.
//
// If not found, we check the potential wildcard, and use that for further processing.
// If not found and no wildcard we will process this as an NXDOMAIN response.
for {
parts, shot = z.nameFromRight(qname, i)
// We overshot the name, break and check if we previously found something.
if shot {
break
}
elem, found = tr.Search(parts)
if !found {
// The apex will always be found, so when we are here we can search for a wildcard
// and save the result of that search. When nothing matches, but we have a
// wildcard, we should expand the wildcard.
wildcard := replaceWithAsteriskLabel(parts)
if wild, found := tr.Search(wildcard); found {
wildElem = wild
}
// Keep on searching, because maybe we hit an empty-non-terminal (which isn't
// stored in the tree). Only when we have matched the full qname (and possibly the wildcard)
// can we be confident that we didn't find anything.
i++
continue
}
// If we see DNAME records, we should return those.
if dnamerrs := elem.Type(dns.TypeDNAME); dnamerrs != nil {
// Only one DNAME is allowed per name. We just pick the first one to synthesize from.
dname := dnamerrs[0]
if cname := synthesizeCNAME(state.Name(), dname.(*dns.DNAME)); cname != nil {
var (
answer, ns, extra []dns.RR
rcode Result
)
// We don't need to chase CNAME chain for synthesized CNAME
if qtype == dns.TypeCNAME {
answer = []dns.RR{cname}
ns = ap.ns(do)
extra = nil
rcode = Success
} else {
ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1)
answer, ns, extra, rcode = z.externalLookup(ctx, state, elem, []dns.RR{cname})
}
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeDNAME)
dnamerrs = append(dnamerrs, sigs...)
}
// The relevant DNAME RR should be included in the answer section,
// if the DNAME is being employed as a substitution instruction.
answer = append(dnamerrs, answer...)
return answer, ns, extra, rcode
}
// The domain name that owns a DNAME record is allowed to have other RR types
// at that domain name, except those that have restrictions on what they can coexist
// with (e.g. another DNAME). So there is nothing special left here.
}
// If we see NS records, it means the name has been delegated, and we should return the delegation.
if nsrrs := elem.Type(dns.TypeNS); nsrrs != nil {
// If the query is specifically for DS and the qname matches the delegated name, we should
// return the DS in the answer section and leave the rest empty, i.e. just continue the loop
// and continue searching.
if qtype == dns.TypeDS && elem.Name() == qname {
i++
continue
}
glue := tr.Glue(nsrrs, do)
if do {
dss := typeFromElem(elem, dns.TypeDS, do)
nsrrs = append(nsrrs, dss...)
}
return nil, nsrrs, glue, Delegation
}
i++
}
// What does found and !shot mean - do we ever hit it?
if found && !shot {
return nil, nil, nil, ServerFailure
}
// Found entire name.
if found && shot {
if rrs := elem.Type(dns.TypeCNAME); len(rrs) > 0 && qtype != dns.TypeCNAME {
ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1)
return z.externalLookup(ctx, state, elem, rrs)
}
rrs := elem.Type(qtype)
// NODATA
if len(rrs) == 0 {
ret := ap.soa(do)
if do {
nsec := typeFromElem(elem, dns.TypeNSEC, do)
ret = append(ret, nsec...)
}
return nil, ret, nil, NoData
}
// Additional section processing for MX, SRV. Check response and see if any of the names are in bailiwick -
// if so add IP addresses to the additional section.
additional := z.additionalProcessing(rrs, do)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, qtype)
rrs = append(rrs, sigs...)
}
return rrs, ap.ns(do), additional, Success
}
// Haven't found the original name.
// Found wildcard.
if wildElem != nil {
// set metadata value for the wildcard record that synthesized the result
metadata.SetValueFunc(ctx, "zone/wildcard", func() string {
return wildElem.Name()
})
if rrs := wildElem.TypeForWildcard(dns.TypeCNAME, qname); len(rrs) > 0 && qtype != dns.TypeCNAME {
ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1)
return z.externalLookup(ctx, state, wildElem, rrs)
}
rrs := wildElem.TypeForWildcard(qtype, qname)
// NODATA response.
if len(rrs) == 0 {
ret := ap.soa(do)
if do {
nsec := typeFromElem(wildElem, dns.TypeNSEC, do)
ret = append(ret, nsec...)
}
return nil, ret, nil, NoData
}
auth := ap.ns(do)
if do {
// An NSEC is needed to prove the queried name itself does not exist, so the wildcard expansion is valid.
if deny, found := tr.Prev(qname); found {
nsec := typeFromElem(deny, dns.TypeNSEC, do)
auth = append(auth, nsec...)
}
sigs := wildElem.TypeForWildcard(dns.TypeRRSIG, qname)
sigs = rrutil.SubTypeSignature(sigs, qtype)
rrs = append(rrs, sigs...)
}
return rrs, auth, nil, Success
}
rcode := NameError
// Hacky way to get around empty-non-terminals. If a longer name does exist, but this qname does not, it
// must be an empty-non-terminal. If so, we do the proper NXDOMAIN handling, but set the rcode to Success.
if x, found := tr.Next(qname); found {
if dns.IsSubDomain(qname, x.Name()) {
rcode = Success
}
}
ret := ap.soa(do)
if do {
deny, found := tr.Prev(qname)
if !found {
goto Out
}
nsec := typeFromElem(deny, dns.TypeNSEC, do)
ret = append(ret, nsec...)
if rcode != NameError {
goto Out
}
ce, found := z.ClosestEncloser(qname)
// wildcard denial only for NXDOMAIN
if found {
// wildcard denial
wildcard := "*." + ce.Name()
if ss, found := tr.Prev(wildcard); found {
// Only add this nsec if it is different than the one already added
if ss.Name() != deny.Name() {
nsec := typeFromElem(ss, dns.TypeNSEC, do)
ret = append(ret, nsec...)
}
}
}
}
Out:
return nil, ret, nil, rcode
}
// typeFromElem returns the RRs of type tp from elem and adds their signatures (if they exist) when do is true.
func typeFromElem(elem *tree.Elem, tp uint16, do bool) []dns.RR {
rrs := elem.Type(tp)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, tp)
rrs = append(rrs, sigs...)
}
return rrs
}
func (a Apex) soa(do bool) []dns.RR {
if do {
ret := append([]dns.RR{a.SOA}, a.SIGSOA...)
return ret
}
return []dns.RR{a.SOA}
}
func (a Apex) ns(do bool) []dns.RR {
if do {
ret := append(a.NS, a.SIGNS...)
return ret
}
return a.NS
}
// externalLookup adds signatures and tries to resolve CNAMEs that point to external names.
func (z *Zone) externalLookup(ctx context.Context, state request.Request, elem *tree.Elem, rrs []dns.RR) ([]dns.RR, []dns.RR, []dns.RR, Result) {
qtype := state.QType()
do := state.Do()
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeCNAME)
rrs = append(rrs, sigs...)
}
targetName := rrs[0].(*dns.CNAME).Target
elem, _ = z.Search(targetName)
if elem == nil {
lookupRRs, result := z.doLookup(ctx, state, targetName, qtype)
rrs = append(rrs, lookupRRs...)
return rrs, z.ns(do), nil, result
}
i := 0
Redo:
cname := elem.Type(dns.TypeCNAME)
if len(cname) > 0 {
rrs = append(rrs, cname...)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeCNAME)
rrs = append(rrs, sigs...)
}
targetName := cname[0].(*dns.CNAME).Target
elem, _ = z.Search(targetName)
if elem == nil {
lookupRRs, result := z.doLookup(ctx, state, targetName, qtype)
rrs = append(rrs, lookupRRs...)
return rrs, z.ns(do), nil, result
}
i++
if i > 8 {
return rrs, z.ns(do), nil, Success
}
goto Redo
}
targets := elem.Type(qtype)
if len(targets) > 0 {
rrs = append(rrs, targets...)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, qtype)
rrs = append(rrs, sigs...)
}
}
return rrs, z.ns(do), nil, Success
}
func (z *Zone) doLookup(ctx context.Context, state request.Request, target string, qtype uint16) ([]dns.RR, Result) {
m, e := z.Upstream.Lookup(ctx, state, target, qtype)
if e != nil {
return nil, ServerFailure
}
if m == nil {
return nil, Success
}
if m.Rcode == dns.RcodeNameError {
return m.Answer, NameError
}
if m.Rcode == dns.RcodeServerFailure {
return m.Answer, ServerFailure
}
if m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0 {
return m.Answer, NoData
}
return m.Answer, Success
}
// additionalProcessing checks the current answer section and retrieves A or AAAA records
// (and possibly their RRSIGs) that need to be put in the additional section.
func (z *Zone) additionalProcessing(answer []dns.RR, do bool) (extra []dns.RR) {
for _, rr := range answer {
name := ""
switch x := rr.(type) {
case *dns.SRV:
name = x.Target
case *dns.MX:
name = x.Mx
}
if len(name) == 0 || !dns.IsSubDomain(z.origin, name) {
continue
}
elem, _ := z.Search(name)
if elem == nil {
continue
}
sigs := elem.Type(dns.TypeRRSIG)
for _, addr := range []uint16{dns.TypeA, dns.TypeAAAA} {
if a := elem.Type(addr); a != nil {
extra = append(extra, a...)
if do {
sig := rrutil.SubTypeSignature(sigs, addr)
extra = append(extra, sig...)
}
}
}
}
return extra
}
package file
import (
"net"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// isNotify checks if state is a notify message and if so, will *also* check if it
// is from one of the configured masters. If not it will not be a valid notify
// message. If the zone z is not a secondary zone the message will also be ignored.
func (z *Zone) isNotify(state request.Request) bool {
if state.Req.Opcode != dns.OpcodeNotify {
return false
}
if len(z.TransferFrom) == 0 {
return false
}
// If remote IP matches we accept.
remote := state.IP()
for _, f := range z.TransferFrom {
from, _, err := net.SplitHostPort(f)
if err != nil {
continue
}
if from == remote {
return true
}
}
return false
}
package file
import (
"os"
"path/filepath"
"time"
"github.com/coredns/coredns/plugin/transfer"
)
// Reload reloads a zone when it is changed on disk. If z.ReloadInterval is zero, no reloading will be done.
func (z *Zone) Reload(t *transfer.Transfer) error {
if z.ReloadInterval == 0 {
return nil
}
tick := time.NewTicker(z.ReloadInterval)
go func() {
for {
select {
case <-tick.C:
zFile := z.File()
reader, err := os.Open(filepath.Clean(zFile))
if err != nil {
log.Errorf("Failed to open zone %q in %q: %v", z.origin, zFile, err)
continue
}
serial := z.SOASerialIfDefined()
zone, err := Parse(reader, z.origin, zFile, serial)
reader.Close()
if err != nil {
if _, ok := err.(*serialErr); !ok {
log.Errorf("Parsing zone %q: %v", z.origin, err)
}
continue
}
// copy elements we need
z.Lock()
z.Apex = zone.Apex
z.Tree = zone.Tree
z.Unlock()
log.Infof("Successfully reloaded zone %q in %q with %d SOA serial", z.origin, zFile, z.SOA.Serial)
if t != nil {
if err := t.Notify(z.origin); err != nil {
log.Warningf("Failed sending notifies: %s", err)
}
}
case <-z.reloadShutdown:
tick.Stop()
return
}
}
}()
return nil
}
// SOASerialIfDefined returns the SOA's serial if the zone has a SOA record in the Apex, or -1 otherwise.
func (z *Zone) SOASerialIfDefined() int64 {
z.RLock()
defer z.RUnlock()
if z.SOA != nil {
return int64(z.SOA.Serial)
}
return -1
}
// Package rrutil provides functions to find certain RRs in slices.
package rrutil
import "github.com/miekg/dns"
// SubTypeSignature returns the RRSIG for the subtype.
func SubTypeSignature(rrs []dns.RR, subtype uint16) []dns.RR {
sigs := []dns.RR{}
// there may be multiple keys that have signed this subtype
for _, sig := range rrs {
if s, ok := sig.(*dns.RRSIG); ok {
if s.TypeCovered == subtype {
sigs = append(sigs, s)
}
}
}
return sigs
}
package file
import (
"math/rand"
"time"
"github.com/miekg/dns"
)
// TransferIn retrieves the zone from the masters, parses it and sets it live.
func (z *Zone) TransferIn() error {
if len(z.TransferFrom) == 0 {
return nil
}
m := new(dns.Msg)
m.SetAxfr(z.origin)
z1 := z.CopyWithoutApex()
var (
Err error
tr string
)
Transfer:
for _, tr = range z.TransferFrom {
t := new(dns.Transfer)
c, err := t.In(m, tr)
if err != nil {
log.Errorf("Failed to setup transfer `%s' with `%q': %v", z.origin, tr, err)
Err = err
continue Transfer
}
for env := range c {
if env.Error != nil {
log.Errorf("Failed to transfer `%s' from %q: %v", z.origin, tr, env.Error)
Err = env.Error
continue Transfer
}
for _, rr := range env.RR {
if err := z1.Insert(rr); err != nil {
log.Errorf("Failed to parse transfer `%s' from: %q: %v", z.origin, tr, err)
Err = err
continue Transfer
}
}
}
Err = nil
break
}
if Err != nil {
return Err
}
z.Lock()
z.Tree = z1.Tree
z.Apex = z1.Apex
z.Expired = false
z.Unlock()
log.Infof("Transferred: %s from %s", z.origin, tr)
return nil
}
// shouldTransfer checks the primaries of the zone, retrieves the SOA record, compares the current serial
// with the remote serial and returns true if the remote one is higher than the locally configured one.
func (z *Zone) shouldTransfer() (bool, error) {
c := new(dns.Client)
c.Net = "tcp" // do this query over TCP to minimize spoofing
m := new(dns.Msg)
m.SetQuestion(z.origin, dns.TypeSOA)
var Err error
serial := -1
Transfer:
for _, tr := range z.TransferFrom {
Err = nil
ret, _, err := c.Exchange(m, tr)
if err != nil || ret.Rcode != dns.RcodeSuccess {
Err = err
continue
}
for _, a := range ret.Answer {
if a.Header().Rrtype == dns.TypeSOA {
serial = int(a.(*dns.SOA).Serial)
break Transfer
}
}
}
if serial == -1 {
return false, Err
}
if z.SOA == nil {
return true, Err
}
return less(z.SOA.Serial, uint32(serial)), Err
}
// less returns true if a is smaller than b when taking RFC 1982 serial arithmetic into account.
func less(a, b uint32) bool {
if a < b {
return (b - a) <= MaxSerialIncrement
}
return (a - b) > MaxSerialIncrement
}
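// Example (illustrative): with RFC 1982 arithmetic a serial that wrapped past zero is
// still considered newer:
//
//    less(1, 2)          // true: 2 is newer
//    less(4294967290, 5) // true: 5 is newer, the counter wrapped
//    less(5, 4294967290) // false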
// Update updates the secondary zone according to its SOA. It will run for the lifetime of the server
// and uses the SOA parameters. Every refresh interval it will check for a new SOA serial. If that fails (for all
// servers) it will retry every retry interval. If the zone fails to transfer before the expire interval elapses, the zone
// will be marked expired.
func (z *Zone) Update() error {
// If we don't have a SOA, we don't have a zone, wait for it to appear.
for z.SOA == nil {
time.Sleep(1 * time.Second)
}
retryActive := false
Restart:
refresh := time.Second * time.Duration(z.SOA.Refresh)
retry := time.Second * time.Duration(z.SOA.Retry)
expire := time.Second * time.Duration(z.SOA.Expire)
refreshTicker := time.NewTicker(refresh)
retryTicker := time.NewTicker(retry)
expireTicker := time.NewTicker(expire)
for {
select {
case <-expireTicker.C:
if !retryActive {
break
}
z.Expired = true
case <-retryTicker.C:
if !retryActive {
break
}
time.Sleep(jitter(2000)) // 2s randomize
ok, err := z.shouldTransfer()
if err != nil {
log.Warningf("Failed retry check %s", err)
continue
}
if ok {
if err := z.TransferIn(); err != nil {
// transfer failed, leave retryActive true
break
}
}
// no errors, stop timers and restart
retryActive = false
refreshTicker.Stop()
retryTicker.Stop()
expireTicker.Stop()
goto Restart
case <-refreshTicker.C:
time.Sleep(jitter(5000)) // 5s randomize
ok, err := z.shouldTransfer()
if err != nil {
log.Warningf("Failed refresh check %s", err)
retryActive = true
continue
}
if ok {
if err := z.TransferIn(); err != nil {
// transfer failed
retryActive = true
break
}
}
// no errors, stop timers and restart
retryActive = false
refreshTicker.Stop()
retryTicker.Stop()
expireTicker.Stop()
goto Restart
}
}
}
// jitter returns a random duration between [0,n) * time.Millisecond
func jitter(n int) time.Duration {
r := rand.Intn(n)
return time.Duration(r) * time.Millisecond
}
// MaxSerialIncrement is the maximum difference between two serial numbers. If the difference between
// two serials is greater than this number, the smaller one is considered greater.
const MaxSerialIncrement uint32 = 2147483647
package file
import (
"errors"
"os"
"path/filepath"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/plugin/transfer"
)
func init() { plugin.Register("file", setup) }
func setup(c *caddy.Controller) error {
zones, fall, err := fileParse(c)
if err != nil {
return plugin.Error("file", err)
}
f := File{Zones: zones, Fall: fall}
// get the transfer plugin, so we can send notifies for zone changes, and send notifies on startup as well.
c.OnStartup(func() error {
t := dnsserver.GetConfig(c).Handler("transfer")
if t == nil {
return nil
}
f.transfer = t.(*transfer.Transfer) // if found this must be OK.
go func() {
for _, n := range zones.Names {
f.transfer.Notify(n)
}
}()
return nil
})
c.OnRestartFailed(func() error {
t := dnsserver.GetConfig(c).Handler("transfer")
if t == nil {
return nil
}
go func() {
for _, n := range zones.Names {
f.transfer.Notify(n)
}
}()
return nil
})
for _, n := range zones.Names {
z := zones.Z[n]
c.OnShutdown(z.OnShutdown)
c.OnStartup(func() error {
z.StartupOnce.Do(func() { z.Reload(f.transfer) })
return nil
})
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
f.Next = next
return f
})
return nil
}
func fileParse(c *caddy.Controller) (Zones, fall.F, error) {
z := make(map[string]*Zone)
names := []string{}
fall := fall.F{}
config := dnsserver.GetConfig(c)
var openErr error
reload := 1 * time.Minute
for c.Next() {
// file db.file [zones...]
if !c.NextArg() {
return Zones{}, fall, c.ArgErr()
}
fileName := c.Val()
origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
if !filepath.IsAbs(fileName) && config.Root != "" {
fileName = filepath.Join(config.Root, fileName)
}
reader, err := os.Open(filepath.Clean(fileName))
if err != nil {
openErr = err
}
err = func() error {
defer reader.Close()
for i := range origins {
z[origins[i]] = NewZone(origins[i], fileName)
if openErr == nil {
reader.Seek(0, 0)
zone, err := Parse(reader, origins[i], fileName, 0)
if err != nil {
return err
}
z[origins[i]] = zone
}
names = append(names, origins[i])
}
return nil
}()
if err != nil {
return Zones{}, fall, err
}
for c.NextBlock() {
switch c.Val() {
case "fallthrough":
fall.SetZonesFromArgs(c.RemainingArgs())
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return Zones{}, fall, errors.New("reload duration value is expected")
}
d, err := time.ParseDuration(t[0])
if err != nil {
return Zones{}, fall, plugin.Error("file", err)
}
reload = d
case "upstream":
// remove soon
c.RemainingArgs()
default:
return Zones{}, fall, c.Errf("unknown property '%s'", c.Val())
}
}
for i := range origins {
z[origins[i]].ReloadInterval = reload
z[origins[i]].Upstream = upstream.New()
}
}
if openErr != nil {
if reload == 0 {
// reload is disabled, so make this a fatal error
return Zones{}, fall, plugin.Error("file", openErr)
}
log.Warningf("Failed to open %q: trying again in %s", openErr, reload)
}
return Zones{Z: z, Names: names}, fall, nil
}
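// For reference, a Corefile stanza using the directives parsed above could look like this
// sketch (zone and file name are illustrative):
//
//    example.org {
//        file db.example.org {
//            reload 30s
//            fallthrough
//        }
//    }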
package file
// OnShutdown shuts down any running go-routines for this zone.
func (z *Zone) OnShutdown() error {
if 0 < z.ReloadInterval {
z.reloadShutdown <- true
}
return nil
}
package tree
// All traverses the tree and returns all elements.
func (t *Tree) All() []*Elem {
if t.Root == nil {
return nil
}
found := t.Root.all(nil)
return found
}
func (n *Node) all(found []*Elem) []*Elem {
if n.Left != nil {
found = n.Left.all(found)
}
found = append(found, n.Elem)
if n.Right != nil {
found = n.Right.all(found)
}
return found
}
package tree
import (
"github.com/miekg/dns"
)
// AuthWalk performs fn on all values stored in the tree, in-order depth first.
// If a non-nil error is returned the AuthWalk was interrupted
// by an fn returning that error. If fn alters stored values' sort
// relationships, future tree operation behaviors are undefined.
//
// The fn function will be called with 3 arguments, the current element, a map containing all
// the RRs for this element and a boolean if this name is considered authoritative.
func (t *Tree) AuthWalk(fn func(*Elem, map[uint16][]dns.RR, bool) error) error {
if t.Root == nil {
return nil
}
return t.Root.authwalk(make(map[string]struct{}), fn)
}
func (n *Node) authwalk(ns map[string]struct{}, fn func(*Elem, map[uint16][]dns.RR, bool) error) error {
if n.Left != nil {
if err := n.Left.authwalk(ns, fn); err != nil {
return err
}
}
// Check if the current name is a subdomain of *any* of the delegated names we've seen; if so, this name is not authoritative.
// The ordering of the tree and how we walk it guarantee that we see parents first.
if n.Elem.Type(dns.TypeNS) != nil {
ns[n.Elem.Name()] = struct{}{}
}
auth := true
i := 0
for {
j, end := dns.NextLabel(n.Elem.Name(), i)
if end {
break
}
if _, ok := ns[n.Elem.Name()[j:]]; ok {
auth = false
break
}
i++
}
if err := fn(n.Elem, n.Elem.m, auth); err != nil {
return err
}
if n.Right != nil {
if err := n.Right.authwalk(ns, fn); err != nil {
return err
}
}
return nil
}
package tree
import "github.com/miekg/dns"
// Elem is an element in the tree.
type Elem struct {
m map[uint16][]dns.RR
name string // owner name
}
// newElem returns a new elem.
func newElem(rr dns.RR) *Elem {
e := Elem{m: make(map[uint16][]dns.RR)}
e.m[rr.Header().Rrtype] = []dns.RR{rr}
// Eagerly set the cached owner name to avoid racy lazy writes later.
e.name = rr.Header().Name
return &e
}
// Types returns the types of the records in e. The returned list is not sorted.
func (e *Elem) Types() []uint16 {
t := make([]uint16, len(e.m))
i := 0
for ty := range e.m {
t[i] = ty
i++
}
return t
}
// Type returns the RRs with type qtype from e.
func (e *Elem) Type(qtype uint16) []dns.RR { return e.m[qtype] }
// TypeForWildcard returns the RRs with type qtype from e. The ownername returned is set to qname.
func (e *Elem) TypeForWildcard(qtype uint16, qname string) []dns.RR {
rrs := e.m[qtype]
if rrs == nil {
return nil
}
copied := make([]dns.RR, len(rrs))
for i := range rrs {
copied[i] = dns.Copy(rrs[i])
copied[i].Header().Name = qname
}
return copied
}
// All returns all RRs from e, regardless of type.
func (e *Elem) All() []dns.RR {
list := []dns.RR{}
for _, rrs := range e.m {
list = append(list, rrs...)
}
return list
}
// Name returns the name for this node.
func (e *Elem) Name() string {
// Read-only: name is eagerly set in newElem and should not be mutated here.
if e.name != "" {
return e.name
}
for _, rrs := range e.m {
return rrs[0].Header().Name
}
return ""
}
// Empty returns true if e does not contain any RRs, i.e. is an empty-non-terminal.
func (e *Elem) Empty() bool { return len(e.m) == 0 }
// Insert inserts rr into e. If rr is equal to existing RRs, the RR will be added anyway.
func (e *Elem) Insert(rr dns.RR) {
t := rr.Header().Rrtype
if e.m == nil {
e.m = make(map[uint16][]dns.RR)
e.m[t] = []dns.RR{rr}
return
}
rrs, ok := e.m[t]
if !ok {
e.m[t] = []dns.RR{rr}
return
}
rrs = append(rrs, rr)
e.m[t] = rrs
}
// Delete removes all RRs of type rr.Header().Rrtype from e.
func (e *Elem) Delete(rr dns.RR) {
if e.m == nil {
return
}
t := rr.Header().Rrtype
delete(e.m, t)
}
// Less is a tree helper function that calls less.
func Less(a *Elem, name string) int { return less(name, a.Name()) }
package tree
import (
"github.com/coredns/coredns/plugin/file/rrutil"
"github.com/miekg/dns"
)
// Glue returns any potential glue records for nsrrs.
func (t *Tree) Glue(nsrrs []dns.RR, do bool) []dns.RR {
glue := []dns.RR{}
for _, rr := range nsrrs {
if ns, ok := rr.(*dns.NS); ok && dns.IsSubDomain(ns.Header().Name, ns.Ns) {
glue = append(glue, t.searchGlue(ns.Ns, do)...)
}
}
return glue
}
// searchGlue looks up A and AAAA for name.
func (t *Tree) searchGlue(name string, do bool) []dns.RR {
glue := []dns.RR{}
// A
if elem, found := t.Search(name); found {
glue = append(glue, elem.Type(dns.TypeA)...)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeA)
glue = append(glue, sigs...)
}
}
// AAAA
if elem, found := t.Search(name); found {
glue = append(glue, elem.Type(dns.TypeAAAA)...)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeAAAA)
glue = append(glue, sigs...)
}
}
return glue
}
package tree
import (
"bytes"
"strings"
"github.com/miekg/dns"
)
// less returns <0 when a is less than b, 0 when they are equal and
// >0 when a is larger than b.
// The function orders names in DNSSEC canonical order: RFC 4034, section 6.1.
//
// See https://bert-hubert.blogspot.co.uk/2015/10/how-to-do-fast-canonical-ordering-of.html
// for a blog article on this implementation, although here we still go label by label.
//
// The values of a and b are *not* lowercased before the comparison!
func less(a, b string) int {
aj := len(a)
bj := len(b)
for {
ai, oka := dns.PrevLabel(a[:aj], 1)
bi, okb := dns.PrevLabel(b[:bj], 1)
if oka && okb {
return 0
}
// sadly this []byte will allocate... TODO(miek): check if this is needed
// for a name, otherwise compare the strings.
ab := []byte(strings.ToLower(a[ai:aj]))
bb := []byte(strings.ToLower(b[bi:bj]))
doDDD(ab)
doDDD(bb)
res := bytes.Compare(ab, bb)
if res != 0 {
return res
}
aj, bj = ai, bi
}
}
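// Example (illustrative): names are compared label by label from the right and
// case-insensitively, so a parent sorts before its children:
//
//    less("example.", "a.example.")   // < 0
//    less("a.example.", "Z.example.") // < 0 ("a" < "z" after lowercasing)
//    less("b.example.", "b.example.") // == 0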
func doDDD(b []byte) {
lb := len(b)
for i := 0; i < lb; i++ {
if i+3 < lb && b[i] == '\\' && isDigit(b[i+1]) && isDigit(b[i+2]) && isDigit(b[i+3]) {
b[i] = dddToByte(b[i:])
for j := i + 1; j < lb-3; j++ {
b[j] = b[j+3]
}
lb -= 3
}
}
}
func isDigit(b byte) bool { return b >= '0' && b <= '9' }
func dddToByte(s []byte) byte { return (s[1]-'0')*100 + (s[2]-'0')*10 + (s[3] - '0') }
package tree
import "fmt"
// Print prints a Tree. Main use is to aid in debugging.
func (t *Tree) Print() {
if t.Root == nil {
fmt.Println("<nil>")
}
t.Root.print()
}
func (n *Node) print() {
q := newQueue()
q.push(n)
nodesInCurrentLevel := 1
nodesInNextLevel := 0
for !q.empty() {
do := q.pop()
nodesInCurrentLevel--
if do != nil {
fmt.Print(do.Elem.Name(), " ")
q.push(do.Left)
q.push(do.Right)
nodesInNextLevel += 2
}
if nodesInCurrentLevel == 0 {
fmt.Println()
nodesInCurrentLevel = nodesInNextLevel
nodesInNextLevel = 0
}
}
fmt.Println()
}
type queue []*Node
// newQueue returns a new queue.
func newQueue() queue {
q := queue([]*Node{})
return q
}
// push pushes n to the end of the queue.
func (q *queue) push(n *Node) {
*q = append(*q, n)
}
// pop pops the first element off the queue.
func (q *queue) pop() *Node {
n := (*q)[0]
*q = (*q)[1:]
return n
}
// empty returns true when the queue contains zero nodes.
func (q *queue) empty() bool {
return len(*q) == 0
}
// Copyright ©2012 The bíogo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found at the end of this file.
// Package tree implements Left-Leaning Red Black trees as described by Robert Sedgewick.
//
// More details relating to the implementation are available at the following locations:
//
// http://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf
// http://www.cs.princeton.edu/~rs/talks/LLRB/Java/RedBlackBST.java
// http://www.teachsolaisgames.com/articles/balanced_left_leaning.html
//
// Heavily modified by Miek Gieben for use in DNS zones.
package tree
import "github.com/miekg/dns"
const (
td234 = iota
bu23
)
// Operation mode of the LLRB tree.
const mode = bu23
func init() {
if mode != td234 && mode != bu23 {
panic("tree: unknown mode")
}
}
// A Color represents the color of a Node.
type Color bool
const (
// Red as false gives us the defined behaviour that new nodes are red. Although this
// is incorrect for the root node, that is resolved on the first insertion.
red Color = false
black Color = true
)
// A Node represents a node in the LLRB tree.
type Node struct {
Elem *Elem
Left, Right *Node
Color Color
}
// A Tree manages the root node of an LLRB tree. Public methods are exposed through this type.
type Tree struct {
Root *Node // Root node of the tree.
Count int // Number of elements stored.
}
// Helper methods
// color returns the effective color of a Node. A nil node returns black.
func (n *Node) color() Color {
if n == nil {
return black
}
return n.Color
}
// (a,c)b -rotL-> ((a,)b,)c
func (n *Node) rotateLeft() (root *Node) {
// Assumes: n has two children.
root = n.Right
n.Right = root.Left
root.Left = n
root.Color = n.Color
n.Color = red
return
}
// (a,c)b -rotR-> (,(,c)b)a
func (n *Node) rotateRight() (root *Node) {
// Assumes: n has two children.
root = n.Left
n.Left = root.Right
root.Right = n
root.Color = n.Color
n.Color = red
return
}
// (aR,cR)bB -flipC-> (aB,cB)bR | (aB,cB)bR -flipC-> (aR,cR)bB
func (n *Node) flipColors() {
// Assumes: n has two children.
n.Color = !n.Color
n.Left.Color = !n.Left.Color
n.Right.Color = !n.Right.Color
}
// fixUp ensures that black link balance is correct, that red nodes lean left,
// and that 4-nodes are split in the case of BU23 and properly balanced in TD234.
func (n *Node) fixUp() *Node {
if n.Right.color() == red {
if mode == td234 && n.Right.Left.color() == red {
n.Right = n.Right.rotateRight()
}
n = n.rotateLeft()
}
if n.Left.color() == red && n.Left.Left.color() == red {
n = n.rotateRight()
}
if mode == bu23 && n.Left.color() == red && n.Right.color() == red {
n.flipColors()
}
return n
}
func (n *Node) moveRedLeft() *Node {
n.flipColors()
if n.Right.Left.color() == red {
n.Right = n.Right.rotateRight()
n = n.rotateLeft()
n.flipColors()
if mode == td234 && n.Right.Right.color() == red {
n.Right = n.Right.rotateLeft()
}
}
return n
}
func (n *Node) moveRedRight() *Node {
n.flipColors()
if n.Left.Left.color() == red {
n = n.rotateRight()
n.flipColors()
}
return n
}
// Len returns the number of elements stored in the Tree.
func (t *Tree) Len() int {
return t.Count
}
// Search returns the first match of qname in the Tree.
func (t *Tree) Search(qname string) (*Elem, bool) {
if t.Root == nil {
return nil, false
}
n, res := t.Root.search(qname)
if n == nil {
return nil, res
}
return n.Elem, res
}
// search searches the subtree rooted at n for qname.
func (n *Node) search(qname string) (*Node, bool) {
for n != nil {
switch c := Less(n.Elem, qname); {
case c == 0:
return n, true
case c < 0:
n = n.Left
default:
n = n.Right
}
}
return n, false
}
// Insert inserts rr into the Tree at the first element whose name matches,
// or creates a new node when a nil node is reached.
func (t *Tree) Insert(rr dns.RR) {
var d int
t.Root, d = t.Root.insert(rr)
t.Count += d
t.Root.Color = black
}
// insert inserts rr in to the tree.
func (n *Node) insert(rr dns.RR) (root *Node, d int) {
if n == nil {
return &Node{Elem: newElem(rr)}, 1
} else if n.Elem == nil {
n.Elem = newElem(rr)
return n, 1
}
if mode == td234 {
if n.Left.color() == red && n.Right.color() == red {
n.flipColors()
}
}
switch c := Less(n.Elem, rr.Header().Name); {
case c == 0:
n.Elem.Insert(rr)
case c < 0:
n.Left, d = n.Left.insert(rr)
default:
n.Right, d = n.Right.insert(rr)
}
if n.Right.color() == red && n.Left.color() == black {
n = n.rotateLeft()
}
if n.Left.color() == red && n.Left.Left.color() == red {
n = n.rotateRight()
}
if mode == bu23 {
if n.Left.color() == red && n.Right.color() == red {
n.flipColors()
}
}
root = n
return root, d
}
// DeleteMin deletes the node with the minimum value in the tree.
func (t *Tree) DeleteMin() {
if t.Root == nil {
return
}
var d int
t.Root, d = t.Root.deleteMin()
t.Count += d
if t.Root == nil {
return
}
t.Root.Color = black
}
func (n *Node) deleteMin() (root *Node, d int) {
if n.Left == nil {
return nil, -1
}
if n.Left.color() == black && n.Left.Left.color() == black {
n = n.moveRedLeft()
}
n.Left, d = n.Left.deleteMin()
root = n.fixUp()
return
}
// DeleteMax deletes the node with the maximum value in the tree.
func (t *Tree) DeleteMax() {
if t.Root == nil {
return
}
var d int
t.Root, d = t.Root.deleteMax()
t.Count += d
if t.Root == nil {
return
}
t.Root.Color = black
}
func (n *Node) deleteMax() (root *Node, d int) {
if n.Left != nil && n.Left.color() == red {
n = n.rotateRight()
}
if n.Right == nil {
return nil, -1
}
if n.Right.color() == black && n.Right.Left.color() == black {
n = n.moveRedRight()
}
n.Right, d = n.Right.deleteMax()
root = n.fixUp()
return
}
// Delete removes all RRs of type rr.Header().Rrtype from e. If after the deletion of rr the node is empty the
// entire node is deleted.
func (t *Tree) Delete(rr dns.RR) {
if t.Root == nil {
return
}
el, _ := t.Search(rr.Header().Name)
if el == nil {
return
}
el.Delete(rr)
if el.Empty() {
t.deleteNode(rr)
}
}
// deleteNode deletes the node whose name matches rr according to Less().
func (t *Tree) deleteNode(rr dns.RR) {
if t.Root == nil {
return
}
var d int
t.Root, d = t.Root.delete(rr)
t.Count += d
if t.Root == nil {
return
}
t.Root.Color = black
}
func (n *Node) delete(rr dns.RR) (root *Node, d int) {
if Less(n.Elem, rr.Header().Name) < 0 {
if n.Left != nil {
if n.Left.color() == black && n.Left.Left.color() == black {
n = n.moveRedLeft()
}
n.Left, d = n.Left.delete(rr)
}
} else {
if n.Left.color() == red {
n = n.rotateRight()
}
if n.Right == nil && Less(n.Elem, rr.Header().Name) == 0 {
return nil, -1
}
if n.Right != nil {
if n.Right.color() == black && n.Right.Left.color() == black {
n = n.moveRedRight()
}
if Less(n.Elem, rr.Header().Name) == 0 {
n.Elem = n.Right.min().Elem
n.Right, d = n.Right.deleteMin()
} else {
n.Right, d = n.Right.delete(rr)
}
}
}
root = n.fixUp()
return
}
// Min returns the minimum value stored in the tree.
func (t *Tree) Min() *Elem {
if t.Root == nil {
return nil
}
return t.Root.min().Elem
}
func (n *Node) min() *Node {
for ; n.Left != nil; n = n.Left {
}
return n
}
// Max returns the maximum value stored in the tree.
func (t *Tree) Max() *Elem {
if t.Root == nil {
return nil
}
return t.Root.max().Elem
}
func (n *Node) max() *Node {
for ; n.Right != nil; n = n.Right {
}
return n
}
// Prev returns the greatest value equal to or less than the qname according to Less().
func (t *Tree) Prev(qname string) (*Elem, bool) {
if t.Root == nil {
return nil, false
}
n := t.Root.floor(qname)
if n == nil {
return nil, false
}
return n.Elem, true
}
func (n *Node) floor(qname string) *Node {
if n == nil {
return nil
}
switch c := Less(n.Elem, qname); {
case c == 0:
return n
case c <= 0:
return n.Left.floor(qname)
default:
if r := n.Right.floor(qname); r != nil {
return r
}
}
return n
}
// Next returns the smallest value equal to or greater than the qname according to Less().
func (t *Tree) Next(qname string) (*Elem, bool) {
if t.Root == nil {
return nil, false
}
n := t.Root.ceil(qname)
if n == nil {
return nil, false
}
return n.Elem, true
}
func (n *Node) ceil(qname string) *Node {
if n == nil {
return nil
}
switch c := Less(n.Elem, qname); {
case c == 0:
return n
case c > 0:
return n.Right.ceil(qname)
default:
if l := n.Left.ceil(qname); l != nil {
return l
}
}
return n
}
/*
Copyright ©2012 The bíogo Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the bíogo project nor the names of its authors and
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package tree
import "github.com/miekg/dns"
// Walk performs fn on all authoritative values stored in the tree, using an in-order
// depth-first traversal. If a non-nil error is returned, the Walk was interrupted by
// fn returning that error. If fn alters stored values' sort relationships, future
// tree operation behaviors are undefined.
func (t *Tree) Walk(fn func(*Elem, map[uint16][]dns.RR) error) error {
if t.Root == nil {
return nil
}
return t.Root.walk(fn)
}
func (n *Node) walk(fn func(*Elem, map[uint16][]dns.RR) error) error {
if n.Left != nil {
if err := n.Left.walk(fn); err != nil {
return err
}
}
if err := fn(n.Elem, n.Elem.m); err != nil {
return err
}
if n.Right != nil {
if err := n.Right.walk(fn); err != nil {
return err
}
}
return nil
}
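// Illustrative sketch (not part of the tree package): counting all RRs stored in a tree
// via Walk, assuming the Tree.Insert, Tree.Walk and Elem.All signatures shown above. The
// countRRs helper and the package main wrapper are hypothetical.
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/file/tree"
	"github.com/miekg/dns"
)

// countRRs walks t in-order and sums the number of RRs stored in each element.
func countRRs(t *tree.Tree) int {
	n := 0
	t.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error {
		n += len(e.All())
		return nil // returning nil keeps the walk going
	})
	return n
}

func main() {
	t := &tree.Tree{}
	rr, _ := dns.NewRR("example.org. 3600 IN A 192.0.2.1")
	t.Insert(rr)
	fmt.Println(countRRs(t)) // 1
}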
package file
import "github.com/miekg/dns"
// replaceWithAsteriskLabel replaces the left most label with '*'.
func replaceWithAsteriskLabel(qname string) (wildcard string) {
i, shot := dns.NextLabel(qname, 0)
if shot {
return ""
}
return "*." + qname[i:]
}
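// Illustrative sketch (standalone, duplicating the logic above for demonstration): the
// left-most label is stripped with dns.NextLabel and replaced by "*".
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	qname := "a.b.example.org."
	if i, shot := dns.NextLabel(qname, 0); !shot {
		fmt.Println("*." + qname[i:]) // *.b.example.org.
	}
}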
package file
import (
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/plugin/transfer"
"github.com/miekg/dns"
)
// Transfer implements the transfer.Transfer interface.
func (f File) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
z, ok := f.Z[zone]
if !ok || z == nil {
return nil, transfer.ErrNotAuthoritative
}
return z.Transfer(serial)
}
// Transfer sends the zone over the returned channel and implements IXFR fallback: when the
// requested serial matches the zone's serial, only a single SOA record is sent.
func (z *Zone) Transfer(serial uint32) (<-chan []dns.RR, error) {
// get soa and apex
apex, err := z.ApexIfDefined()
if err != nil {
return nil, err
}
ch := make(chan []dns.RR)
go func() {
if serial != 0 && apex[0].(*dns.SOA).Serial == serial { // ixfr fallback, only send SOA
ch <- []dns.RR{apex[0]}
close(ch)
return
}
ch <- apex
z.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error { ch <- e.All(); return nil })
ch <- []dns.RR{apex[0]}
close(ch)
}()
return ch, nil
}
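// Illustrative sketch (not part of the plugin): the channel returned by Transfer must be
// drained until it is closed, otherwise the sending goroutine above blocks. drainTransfer
// and the sample main are hypothetical names used only for demonstration.
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

// drainTransfer reads RR batches until the channel is closed and returns the total RR count.
func drainTransfer(ch <-chan []dns.RR) int {
	n := 0
	for rrs := range ch {
		n += len(rrs)
	}
	return n
}

func main() {
	ch := make(chan []dns.RR, 1)
	soa, _ := dns.NewRR("example.org. 3600 IN SOA ns.example.org. noc.example.org. 1 7200 3600 1209600 3600")
	ch <- []dns.RR{soa}
	close(ch)
	fmt.Println(drainTransfer(ch)) // 1
}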
package file
import (
"fmt"
"path/filepath"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/miekg/dns"
)
// Zone is a structure that contains all data related to a DNS zone.
type Zone struct {
origin string
origLen int
file string
*tree.Tree
Apex
Expired bool
sync.RWMutex
StartupOnce sync.Once
TransferFrom []string
ReloadInterval time.Duration
reloadShutdown chan bool
Upstream *upstream.Upstream // Upstream for looking up external names during the resolution process.
}
// Apex contains the apex records of a zone: SOA, NS and their potential signatures.
type Apex struct {
SOA *dns.SOA
NS []dns.RR
SIGSOA []dns.RR
SIGNS []dns.RR
}
// NewZone returns a new zone.
func NewZone(name, file string) *Zone {
return &Zone{
origin: dns.Fqdn(name),
origLen: dns.CountLabel(dns.Fqdn(name)),
file: filepath.Clean(file),
Tree: &tree.Tree{},
reloadShutdown: make(chan bool),
}
}
// Copy copies a zone.
func (z *Zone) Copy() *Zone {
z1 := NewZone(z.origin, z.file)
z1.TransferFrom = z.TransferFrom
z1.Expired = z.Expired
z1.Apex = z.Apex
return z1
}
// CopyWithoutApex copies zone z without the Apex records.
func (z *Zone) CopyWithoutApex() *Zone {
z1 := NewZone(z.origin, z.file)
z1.TransferFrom = z.TransferFrom
z1.Expired = z.Expired
return z1
}
// Insert inserts r into z.
func (z *Zone) Insert(r dns.RR) error {
// r.Header().Name = strings.ToLower(r.Header().Name)
if r.Header().Rrtype != dns.TypeSRV {
r.Header().Name = strings.ToLower(r.Header().Name)
}
switch h := r.Header().Rrtype; h {
case dns.TypeNS:
r.(*dns.NS).Ns = strings.ToLower(r.(*dns.NS).Ns)
if r.Header().Name == z.origin {
z.NS = append(z.NS, r)
return nil
}
case dns.TypeSOA:
r.(*dns.SOA).Ns = strings.ToLower(r.(*dns.SOA).Ns)
r.(*dns.SOA).Mbox = strings.ToLower(r.(*dns.SOA).Mbox)
z.SOA = r.(*dns.SOA)
return nil
case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
return fmt.Errorf("NSEC3 zone is not supported, dropping RR: %s for zone: %s", r.Header().Name, z.origin)
case dns.TypeRRSIG:
x := r.(*dns.RRSIG)
switch x.TypeCovered {
case dns.TypeSOA:
z.SIGSOA = append(z.SIGSOA, x)
return nil
case dns.TypeNS:
if r.Header().Name == z.origin {
z.SIGNS = append(z.SIGNS, x)
return nil
}
}
case dns.TypeCNAME:
r.(*dns.CNAME).Target = strings.ToLower(r.(*dns.CNAME).Target)
case dns.TypeMX:
r.(*dns.MX).Mx = strings.ToLower(r.(*dns.MX).Mx)
case dns.TypeSRV:
// r.(*dns.SRV).Target = strings.ToLower(r.(*dns.SRV).Target)
}
z.Tree.Insert(r)
return nil
}
// File retrieves the file path in a safe way.
func (z *Zone) File() string {
z.RLock()
defer z.RUnlock()
return z.file
}
// SetFile updates the file path in a safe way.
func (z *Zone) SetFile(path string) {
z.Lock()
z.file = path
z.Unlock()
}
// ApexIfDefined returns the apex records from z. The SOA record is the first record; if it does not exist, an error is returned.
func (z *Zone) ApexIfDefined() ([]dns.RR, error) {
z.RLock()
defer z.RUnlock()
if z.SOA == nil {
return nil, fmt.Errorf("no SOA")
}
rrs := []dns.RR{z.SOA}
if len(z.SIGSOA) > 0 {
rrs = append(rrs, z.SIGSOA...)
}
if len(z.NS) > 0 {
rrs = append(rrs, z.NS...)
}
if len(z.SIGNS) > 0 {
rrs = append(rrs, z.SIGNS...)
}
return rrs, nil
}
// nameFromRight returns the name made of the labels from the right, starting with the
// origin and then i extra labels. When we overshoot the name, the returned boolean is
// set to true.
func (z *Zone) nameFromRight(qname string, i int) (string, bool) {
if i <= 0 {
return z.origin, false
}
n := len(qname)
for j := 1; j <= z.origLen; j++ {
	m, shot := dns.PrevLabel(qname[:n], 1)
	if shot {
		return qname, shot
	}
	n = m
}
for j := 1; j <= i; j++ {
	m, shot := dns.PrevLabel(qname[:n], 1)
	if shot {
		return qname, shot
	}
	n = m
}
return qname[n:], false
}
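// Illustrative sketch (standalone): nameFromRight above peels labels off qname from the
// right with dns.PrevLabel. This shows the underlying primitive: the index returned by
// dns.PrevLabel keeps the i right-most labels of the name.
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	qname := "www.example.org."
	for i := 1; i <= 3; i++ {
		if idx, shot := dns.PrevLabel(qname, i); !shot {
			fmt.Println(qname[idx:]) // org. / example.org. / www.example.org.
		}
	}
}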
package forward
import (
"context"
"net"
"net/netip"
"time"
"github.com/coredns/coredns/plugin/dnstap/msg"
"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/coredns/coredns/request"
tap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)
// toDnstap sends the forwarded query and the received reply to the dnstap plugin.
func toDnstap(ctx context.Context, f *Forward, host string, state request.Request, opts proxy.Options, reply *dns.Msg, start time.Time) {
ap, _ := netip.ParseAddrPort(host) // this is preparsed and can't err here
ip := net.IP(ap.Addr().AsSlice())
port := int(ap.Port())
var ta net.Addr = &net.UDPAddr{
IP: ip,
Port: port,
}
t := state.Proto()
switch {
case opts.ForceTCP:
t = "tcp"
case opts.PreferUDP:
t = "udp"
}
if t == "tcp" {
ta = &net.TCPAddr{IP: ip, Port: port}
}
for _, t := range f.tapPlugins {
// Query
q := new(tap.Message)
msg.SetQueryTime(q, start)
// Forwarder dnstap messages are from the perspective of the downstream server
// (upstream is the forward server)
msg.SetQueryAddress(q, state.W.RemoteAddr())
msg.SetResponseAddress(q, ta)
if t.IncludeRawMessage {
buf, _ := state.Req.Pack()
q.QueryMessage = buf
}
msg.SetType(q, tap.Message_FORWARDER_QUERY)
t.TapMessageWithMetadata(ctx, q, state)
// Response
if reply != nil {
r := new(tap.Message)
if t.IncludeRawMessage {
buf, _ := reply.Pack()
r.ResponseMessage = buf
}
msg.SetQueryTime(r, start)
msg.SetQueryAddress(r, state.W.RemoteAddr())
msg.SetResponseAddress(r, ta)
msg.SetResponseTime(r, time.Now())
msg.SetType(r, tap.Message_FORWARDER_RESPONSE)
t.TapMessageWithMetadata(ctx, r, state)
}
}
}
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns, the upstream's Conn is reused rather than re-opened. Depending on how you benchmark it, this
// looks to be about 50% faster than opening a new connection for every client. It works with UDP and TCP and
// uses in-band health checking.
package forward
import (
"context"
"crypto/tls"
"errors"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/debug"
"github.com/coredns/coredns/plugin/dnstap"
"github.com/coredns/coredns/plugin/metadata"
clog "github.com/coredns/coredns/plugin/pkg/log"
proxyPkg "github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
otext "github.com/opentracing/opentracing-go/ext"
)
var log = clog.NewWithPlugin("forward")
const (
defaultExpire = 10 * time.Second
hcInterval = 500 * time.Millisecond
)
// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list
// of proxies each representing one upstream proxy.
type Forward struct {
concurrent int64 // atomic counters need to be first in struct for proper alignment
proxies []*proxyPkg.Proxy
p Policy
hcInterval time.Duration
from string
ignored []string
nextAlternateRcodes []int
tlsConfig *tls.Config
tlsServerName string
maxfails uint32
expire time.Duration
maxConcurrent int64
failfastUnhealthyUpstreams bool
failoverRcodes []int
opts proxyPkg.Options // also here for testing
// ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded
// the maximum allowed (maxConcurrent)
ErrLimitExceeded error
tapPlugins []*dnstap.Dnstap // when dnstap plugins are loaded, we use this to send messages out.
Next plugin.Handler
}
// New returns a new Forward.
func New() *Forward {
f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: proxyPkg.Options{ForceTCP: false, PreferUDP: false, HCRecursionDesired: true, HCDomain: "."}}
return f
}
// SetProxy appends p to the proxy list and starts healthchecking.
func (f *Forward) SetProxy(p *proxyPkg.Proxy) {
f.proxies = append(f.proxies, p)
p.Start(f.hcInterval)
}
// SetProxyOptions sets the proxy options.
func (f *Forward) SetProxyOptions(opts proxyPkg.Options) {
f.opts = opts
}
// SetTapPlugin appends one or more dnstap plugins to the tap plugin list.
func (f *Forward) SetTapPlugin(tapPlugin *dnstap.Dnstap) {
f.tapPlugins = append(f.tapPlugins, tapPlugin)
if nextPlugin, ok := tapPlugin.Next.(*dnstap.Dnstap); ok {
f.SetTapPlugin(nextPlugin)
}
}
// Len returns the number of configured proxies.
func (f *Forward) Len() int { return len(f.proxies) }
// Name implements plugin.Handler.
func (f *Forward) Name() string { return "forward" }
// ServeDNS implements plugin.Handler.
func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if !f.match(state) {
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
if f.maxConcurrent > 0 {
count := atomic.AddInt64(&(f.concurrent), 1)
defer atomic.AddInt64(&(f.concurrent), -1)
if count > f.maxConcurrent {
maxConcurrentRejectCount.Add(1)
return dns.RcodeRefused, f.ErrLimitExceeded
}
}
fails := 0
var span, child ot.Span
var upstreamErr error
span = ot.SpanFromContext(ctx)
i := 0
list := f.List()
deadline := time.Now().Add(defaultTimeout)
start := time.Now()
for time.Now().Before(deadline) && ctx.Err() == nil {
if i >= len(list) {
// reached the end of list, reset to begin
i = 0
fails = 0
}
proxy := list[i]
i++
if proxy.Down(f.maxfails) {
fails++
if fails < len(f.proxies) {
continue
}
healthcheckBrokenCount.Add(1)
// All upstreams are down; if failfast is configured, stop retrying and return SERVFAIL.
if f.failfastUnhealthyUpstreams {
break
}
// assume healthcheck is completely broken and randomly
// select an upstream to connect to.
r := new(random)
proxy = r.List(f.proxies)[0]
}
if span != nil {
child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context()))
otext.PeerAddress.Set(child, proxy.Addr())
ctx = ot.ContextWithSpan(ctx, child)
}
metadata.SetValueFunc(ctx, "forward/upstream", func() string {
return proxy.Addr()
})
var (
ret *dns.Msg
err error
)
opts := f.opts
for {
ret, err = proxy.Connect(ctx, state, opts)
if err == proxyPkg.ErrCachedClosed { // Remote side closed conn, can only happen with TCP.
continue
}
// Retry with TCP if truncated and prefer_udp configured.
if ret != nil && ret.Truncated && !opts.ForceTCP && opts.PreferUDP {
opts.ForceTCP = true
continue
}
break
}
if child != nil {
child.Finish()
}
if len(f.tapPlugins) != 0 {
toDnstap(ctx, f, proxy.Addr(), state, opts, ret, start)
}
upstreamErr = err
if err != nil {
// Kick off health check to see if *our* upstream is broken.
if f.maxfails != 0 {
proxy.Healthcheck()
}
if fails < len(f.proxies) {
continue
}
break
}
// Check if the reply is correct; if not return FormErr.
if !state.Match(ret) {
debug.Hexdumpf(ret, "Wrong reply for id: %d, %s %d", ret.Id, state.QName(), state.QType())
formerr := new(dns.Msg)
formerr.SetRcode(state.Req, dns.RcodeFormatError)
w.WriteMsg(formerr)
return 0, nil
}
// Check if we have a failover Rcode defined, check if we match on the code
tryNext := false
for _, failoverRcode := range f.failoverRcodes {
// if we match, we continue to the next upstream in the list
if failoverRcode == ret.Rcode {
if fails < len(f.proxies) {
tryNext = true
}
}
}
if tryNext {
fails++
continue
}
// Check if we have an alternate Rcode defined, check if we match on the code
for _, alternateRcode := range f.nextAlternateRcodes {
if alternateRcode == ret.Rcode && f.Next != nil { // In case we do not have a Next handler, just continue normally
if _, ok := f.Next.(*Forward); ok { // Only continue if the next plugin is also a Forward instance
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
}
}
w.WriteMsg(ret)
return 0, nil
}
if upstreamErr != nil {
return dns.RcodeServerFailure, upstreamErr
}
return dns.RcodeServerFailure, ErrNoHealthy
}
func (f *Forward) match(state request.Request) bool {
if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) {
return false
}
return true
}
func (f *Forward) isAllowedDomain(name string) bool {
if dns.Name(name) == dns.Name(f.from) {
return true
}
for _, ignore := range f.ignored {
if plugin.Name(ignore).Matches(name) {
return false
}
}
return true
}
// ForceTCP returns if TCP is forced to be used even when the request comes in over UDP.
func (f *Forward) ForceTCP() bool { return f.opts.ForceTCP }
// PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP.
func (f *Forward) PreferUDP() bool { return f.opts.PreferUDP }
// List returns a set of proxies to be used for this client depending on the policy in f.
func (f *Forward) List() []*proxyPkg.Proxy { return f.p.List(f.proxies) }
var (
// ErrNoHealthy means no healthy proxies left.
ErrNoHealthy = errors.New("no healthy proxies")
// ErrNoForward means no forwarder defined.
ErrNoForward = errors.New("no forwarder defined")
// ErrCachedClosed means cached connection was closed by peer.
ErrCachedClosed = errors.New("cached connection was closed by peer")
)
// Options holds various Options that can be set.
type Options struct {
// ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag
ForceTCP bool
// PreferUDP use UDP protocol for upstream DNS request.
PreferUDP bool
// HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests
HCRecursionDesired bool
// HCDomain sets domain for Proxy healthcheck requests
HCDomain string
}
var defaultTimeout = 5 * time.Second
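// Illustrative sketch (same package, hypothetical helper): wiring a Forward by hand with
// the New and SetProxy functions shown above. In a real deployment the setup code further
// below builds this from a Corefile; the upstream address here is only an example.
package forward

import "github.com/coredns/coredns/plugin/pkg/proxy"

// newExampleForward returns a Forward that sends everything for "." to a single upstream.
// SetProxy also starts health checking; call OnShutdown to stop it again.
func newExampleForward() *Forward {
	f := New() // defaults: policy random, from ".", max_fails 2
	f.SetProxy(proxy.NewProxy("forward", "192.0.2.53:53", "dns"))
	return f
}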
//go:build gofuzz
package forward
import (
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/fuzz"
"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/miekg/dns"
)
var f *Forward
// abuse init to set up an environment to test against. This starts another server that will
// reflect responses.
func init() {
f = New()
s := dnstest.NewServer(r{}.reflectHandler)
f.SetProxy(proxy.NewProxy("FuzzForwardPlugin1", s.Addr, "tcp"))
f.SetProxy(proxy.NewProxy("FuzzForwardPlugin2", s.Addr, "udp"))
}
// Fuzz fuzzes forward.
func Fuzz(data []byte) int {
return fuzz.Do(f, data)
}
type r struct{}
func (r r) reflectHandler(w dns.ResponseWriter, req *dns.Msg) {
m := new(dns.Msg)
m.SetReply(req)
w.WriteMsg(m)
}
package forward
import (
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/coredns/coredns/plugin/pkg/rand"
)
// Policy defines a policy we use for selecting upstreams.
type Policy interface {
List([]*proxy.Proxy) []*proxy.Proxy
String() string
}
// random is a policy that implements random upstream selection.
type random struct{}
func (r *random) String() string { return "random" }
func (r *random) List(p []*proxy.Proxy) []*proxy.Proxy {
switch len(p) {
case 1:
return p
case 2:
if rn.Int()%2 == 0 {
return []*proxy.Proxy{p[1], p[0]} // swap
}
return p
}
perms := rn.Perm(len(p))
rnd := make([]*proxy.Proxy, len(p))
for i, p1 := range perms {
rnd[i] = p[p1]
}
return rnd
}
// roundRobin is a policy that selects hosts based on round robin ordering.
type roundRobin struct {
robin uint32
}
func (r *roundRobin) String() string { return "round_robin" }
func (r *roundRobin) List(p []*proxy.Proxy) []*proxy.Proxy {
poolLen := uint32(len(p))
i := atomic.AddUint32(&r.robin, 1) % poolLen
robin := []*proxy.Proxy{p[i]}
robin = append(robin, p[:i]...)
robin = append(robin, p[i+1:]...)
return robin
}
// sequential is a policy that selects hosts based on sequential ordering.
type sequential struct{}
func (r *sequential) String() string { return "sequential" }
func (r *sequential) List(p []*proxy.Proxy) []*proxy.Proxy {
return p
}
var rn = rand.New(time.Now().UnixNano())
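// Illustrative sketch (same package, hypothetical function): how the roundRobin policy
// rotates the proxy list on successive List calls. The upstream addresses are examples only.
package forward

import (
	"fmt"

	"github.com/coredns/coredns/plugin/pkg/proxy"
)

func examplePolicyRotation() {
	ps := []*proxy.Proxy{
		proxy.NewProxy("forward", "192.0.2.1:53", "dns"),
		proxy.NewProxy("forward", "192.0.2.2:53", "dns"),
		proxy.NewProxy("forward", "192.0.2.3:53", "dns"),
	}
	rr := &roundRobin{}
	for range 3 {
		list := rr.List(ps)
		fmt.Println(list[0].Addr()) // first 192.0.2.2:53, then 192.0.2.3:53, then 192.0.2.1:53
	}
}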
package forward
import (
"crypto/tls"
"errors"
"fmt"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/dnstap"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/proxy"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
func init() {
plugin.Register("forward", setup)
}
func setup(c *caddy.Controller) error {
fs, err := parseForward(c)
if err != nil {
return plugin.Error("forward", err)
}
for i := range fs {
f := fs[i]
if f.Len() > max {
return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
}
if i == len(fs)-1 {
// last forward: point next to next plugin
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
f.Next = next
return f
})
} else {
// middle forward: point next to next forward
nextForward := fs[i+1]
dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler {
f.Next = nextForward
return f
})
}
c.OnStartup(func() error {
return f.OnStartup()
})
c.OnStartup(func() error {
if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil {
f.SetTapPlugin(taph.(*dnstap.Dnstap))
}
return nil
})
c.OnShutdown(func() error {
return f.OnShutdown()
})
}
return nil
}
// OnStartup starts a health-check goroutine for each proxy.
func (f *Forward) OnStartup() error {
for _, p := range f.proxies {
p.Start(f.hcInterval)
}
return nil
}
// OnShutdown stops all configured proxies.
func (f *Forward) OnShutdown() error {
for _, p := range f.proxies {
p.Stop()
}
return nil
}
func parseForward(c *caddy.Controller) ([]*Forward, error) {
var fs = []*Forward{}
for c.Next() {
f, err := parseStanza(c)
if err != nil {
return nil, err
}
fs = append(fs, f)
}
return fs, nil
}
// splitZone splits off the zone part of host (everything after the last '%'), preserving any port that comes after the zone.
func splitZone(host string) (newHost string, zone string) {
newHost = host
if strings.Contains(host, "%") {
lastPercent := strings.LastIndex(host, "%")
newHost = host[:lastPercent]
zone = host[lastPercent+1:]
if strings.Contains(zone, ":") {
lastColon := strings.LastIndex(zone, ":")
newHost += zone[lastColon:]
zone = zone[:lastColon]
}
}
return
}
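// For illustration (hypothetical input): splitZone("tls://9.9.9.9%dns.quad9.net:853")
// returns ("tls://9.9.9.9:853", "dns.quad9.net"); a host without '%' is returned
// unchanged with an empty zone.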
func parseStanza(c *caddy.Controller) (*Forward, error) {
f := New()
if !c.Args(&f.from) {
return f, c.ArgErr()
}
origFrom := f.from
zones := plugin.Host(f.from).NormalizeExact()
if len(zones) == 0 {
return f, fmt.Errorf("unable to normalize '%s'", f.from)
}
f.from = zones[0] // there can only be one here, won't work with non-octet reverse
if len(zones) > 1 {
log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", origFrom, f.from)
}
to := c.RemainingArgs()
if len(to) == 0 {
return f, c.ArgErr()
}
toHosts, err := parse.HostPortOrFile(to...)
if err != nil {
return f, err
}
for c.NextBlock() {
if err := parseBlock(c, f); err != nil {
return f, err
}
}
tlsServerNames := make([]string, len(toHosts))
perServerNameProxyCount := make(map[string]int)
transports := make([]string, len(toHosts))
allowedTrans := map[string]bool{"dns": true, "tls": true}
for i, hostWithZone := range toHosts {
host, serverName := splitZone(hostWithZone)
trans, h := parse.Transport(host)
if !allowedTrans[trans] {
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
}
if trans == transport.TLS && serverName != "" {
if f.tlsServerName != "" {
return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host)
}
tlsServerNames[i] = serverName
perServerNameProxyCount[serverName]++
}
p := proxy.NewProxy("forward", h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
perServerNameTlsConfig := make(map[string]*tls.Config)
if f.tlsServerName != "" {
f.tlsConfig.ServerName = f.tlsServerName
} else {
for serverName, proxyCount := range perServerNameProxyCount {
tlsConfig := f.tlsConfig.Clone()
tlsConfig.ServerName = serverName
tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount)
perServerNameTlsConfig[serverName] = tlsConfig
}
}
// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
// in upcoming connections to the same TLS server.
f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies))
for i := range f.proxies {
// Only set this for proxies that need it.
if transports[i] == transport.TLS {
if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok {
f.proxies[i].SetTLSConfig(tlsConfig)
} else {
f.proxies[i].SetTLSConfig(f.tlsConfig)
}
}
f.proxies[i].SetExpire(f.expire)
f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired)
// when TLS is used, checks are set to tcp-tls
if f.opts.ForceTCP && transports[i] != transport.TLS {
f.proxies[i].GetHealthchecker().SetTCPTransport()
}
f.proxies[i].GetHealthchecker().SetDomain(f.opts.HCDomain)
}
return f, nil
}
func parseBlock(c *caddy.Controller, f *Forward) error {
config := dnsserver.GetConfig(c)
switch c.Val() {
case "except":
ignore := c.RemainingArgs()
if len(ignore) == 0 {
return c.ArgErr()
}
for i := range ignore {
f.ignored = append(f.ignored, plugin.Host(ignore[i]).NormalizeExact()...)
}
case "max_fails":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.ParseUint(c.Val(), 10, 32)
if err != nil {
return err
}
f.maxfails = uint32(n)
case "health_check":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
if dur < 0 {
return fmt.Errorf("health_check can't be negative: %d", dur)
}
f.hcInterval = dur
f.opts.HCDomain = "."
for c.NextArg() {
switch hcOpts := c.Val(); hcOpts {
case "no_rec":
f.opts.HCRecursionDesired = false
case "domain":
if !c.NextArg() {
return c.ArgErr()
}
hcDomain := c.Val()
if _, ok := dns.IsDomainName(hcDomain); !ok {
return fmt.Errorf("health_check: invalid domain name %s", hcDomain)
}
f.opts.HCDomain = plugin.Name(hcDomain).Normalize()
default:
return fmt.Errorf("health_check: unknown option %s", hcOpts)
}
}
case "force_tcp":
if c.NextArg() {
return c.ArgErr()
}
f.opts.ForceTCP = true
case "prefer_udp":
if c.NextArg() {
return c.ArgErr()
}
f.opts.PreferUDP = true
case "tls":
args := c.RemainingArgs()
if len(args) > 3 {
return c.ArgErr()
}
for i := range args {
if !filepath.IsAbs(args[i]) && config.Root != "" {
args[i] = filepath.Join(config.Root, args[i])
}
}
tlsConfig, err := pkgtls.NewTLSConfigFromArgs(args...)
if err != nil {
return err
}
f.tlsConfig = tlsConfig
case "tls_servername":
if !c.NextArg() {
return c.ArgErr()
}
f.tlsServerName = c.Val()
case "expire":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
if dur < 0 {
return fmt.Errorf("expire can't be negative: %s", dur)
}
f.expire = dur
case "policy":
if !c.NextArg() {
return c.ArgErr()
}
switch x := c.Val(); x {
case "random":
f.p = &random{}
case "round_robin":
f.p = &roundRobin{}
case "sequential":
f.p = &sequential{}
default:
return c.Errf("unknown policy '%s'", x)
}
case "max_concurrent":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.Atoi(c.Val())
if err != nil {
return err
}
if n < 0 {
return fmt.Errorf("max_concurrent can't be negative: %d", n)
}
f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val())
f.maxConcurrent = int64(n)
case "next":
args := c.RemainingArgs()
if len(args) == 0 {
return c.ArgErr()
}
for _, rcode := range args {
var rc int
var ok bool
if rc, ok = dns.StringToRcode[strings.ToUpper(rcode)]; !ok {
return fmt.Errorf("%s is not a valid rcode", rcode)
}
f.nextAlternateRcodes = append(f.nextAlternateRcodes, rc)
}
case "failfast_all_unhealthy_upstreams":
args := c.RemainingArgs()
if len(args) != 0 {
return c.ArgErr()
}
f.failfastUnhealthyUpstreams = true
case "failover":
args := c.RemainingArgs()
if len(args) == 0 {
return c.ArgErr()
}
toRcode := dns.StringToRcode
for _, rcode := range args {
rc, ok := toRcode[strings.ToUpper(rcode)]
if !ok {
return fmt.Errorf("%s is not a valid rcode", rcode)
}
if rc == dns.RcodeSuccess {
return fmt.Errorf("NoError cannot be used in failover")
}
f.failoverRcodes = append(f.failoverRcodes, rc)
}
default:
return c.Errf("unknown property '%s'", c.Val())
}
return nil
}
const max = 15 // Maximum number of upstreams.
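// An illustrative Corefile stanza exercising the options parsed above (a sketch, not a
// recommended configuration):
//
//	. {
//	    forward . 8.8.8.8 tls://9.9.9.9 {
//	        tls_servername dns.quad9.net
//	        policy sequential
//	        max_fails 3
//	        health_check 5s domain example.org.
//	        expire 20s
//	        prefer_udp
//	        except internal.example.org
//	    }
//	}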
package geoip
import (
"context"
"strconv"
"github.com/coredns/coredns/plugin/metadata"
"github.com/oschwald/geoip2-golang"
)
const defaultLang = "en"
func (g GeoIP) setCityMetadata(ctx context.Context, data *geoip2.City) {
// Set labels for city, country and continent names.
cityName := data.City.Names[defaultLang]
metadata.SetValueFunc(ctx, pluginName+"/city/name", func() string {
return cityName
})
countryName := data.Country.Names[defaultLang]
metadata.SetValueFunc(ctx, pluginName+"/country/name", func() string {
return countryName
})
continentName := data.Continent.Names[defaultLang]
metadata.SetValueFunc(ctx, pluginName+"/continent/name", func() string {
return continentName
})
countryCode := data.Country.IsoCode
metadata.SetValueFunc(ctx, pluginName+"/country/code", func() string {
return countryCode
})
isInEurope := strconv.FormatBool(data.Country.IsInEuropeanUnion)
metadata.SetValueFunc(ctx, pluginName+"/country/is_in_european_union", func() string {
return isInEurope
})
continentCode := data.Continent.Code
metadata.SetValueFunc(ctx, pluginName+"/continent/code", func() string {
return continentCode
})
latitude := strconv.FormatFloat(data.Location.Latitude, 'f', -1, 64)
metadata.SetValueFunc(ctx, pluginName+"/latitude", func() string {
return latitude
})
longitude := strconv.FormatFloat(data.Location.Longitude, 'f', -1, 64)
metadata.SetValueFunc(ctx, pluginName+"/longitude", func() string {
return longitude
})
timeZone := data.Location.TimeZone
metadata.SetValueFunc(ctx, pluginName+"/timezone", func() string {
return timeZone
})
postalCode := data.Postal.Code
metadata.SetValueFunc(ctx, pluginName+"/postalcode", func() string {
return postalCode
})
}
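// Illustrative sketch (not part of this plugin): a plugin later in the chain can read the
// labels set above through the metadata package; metadata.ValueFunc returning a func() string
// (nil when the label is absent) is assumed from the metadata plugin's API.
//
//	if f := metadata.ValueFunc(ctx, "geoip/country/code"); f != nil {
//		code := f() // e.g. "NL"
//		_ = code
//	}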
// Package geoip implements a MaxMind database plugin.
package geoip
import (
"context"
"fmt"
"net"
"path/filepath"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
"github.com/oschwald/geoip2-golang"
)
var log = clog.NewWithPlugin(pluginName)
// GeoIP is a plugin that adds geolocation data to the request context by looking it up in a MaxMind
// geoIP2 database; that data can later be consumed by other plugins.
type GeoIP struct {
Next plugin.Handler
db db
edns0 bool
}
type db struct {
*geoip2.Reader
// provides defines the schemas that can be obtained by querying this database, by using
// bitwise operations.
provides int
}
const (
city = 1 << iota
)
var probingIP = net.ParseIP("127.0.0.1")
func newGeoIP(dbPath string, edns0 bool) (*GeoIP, error) {
reader, err := geoip2.Open(dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database file: %v", err)
}
db := db{Reader: reader}
schemas := []struct {
provides int
name string
validate func() error
}{
{name: "city", provides: city, validate: func() error { _, err := reader.City(probingIP); return err }},
}
// Query the database to figure out the database type.
for _, schema := range schemas {
if err := schema.validate(); err != nil {
// If we get an InvalidMethodError then we know this database does not provide that schema.
if _, ok := err.(geoip2.InvalidMethodError); !ok {
return nil, fmt.Errorf("unexpected failure looking up database %q schema %q: %v", filepath.Base(dbPath), schema.name, err)
}
} else {
db.provides |= schema.provides
}
}
if db.provides&city == 0 {
return nil, fmt.Errorf("database does not provide city schema")
}
return &GeoIP{db: db, edns0: edns0}, nil
}
// ServeDNS implements the plugin.Handler interface.
func (g GeoIP) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return plugin.NextOrFailure(pluginName, g.Next, ctx, w, r)
}
// Metadata implements the metadata.Provider Interface in the metadata plugin, and is used to store
// the data associated with the source IP of every request.
func (g GeoIP) Metadata(ctx context.Context, state request.Request) context.Context {
srcIP := net.ParseIP(state.IP())
if g.edns0 {
if o := state.Req.IsEdns0(); o != nil {
for _, s := range o.Option {
if e, ok := s.(*dns.EDNS0_SUBNET); ok {
srcIP = e.Address
break
}
}
}
}
switch g.db.provides & city {
case city:
data, err := g.db.City(srcIP)
if err != nil {
log.Debugf("Setting up metadata failed due to database lookup error: %v", err)
return ctx
}
g.setCityMetadata(ctx, data)
}
return ctx
}
// Name implements the Handler interface.
func (g GeoIP) Name() string { return pluginName }
package geoip
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
const pluginName = "geoip"
func init() { plugin.Register(pluginName, setup) }
func setup(c *caddy.Controller) error {
geoip, err := geoipParse(c)
if err != nil {
return plugin.Error(pluginName, err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
geoip.Next = next
return geoip
})
return nil
}
func geoipParse(c *caddy.Controller) (*GeoIP, error) {
var dbPath string
var edns0 bool
for c.Next() {
if !c.NextArg() {
return nil, c.ArgErr()
}
if dbPath != "" {
return nil, c.Errf("configuring multiple databases is not supported")
}
dbPath = c.Val()
// There shouldn't be any more arguments.
if len(c.RemainingArgs()) != 0 {
return nil, c.ArgErr()
}
for c.NextBlock() {
if c.Val() != "edns-subnet" {
return nil, c.Errf("unknown property %q", c.Val())
}
edns0 = true
}
}
geoIP, err := newGeoIP(dbPath, edns0)
if err != nil {
return geoIP, c.Err(err.Error())
}
return geoIP, nil
}
//go:build gofuzz
package grpc
import (
"context"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/fuzz"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
grpcgo "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// fakeClient implements pb.DnsServiceClient without doing any network I/O.
// Its behavior is controlled by the mode field.
type fakeClient struct {
mode byte
idx int
}
func (f *fakeClient) Query(_ context.Context, in *pb.DnsPacket, _ ...grpcgo.CallOption) (*pb.DnsPacket, error) {
// Derive mode deterministically from request bytes to vary behavior per call.
m := f.mode
if len(in.GetMsg()) > 0 {
b := in.GetMsg()[f.idx%len(in.GetMsg())]
f.idx++
m = b
}
switch m % 12 {
case 0:
// Success echo: return the same bytes.
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
case 1:
// Return NotFound to exercise NXDOMAIN conversion and optional fallthrough.
return nil, status.Error(codes.NotFound, "not found")
case 2:
// Return a transient error to trigger retry/rotation.
return nil, status.Error(codes.Unavailable, "unavailable")
case 3:
// Corrupt response that fails dns.Msg Unpack.
return &pb.DnsPacket{Msg: []byte{0x00, 0x01, 0x02}}, nil
case 4:
// Valid DNS message with mismatched ID/qname to trigger formerr path in ServeDNS.
var req dns.Msg
if err := req.Unpack(in.GetMsg()); err != nil {
// If input isn't a DNS message, just echo to avoid blocking fuzzing.
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
resp := new(dns.Msg)
resp.SetReply(&req)
resp.Id = req.Id + 1
// Alter question name if present.
if len(req.Question) > 0 {
resp.Question[0].Name = "example.net."
}
packed, err := resp.Pack()
if err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
return &pb.DnsPacket{Msg: packed}, nil
case 5:
// Success with EDNS and larger answer to stress flags and sizes.
var req dns.Msg
if err := req.Unpack(in.GetMsg()); err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
resp := new(dns.Msg)
resp.SetReply(&req)
// Set EDNS0 with varying UDP size and DO bit based on m.
size := uint16(512)
if (m>>1)&1 == 1 {
size = 1232
}
if (m>>2)&1 == 1 {
size = 4096
}
do := ((m>>3)&1 == 1)
resp.SetEdns0(size, do)
// Optionally set TC bit to exercise truncation handling.
if (m>>4)&1 == 1 {
resp.Truncated = true
}
// Add a few TXT records to grow the payload.
name := "."
if len(req.Question) > 0 {
name = req.Question[0].Name
}
n := int(1 + (m % 16))
for range n {
resp.Answer = append(resp.Answer, &dns.TXT{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}, Txt: []string{"aaaaaaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbbbbbbbbb"}})
}
packed, err := resp.Pack()
if err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
return &pb.DnsPacket{Msg: packed}, nil
case 6:
return nil, status.Error(codes.DeadlineExceeded, "timeout")
case 7:
return nil, status.Error(codes.Internal, "internal")
case 8:
return nil, status.Error(codes.ResourceExhausted, "quota")
case 9:
return nil, status.Error(codes.PermissionDenied, "denied")
case 10:
// NODATA: NOERROR with empty Answer and SOA in Authority.
var req dns.Msg
if err := req.Unpack(in.GetMsg()); err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
resp := new(dns.Msg)
resp.SetRcode(&req, dns.RcodeSuccess)
name := "."
if len(req.Question) > 0 {
name = req.Question[0].Name
}
resp.Ns = append(resp.Ns, &dns.SOA{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 60}, Ns: "ns.example.", Mbox: "hostmaster.example.", Serial: 1, Refresh: 3600, Retry: 600, Expire: 86400, Minttl: 60})
packed, err := resp.Pack()
if err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
return &pb.DnsPacket{Msg: packed}, nil
case 11:
// TC-only: truncated response without answers.
var req dns.Msg
if err := req.Unpack(in.GetMsg()); err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
resp := new(dns.Msg)
resp.SetReply(&req)
resp.Truncated = true
packed, err := resp.Pack()
if err != nil {
return &pb.DnsPacket{Msg: in.GetMsg()}, nil
}
return &pb.DnsPacket{Msg: packed}, nil
default:
// Empty/zero-length response to exercise unpack error path.
return &pb.DnsPacket{Msg: nil}, nil
}
}
// Fuzz exercises the grpc plugin using a fake client and the shared fuzz harness.
func Fuzz(data []byte) int {
if len(data) == 0 {
return 0
}
cfg := data[0]
rest := data[1:]
g := &GRPC{
from: ".",
Next: test.ErrorHandler(),
}
// Select policy based on cfg bits to vary list() ordering.
switch cfg % 3 {
case 0:
g.p = &random{}
case 1:
g.p = &roundRobin{}
default:
g.p = &sequential{}
}
// Optionally enable fallthrough; choose scope based on input bit.
if cfg&0x80 != 0 {
if cfg&0x01 != 0 {
g.Fall.SetZonesFromArgs([]string{"."})
} else {
g.Fall.SetZonesFromArgs([]string{g.from})
}
}
// Create 0–3 fake proxies with varied behaviors.
numProxies := int((cfg >> 4) & 0x03)
if numProxies == 0 {
if _, is := g.p.(*roundRobin); is {
// Avoid divide-by-zero in roundRobin policy when pool is empty.
g.p = &sequential{}
}
}
for i := range numProxies {
mode := byte(i)
if len(rest) > 0 {
mode = rest[i%len(rest)]
}
p := &Proxy{addr: "fake"}
p.client = &fakeClient{mode: mode}
g.proxies = append(g.proxies, p)
}
// Deterministically set a narrow from to miss match and hit Next/SERVFAIL paths.
if cfg&0x20 != 0 {
g.from = "_not_matching_."
}
// Optionally construct a tiny deterministic query to vary RD/CD flags.
if cfg&0x08 != 0 {
var rq dns.Msg
rq.SetQuestion("example.org.", dns.TypeA)
rq.RecursionDesired = (cfg&0x04 != 0)
rq.CheckingDisabled = (cfg&0x02 != 0)
if packed, err := rq.Pack(); err == nil {
rest = packed
}
}
return fuzz.Do(g, rest)
}
package grpc
import (
"context"
"crypto/tls"
"errors"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/debug"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
)
// GRPC represents a plugin instance that can proxy requests to another (DNS) server via gRPC protocol.
// It has a list of proxies each representing one upstream proxy.
type GRPC struct {
proxies []*Proxy
p Policy
from string
ignored []string
tlsConfig *tls.Config
tlsServerName string
Fall fall.F
Next plugin.Handler
}
// ServeDNS implements the plugin.Handler interface.
func (g *GRPC) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if !g.match(state) {
if g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
// No next plugin, return SERVFAIL
return dns.RcodeServerFailure, nil
}
var (
span ot.Span
ret *dns.Msg
err error
i int
)
span = ot.SpanFromContext(ctx)
list := g.list()
deadline := time.Now().Add(defaultTimeout)
for time.Now().Before(deadline) {
if i >= len(list) {
// reached the end of list without any answer
if ret != nil {
// write empty response and finish
w.WriteMsg(ret)
}
break
}
proxy := list[i]
i++
callCtx := ctx
var child ot.Span
if span != nil {
child, callCtx = ot.StartSpanFromContext(callCtx, "query")
}
var cancel context.CancelFunc
callCtx, cancel = context.WithDeadline(callCtx, deadline)
ret, err = proxy.query(callCtx, r)
cancel()
if child != nil {
child.Finish()
}
if err != nil {
// Continue with the next proxy
continue
}
// Check if the reply is correct; if not return FormErr.
if !state.Match(ret) {
debug.Hexdumpf(ret, "Wrong reply for id: %d, %s %d", ret.Id, state.QName(), state.QType())
formerr := new(dns.Msg)
formerr.SetRcode(state.Req, dns.RcodeFormatError)
w.WriteMsg(formerr)
return 0, nil
}
// Check if we should fallthrough on NXDOMAIN responses
if ret.Rcode == dns.RcodeNameError && g.Fall.Through(state.Name()) {
if g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
// No next plugin to fallthrough to, return the NXDOMAIN response
}
w.WriteMsg(ret)
return 0, nil
}
// SERVFAIL if all healthy proxies returned errors.
if err != nil {
// If fallthrough is enabled, try the next plugin instead of returning SERVFAIL
if g.Fall.Through(state.Name()) && g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
// just return the last error received
return dns.RcodeServerFailure, err
}
// If fallthrough is enabled, try the next plugin instead of returning SERVFAIL
if g.Fall.Through(state.Name()) && g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
return dns.RcodeServerFailure, ErrNoHealthy
}
// newGRPC returns a new GRPC.
func newGRPC() *GRPC {
g := &GRPC{
p: new(random),
}
return g
}
// Name implements the Handler interface.
func (g *GRPC) Name() string { return "grpc" }
// Len returns the number of configured proxies.
func (g *GRPC) len() int { return len(g.proxies) }
func (g *GRPC) match(state request.Request) bool {
if !plugin.Name(g.from).Matches(state.Name()) || !g.isAllowedDomain(state.Name()) {
return false
}
return true
}
func (g *GRPC) isAllowedDomain(name string) bool {
if dns.Name(name) == dns.Name(g.from) {
return true
}
for _, ignore := range g.ignored {
if plugin.Name(ignore).Matches(name) {
return false
}
}
return true
}
// List returns a set of proxies to be used for this client depending on the policy in p.
func (g *GRPC) list() []*Proxy { return g.p.List(g.proxies) }
const defaultTimeout = 5 * time.Second
var (
// ErrNoHealthy means no healthy proxies left.
ErrNoHealthy = errors.New("no healthy gRPC proxies")
)
package grpc
import (
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/rand"
)
// Policy defines a policy we use for selecting upstreams.
type Policy interface {
List([]*Proxy) []*Proxy
String() string
}
// random is a policy that implements random upstream selection.
type random struct{}
func (r *random) String() string { return "random" }
func (r *random) List(p []*Proxy) []*Proxy {
switch len(p) {
case 0:
return nil
case 1:
return p
case 2:
if rn.Int()%2 == 0 {
return []*Proxy{p[1], p[0]} // swap
}
return p
}
perms := rn.Perm(len(p))
rnd := make([]*Proxy, len(p))
for i, p1 := range perms {
rnd[i] = p[p1]
}
return rnd
}
// roundRobin is a policy that selects hosts based on round robin ordering.
type roundRobin struct {
robin uint32
}
func (r *roundRobin) String() string { return "round_robin" }
func (r *roundRobin) List(p []*Proxy) []*Proxy {
if len(p) == 0 {
return nil
}
poolLen := uint32(len(p))
i := atomic.AddUint32(&r.robin, 1) % poolLen
robin := []*Proxy{p[i]}
robin = append(robin, p[:i]...)
robin = append(robin, p[i+1:]...)
return robin
}
// sequential is a policy that selects hosts based on sequential ordering.
type sequential struct{}
func (r *sequential) String() string { return "sequential" }
func (r *sequential) List(p []*Proxy) []*Proxy {
return p
}
var rn = rand.New(time.Now().UnixNano())
package grpc
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"time"
"github.com/coredns/coredns/pb"
"github.com/miekg/dns"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
const (
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
maxDNSMessageBytes = dns.MaxMsgSize
// maxProtobufPayloadBytes accounts for protobuf overhead.
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
)
var (
// ErrDNSMessageTooLarge is returned when a DNS message exceeds the maximum allowed size.
ErrDNSMessageTooLarge = errors.New("dns message exceeds size limit")
)
// Proxy defines an upstream host.
type Proxy struct {
addr string
// connection
client pb.DnsServiceClient
dialOpts []grpc.DialOption
}
// newProxy returns a new proxy.
func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) {
p := &Proxy{
addr: addr,
}
if tlsConfig != nil {
p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
} else {
p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
// Cap send/recv sizes to avoid oversized messages.
// Note: gRPC size limits apply to the serialized protobuf message size.
p.dialOpts = append(p.dialOpts,
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(maxProtobufPayloadBytes),
grpc.MaxCallSendMsgSize(maxProtobufPayloadBytes),
),
)
conn, err := grpc.NewClient(p.addr, p.dialOpts...)
if err != nil {
return nil, err
}
p.client = pb.NewDnsServiceClient(conn)
return p, nil
}
// query sends the request and waits for a response.
func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
start := time.Now()
msg, err := req.Pack()
if err != nil {
return nil, err
}
if err := validateDNSSize(msg); err != nil {
return nil, err
}
reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg})
if err != nil {
// if not found message, return empty message with NXDomain code
if status.Code(err) == codes.NotFound {
m := new(dns.Msg).SetRcode(req, dns.RcodeNameError)
return m, nil
}
return nil, err
}
wire := reply.GetMsg()
if err := validateDNSSize(wire); err != nil {
return nil, err
}
ret := new(dns.Msg)
if err := ret.Unpack(wire); err != nil {
return nil, err
}
rc, ok := dns.RcodeToString[ret.Rcode]
if !ok {
rc = strconv.Itoa(ret.Rcode)
}
RequestCount.WithLabelValues(p.addr).Add(1)
RcodeCount.WithLabelValues(rc, p.addr).Add(1)
RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds())
return ret, nil
}
func validateDNSSize(data []byte) error {
l := len(data)
if l > maxDNSMessageBytes {
return fmt.Errorf("%w: %d bytes (limit %d)", ErrDNSMessageTooLarge, l, maxDNSMessageBytes)
}
return nil
}
package grpc
import (
"crypto/tls"
"fmt"
"path/filepath"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/parse"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
)
func init() { plugin.Register("grpc", setup) }
func setup(c *caddy.Controller) error {
g, err := parseGRPC(c)
if err != nil {
return plugin.Error("grpc", err)
}
if g.len() > max {
return plugin.Error("grpc", fmt.Errorf("more than %d TOs configured: %d", max, g.len()))
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
g.Next = next // Set the Next field, so the plugin chaining works.
return g
})
return nil
}
func parseGRPC(c *caddy.Controller) (*GRPC, error) {
var (
g *GRPC
err error
i int
)
for c.Next() {
if i > 0 {
return nil, plugin.ErrOnce
}
i++
g, err = parseStanza(c)
if err != nil {
return nil, err
}
}
return g, nil
}
func parseStanza(c *caddy.Controller) (*GRPC, error) {
g := newGRPC()
if !c.Args(&g.from) {
return g, c.ArgErr()
}
normalized := plugin.Host(g.from).NormalizeExact()
if len(normalized) == 0 {
return g, fmt.Errorf("unable to normalize '%s'", g.from)
}
g.from = normalized[0] // only the first is used.
to := c.RemainingArgs()
if len(to) == 0 {
return g, c.ArgErr()
}
toHosts, err := parse.HostPortOrFile(to...)
if err != nil {
return g, err
}
for c.NextBlock() {
if err := parseBlock(c, g); err != nil {
return g, err
}
}
if g.tlsServerName != "" {
if g.tlsConfig == nil {
g.tlsConfig = new(tls.Config)
}
g.tlsConfig.ServerName = g.tlsServerName
}
for _, host := range toHosts {
pr, err := newProxy(host, g.tlsConfig)
if err != nil {
return nil, err
}
g.proxies = append(g.proxies, pr)
}
return g, nil
}
func parseBlock(c *caddy.Controller, g *GRPC) error {
switch c.Val() {
case "except":
ignore := c.RemainingArgs()
if len(ignore) == 0 {
return c.ArgErr()
}
for i := range ignore {
g.ignored = append(g.ignored, plugin.Host(ignore[i]).NormalizeExact()...)
}
case "tls":
args := c.RemainingArgs()
if len(args) > 3 {
return c.ArgErr()
}
for i := range args {
if !filepath.IsAbs(args[i]) && dnsserver.GetConfig(c).Root != "" {
args[i] = filepath.Join(dnsserver.GetConfig(c).Root, args[i])
}
}
tlsConfig, err := pkgtls.NewTLSConfigFromArgs(args...)
if err != nil {
return err
}
g.tlsConfig = tlsConfig
case "tls_servername":
if !c.NextArg() {
return c.ArgErr()
}
g.tlsServerName = c.Val()
case "policy":
if !c.NextArg() {
return c.ArgErr()
}
switch x := c.Val(); x {
case "random":
g.p = &random{}
case "round_robin":
g.p = &roundRobin{}
case "sequential":
g.p = &sequential{}
default:
return c.Errf("unknown policy '%s'", x)
}
case "fallthrough":
g.Fall.SetZonesFromArgs(c.RemainingArgs())
default:
if c.Val() != "}" {
return c.Errf("unknown property '%s'", c.Val())
}
}
return nil
}
const max = 15 // Maximum number of upstreams.
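// An illustrative Corefile stanza for the options parsed above (a sketch; the upstream
// address is an example):
//
//	example.org {
//	    grpc . 127.0.0.1:9005 {
//	        policy round_robin
//	        except internal.example.org
//	        fallthrough
//	    }
//	}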
package header
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
// Header modifies flags of dns.MsgHdr in queries and/or responses.
type Header struct {
QueryRules []Rule
ResponseRules []Rule
Next plugin.Handler
}
// ServeDNS implements the plugin.Handler interface.
func (h Header) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
applyRules(r, h.QueryRules)
wr := ResponseHeaderWriter{ResponseWriter: w, Rules: h.ResponseRules}
return plugin.NextOrFailure(h.Name(), h.Next, ctx, &wr, r)
}
// Name implements the plugin.Handler interface.
func (h Header) Name() string { return "header" }
package header
import (
"fmt"
"strings"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
// Supported flags
const (
authoritative = "aa"
recursionAvailable = "ra"
recursionDesired = "rd"
)
var log = clog.NewWithPlugin("header")
// ResponseHeaderWriter is a response writer that allows modifying dns.MsgHdr
type ResponseHeaderWriter struct {
dns.ResponseWriter
Rules []Rule
}
// WriteMsg implements the dns.ResponseWriter interface.
func (r *ResponseHeaderWriter) WriteMsg(res *dns.Msg) error {
applyRules(res, r.Rules)
return r.ResponseWriter.WriteMsg(res)
}
// Write implements the dns.ResponseWriter interface.
func (r *ResponseHeaderWriter) Write(buf []byte) (int, error) {
log.Warning("ResponseHeaderWriter called with Write: not ensuring headers")
n, err := r.ResponseWriter.Write(buf)
return n, err
}
// Rule is used to set/clear Flag in dns.MsgHdr
type Rule struct {
Flag string
State bool
}
func newRules(key string, args []string) ([]Rule, error) {
if key == "" {
return nil, fmt.Errorf("no flag action provided")
}
if len(args) < 1 {
return nil, fmt.Errorf("invalid length for flags, at least one should be provided")
}
var state bool
action := strings.ToLower(key)
switch action {
case "set":
state = true
case "clear":
state = false
default:
return nil, fmt.Errorf("unknown flag action=%s, should be set or clear", action)
}
rules := make([]Rule, 0, len(args))
for _, arg := range args {
flag := strings.ToLower(arg)
switch flag {
case authoritative:
case recursionAvailable:
case recursionDesired:
default:
return nil, fmt.Errorf("unknown/unsupported flag=%s", flag)
}
rule := Rule{Flag: flag, State: state}
rules = append(rules, rule)
}
return rules, nil
}
func applyRules(res *dns.Msg, rules []Rule) {
// handle all supported flags
for _, rule := range rules {
switch rule.Flag {
case authoritative:
res.Authoritative = rule.State
case recursionAvailable:
res.RecursionAvailable = rule.State
case recursionDesired:
res.RecursionDesired = rule.State
}
}
}
package header
import (
"fmt"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("header", setup) }
func setup(c *caddy.Controller) error {
queryRules, responseRules, err := parse(c)
if err != nil {
return plugin.Error("header", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Header{
QueryRules: queryRules,
ResponseRules: responseRules,
Next: next,
}
})
return nil
}
func parse(c *caddy.Controller) ([]Rule, []Rule, error) {
for c.Next() {
var queryRules []Rule
var responseRules []Rule
for c.NextBlock() {
selector := strings.ToLower(c.Val())
var action string
switch selector {
case "query", "response":
if c.NextArg() {
action = c.Val()
}
default:
return nil, nil, fmt.Errorf("setting up rule: invalid selector=%s should be query or response", selector)
}
args := c.RemainingArgs()
rules, err := newRules(action, args)
if err != nil {
return nil, nil, fmt.Errorf("setting up rule: %w", err)
}
if selector == "response" {
responseRules = append(responseRules, rules...)
} else {
queryRules = append(queryRules, rules...)
}
}
if len(queryRules) > 0 || len(responseRules) > 0 {
return queryRules, responseRules, nil
}
}
return nil, nil, c.ArgErr()
}
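// An illustrative Corefile stanza matching the grammar parsed above (a sketch):
//
//	. {
//	    header {
//	        response set ra
//	        response clear aa rd
//	    }
//	}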
// Package health implements an HTTP handler that responds to health checks.
package health
import (
"context"
"io"
"net"
"net/http"
"net/url"
"time"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/reuseport"
)
var log = clog.NewWithPlugin("health")
// health implements health checks by exporting an HTTP endpoint.
type health struct {
Addr string
lameduck time.Duration
healthURI *url.URL
ln net.Listener
nlSetup bool
mux *http.ServeMux
stop context.CancelFunc
}
func (h *health) OnStartup() error {
if h.Addr == "" {
h.Addr = ":8080"
}
var err error
h.healthURI, err = url.Parse("http://" + h.Addr)
if err != nil {
return err
}
h.healthURI.Path = "/health"
if h.healthURI.Host == "" {
// while we can listen on multiple network interfaces, we need to pick one to poll
h.healthURI.Host = "localhost"
}
ln, err := reuseport.Listen("tcp", h.Addr)
if err != nil {
return err
}
h.ln = ln
h.mux = http.NewServeMux()
h.nlSetup = true
h.mux.HandleFunc(h.healthURI.Path, func(w http.ResponseWriter, r *http.Request) {
// We're always healthy.
w.WriteHeader(http.StatusOK)
io.WriteString(w, http.StatusText(http.StatusOK))
})
ctx := context.Background()
ctx, h.stop = context.WithCancel(ctx)
go func() { http.Serve(h.ln, h.mux) }()
go func() { h.overloaded(ctx) }()
return nil
}
func (h *health) OnFinalShutdown() error {
if !h.nlSetup {
return nil
}
if h.lameduck > 0 {
log.Infof("Going into lameduck mode for %s", h.lameduck)
time.Sleep(h.lameduck)
}
h.stop()
h.ln.Close()
h.nlSetup = false
return nil
}
func (h *health) OnReload() error {
if !h.nlSetup {
return nil
}
h.stop()
h.ln.Close()
h.nlSetup = false
return nil
}
package health
import (
"context"
"net"
"net/http"
"time"
"github.com/coredns/coredns/plugin"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// overloaded queries the health endpoint and updates a metric showing how long it took.
func (h *health) overloaded(ctx context.Context) {
bypassProxy := &http.Transport{
Proxy: nil,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
timeout := 3 * time.Second
client := http.Client{
Timeout: timeout,
Transport: bypassProxy,
}
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, h.healthURI.String(), nil)
tick := time.NewTicker(1 * time.Second)
defer tick.Stop()
for {
select {
case <-tick.C:
start := time.Now()
resp, err := client.Do(req)
if err != nil && ctx.Err() == context.Canceled {
// request was cancelled by parent goroutine
return
}
if err != nil {
HealthDuration.Observe(time.Since(start).Seconds())
HealthFailures.Inc()
log.Warningf("Local health request to %q failed: %s", req.URL.String(), err)
continue
}
resp.Body.Close()
elapsed := time.Since(start)
HealthDuration.Observe(elapsed.Seconds())
if elapsed > time.Second { // 1s is pretty random, but a *local* scrape taking that long isn't good
log.Warningf("Local health request to %q took more than 1s: %s", req.URL.String(), elapsed)
}
case <-ctx.Done():
return
}
}
}
var (
// HealthDuration is the metric used for exporting how fast we can retrieve the /health endpoint.
HealthDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: plugin.Namespace,
Subsystem: "health",
Name: "request_duration_seconds",
Buckets: plugin.SlimTimeBuckets,
NativeHistogramBucketFactor: plugin.NativeHistogramBucketFactor,
Help: "Histogram of the time (in seconds) each request took.",
})
// HealthFailures is the metric used to count how many times the health request failed
HealthFailures = promauto.NewCounter(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "health",
Name: "request_failures_total",
Help: "The number of times the health check failed.",
})
)
package health
import (
"fmt"
"net"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("health", setup) }
func setup(c *caddy.Controller) error {
addr, lame, err := parse(c)
if err != nil {
return plugin.Error("health", err)
}
h := &health{Addr: addr, lameduck: lame}
c.OnStartup(h.OnStartup)
c.OnRestart(h.OnReload)
c.OnFinalShutdown(h.OnFinalShutdown)
c.OnRestartFailed(h.OnStartup)
// Don't do AddPlugin, as health is not *really* a plugin, just a separate webserver that we run.
return nil
}
func parse(c *caddy.Controller) (string, time.Duration, error) {
addr := ""
dur := time.Duration(0)
for c.Next() {
args := c.RemainingArgs()
switch len(args) {
case 0:
case 1:
addr = args[0]
if _, _, e := net.SplitHostPort(addr); e != nil {
return "", 0, e
}
default:
return "", 0, c.ArgErr()
}
for c.NextBlock() {
switch c.Val() {
case "lameduck":
args := c.RemainingArgs()
if len(args) != 1 {
return "", 0, c.ArgErr()
}
l, err := time.ParseDuration(args[0])
if err != nil {
return "", 0, fmt.Errorf("unable to parse lameduck duration value: '%v' : %v", args[0], err)
}
dur = l
default:
return "", 0, c.ArgErr()
}
}
}
return addr, dur, nil
}
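// parseExample is a minimal illustrative sketch (not part of the original plugin)
// showing how the directive above is typically written in a Corefile and parsed.
// caddy.NewTestController is assumed to be available here, as it is in the plugin tests.
func parseExample() (string, time.Duration, error) {
	corefile := `health 127.0.0.1:8091 {
    lameduck 5s
}`
	c := caddy.NewTestController("dns", corefile)
	return parse(c) // expected: "127.0.0.1:8091", 5*time.Second, nil
}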
package hosts
import (
"context"
"net"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Hosts is the plugin handler
type Hosts struct {
Next plugin.Handler
*Hostsfile
Fall fall.F
}
// ServeDNS implements the plugin.Handler interface.
func (h Hosts) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
answers := []dns.RR{}
zone := plugin.Zones(h.Origins).Matches(qname)
if zone == "" {
// PTR zones don't need to be specified in Origins.
if state.QType() != dns.TypePTR {
// if this doesn't match we need to fall through regardless of h.Fallthrough
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
}
switch state.QType() {
case dns.TypePTR:
names := h.LookupStaticAddr(dnsutil.ExtractAddressFromReverse(qname))
if len(names) == 0 {
// If this doesn't match we need to fall through regardless of h.Fallthrough
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
answers = h.ptr(qname, h.options.ttl, names)
case dns.TypeA:
ips := h.LookupStaticHostV4(qname)
answers = a(qname, h.options.ttl, ips)
case dns.TypeAAAA:
ips := h.LookupStaticHostV6(qname)
answers = aaaa(qname, h.options.ttl, ips)
}
// Only on NXDOMAIN will we fall through.
if len(answers) == 0 && !h.otherRecordsExist(qname) {
if h.Fall.Through(qname) {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
// We want to send an NXDOMAIN, but because of /etc/hosts' setup we don't have a SOA, so we make it SERVFAIL
// to at least give an answer back that signals we're having problems resolving this.
return dns.RcodeServerFailure, nil
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Answer = answers
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
func (h Hosts) otherRecordsExist(qname string) bool {
if len(h.LookupStaticHostV4(qname)) > 0 {
return true
}
if len(h.LookupStaticHostV6(qname)) > 0 {
return true
}
return false
}
// Name implements the plugin.Handler interface.
func (h Hosts) Name() string { return "hosts" }
// a takes a slice of net.IPs and returns a slice of A RRs.
func a(zone string, ttl uint32, ips []net.IP) []dns.RR {
answers := make([]dns.RR, len(ips))
for i, ip := range ips {
r := new(dns.A)
r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}
r.A = ip
answers[i] = r
}
return answers
}
// aaaa takes a slice of net.IPs and returns a slice of AAAA RRs.
func aaaa(zone string, ttl uint32, ips []net.IP) []dns.RR {
answers := make([]dns.RR, len(ips))
for i, ip := range ips {
r := new(dns.AAAA)
r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}
r.AAAA = ip
answers[i] = r
}
return answers
}
// ptr takes a slice of host names and filters out the ones that aren't in Origins, if specified, and returns a slice of PTR RRs.
func (h *Hosts) ptr(zone string, ttl uint32, names []string) []dns.RR {
answers := make([]dns.RR, len(names))
for i, n := range names {
r := new(dns.PTR)
r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: ttl}
r.Ptr = dns.Fqdn(n)
answers[i] = r
}
return answers
}
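// exampleA is a small illustrative sketch (not part of the original plugin) of the
// a helper above: it builds a single A record for "example.org." with a 3600s TTL.
// The name and address are made up for illustration.
func exampleA() []dns.RR {
	return a("example.org.", 3600, []net.IP{net.ParseIP("203.0.113.10")})
}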
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is a modified version of net/hosts.go from the golang repo
package hosts
import (
"bufio"
"bytes"
"io"
"net"
"os"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin"
)
// parseIP discards any IPv6 zone info before calling net.ParseIP.
func parseIP(addr string) net.IP {
if i := strings.Index(addr, "%"); i >= 0 {
// discard ipv6 zone
addr = addr[0:i]
}
return net.ParseIP(addr)
}
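// parseIPExample is an illustrative sketch (not part of the original file): the
// IPv6 zone identifier after '%' is dropped, so the address below parses the same
// as "fe80::1".
func parseIPExample() net.IP {
	return parseIP("fe80::1%eth0") // equivalent to net.ParseIP("fe80::1")
}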
type options struct {
// automatically generate IP to Hostname PTR entries
// for host entries we parse
autoReverse bool
// The TTL of the record we generate
ttl uint32
// The time between two reloads of the configuration
reload time.Duration
}
func newOptions() *options {
return &options{
autoReverse: true,
ttl: 3600,
reload: 5 * time.Second,
}
}
// Map contains the IPv4/IPv6 and reverse mapping.
type Map struct {
// The key for the list of literal IP addresses must be a lowercased FQDN host name.
name4 map[string][]net.IP
name6 map[string][]net.IP
// The key for the list of host names must be a literal IP address,
// including IPv6 addresses without the zone identifier.
// We don't support old-classful IP address notation.
addr map[string][]string
}
func newMap() *Map {
return &Map{
name4: make(map[string][]net.IP),
name6: make(map[string][]net.IP),
addr: make(map[string][]string),
}
}
// Len returns the total number of addresses in the hostmap, this includes V4/V6 and any reverse addresses.
func (h *Map) Len() int {
l := 0
for _, v4 := range h.name4 {
l += len(v4)
}
for _, v6 := range h.name6 {
l += len(v6)
}
for _, a := range h.addr {
l += len(a)
}
return l
}
// Hostsfile contains known host entries.
type Hostsfile struct {
sync.RWMutex
// list of zones we are authoritative for
Origins []string
// hosts maps for lookups
hmap *Map
// inline saves the hosts file that is inlined in a Corefile.
inline *Map
// path to the hosts file
path string
// mtime and size are only read and modified by a single goroutine
mtime time.Time
size int64
options *options
}
// readHosts determines if the cached data needs to be updated based on the size and modification time of the hostsfile.
func (h *Hostsfile) readHosts() {
file, err := os.Open(h.path)
if err != nil {
// We already log a warning if the file doesn't exist or can't be opened on setup. No need to return the error here.
return
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return
}
h.RLock()
size := h.size
h.RUnlock()
if h.mtime.Equal(stat.ModTime()) && size == stat.Size() {
return
}
newMap := h.parse(file)
log.Debugf("Parsed hosts file into %d entries", newMap.Len())
h.Lock()
h.hmap = newMap
// Update the data cache.
h.mtime = stat.ModTime()
h.size = stat.Size()
hostsEntries.WithLabelValues(h.path).Set(float64(h.inline.Len() + h.hmap.Len()))
hostsReloadTime.Set(float64(stat.ModTime().UnixNano()) / 1e9)
h.Unlock()
}
func (h *Hostsfile) initInline(inline []string) {
if len(inline) == 0 {
return
}
h.inline = h.parse(strings.NewReader(strings.Join(inline, "\n")))
}
// parse reads the hostsfile and populates the name4, name6 and addr maps.
func (h *Hostsfile) parse(r io.Reader) *Map {
hmap := newMap()
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Bytes()
if i := bytes.Index(line, []byte{'#'}); i >= 0 {
// Discard comments.
line = line[0:i]
}
f := bytes.Fields(line)
if len(f) < 2 {
continue
}
addr := parseIP(string(f[0]))
if addr == nil {
continue
}
var family int
if addr.To4() != nil {
family = 1
} else {
family = 2
}
for i := 1; i < len(f); i++ {
name := plugin.Name(string(f[i])).Normalize()
if plugin.Zones(h.Origins).Matches(name) == "" {
// name is not in Origins
continue
}
switch family {
case 1:
hmap.name4[name] = append(hmap.name4[name], addr)
case 2:
hmap.name6[name] = append(hmap.name6[name], addr)
default:
continue
}
if !h.options.autoReverse {
continue
}
hmap.addr[addr.String()] = append(hmap.addr[addr.String()], name)
}
}
return hmap
}
// lookupStaticHost looks up the IP addresses for the given host from the hosts file.
func (h *Hostsfile) lookupStaticHost(m map[string][]net.IP, host string) []net.IP {
h.RLock()
defer h.RUnlock()
if len(m) == 0 {
return nil
}
ips, ok := m[host]
if !ok {
return nil
}
ipsCp := make([]net.IP, len(ips))
copy(ipsCp, ips)
return ipsCp
}
// LookupStaticHostV4 looks up the IPv4 addresses for the given host from the hosts file.
func (h *Hostsfile) LookupStaticHostV4(host string) []net.IP {
host = strings.ToLower(host)
ip1 := h.lookupStaticHost(h.hmap.name4, host)
ip2 := h.lookupStaticHost(h.inline.name4, host)
return append(ip1, ip2...)
}
// LookupStaticHostV6 looks up the IPv6 addresses for the given host from the hosts file.
func (h *Hostsfile) LookupStaticHostV6(host string) []net.IP {
host = strings.ToLower(host)
ip1 := h.lookupStaticHost(h.hmap.name6, host)
ip2 := h.lookupStaticHost(h.inline.name6, host)
return append(ip1, ip2...)
}
// LookupStaticAddr looks up the hosts for the given address from the hosts file.
func (h *Hostsfile) LookupStaticAddr(addr string) []string {
addr = parseIP(addr).String()
if addr == "" {
return nil
}
h.RLock()
defer h.RUnlock()
hosts1 := h.hmap.addr[addr]
hosts2 := h.inline.addr[addr]
if len(hosts1) == 0 && len(hosts2) == 0 {
return nil
}
hostsCp := make([]string, len(hosts1)+len(hosts2))
copy(hostsCp, hosts1)
copy(hostsCp[len(hosts1):], hosts2)
return hostsCp
}
package hosts
import (
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
var log = clog.NewWithPlugin("hosts")
func init() { plugin.Register("hosts", setup) }
func periodicHostsUpdate(h *Hosts) chan bool {
parseChan := make(chan bool)
if h.options.reload == 0 {
return parseChan
}
go func() {
ticker := time.NewTicker(h.options.reload)
defer ticker.Stop()
for {
select {
case <-parseChan:
return
case <-ticker.C:
h.readHosts()
}
}
}()
return parseChan
}
func setup(c *caddy.Controller) error {
h, err := hostsParse(c)
if err != nil {
return plugin.Error("hosts", err)
}
parseChan := periodicHostsUpdate(&h)
c.OnStartup(func() error {
h.readHosts()
return nil
})
c.OnShutdown(func() error {
close(parseChan)
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
h.Next = next
return h
})
return nil
}
func hostsParse(c *caddy.Controller) (Hosts, error) {
config := dnsserver.GetConfig(c)
h := Hosts{
Hostsfile: &Hostsfile{
path: "/etc/hosts",
hmap: newMap(),
inline: newMap(),
options: newOptions(),
},
}
inline := []string{}
i := 0
for c.Next() {
if i > 0 {
return h, plugin.ErrOnce
}
i++
args := c.RemainingArgs()
if len(args) >= 1 {
h.path = args[0]
args = args[1:]
if !filepath.IsAbs(h.path) && config.Root != "" {
h.path = filepath.Join(config.Root, h.path)
}
s, err := os.Stat(h.path)
if err != nil {
if !os.IsNotExist(err) {
return h, c.Errf("unable to access hosts file '%s': %v", h.path, err)
}
log.Warningf("File does not exist: %s", h.path)
}
if s != nil && s.IsDir() {
log.Warningf("Hosts file %q is a directory", h.path)
}
}
h.Origins = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)
for c.NextBlock() {
switch c.Val() {
case "fallthrough":
h.Fall.SetZonesFromArgs(c.RemainingArgs())
case "no_reverse":
h.options.autoReverse = false
case "ttl":
remaining := c.RemainingArgs()
if len(remaining) < 1 {
return h, c.Errf("ttl needs a time in seconds")
}
ttl, err := strconv.Atoi(remaining[0])
if err != nil {
return h, c.Errf("ttl needs a number of seconds")
}
if ttl <= 0 || ttl > 65535 {
return h, c.Errf("ttl provided is invalid")
}
h.options.ttl = uint32(ttl)
case "reload":
remaining := c.RemainingArgs()
if len(remaining) != 1 {
return h, c.Errf("reload needs a duration (zero seconds to disable)")
}
reload, err := time.ParseDuration(remaining[0])
if err != nil {
return h, c.Errf("invalid duration for reload '%s'", remaining[0])
}
if reload < 0 {
return h, c.Errf("invalid negative duration for reload '%s'", remaining[0])
}
h.options.reload = reload
default:
if len(h.Fall.Zones) == 0 {
line := strings.Join(append([]string{c.Val()}, c.RemainingArgs()...), " ")
inline = append(inline, line)
continue
}
return h, c.Errf("unknown property '%s'", c.Val())
}
}
}
h.initInline(inline)
return h, nil
}
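// hostsParseExample is a minimal sketch (not part of the original plugin) showing a
// typical Corefile stanza for this plugin being parsed. caddy.NewTestController is
// assumed to be available here, as it is in the plugin tests; the path and zone are
// made up for illustration.
func hostsParseExample() (Hosts, error) {
	corefile := `hosts /etc/hosts example.org {
    ttl 60
    reload 30s
    fallthrough
}`
	return hostsParse(caddy.NewTestController("dns", corefile))
}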
package external
import (
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// serveApex serves requests that hit the zone's apex. A reply is written back to the client.
func (e *External) serveApex(state request.Request) (int, error) {
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
switch state.QType() {
case dns.TypeSOA:
m.Answer = []dns.RR{e.soa(state)}
case dns.TypeNS:
m.Answer = []dns.RR{e.ns(state)}
addr := e.externalAddrFunc(state, e.headless)
for _, rr := range addr {
rr.Header().Ttl = e.ttl
rr.Header().Name = dnsutil.Join("ns1", e.apex, state.QName())
m.Extra = append(m.Extra, rr)
}
default:
m.Ns = []dns.RR{e.soa(state)}
}
state.W.WriteMsg(m)
return 0, nil
}
// serveSubApex serves requests that hit the zone's fake 'dns' subdomain where our nameservers live.
func (e *External) serveSubApex(state request.Request) (int, error) {
base, _ := dnsutil.TrimZone(state.Name(), state.Zone)
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
// base is either dns. or ns1.dns (or another name); if it's longer, return NXDOMAIN
switch labels := dns.CountLabel(base); labels {
default:
m.SetRcode(m, dns.RcodeNameError)
m.Ns = []dns.RR{e.soa(state)}
state.W.WriteMsg(m)
return 0, nil
case 2:
nl, _ := dns.NextLabel(base, 0)
ns := base[:nl]
if ns != "ns1." {
// nxdomain
m.SetRcode(m, dns.RcodeNameError)
m.Ns = []dns.RR{e.soa(state)}
state.W.WriteMsg(m)
return 0, nil
}
addr := e.externalAddrFunc(state, e.headless)
for _, rr := range addr {
rr.Header().Ttl = e.ttl
rr.Header().Name = state.QName()
switch state.QType() {
case dns.TypeA:
if rr.Header().Rrtype == dns.TypeA {
m.Answer = append(m.Answer, rr)
}
case dns.TypeAAAA:
if rr.Header().Rrtype == dns.TypeAAAA {
m.Answer = append(m.Answer, rr)
}
}
}
if len(m.Answer) == 0 {
m.Ns = []dns.RR{e.soa(state)}
}
state.W.WriteMsg(m)
return 0, nil
case 1:
// nodata for the dns empty non-terminal
m.Ns = []dns.RR{e.soa(state)}
state.W.WriteMsg(m)
return 0, nil
}
}
func (e *External) soa(state request.Request) *dns.SOA {
header := dns.RR_Header{Name: state.Zone, Rrtype: dns.TypeSOA, Ttl: e.ttl, Class: dns.ClassINET}
soa := &dns.SOA{Hdr: header,
Mbox: dnsutil.Join(e.hostmaster, e.apex, state.Zone),
Ns: dnsutil.Join("ns1", e.apex, state.Zone),
Serial: e.externalSerialFunc(state.Zone),
Refresh: 7200,
Retry: 1800,
Expire: 86400,
Minttl: e.ttl,
}
return soa
}
func (e *External) ns(state request.Request) *dns.NS {
header := dns.RR_Header{Name: state.Zone, Rrtype: dns.TypeNS, Ttl: e.ttl, Class: dns.ClassINET}
ns := &dns.NS{Hdr: header, Ns: dnsutil.Join("ns1", e.apex, state.Zone)}
return ns
}
/*
Package external implements external names for kubernetes clusters.
This plugin only handles three qtypes (apex queries excepted, because those are handled
differently). We support A, AAAA and SRV requests; for all other types we return NODATA or
NXDOMAIN depending on the state of the cluster.
A plugin willing to provide these services must implement the Externaler interface, although it
likely only makes sense for the *kubernetes* plugin.
*/
package external
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Externaler defines the interface that a plugin should implement in order to be used by External.
type Externaler interface {
// External returns a slice of msg.Services that are looked up in the backend and match
// the request.
External(request.Request, bool) ([]msg.Service, int)
// ExternalAddress should return a slice of address RRs (A/AAAA) for the nameserving endpoint.
ExternalAddress(state request.Request, headless bool) []dns.RR
// ExternalServices returns all services in the given zone as a slice of msg.Service and if enabled, headless services as a map of services.
ExternalServices(zone string, headless bool) ([]msg.Service, map[string][]msg.Service)
// ExternalSerial gets the current serial.
ExternalSerial(string) uint32
}
// External serves records for External IPs and LoadBalancer IPs of Services in Kubernetes clusters.
type External struct {
Next plugin.Handler
Zones []string
Fall fall.F
hostmaster string
apex string
ttl uint32
headless bool
upstream *upstream.Upstream
externalFunc func(request.Request, bool) ([]msg.Service, int)
externalAddrFunc func(request.Request, bool) []dns.RR
externalSerialFunc func(string) uint32
externalServicesFunc func(string, bool) ([]msg.Service, map[string][]msg.Service)
}
// New returns a new and initialized *External.
func New() *External {
e := &External{hostmaster: "hostmaster", ttl: 5, apex: "dns"}
return e
}
// ServeDNS implements the plugin.Handler interface.
func (e *External) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
zone := plugin.Zones(e.Zones).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r)
}
state.Zone = zone
for _, z := range e.Zones {
// TODO(miek): save this in the External struct.
if state.Name() == z { // apex query
ret, err := e.serveApex(state)
return ret, err
}
if dns.IsSubDomain(e.apex+"."+z, state.Name()) {
// dns subdomain test for ns. and dns. queries
ret, err := e.serveSubApex(state)
return ret, err
}
}
svc, rcode := e.externalFunc(state, e.headless)
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative = true
if len(svc) == 0 {
if e.Fall.Through(state.Name()) && rcode == dns.RcodeNameError {
return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r)
}
m.Rcode = rcode
m.Ns = []dns.RR{e.soa(state)}
w.WriteMsg(m)
return 0, nil
}
switch state.QType() {
case dns.TypeA:
m.Answer, m.Truncated = e.a(ctx, svc, state)
case dns.TypeAAAA:
m.Answer, m.Truncated = e.aaaa(ctx, svc, state)
case dns.TypeSRV:
m.Answer, m.Extra = e.srv(ctx, svc, state)
case dns.TypePTR:
m.Answer = e.ptr(svc, state)
default:
m.Ns = []dns.RR{e.soa(state)}
}
// If we did have records, but queried for the wrong qtype, return a NODATA response.
if len(m.Answer) == 0 {
m.Ns = []dns.RR{e.soa(state)}
}
w.WriteMsg(m)
return 0, nil
}
// Name implements the Handler interface.
func (e *External) Name() string { return "k8s_external" }
package external
import (
"context"
"math"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
func (e *External) a(ctx context.Context, services []msg.Service, state request.Request) (records []dns.RR, truncated bool) {
dup := make(map[string]struct{})
for _, s := range services {
what, ip := s.HostType()
switch what {
case dns.TypeCNAME:
rr := s.NewCNAME(state.QName(), s.Host)
records = append(records, rr)
if resp, err := e.upstream.Lookup(ctx, state, dns.Fqdn(s.Host), dns.TypeA); err == nil {
records = append(records, resp.Answer...)
if resp.Truncated {
truncated = true
}
}
case dns.TypeA:
if _, ok := dup[s.Host]; !ok {
dup[s.Host] = struct{}{}
rr := s.NewA(state.QName(), ip)
rr.Hdr.Ttl = e.ttl
records = append(records, rr)
}
case dns.TypeAAAA:
// nada
}
}
return records, truncated
}
func (e *External) aaaa(ctx context.Context, services []msg.Service, state request.Request) (records []dns.RR, truncated bool) {
dup := make(map[string]struct{})
for _, s := range services {
what, ip := s.HostType()
switch what {
case dns.TypeCNAME:
rr := s.NewCNAME(state.QName(), s.Host)
records = append(records, rr)
if resp, err := e.upstream.Lookup(ctx, state, dns.Fqdn(s.Host), dns.TypeAAAA); err == nil {
records = append(records, resp.Answer...)
if resp.Truncated {
truncated = true
}
}
case dns.TypeA:
// nada
case dns.TypeAAAA:
if _, ok := dup[s.Host]; !ok {
dup[s.Host] = struct{}{}
rr := s.NewAAAA(state.QName(), ip)
rr.Hdr.Ttl = e.ttl
records = append(records, rr)
}
}
}
return records, truncated
}
func (e *External) ptr(services []msg.Service, state request.Request) (records []dns.RR) {
dup := make(map[string]struct{})
for _, s := range services {
if _, ok := dup[s.Host]; !ok {
dup[s.Host] = struct{}{}
rr := s.NewPTR(state.QName(), dnsutil.Join(s.Host, e.Zones[0]))
rr.Hdr.Ttl = e.ttl
records = append(records, rr)
}
}
return records
}
func (e *External) srv(ctx context.Context, services []msg.Service, state request.Request) (records, extra []dns.RR) {
dup := make(map[item]struct{})
// Looping twice to get the right weight vs priority. This might break because we may drop duplicate SRV records later on.
w := make(map[int]int)
for _, s := range services {
weight := 100
if s.Weight != 0 {
weight = s.Weight
}
if _, ok := w[s.Priority]; !ok {
w[s.Priority] = weight
continue
}
w[s.Priority] += weight
}
for _, s := range services {
// Don't add the entry if the port is -1 (invalid). The kubernetes plugin uses port -1 when a service/endpoint
// does not have any declared ports.
if s.Port == -1 {
continue
}
w1 := 100.0 / float64(w[s.Priority])
if s.Weight == 0 {
w1 *= 100
} else {
w1 *= float64(s.Weight)
}
weight := uint16(math.Floor(w1))
// weight should be at least 1
if weight == 0 {
weight = 1
}
what, ip := s.HostType()
switch what {
case dns.TypeCNAME:
addr := dns.Fqdn(s.Host)
srv := s.NewSRV(state.QName(), weight)
if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok {
records = append(records, srv)
}
if ok := isDuplicate(dup, srv.Target, addr, 0); !ok {
if resp, err := e.upstream.Lookup(ctx, state, addr, dns.TypeA); err == nil {
extra = append(extra, resp.Answer...)
}
if resp, err := e.upstream.Lookup(ctx, state, addr, dns.TypeAAAA); err == nil {
extra = append(extra, resp.Answer...)
}
}
case dns.TypeA, dns.TypeAAAA:
addr := s.Host
s.Host = msg.Domain(s.Key)
srv := s.NewSRV(state.QName(), weight)
if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok {
records = append(records, srv)
}
if ok := isDuplicate(dup, srv.Target, addr, 0); !ok {
hdr := dns.RR_Header{Name: srv.Target, Rrtype: what, Class: dns.ClassINET, Ttl: e.ttl}
switch what {
case dns.TypeA:
extra = append(extra, &dns.A{Hdr: hdr, A: ip})
case dns.TypeAAAA:
extra = append(extra, &dns.AAAA{Hdr: hdr, AAAA: ip})
}
}
}
}
return records, extra
}
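// srvWeightExample is an illustrative sketch (not used by the plugin) of the weight
// normalisation in srv above: two services at the same priority with unset weights
// each default to 100, the per-priority total becomes 200, and every SRV record is
// emitted with weight floor(100.0/200*100) = 50.
func srvWeightExample() uint16 {
	total := 100 + 100                 // two services, default weight 100 each
	w1 := 100.0 / float64(total) * 100 // an unset weight is treated as 100
	return uint16(math.Floor(w1))      // 50
}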
// not sure if this is even needed.
// item holds records.
type item struct {
name string // name of the record (either owner or something else unique).
port uint16 // port of the record (used for address records, A and AAAA).
addr string // address of the record (A and AAAA).
}
// isDuplicate uses m to see if the combo (name, addr, port) already exists. If it does
// not already exist, isDuplicate will also add the record to the map.
func isDuplicate(m map[item]struct{}, name, addr string, port uint16) bool {
if addr != "" {
_, ok := m[item{name, 0, addr}]
if !ok {
m[item{name, 0, addr}] = struct{}{}
}
return ok
}
_, ok := m[item{name, port, ""}]
if !ok {
m[item{name, port, ""}] = struct{}{}
}
return ok
}
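// isDuplicateExample is an illustrative sketch (not used by the plugin): the first
// lookup of a (name, port) combination registers it and reports false, the second
// lookup reports true.
func isDuplicateExample() (first, second bool) {
	seen := make(map[item]struct{})
	first = isDuplicate(seen, "svc.example.org.", "", 80)  // false, now registered
	second = isDuplicate(seen, "svc.example.org.", "", 80) // true, already present
	return first, second
}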
package external
import (
"errors"
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/upstream"
)
const pluginName = "k8s_external"
func init() { plugin.Register(pluginName, setup) }
func setup(c *caddy.Controller) error {
e, err := parse(c)
if err != nil {
return plugin.Error("k8s_external", err)
}
// Do this in OnStartup, so all plugins have been initialized.
c.OnStartup(func() error {
m := dnsserver.GetConfig(c).Handler("kubernetes")
if m == nil {
return plugin.Error(pluginName, errors.New("kubernetes plugin not loaded"))
}
x, ok := m.(Externaler)
if !ok {
return plugin.Error(pluginName, errors.New("kubernetes plugin does not implement the Externaler interface"))
}
e.externalFunc = x.External
e.externalAddrFunc = x.ExternalAddress
e.externalServicesFunc = x.ExternalServices
e.externalSerialFunc = x.ExternalSerial
return nil
})
e.upstream = upstream.New()
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
e.Next = next
return e
})
return nil
}
func parse(c *caddy.Controller) (*External, error) {
e := New()
for c.Next() { // external
e.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
for c.NextBlock() {
switch c.Val() {
case "ttl":
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
t, err := strconv.Atoi(args[0])
if err != nil {
return nil, err
}
if t < 0 || t > 3600 {
return nil, c.Errf("ttl must be in range [0, 3600]: %d", t)
}
e.ttl = uint32(t)
case "apex":
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
e.apex = args[0]
case "headless":
e.headless = true
case "fallthrough":
e.Fall.SetZonesFromArgs(c.RemainingArgs())
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
}
return e, nil
}
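// parseExample is a minimal sketch (not part of the original plugin) of a typical
// Corefile stanza for k8s_external being parsed. caddy.NewTestController is assumed
// to be available here, as it is in the plugin tests; the zone is made up for
// illustration.
func parseExample() (*External, error) {
	corefile := `k8s_external example.org {
    apex dns
    ttl 10
}`
	return parse(caddy.NewTestController("dns", corefile))
}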
package external
import (
"context"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/kubernetes"
"github.com/coredns/coredns/plugin/transfer"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Transfer implements transfer.Transferer
func (e *External) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
z := plugin.Zones(e.Zones).Matches(zone)
if z != zone {
return nil, transfer.ErrNotAuthoritative
}
ctx := context.Background()
ch := make(chan []dns.RR, 2)
if zone == "." {
zone = ""
}
state := request.Request{Zone: zone}
// SOA
soa := e.soa(state)
ch <- []dns.RR{soa}
if serial != 0 && serial >= soa.Serial {
close(ch)
return ch, nil
}
go func() {
// Add NS
nsName := "ns1." + e.apex + "." + zone
nsHdr := dns.RR_Header{Name: zone, Rrtype: dns.TypeNS, Ttl: e.ttl, Class: dns.ClassINET}
ch <- []dns.RR{&dns.NS{Hdr: nsHdr, Ns: nsName}}
// Add Nameserver A/AAAA records
nsRecords := e.externalAddrFunc(state, e.headless)
for i := range nsRecords {
// externalAddrFunc returns incomplete header names; correct them here
nsRecords[i].Header().Name = nsName
nsRecords[i].Header().Ttl = e.ttl
ch <- []dns.RR{nsRecords[i]}
}
svcs, headlessSvcs := e.externalServicesFunc(zone, e.headless)
srvSeen := make(map[string]struct{})
for i := range svcs {
name := msg.Domain(svcs[i].Key)
if svcs[i].TargetStrip == 0 {
// Add Service A/AAAA records
s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}}
as, _ := e.a(ctx, []msg.Service{svcs[i]}, s)
if len(as) > 0 {
ch <- as
}
aaaas, _ := e.aaaa(ctx, []msg.Service{svcs[i]}, s)
if len(aaaas) > 0 {
ch <- aaaas
}
// Add bare SRV record, ensuring uniqueness
recs, _ := e.srv(ctx, []msg.Service{svcs[i]}, s)
for _, srv := range recs {
if !nameSeen(srvSeen, srv) {
ch <- []dns.RR{srv}
}
}
continue
}
// Add full SRV record, ensuring uniqueness
s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}}
recs, _ := e.srv(ctx, []msg.Service{svcs[i]}, s)
for _, srv := range recs {
if !nameSeen(srvSeen, srv) {
ch <- []dns.RR{srv}
}
}
}
for key, svcs := range headlessSvcs {
// we have to strip the trailing key segment because it's either port.protocol or endpoint
name := msg.Domain(key[:strings.LastIndex(key, "/")])
switchKey := key[strings.LastIndex(key, "/")+1:]
switch switchKey {
case kubernetes.Endpoint:
// headless.namespace.example.com records
s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}}
as, _ := e.a(ctx, svcs, s)
if len(as) > 0 {
ch <- as
}
aaaas, _ := e.aaaa(ctx, svcs, s)
if len(aaaas) > 0 {
ch <- aaaas
}
// Add bare SRV records
recs, _ := e.srv(ctx, svcs, s)
for _, srv := range recs {
ch <- []dns.RR{srv}
}
for i := range svcs {
// endpoint.headless.namespace.example.com record
s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: msg.Domain(svcs[i].Key)}}}}
as, _ := e.a(ctx, []msg.Service{svcs[i]}, s)
if len(as) > 0 {
ch <- as
}
aaaas, _ := e.aaaa(ctx, []msg.Service{svcs[i]}, s)
if len(aaaas) > 0 {
ch <- aaaas
}
// Add the endpoint SRV record
recs, _ := e.srv(ctx, []msg.Service{svcs[i]}, s)
for _, srv := range recs {
ch <- []dns.RR{srv}
}
}
case kubernetes.PortProtocol:
s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}}
recs, _ := e.srv(ctx, svcs, s)
ch <- recs
}
}
ch <- []dns.RR{soa}
close(ch)
}()
return ch, nil
}
func nameSeen(namesSeen map[string]struct{}, rr dns.RR) bool {
if _, duplicate := namesSeen[rr.Header().Name]; duplicate {
return true
}
namesSeen[rr.Header().Name] = struct{}{}
return false
}
package kubernetes
import (
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/kubernetes/object"
"github.com/coredns/coredns/request"
)
// AutoPath implements the AutoPathFunc call from the autopath plugin.
// It returns a per-query search path, or nil to indicate that no search path expansion should happen.
func (k *Kubernetes) AutoPath(state request.Request) []string {
// Check if the query falls in a zone we are actually authoritative for and thus if we want autopath.
zone := plugin.Zones(k.Zones).Matches(state.Name())
if zone == "" {
return nil
}
// cluster.local {
// autopath @kubernetes
// kubernetes {
// pods verified #
// }
// }
// If pods is not set to verified the pod cache is not initialized; without this check that would panic and return SERVFAIL. Instead we skip autopath and answer as normal.
if !k.opts.initPodCache {
return nil
}
ip := state.IP()
pod := k.podWithIP(ip)
if pod == nil {
return nil
}
totalSize := 3 + len(k.autoPathSearch) + 1 // +1 for sentinel
search := make([]string, 0, totalSize)
if zone == "." {
search = append(search, pod.Namespace+".svc.", "svc.", ".")
} else {
search = append(search, pod.Namespace+".svc."+zone, "svc."+zone, zone)
}
search = append(search, k.autoPathSearch...)
search = append(search, "") // sentinel
return search
}
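// autoPathSearchExample is an illustrative sketch (not used by the plugin) of the
// search path AutoPath builds for a verified pod in namespace "default" under the
// zone "cluster.local." when no additional search domains are configured; the final
// empty string is the sentinel appended above.
func autoPathSearchExample() []string {
	return []string{"default.svc.cluster.local.", "svc.cluster.local.", "cluster.local.", ""}
}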
// podWithIP returns the *object.Pod for the source IP. It returns nil if nothing can be found.
func (k *Kubernetes) podWithIP(ip string) *object.Pod {
if k.podMode != podModeVerified {
return nil
}
ps := k.APIConn.PodIndex(ip)
if len(ps) == 0 {
return nil
}
return ps[0]
}
package kubernetes
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/kubernetes/object"
api "k8s.io/api/core/v1"
discovery "k8s.io/api/discovery/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/selection"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/cache"
mcs "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1"
mcsClientset "sigs.k8s.io/mcs-api/pkg/client/clientset/versioned/typed/apis/v1alpha1"
)
const (
podIPIndex = "PodIP"
svcNameNamespaceIndex = "ServiceNameNamespace"
svcIPIndex = "ServiceIP"
svcExtIPIndex = "ServiceExternalIP"
epNameNamespaceIndex = "EndpointNameNamespace"
epIPIndex = "EndpointsIP"
svcImportNameNamespaceIndex = "ServiceImportNameNamespace"
mcEpNameNamespaceIndex = "MultiClusterEndpointsImportNameNamespace"
)
type ModifiedMode int
const (
ModifiedInternal ModifiedMode = iota
ModifiedExternal
ModifiedMultiCluster
)
type dnsController interface {
ServiceList() []*object.Service
EndpointsList() []*object.Endpoints
ServiceImportList() []*object.ServiceImport
SvcIndex(string) []*object.Service
SvcIndexReverse(string) []*object.Service
SvcExtIndexReverse(string) []*object.Service
SvcImportIndex(string) []*object.ServiceImport
PodIndex(string) []*object.Pod
EpIndex(string) []*object.Endpoints
EpIndexReverse(string) []*object.Endpoints
McEpIndex(string) []*object.MultiClusterEndpoints
GetNodeByName(context.Context, string) (*api.Node, error)
GetNamespaceByName(string) (*object.Namespace, error)
Run()
HasSynced() bool
Stop() error
// Modified returns the timestamp of the most recent changes to services.
Modified(ModifiedMode) int64
}
type dnsControl struct {
// modified tracks the timestamp of the most recent changes.
// It needs to be first because it is guaranteed to be 8-byte
// aligned (we access it with atomic.LoadInt64).
modified int64
// multiClusterModified tracks timestamp of the most recent changes to
// multi cluster services
multiClusterModified int64
// extModified tracks timestamp of the most recent changes to
// services with external facing IP addresses
extModified int64
client kubernetes.Interface
mcsClient mcsClientset.MulticlusterV1alpha1Interface
selector labels.Selector
namespaceSelector labels.Selector
svcController cache.Controller
podController cache.Controller
epController cache.Controller
nsController cache.Controller
svcImportController cache.Controller
mcEpController cache.Controller
svcLister cache.Indexer
podLister cache.Indexer
epLister cache.Indexer
nsLister cache.Store
svcImportLister cache.Indexer
mcEpLister cache.Indexer
// stopLock is used to enforce that only a single call to Stop is active.
// Needed because we allow stopping through an HTTP endpoint and
// allowing concurrent stoppers leads to stack traces.
stopLock sync.Mutex
shutdown bool
stopCh chan struct{}
zones []string
endpointNameMode bool
multiclusterZones []string
}
type dnsControlOpts struct {
initPodCache bool
initEndpointsCache bool
ignoreEmptyService bool
// Label handling.
labelSelector *meta.LabelSelector
selector labels.Selector
namespaceLabelSelector *meta.LabelSelector
namespaceSelector labels.Selector
zones []string
endpointNameMode bool
multiclusterZones []string
}
// newdnsController creates a controller for CoreDNS.
func newdnsController(ctx context.Context, kubeClient kubernetes.Interface, mcsClient mcsClientset.MulticlusterV1alpha1Interface, opts dnsControlOpts) *dnsControl {
dns := dnsControl{
client: kubeClient,
mcsClient: mcsClient,
selector: opts.selector,
namespaceSelector: opts.namespaceSelector,
stopCh: make(chan struct{}),
zones: opts.zones,
endpointNameMode: opts.endpointNameMode,
multiclusterZones: opts.multiclusterZones,
}
dns.svcLister, dns.svcController = object.NewIndexerInformer(
&cache.ListWatch{
ListFunc: serviceListFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
WatchFunc: serviceWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
},
&api.Service{},
cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{svcNameNamespaceIndex: svcNameNamespaceIndexFunc, svcIPIndex: svcIPIndexFunc, svcExtIPIndex: svcExtIPIndexFunc},
object.DefaultProcessor(object.ToService, nil),
)
podLister, podController := object.NewIndexerInformer(
&cache.ListWatch{
ListFunc: podListFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
WatchFunc: podWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
},
&api.Pod{},
cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{podIPIndex: podIPIndexFunc},
object.DefaultProcessor(object.ToPod, nil),
)
dns.podLister = podLister
if opts.initPodCache {
dns.podController = podController
}
epLister, epController := object.NewIndexerInformer(
&cache.ListWatch{
ListFunc: endpointSliceListFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
WatchFunc: endpointSliceWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector),
},
&discovery.EndpointSlice{},
cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{epNameNamespaceIndex: epNameNamespaceIndexFunc, epIPIndex: epIPIndexFunc},
object.DefaultProcessor(object.EndpointSliceToEndpoints, dns.EndpointSliceLatencyRecorder()),
)
dns.epLister = epLister
if opts.initEndpointsCache {
dns.epController = epController
}
dns.nsLister, dns.nsController = object.NewIndexerInformer(
&cache.ListWatch{
ListFunc: namespaceListFunc(ctx, dns.client, dns.namespaceSelector),
WatchFunc: namespaceWatchFunc(ctx, dns.client, dns.namespaceSelector),
},
&api.Namespace{},
cache.ResourceEventHandlerFuncs{},
cache.Indexers{},
object.DefaultProcessor(object.ToNamespace, nil),
)
if len(opts.multiclusterZones) > 0 {
mcsEpReq, _ := labels.NewRequirement(mcs.LabelServiceName, selection.Exists, []string{})
mcsEpSelector := dns.selector
if mcsEpSelector == nil {
mcsEpSelector = labels.NewSelector()
}
mcsEpSelector = mcsEpSelector.Add(*mcsEpReq)
dns.mcEpLister, dns.mcEpController = object.NewIndexerInformer(
&cache.ListWatch{
ListFunc: endpointSliceListFunc(ctx, dns.client, api.NamespaceAll, mcsEpSelector),
WatchFunc: endpointSliceWatchFunc(ctx, dns.client, api.NamespaceAll, mcsEpSelector),
},
&discovery.EndpointSlice{},
cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{mcEpNameNamespaceIndex: mcEpNameNamespaceIndexFunc},
object.DefaultProcessor(object.EndpointSliceToMultiClusterEndpoints, dns.EndpointSliceLatencyRecorder()),
)
dns.svcImportLister, dns.svcImportController = object.NewIndexerInformer(
&cache.ListWatch{
ListFunc: serviceImportListFunc(ctx, dns.mcsClient, api.NamespaceAll, dns.namespaceSelector),
WatchFunc: serviceImportWatchFunc(ctx, dns.mcsClient, api.NamespaceAll, dns.namespaceSelector),
},
&mcs.ServiceImport{},
cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete},
cache.Indexers{svcImportNameNamespaceIndex: svcImportNameNamespaceIndexFunc},
object.DefaultProcessor(object.ToServiceImport, nil),
)
}
return &dns
}
func (dns *dnsControl) EndpointsLatencyRecorder() *object.EndpointLatencyRecorder {
return &object.EndpointLatencyRecorder{
ServiceFunc: func(o meta.Object) []*object.Service {
return dns.SvcIndex(object.ServiceKey(o.GetName(), o.GetNamespace()))
},
}
}
func (dns *dnsControl) EndpointSliceLatencyRecorder() *object.EndpointLatencyRecorder {
return &object.EndpointLatencyRecorder{
ServiceFunc: func(o meta.Object) []*object.Service {
return dns.SvcIndex(object.ServiceKey(o.GetLabels()[discovery.LabelServiceName], o.GetNamespace()))
},
}
}
func podIPIndexFunc(obj any) ([]string, error) {
p, ok := obj.(*object.Pod)
if !ok {
return nil, errObj
}
return []string{p.PodIP}, nil
}
func svcIPIndexFunc(obj any) ([]string, error) {
svc, ok := obj.(*object.Service)
if !ok {
return nil, errObj
}
idx := make([]string, len(svc.ClusterIPs))
copy(idx, svc.ClusterIPs)
return idx, nil
}
func svcExtIPIndexFunc(obj any) ([]string, error) {
svc, ok := obj.(*object.Service)
if !ok {
return nil, errObj
}
idx := make([]string, len(svc.ExternalIPs))
copy(idx, svc.ExternalIPs)
return idx, nil
}
func svcNameNamespaceIndexFunc(obj any) ([]string, error) {
s, ok := obj.(*object.Service)
if !ok {
return nil, errObj
}
return []string{s.Index}, nil
}
func epNameNamespaceIndexFunc(obj any) ([]string, error) {
s, ok := obj.(*object.Endpoints)
if !ok {
return nil, errObj
}
return []string{s.Index}, nil
}
func epIPIndexFunc(obj any) ([]string, error) {
ep, ok := obj.(*object.Endpoints)
if !ok {
return nil, errObj
}
return ep.IndexIP, nil
}
func svcImportNameNamespaceIndexFunc(obj any) ([]string, error) {
s, ok := obj.(*object.ServiceImport)
if !ok {
return nil, errObj
}
return []string{s.Index}, nil
}
func mcEpNameNamespaceIndexFunc(obj any) ([]string, error) {
mcEp, ok := obj.(*object.MultiClusterEndpoints)
if !ok {
return nil, errObj
}
return []string{mcEp.Index}, nil
}
func serviceListFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) {
return func(opts meta.ListOptions) (runtime.Object, error) {
if s != nil {
opts.LabelSelector = s.String()
}
return c.CoreV1().Services(ns).List(ctx, opts)
}
}
func podListFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) {
return func(opts meta.ListOptions) (runtime.Object, error) {
if s != nil {
opts.LabelSelector = s.String()
}
if len(opts.FieldSelector) > 0 {
opts.FieldSelector = opts.FieldSelector + ","
}
opts.FieldSelector = opts.FieldSelector + "status.phase!=Succeeded,status.phase!=Failed,status.phase!=Unknown"
return c.CoreV1().Pods(ns).List(ctx, opts)
}
}
func endpointSliceListFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) {
return func(opts meta.ListOptions) (runtime.Object, error) {
if s != nil {
opts.LabelSelector = s.String()
}
return c.DiscoveryV1().EndpointSlices(ns).List(ctx, opts)
}
}
func namespaceListFunc(ctx context.Context, c kubernetes.Interface, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) {
return func(opts meta.ListOptions) (runtime.Object, error) {
if s != nil {
opts.LabelSelector = s.String()
}
return c.CoreV1().Namespaces().List(ctx, opts)
}
}
func serviceImportListFunc(ctx context.Context, c mcsClientset.MulticlusterV1alpha1Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) {
return func(opts meta.ListOptions) (runtime.Object, error) {
if s != nil {
opts.LabelSelector = s.String()
}
return c.ServiceImports(ns).List(ctx, opts)
}
}
func serviceWatchFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) {
return func(options meta.ListOptions) (watch.Interface, error) {
if s != nil {
options.LabelSelector = s.String()
}
return c.CoreV1().Services(ns).Watch(ctx, options)
}
}
func podWatchFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) {
return func(options meta.ListOptions) (watch.Interface, error) {
if s != nil {
options.LabelSelector = s.String()
}
if len(options.FieldSelector) > 0 {
options.FieldSelector = options.FieldSelector + ","
}
options.FieldSelector = options.FieldSelector + "status.phase!=Succeeded,status.phase!=Failed,status.phase!=Unknown"
return c.CoreV1().Pods(ns).Watch(ctx, options)
}
}
func endpointSliceWatchFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) {
return func(options meta.ListOptions) (watch.Interface, error) {
if s != nil {
options.LabelSelector = s.String()
}
return c.DiscoveryV1().EndpointSlices(ns).Watch(ctx, options)
}
}
func namespaceWatchFunc(ctx context.Context, c kubernetes.Interface, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) {
return func(options meta.ListOptions) (watch.Interface, error) {
if s != nil {
options.LabelSelector = s.String()
}
return c.CoreV1().Namespaces().Watch(ctx, options)
}
}
func serviceImportWatchFunc(ctx context.Context, c mcsClientset.MulticlusterV1alpha1Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) {
return func(options meta.ListOptions) (watch.Interface, error) {
if s != nil {
options.LabelSelector = s.String()
}
return c.ServiceImports(ns).Watch(ctx, options)
}
}
// Stop stops the controller.
func (dns *dnsControl) Stop() error {
dns.stopLock.Lock()
defer dns.stopLock.Unlock()
// Only try draining the workqueue if we haven't already.
if !dns.shutdown {
close(dns.stopCh)
dns.shutdown = true
return nil
}
return fmt.Errorf("shutdown already in progress")
}
// Run starts the controller.
func (dns *dnsControl) Run() {
go dns.svcController.Run(dns.stopCh)
if dns.epController != nil {
go func() {
dns.epController.Run(dns.stopCh)
}()
}
if dns.podController != nil {
go dns.podController.Run(dns.stopCh)
}
go dns.nsController.Run(dns.stopCh)
if dns.svcImportController != nil {
go dns.svcImportController.Run(dns.stopCh)
}
if dns.mcEpController != nil {
go dns.mcEpController.Run(dns.stopCh)
}
<-dns.stopCh
}
// HasSynced calls HasSynced on all controllers and reports whether they have all synced.
func (dns *dnsControl) HasSynced() bool {
a := dns.svcController.HasSynced()
b := true
if dns.epController != nil {
b = dns.epController.HasSynced()
}
c := true
if dns.podController != nil {
c = dns.podController.HasSynced()
}
d := dns.nsController.HasSynced()
e := true
if dns.svcImportController != nil {
e = dns.svcImportController.HasSynced()
}
f := true
if dns.mcEpController != nil {
f = dns.mcEpController.HasSynced()
}
return a && b && c && d && e && f
}
func (dns *dnsControl) ServiceList() (svcs []*object.Service) {
os := dns.svcLister.List()
for _, o := range os {
s, ok := o.(*object.Service)
if !ok {
continue
}
svcs = append(svcs, s)
}
return svcs
}
func (dns *dnsControl) ServiceImportList() (svcs []*object.ServiceImport) {
os := dns.svcImportLister.List()
for _, o := range os {
s, ok := o.(*object.ServiceImport)
if !ok {
continue
}
svcs = append(svcs, s)
}
return svcs
}
func (dns *dnsControl) EndpointsList() (eps []*object.Endpoints) {
os := dns.epLister.List()
for _, o := range os {
ep, ok := o.(*object.Endpoints)
if !ok {
continue
}
eps = append(eps, ep)
}
return eps
}
func (dns *dnsControl) PodIndex(ip string) (pods []*object.Pod) {
os, err := dns.podLister.ByIndex(podIPIndex, ip)
if err != nil {
return nil
}
for _, o := range os {
p, ok := o.(*object.Pod)
if !ok {
continue
}
pods = append(pods, p)
}
return pods
}
func (dns *dnsControl) SvcIndex(idx string) (svcs []*object.Service) {
os, err := dns.svcLister.ByIndex(svcNameNamespaceIndex, idx)
if err != nil {
return nil
}
for _, o := range os {
s, ok := o.(*object.Service)
if !ok {
continue
}
svcs = append(svcs, s)
}
return svcs
}
func (dns *dnsControl) SvcIndexReverse(ip string) (svcs []*object.Service) {
os, err := dns.svcLister.ByIndex(svcIPIndex, ip)
if err != nil {
return nil
}
for _, o := range os {
s, ok := o.(*object.Service)
if !ok {
continue
}
svcs = append(svcs, s)
}
return svcs
}
func (dns *dnsControl) SvcExtIndexReverse(ip string) (svcs []*object.Service) {
os, err := dns.svcLister.ByIndex(svcExtIPIndex, ip)
if err != nil {
return nil
}
for _, o := range os {
s, ok := o.(*object.Service)
if !ok {
continue
}
svcs = append(svcs, s)
}
return svcs
}
func (dns *dnsControl) SvcImportIndex(idx string) (svcs []*object.ServiceImport) {
os, err := dns.svcImportLister.ByIndex(svcImportNameNamespaceIndex, idx)
if err != nil {
return nil
}
for _, o := range os {
s, ok := o.(*object.ServiceImport)
if !ok {
continue
}
svcs = append(svcs, s)
}
return svcs
}
func (dns *dnsControl) EpIndex(idx string) (ep []*object.Endpoints) {
os, err := dns.epLister.ByIndex(epNameNamespaceIndex, idx)
if err != nil {
return nil
}
for _, o := range os {
e, ok := o.(*object.Endpoints)
if !ok {
continue
}
ep = append(ep, e)
}
return ep
}
func (dns *dnsControl) EpIndexReverse(ip string) (ep []*object.Endpoints) {
os, err := dns.epLister.ByIndex(epIPIndex, ip)
if err != nil {
return nil
}
for _, o := range os {
e, ok := o.(*object.Endpoints)
if !ok {
continue
}
ep = append(ep, e)
}
return ep
}
func (dns *dnsControl) McEpIndex(idx string) (ep []*object.MultiClusterEndpoints) {
os, err := dns.mcEpLister.ByIndex(mcEpNameNamespaceIndex, idx)
if err != nil {
return nil
}
for _, o := range os {
e, ok := o.(*object.MultiClusterEndpoints)
if !ok {
continue
}
ep = append(ep, e)
}
return ep
}
// GetNodeByName returns the node by name. If nothing is found an error is
// returned. This query causes a round trip to the k8s API server, so use
// sparingly. Currently, this is only used for Federation.
func (dns *dnsControl) GetNodeByName(ctx context.Context, name string) (*api.Node, error) {
v1node, err := dns.client.CoreV1().Nodes().Get(ctx, name, meta.GetOptions{})
return v1node, err
}
// GetNamespaceByName returns the namespace by name. If nothing is found an error is returned.
func (dns *dnsControl) GetNamespaceByName(name string) (*object.Namespace, error) {
o, exists, err := dns.nsLister.GetByKey(name)
if err != nil {
return nil, err
}
if !exists {
return nil, fmt.Errorf("namespace not found")
}
ns, ok := o.(*object.Namespace)
if !ok {
return nil, fmt.Errorf("found key but not namespace")
}
return ns, nil
}
func (dns *dnsControl) Add(obj any) { dns.updateModified() }
func (dns *dnsControl) Delete(obj any) { dns.updateModified() }
func (dns *dnsControl) Update(oldObj, newObj any) { dns.detectChanges(oldObj, newObj) }
// detectChanges detects changes in objects, and updates the modified timestamp
func (dns *dnsControl) detectChanges(oldObj, newObj any) {
// If both objects have the same resource version, they are identical.
if newObj != nil && oldObj != nil && (oldObj.(meta.Object).GetResourceVersion() == newObj.(meta.Object).GetResourceVersion()) {
return
}
obj := newObj
if obj == nil {
obj = oldObj
}
switch ob := obj.(type) {
case *object.Service:
imod, emod := serviceModified(oldObj, newObj)
if imod {
dns.updateModified()
}
if emod {
dns.updateExtModified()
}
case *object.ServiceImport:
if !serviceImportEquivalent(oldObj, newObj) {
dns.updateMultiClusterModified()
}
case *object.Pod:
dns.updateModified()
case *object.Endpoints:
if !endpointsEquivalent(oldObj.(*object.Endpoints), newObj.(*object.Endpoints)) {
dns.updateModified()
}
case *object.MultiClusterEndpoints:
if !multiclusterEndpointsEquivalent(oldObj.(*object.MultiClusterEndpoints), newObj.(*object.MultiClusterEndpoints)) {
dns.updateMultiClusterModified()
}
default:
log.Warningf("Updates for %T not supported.", ob)
}
}
// subsetsEquivalent checks if two endpoint subsets are significantly equivalent,
// i.e. that they have the same ready addresses, host names, and ports (including
// protocol and service names for SRV).
func subsetsEquivalent(sa, sb object.EndpointSubset) bool {
if len(sa.Addresses) != len(sb.Addresses) {
return false
}
if len(sa.Ports) != len(sb.Ports) {
return false
}
// In Addresses and Ports we should be able to rely on these being sorted and
// comparable; they are supposed to be in a canonical format.
for addr, aaddr := range sa.Addresses {
baddr := sb.Addresses[addr]
if aaddr.IP != baddr.IP {
return false
}
if aaddr.Hostname != baddr.Hostname {
return false
}
}
for port, aport := range sa.Ports {
bport := sb.Ports[port]
if aport.Name != bport.Name {
return false
}
if aport.Port != bport.Port {
return false
}
if aport.Protocol != bport.Protocol {
return false
}
}
return true
}
// endpointsEquivalent checks if the update to an endpoint is something
// that matters to us or if they are effectively equivalent.
func endpointsEquivalent(a, b *object.Endpoints) bool {
if a == nil || b == nil {
return false
}
if len(a.Subsets) != len(b.Subsets) {
return false
}
// We should be able to rely on these being sorted and comparable;
// they are supposed to be in a canonical format.
for i, sa := range a.Subsets {
sb := b.Subsets[i]
if !subsetsEquivalent(sa, sb) {
return false
}
}
return true
}
// multiclusterEndpointsEquivalent checks if the update to an endpoint is something
// that matters to us or if they are effectively equivalent.
func multiclusterEndpointsEquivalent(a, b *object.MultiClusterEndpoints) bool {
if a == nil || b == nil {
return false
}
if !endpointsEquivalent(&a.Endpoints, &b.Endpoints) {
return false
}
if a.ClusterId != b.ClusterId {
return false
}
return true
}
// serviceModified checks the services passed for changes that result in changes
// to internal and/or external records. It returns two booleans, one for internal
// record changes, and a second for external record changes
func serviceModified(oldObj, newObj any) (intSvc, extSvc bool) {
if oldObj != nil && newObj == nil {
// deleted service only modifies external zone records if it had external ips
return true, len(oldObj.(*object.Service).ExternalIPs) > 0
}
if oldObj == nil && newObj != nil {
// added service only modifies external zone records if it has external ips
return true, len(newObj.(*object.Service).ExternalIPs) > 0
}
newSvc := newObj.(*object.Service)
oldSvc := oldObj.(*object.Service)
// External IPs are mutable, affecting external zone records
if len(oldSvc.ExternalIPs) != len(newSvc.ExternalIPs) {
extSvc = true
} else {
for i := range oldSvc.ExternalIPs {
if oldSvc.ExternalIPs[i] != newSvc.ExternalIPs[i] {
extSvc = true
break
}
}
}
// ExternalName is mutable, affecting internal zone records
intSvc = oldSvc.ExternalName != newSvc.ExternalName
if intSvc && extSvc {
return intSvc, extSvc
}
// All Port fields are mutable, affecting both internal/external zone records
if len(oldSvc.Ports) != len(newSvc.Ports) {
return true, true
}
for i := range oldSvc.Ports {
if oldSvc.Ports[i].Name != newSvc.Ports[i].Name {
return true, true
}
if oldSvc.Ports[i].Port != newSvc.Ports[i].Port {
return true, true
}
if oldSvc.Ports[i].Protocol != newSvc.Ports[i].Protocol {
return true, true
}
}
return intSvc, extSvc
}
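// Illustrative note (not part of the original file): changing only a Service's
// ExternalIPs marks just the external zone as modified, changing only its
// ExternalName marks just the internal zone, and any change to the Ports slice
// marks both.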
// serviceImportEquivalent checks if the update to a ServiceImport is something
// that matters to us or if they are effectively equivalent.
func serviceImportEquivalent(oldObj, newObj any) bool {
if oldObj != nil && newObj == nil {
return false
}
if oldObj == nil && newObj != nil {
return false
}
newSvc := newObj.(*object.ServiceImport)
oldSvc := oldObj.(*object.ServiceImport)
if oldSvc.Type != newSvc.Type {
return false
}
// All Port fields are mutable, affecting both internal/external zone records
if len(oldSvc.Ports) != len(newSvc.Ports) {
return false
}
for i := range oldSvc.Ports {
if oldSvc.Ports[i].Name != newSvc.Ports[i].Name {
return false
}
if oldSvc.Ports[i].Port != newSvc.Ports[i].Port {
return false
}
if oldSvc.Ports[i].Protocol != newSvc.Ports[i].Protocol {
return false
}
}
return true
}
func (dns *dnsControl) Modified(mode ModifiedMode) int64 {
switch mode {
case ModifiedInternal:
return atomic.LoadInt64(&dns.modified)
case ModifiedExternal:
return atomic.LoadInt64(&dns.extModified)
case ModifiedMultiCluster:
return atomic.LoadInt64(&dns.multiClusterModified)
}
return -1
}
// updateModified sets dns.modified to the current time.
func (dns *dnsControl) updateModified() {
unix := time.Now().Unix()
atomic.StoreInt64(&dns.modified, unix)
}
// updateMultiClusterModified sets dns.multiClusterModified to the current time.
func (dns *dnsControl) updateMultiClusterModified() {
unix := time.Now().Unix()
atomic.StoreInt64(&dns.multiClusterModified, unix)
}
// updateExtModified sets dns.extModified to the current time.
func (dns *dnsControl) updateExtModified() {
unix := time.Now().Unix()
atomic.StoreInt64(&dns.extModified, unix)
}
var errObj = errors.New("obj was not of the correct type")
package kubernetes
import (
"strings"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/kubernetes/object"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// These constants are used to distinguish between records in the headless return
// values of ExternalServices. They are always appended to a key in a map, which is
// either the base service key, e.g. /com/example/namespace/service/endpoint, or
// /com/example/namespace/service/_tcp/_http/port.protocol.
// This allows us to distinguish services in the implementation of the Transfer
// protocol; see plugin/k8s_external/transfer.go.
const (
Endpoint = "endpoint"
PortProtocol = "port.protocol"
)
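// headlessKeyExample is an illustrative sketch (not used by the plugin) of the two
// map-key shapes described above, for a headless service "web" in namespace "ns"
// under the zone path "/com/example" with a named "http" port over TCP.
func headlessKeyExample() (endpointKey, portProtocolKey string) {
	endpointKey = strings.Join([]string{"/com/example", "ns", "web", Endpoint}, "/")
	portProtocolKey = strings.Join([]string{"/com/example", "ns", "web", "_tcp", "_http", PortProtocol}, "/")
	return endpointKey, portProtocolKey
}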
// External implements the ExternalFunc call from the external plugin.
// It returns any services matching the request in the services' ExternalIPs and, if enabled, headless endpoints.
func (k *Kubernetes) External(state request.Request, headless bool) ([]msg.Service, int) {
if state.QType() == dns.TypePTR {
ip := dnsutil.ExtractAddressFromReverse(state.Name())
if ip != "" {
svcs, err := k.ExternalReverse(ip)
if err != nil {
return nil, dns.RcodeNameError
}
return svcs, dns.RcodeSuccess
}
// for invalid reverse names, fall through to determine proper nxdomain/nodata response
}
base, _ := dnsutil.TrimZone(state.Name(), state.Zone)
segs := dns.SplitDomainName(base)
last := len(segs) - 1
if last < 0 {
return nil, dns.RcodeServerFailure
}
// We are dealing with a fairly normal domain name here, but we still need to have the service,
// namespace and if present, endpoint:
// service.namespace.<base> or
// endpoint.service.namespace.<base>
var port, protocol, endpoint string
namespace := segs[last]
if !k.namespaceExposed(namespace) {
return nil, dns.RcodeNameError
}
last--
if last < 0 {
return nil, dns.RcodeSuccess
}
service := segs[last]
last--
switch last {
case 0:
endpoint = stripUnderscore(segs[last])
last--
case 1:
protocol = stripUnderscore(segs[last])
port = stripUnderscore(segs[last-1])
last -= 2
}
if last != -1 {
// too long
return nil, dns.RcodeNameError
}
var (
endpointsList []*object.Endpoints
serviceList []*object.Service
)
idx := object.ServiceKey(service, namespace)
serviceList = k.APIConn.SvcIndex(idx)
services := []msg.Service{}
zonePath := msg.Path(state.Zone, coredns)
rcode := dns.RcodeNameError
for _, svc := range serviceList {
if namespace != svc.Namespace {
continue
}
if service != svc.Name {
continue
}
if headless && len(svc.ExternalIPs) == 0 && (svc.Headless() || endpoint != "") {
if endpointsList == nil {
endpointsList = k.APIConn.EpIndex(idx)
}
// Endpoint query or headless service
for _, ep := range endpointsList {
if object.EndpointsKey(svc.Name, svc.Namespace) != ep.Index {
continue
}
for _, eps := range ep.Subsets {
for _, addr := range eps.Addresses {
if endpoint != "" && !match(endpoint, endpointHostname(addr, k.endpointNameMode)) {
continue
}
for _, p := range eps.Ports {
if !(matchPortAndProtocol(port, p.Name, protocol, p.Protocol)) {
continue
}
rcode = dns.RcodeSuccess
s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, svc.Namespace, svc.Name, endpointHostname(addr, k.endpointNameMode)}, "/")
services = append(services, s)
}
}
}
}
continue
}
for _, ip := range svc.ExternalIPs {
for _, p := range svc.Ports {
if !(matchPortAndProtocol(port, p.Name, protocol, string(p.Protocol))) {
continue
}
rcode = dns.RcodeSuccess
s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, svc.Namespace, svc.Name}, "/")
services = append(services, s)
}
}
}
if state.QType() == dns.TypePTR {
// if this was a PTR request, return empty service list, but retain rcode for proper nxdomain/nodata response
return nil, rcode
}
return services, rcode
}
// ExternalAddress returns the external service address(es) for the CoreDNS service.
func (k *Kubernetes) ExternalAddress(state request.Request, headless bool) []dns.RR {
// If CoreDNS is running inside the Kubernetes cluster: k.nsAddrs() will return the external IPs of the services
// targeting the CoreDNS Pod.
// If CoreDNS is running outside of the Kubernetes cluster: k.nsAddrs() will return the first non-loopback IP
// address seen on the local system it is running on. This could be the wrong answer if coredns is using the *bind*
// plugin to bind to a different IP address.
return k.nsAddrs(true, headless, state.Zone)
}
// ExternalServices returns all services with external IPs and, if enabled, headless services.
func (k *Kubernetes) ExternalServices(zone string, headless bool) (services []msg.Service, headlessServices map[string][]msg.Service) {
zonePath := msg.Path(zone, coredns)
headlessServices = make(map[string][]msg.Service)
for _, svc := range k.APIConn.ServiceList() {
// Endpoints and headless services
if headless && len(svc.ExternalIPs) == 0 && svc.Headless() {
idx := object.ServiceKey(svc.Name, svc.Namespace)
endpointsList := k.APIConn.EpIndex(idx)
for _, ep := range endpointsList {
for _, eps := range ep.Subsets {
for _, addr := range eps.Addresses {
// We need to group some answers together:
// 1. endpoint requests, e.g. endpoint-0.service.example.com - always a single endpoint
// 2. service requests, e.g. service.example.com - can have multiple endpoints
// 3. port/protocol requests, e.g. _http._tcp.service.example.com - can have multiple endpoints
for _, p := range eps.Ports {
s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl}
baseSvc := strings.Join([]string{zonePath, svc.Namespace, svc.Name}, "/")
s.Key = strings.Join([]string{baseSvc, endpointHostname(addr, k.endpointNameMode)}, "/")
headlessServices[strings.Join([]string{baseSvc, Endpoint}, "/")] = append(headlessServices[strings.Join([]string{baseSvc, Endpoint}, "/")], s)
// As per the spec, unnamed ports do not have an SRV record
// https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
if p.Name == "" {
continue
}
s.Host = msg.Domain(s.Key)
s.Key = strings.Join(append([]string{zonePath, svc.Namespace, svc.Name}, strings.ToLower("_"+p.Protocol), strings.ToLower("_"+p.Name)), "/")
headlessServices[strings.Join([]string{s.Key, PortProtocol}, "/")] = append(headlessServices[strings.Join([]string{s.Key, PortProtocol}, "/")], s)
}
}
}
}
continue
}
for _, ip := range svc.ExternalIPs {
for _, p := range svc.Ports {
s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, svc.Namespace, svc.Name}, "/")
services = append(services, s)
s.Key = strings.Join(append([]string{zonePath, svc.Namespace, svc.Name}, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+p.Name)), "/")
s.TargetStrip = 2
services = append(services, s)
}
}
}
return services, headlessServices
}
// ExternalSerial returns the serial of the external zone
func (k *Kubernetes) ExternalSerial(string) uint32 {
return uint32(k.APIConn.Modified(ModifiedExternal))
}
// ExternalReverse does a reverse lookup for the external IPs
func (k *Kubernetes) ExternalReverse(ip string) ([]msg.Service, error) {
records := k.serviceRecordForExternalIP(ip)
if len(records) == 0 {
return records, errNoItems
}
return records, nil
}
func (k *Kubernetes) serviceRecordForExternalIP(ip string) []msg.Service {
svcList := k.APIConn.SvcExtIndexReverse(ip)
svcLen := len(svcList)
svcs := make([]msg.Service, 0, svcLen)
for _, service := range svcList {
if len(k.Namespaces) > 0 && !k.namespaceExposed(service.Namespace) {
continue
}
domain := strings.Join([]string{service.Name, service.Namespace}, ".")
svcs = append(svcs, msg.Service{Host: domain, TTL: k.ttl})
}
return svcs
}
package kubernetes
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the plugin.Handler interface.
func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.QName()
zone := plugin.Zones(k.Zones).Matches(qname)
if zone == "" {
return plugin.NextOrFailure(k.Name(), k.Next, ctx, w, r)
}
zone = qname[len(qname)-len(zone):] // maintain case of original query
state.Zone = zone
var (
records []dns.RR
extra []dns.RR
truncated bool
err error
)
switch state.QType() {
case dns.TypeA:
records, truncated, err = plugin.A(ctx, &k, zone, state, nil, plugin.Options{})
case dns.TypeAAAA:
records, truncated, err = plugin.AAAA(ctx, &k, zone, state, nil, plugin.Options{})
case dns.TypeTXT:
records, truncated, err = plugin.TXT(ctx, &k, zone, state, nil, plugin.Options{})
case dns.TypeCNAME:
records, err = plugin.CNAME(ctx, &k, zone, state, plugin.Options{})
case dns.TypePTR:
records, err = plugin.PTR(ctx, &k, zone, state, plugin.Options{})
case dns.TypeMX:
records, extra, err = plugin.MX(ctx, &k, zone, state, plugin.Options{})
case dns.TypeSRV:
records, extra, err = plugin.SRV(ctx, &k, zone, state, plugin.Options{})
case dns.TypeSOA:
if qname == zone {
records, err = plugin.SOA(ctx, &k, zone, state, plugin.Options{})
}
case dns.TypeAXFR, dns.TypeIXFR:
return dns.RcodeRefused, nil
case dns.TypeNS:
if state.Name() == zone {
records, extra, err = plugin.NS(ctx, &k, zone, state, plugin.Options{})
break
}
fallthrough
default:
// Do a fake A lookup, so we can distinguish between NODATA and NXDOMAIN
fake := state.NewWithQuestion(state.QName(), dns.TypeA)
fake.Zone = state.Zone
_, _, err = plugin.A(ctx, &k, zone, fake, nil, plugin.Options{})
}
if k.IsNameError(err) {
if k.Fall.Through(state.Name()) {
return plugin.NextOrFailure(k.Name(), k.Next, ctx, w, r)
}
if !k.APIConn.HasSynced() {
// If we haven't synchronized with the kubernetes cluster, return server failure
return plugin.BackendError(ctx, &k, zone, dns.RcodeServerFailure, state, nil /* err */, plugin.Options{})
}
return plugin.BackendError(ctx, &k, zone, dns.RcodeNameError, state, nil /* err */, plugin.Options{})
}
if err != nil {
return dns.RcodeServerFailure, err
}
if len(records) == 0 {
return plugin.BackendError(ctx, &k, zone, dns.RcodeSuccess, state, nil, plugin.Options{})
}
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = truncated
m.Authoritative = true
m.Answer = append(m.Answer, records...)
m.Extra = append(m.Extra, extra...)
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
// Name implements the Handler interface.
func (k Kubernetes) Name() string { return "kubernetes" }
// Package kubernetes provides the kubernetes backend.
package kubernetes
import (
"context"
"errors"
"fmt"
"net"
"runtime"
"strings"
"time"
"github.com/coredns/coredns/coremain"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/kubernetes/object"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
mcsClientset "sigs.k8s.io/mcs-api/pkg/client/clientset/versioned/typed/apis/v1alpha1"
)
// Kubernetes implements a plugin that connects to a Kubernetes cluster.
type Kubernetes struct {
Next plugin.Handler
Zones []string
Upstream Upstreamer
APIServerList []string
APICertAuth string
APIClientCert string
APIClientKey string
ClientConfig clientcmd.ClientConfig
APIConn dnsController
Namespaces map[string]struct{}
podMode string
endpointNameMode bool
Fall fall.F
ttl uint32
opts dnsControlOpts
primaryZoneIndex int
localIPs []net.IP
autoPathSearch []string // Local search path from /etc/resolv.conf. Needed for autopath.
startupTimeout time.Duration // startupTimeout set timeout of startup
}
// Upstreamer is used to resolve CNAME or other external targets
type Upstreamer interface {
Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
}
// New returns an initialized Kubernetes. All values default to their zero value;
// primaryZoneIndex will thus point to the first zone.
func New(zones []string) *Kubernetes {
k := new(Kubernetes)
k.Zones = zones
k.Namespaces = make(map[string]struct{})
k.podMode = podModeDisabled
k.ttl = defaultTTL
return k
}
const (
// podModeDisabled is the default value where pod requests are ignored
podModeDisabled = "disabled"
// podModeVerified is where Pod requests are answered only if they exist
podModeVerified = "verified"
// podModeInsecure is where pod requests are answered without verifying they exist
podModeInsecure = "insecure"
// DNSSchemaVersion is the schema version: https://github.com/kubernetes/dns/blob/master/docs/specification.md
DNSSchemaVersion = "1.1.0"
// Svc is the DNS schema for kubernetes services
Svc = "svc"
// Pod is the DNS schema for kubernetes pods
Pod = "pod"
// defaultTTL to apply to all answers.
defaultTTL = 5
)
var (
errNoItems = errors.New("no items found")
errNsNotExposed = errors.New("namespace is not exposed")
errInvalidRequest = errors.New("invalid query name")
)
// Services implements the ServiceBackend interface.
func (k *Kubernetes) Services(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (svcs []msg.Service, err error) {
// We're looking again at types, which we've already done in ServeDNS, but there are some types k8s just can't answer.
switch state.QType() {
case dns.TypeTXT:
// 1 label + zone, label must be "dns-version".
t, _ := dnsutil.TrimZone(state.Name(), state.Zone)
// Hard code the only valid TXT - "dns-version.<zone>"
segs := dns.SplitDomainName(t)
if len(segs) == 1 && segs[0] == "dns-version" {
svc := msg.Service{Text: DNSSchemaVersion, TTL: 28800, Key: msg.Path(state.QName(), coredns)}
return []msg.Service{svc}, nil
}
// Check if we have an existing record for this query of another type
services, _ := k.Records(ctx, state, false)
if len(services) > 0 {
// If so we return an empty NOERROR
return nil, nil
}
// Return NXDOMAIN for no match
return nil, errNoItems
case dns.TypeNS:
// We can only get here if the qname equals the zone, see ServeDNS in handler.go.
nss := k.nsAddrs(false, false, state.Zone)
var svcs []msg.Service
for _, ns := range nss {
if ns.Header().Rrtype == dns.TypeA {
svcs = append(svcs, msg.Service{Host: ns.(*dns.A).A.String(), Key: msg.Path(ns.Header().Name, coredns), TTL: k.ttl})
continue
}
if ns.Header().Rrtype == dns.TypeAAAA {
svcs = append(svcs, msg.Service{Host: ns.(*dns.AAAA).AAAA.String(), Key: msg.Path(ns.Header().Name, coredns), TTL: k.ttl})
}
}
return svcs, nil
}
if isDefaultNS(state.Name(), state.Zone) {
nss := k.nsAddrs(false, false, state.Zone)
var svcs []msg.Service
for _, ns := range nss {
if ns.Header().Rrtype == dns.TypeA && state.QType() == dns.TypeA {
svcs = append(svcs, msg.Service{Host: ns.(*dns.A).A.String(), Key: msg.Path(state.QName(), coredns), TTL: k.ttl})
continue
}
if ns.Header().Rrtype == dns.TypeAAAA && state.QType() == dns.TypeAAAA {
svcs = append(svcs, msg.Service{Host: ns.(*dns.AAAA).AAAA.String(), Key: msg.Path(state.QName(), coredns), TTL: k.ttl})
}
}
return svcs, nil
}
s, e := k.Records(ctx, state, false)
// SRV for external services is not yet implemented, so remove those records.
if state.QType() != dns.TypeSRV {
return s, e
}
internal := []msg.Service{}
for _, svc := range s {
if t, _ := svc.HostType(); t != dns.TypeCNAME {
internal = append(internal, svc)
}
}
return internal, e
}
// primaryZone returns the first non-reverse zone handled by this plugin.
func (k *Kubernetes) primaryZone() string { return k.Zones[k.primaryZoneIndex] }
// Lookup implements the ServiceBackend interface.
func (k *Kubernetes) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
return k.Upstream.Lookup(ctx, state, name, typ)
}
// IsNameError implements the ServiceBackend interface.
func (k *Kubernetes) IsNameError(err error) bool {
return err == errNoItems || err == errNsNotExposed || err == errInvalidRequest
}
func (k *Kubernetes) getClientConfig() (*rest.Config, error) {
if k.ClientConfig != nil {
return k.ClientConfig.ClientConfig()
}
loadingRules := &clientcmd.ClientConfigLoadingRules{}
overrides := &clientcmd.ConfigOverrides{}
clusterinfo := clientcmdapi.Cluster{}
authinfo := clientcmdapi.AuthInfo{}
// Connect to API from in cluster
if len(k.APIServerList) == 0 {
cc, err := rest.InClusterConfig()
if err != nil {
return nil, err
}
cc.ContentType = "application/vnd.kubernetes.protobuf"
cc.UserAgent = fmt.Sprintf("%s/%s git_commit:%s (%s/%s/%s)", coremain.CoreName, coremain.CoreVersion, coremain.GitCommit, runtime.GOOS, runtime.GOARCH, runtime.Version())
return cc, err
}
// Connect to API from out of cluster
// Only the first one is used. We will deprecate multiple endpoints later.
clusterinfo.Server = k.APIServerList[0]
if len(k.APICertAuth) > 0 {
clusterinfo.CertificateAuthority = k.APICertAuth
}
if len(k.APIClientCert) > 0 {
authinfo.ClientCertificate = k.APIClientCert
}
if len(k.APIClientKey) > 0 {
authinfo.ClientKey = k.APIClientKey
}
overrides.ClusterInfo = clusterinfo
overrides.AuthInfo = authinfo
clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, overrides)
cc, err := clientConfig.ClientConfig()
if err != nil {
return nil, err
}
cc.ContentType = "application/vnd.kubernetes.protobuf"
cc.UserAgent = fmt.Sprintf("%s/%s git_commit:%s (%s/%s/%s)", coremain.CoreName, coremain.CoreVersion, coremain.GitCommit, runtime.GOOS, runtime.GOARCH, runtime.Version())
return cc, err
}
// InitKubeCache initializes a new Kubernetes cache.
func (k *Kubernetes) InitKubeCache(ctx context.Context) (onStart func() error, onShut func() error, err error) {
config, err := k.getClientConfig()
if err != nil {
return nil, nil, err
}
kubeClient, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, nil, fmt.Errorf("failed to create kubernetes notification controller: %q", err)
}
var mcsClient mcsClientset.MulticlusterV1alpha1Interface
if len(k.opts.multiclusterZones) > 0 {
mcsClient, err = mcsClientset.NewForConfig(config)
if err != nil {
return nil, nil, fmt.Errorf("failed to create kubernetes multicluster notification controller: %q", err)
}
}
if k.opts.labelSelector != nil {
var selector labels.Selector
selector, err = meta.LabelSelectorAsSelector(k.opts.labelSelector)
if err != nil {
return nil, nil, fmt.Errorf("unable to create Selector for LabelSelector '%s': %q", k.opts.labelSelector, err)
}
k.opts.selector = selector
}
if k.opts.namespaceLabelSelector != nil {
var selector labels.Selector
selector, err = meta.LabelSelectorAsSelector(k.opts.namespaceLabelSelector)
if err != nil {
return nil, nil, fmt.Errorf("unable to create Selector for LabelSelector '%s': %q", k.opts.namespaceLabelSelector, err)
}
k.opts.namespaceSelector = selector
}
k.opts.initPodCache = k.podMode == podModeVerified
k.opts.zones = k.Zones
k.opts.endpointNameMode = k.endpointNameMode
k.APIConn = newdnsController(ctx, kubeClient, mcsClient, k.opts)
onStart = func() error {
go func() {
k.APIConn.Run()
}()
timeoutTicker := time.NewTicker(k.startupTimeout)
defer timeoutTicker.Stop()
logDelay := 500 * time.Millisecond
logTicker := time.NewTicker(logDelay)
defer logTicker.Stop()
checkSyncTicker := time.NewTicker(100 * time.Millisecond)
defer checkSyncTicker.Stop()
for {
select {
case <-checkSyncTicker.C:
if k.APIConn.HasSynced() {
return nil
}
case <-logTicker.C:
log.Info("waiting for Kubernetes API before starting server")
case <-timeoutTicker.C:
log.Warning("starting server with unsynced Kubernetes API")
return nil
}
}
}
onShut = func() error {
return k.APIConn.Stop()
}
return onStart, onShut, err
}
// Records looks up services in kubernetes.
func (k *Kubernetes) Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) {
multicluster := k.isMultiClusterZone(state.Zone)
r, e := parseRequest(state.Name(), state.Zone, multicluster)
if e != nil {
return nil, e
}
if r.podOrSvc == "" {
return nil, nil
}
if dnsutil.IsReverse(state.Name()) > 0 {
return nil, errNoItems
}
if !k.namespaceExposed(r.namespace) {
return nil, errNsNotExposed
}
if r.podOrSvc == Pod {
pods, err := k.findPods(r, state.Zone)
return pods, err
}
var services []msg.Service
var err error
if !multicluster {
services, err = k.findServices(r, state.Zone)
} else {
services, err = k.findMultiClusterServices(r, state.Zone)
}
return services, err
}
func endpointHostname(addr object.EndpointAddress, endpointNameMode bool) string {
if addr.Hostname != "" {
return addr.Hostname
}
if endpointNameMode && addr.TargetRefName != "" {
return addr.TargetRefName
}
if strings.Contains(addr.IP, ".") {
return strings.ReplaceAll(addr.IP, ".", "-")
}
if strings.Contains(addr.IP, ":") {
ipv6Hostname := strings.ReplaceAll(addr.IP, ":", "-")
if strings.HasSuffix(ipv6Hostname, "-") {
return ipv6Hostname + "0"
}
return ipv6Hostname
}
return ""
}
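// A minimal illustrative sketch (hypothetical helper, not used anywhere): when neither a
// hostname nor a target reference is available, endpointHostname falls back to a dashed
// form of the endpoint IP.
func exampleEndpointHostname() (v4, v6 string) {
v4 = endpointHostname(object.EndpointAddress{IP: "10.0.0.1"}, false) // "10-0-0-1"
v6 = endpointHostname(object.EndpointAddress{IP: "fe80::1"}, false)  // "fe80--1"
return v4, v6
}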
func (k *Kubernetes) findPods(r recordRequest, zone string) (pods []msg.Service, err error) {
if k.podMode == podModeDisabled {
return nil, errNoItems
}
namespace := r.namespace
if !k.namespaceExposed(namespace) {
return nil, errNoItems
}
podname := r.service
// handle empty pod name
if podname == "" {
if k.namespaceExposed(namespace) {
// NODATA
return nil, nil
}
// NXDOMAIN
return nil, errNoItems
}
zonePath := msg.Path(zone, coredns)
var ip string
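// The pod query name encodes the pod IP with dashes: exactly three dashes and no "--"
// means IPv4 (e.g. "10-0-0-1" -> "10.0.0.1"); anything else is treated as a dashed
// IPv6 address (e.g. "fe80--1" -> "fe80::1").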
if strings.Count(podname, "-") == 3 && !strings.Contains(podname, "--") {
ip = strings.ReplaceAll(podname, "-", ".")
} else {
ip = strings.ReplaceAll(podname, "-", ":")
}
if k.podMode == podModeInsecure {
if !k.namespaceExposed(namespace) { // namespace does not exist
return nil, errNoItems
}
// If ip does not parse as an IP address, we return an error, otherwise we assume a CNAME and will try to resolve it in backend_lookup.go
if net.ParseIP(ip) == nil {
return nil, errNoItems
}
return []msg.Service{{Key: strings.Join([]string{zonePath, Pod, namespace, podname}, "/"), Host: ip, TTL: k.ttl}}, err
}
// PodModeVerified
err = errNoItems
for _, p := range k.APIConn.PodIndex(ip) {
// check for matching ip and namespace
if ip == p.PodIP && match(namespace, p.Namespace) {
s := msg.Service{Key: strings.Join([]string{zonePath, Pod, namespace, podname}, "/"), Host: ip, TTL: k.ttl}
pods = append(pods, s)
err = nil
}
}
return pods, err
}
// findServices returns the services matching r from the cache.
func (k *Kubernetes) findServices(r recordRequest, zone string) (services []msg.Service, err error) {
if !k.namespaceExposed(r.namespace) {
return nil, errNoItems
}
// handle empty service name
if r.service == "" {
if k.namespaceExposed(r.namespace) {
// NODATA
return nil, nil
}
// NXDOMAIN
return nil, errNoItems
}
err = errNoItems
var (
endpointsListFunc func() []*object.Endpoints
endpointsList []*object.Endpoints
serviceList []*object.Service
)
idx := object.ServiceKey(r.service, r.namespace)
serviceList = k.APIConn.SvcIndex(idx)
endpointsListFunc = func() []*object.Endpoints { return k.APIConn.EpIndex(idx) }
zonePath := msg.Path(zone, coredns)
for _, svc := range serviceList {
if !match(r.namespace, svc.Namespace) || !match(r.service, svc.Name) {
continue
}
// If "ignore empty_service" option is set and no endpoints exist, return NXDOMAIN unless
// it's a headless or externalName service (covered below).
if k.opts.ignoreEmptyService && svc.Type != api.ServiceTypeExternalName && !svc.Headless() { // serve NXDOMAIN if no endpoint is able to answer
podsCount := 0
for _, ep := range endpointsListFunc() {
for _, eps := range ep.Subsets {
podsCount += len(eps.Addresses)
}
}
if podsCount == 0 {
continue
}
}
// External service
if svc.Type == api.ServiceTypeExternalName {
// External services do not have endpoints, nor can we accept port/protocol pseudo subdomains in an SRV query, so skip this service if endpoint, port, or protocol is non-empty in the request
if r.endpoint != "" || r.port != "" || r.protocol != "" {
continue
}
s := msg.Service{Key: strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/"), Host: svc.ExternalName, TTL: k.ttl}
if t, _ := s.HostType(); t == dns.TypeCNAME {
s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/")
services = append(services, s)
err = nil
}
continue
}
// Endpoint query or headless service
if svc.Headless() || r.endpoint != "" {
if endpointsList == nil {
endpointsList = endpointsListFunc()
}
for _, ep := range endpointsList {
if object.EndpointsKey(svc.Name, svc.Namespace) != ep.Index {
continue
}
for _, eps := range ep.Subsets {
for _, addr := range eps.Addresses {
// See comments in parse.go parseRequest about the endpoint handling.
if r.endpoint != "" {
if !match(r.endpoint, endpointHostname(addr, k.endpointNameMode)) {
continue
}
}
for _, p := range eps.Ports {
if !(matchPortAndProtocol(r.port, p.Name, r.protocol, p.Protocol)) {
continue
}
s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name, endpointHostname(addr, k.endpointNameMode)}, "/")
err = nil
services = append(services, s)
}
}
}
}
continue
}
// ClusterIP service
for _, p := range svc.Ports {
if !(matchPortAndProtocol(r.port, p.Name, r.protocol, string(p.Protocol))) {
continue
}
err = nil
for _, ip := range svc.ClusterIPs {
s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/")
services = append(services, s)
}
}
}
return services, err
}
// findMultiClusterServices returns the multicluster services matching r from the cache.
func (k *Kubernetes) findMultiClusterServices(r recordRequest, zone string) (services []msg.Service, err error) {
if !k.namespaceExposed(r.namespace) {
return nil, errNoItems
}
// handle empty service name
if r.service == "" {
if k.namespaceExposed(r.namespace) {
// NODATA
return nil, nil
}
// NXDOMAIN
return nil, errNoItems
}
err = errNoItems
var (
endpointsListFunc func() []*object.MultiClusterEndpoints
endpointsList []*object.MultiClusterEndpoints
serviceList []*object.ServiceImport
)
idx := object.ServiceImportKey(r.service, r.namespace)
serviceList = k.APIConn.SvcImportIndex(idx)
endpointsListFunc = func() []*object.MultiClusterEndpoints { return k.APIConn.McEpIndex(idx) }
zonePath := msg.Path(zone, coredns)
for _, svc := range serviceList {
if !match(r.namespace, svc.Namespace) || !match(r.service, svc.Name) {
continue
}
// If "ignore empty_service" option is set and no endpoints exist, return NXDOMAIN unless
// it's a headless or externalName service (covered below).
if k.opts.ignoreEmptyService && !svc.Headless() { // serve NXDOMAIN if no endpoint is able to answer
podsCount := 0
for _, ep := range endpointsListFunc() {
for _, eps := range ep.Subsets {
podsCount += len(eps.Addresses)
}
}
if podsCount == 0 {
continue
}
}
// Endpoint query or headless service
if svc.Headless() || r.endpoint != "" {
if endpointsList == nil {
endpointsList = endpointsListFunc()
}
for _, ep := range endpointsList {
if object.MultiClusterEndpointsKey(svc.Name, svc.Namespace) != ep.Index {
continue
}
for _, eps := range ep.Subsets {
for _, addr := range eps.Addresses {
// See comments in parse.go parseRequest about the endpoint handling.
if r.endpoint != "" {
if !match(r.cluster, ep.ClusterId) || !match(r.endpoint, endpointHostname(addr, k.endpointNameMode)) {
continue
}
}
for _, p := range eps.Ports {
if !(matchPortAndProtocol(r.port, p.Name, r.protocol, p.Protocol)) {
continue
}
s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name, ep.ClusterId, endpointHostname(addr, k.endpointNameMode)}, "/")
err = nil
services = append(services, s)
}
}
}
}
continue
}
// ClusterIP service
for _, p := range svc.Ports {
if !(matchPortAndProtocol(r.port, p.Name, r.protocol, string(p.Protocol))) {
continue
}
err = nil
for _, ip := range svc.ClusterIPs {
s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/")
services = append(services, s)
}
}
}
return services, err
}
// Serial return the SOA serial.
func (k *Kubernetes) Serial(state request.Request) uint32 {
if !k.isMultiClusterZone(state.Zone) {
return uint32(k.APIConn.Modified(ModifiedInternal))
} else {
return uint32(k.APIConn.Modified(ModifiedMultiCluster))
}
}
// MinTTL returns the minimal TTL.
func (k *Kubernetes) MinTTL(state request.Request) uint32 { return k.ttl }
func (k *Kubernetes) isMultiClusterZone(zone string) bool {
z := plugin.Zones(k.opts.multiclusterZones).Matches(zone)
return z != ""
}
// match checks if a and b are equal, ignoring case.
func match(a, b string) bool {
return strings.EqualFold(a, b)
}
// matchPortAndProtocol matches port and protocol, permitting empty 'a' inputs to act as wildcards.
func matchPortAndProtocol(aPort, bPort, aProtocol, bProtocol string) bool {
return (match(aPort, bPort) || aPort == "") && (match(aProtocol, bProtocol) || aProtocol == "")
}
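// A minimal illustrative sketch: empty query values act as wildcards, so a bare service
// query (no _port._protocol labels) matches every declared port.
func exampleMatchWildcard() bool {
return matchPortAndProtocol("", "dns", "", "UDP") // true: both query parts are empty
}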
const coredns = "c" // used as a fake key prefix in msg.Service
package kubernetes
import (
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
)
// boundIPs returns the list of non-loopback IPs that CoreDNS is bound to
func boundIPs(c *caddy.Controller) (ips []net.IP) {
conf := dnsserver.GetConfig(c)
hosts := conf.ListenHosts
if hosts == nil || hosts[0] == "" {
hosts = nil
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil
}
for _, addr := range addrs {
hosts = append(hosts, addr.String())
}
}
for _, host := range hosts {
ip, _, _ := net.ParseCIDR(host)
ip4 := ip.To4()
if ip4 != nil && !ip4.IsLoopback() {
ips = append(ips, ip4)
continue
}
ip6 := ip.To16()
if ip6 != nil && !ip6.IsLoopback() {
ips = append(ips, ip6)
}
}
return ips
}
package kubernetes
import (
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/go-logr/logr"
)
// loggerAdapter is a simple wrapper around the CoreDNS plugin logger that implements the logr.LogSink
// interface, which the klog library uses for logging in the Kubernetes client. With this adapter,
// CoreDNS can log messages and errors from the Kubernetes client in the CoreDNS logging format.
type loggerAdapter struct {
clog.P
}
func (l *loggerAdapter) Init(_ logr.RuntimeInfo) {
}
func (l *loggerAdapter) Enabled(_ int) bool {
// verbosity is controlled inside klog library, we do not need to do anything here
return true
}
func (l *loggerAdapter) Info(_ int, msg string, _ ...any) {
l.P.Info(msg)
}
func (l *loggerAdapter) Error(_ error, msg string, _ ...any) {
l.P.Error(msg)
}
func (l *loggerAdapter) WithValues(_ ...any) logr.LogSink {
return l
}
func (l *loggerAdapter) WithName(_ string) logr.LogSink {
return l
}
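// newKlogSink is a minimal illustrative sketch (the name is hypothetical; the actual wiring
// lives in the plugin's setup code) of how this adapter can be handed to klog, roughly:
//
// klog.SetLogger(logr.New(newKlogSink()))
func newKlogSink() logr.LogSink {
return &loggerAdapter{clog.NewWithPlugin("kubernetes")}
}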
package kubernetes
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/request"
)
// Metadata implements the metadata.Provider interface.
func (k *Kubernetes) Metadata(ctx context.Context, state request.Request) context.Context {
pod := k.podWithIP(state.IP())
if pod != nil {
metadata.SetValueFunc(ctx, "kubernetes/client-namespace", func() string {
return pod.Namespace
})
metadata.SetValueFunc(ctx, "kubernetes/client-pod-name", func() string {
return pod.Name
})
for k, v := range pod.Labels {
metadata.SetValueFunc(ctx, "kubernetes/client-label/"+k, func() string {
return v
})
}
}
zone := plugin.Zones(k.Zones).Matches(state.Name())
if zone == "" {
return ctx
}
multicluster := false
if z := plugin.Zones(k.opts.multiclusterZones).Matches(state.Zone); z != "" {
multicluster = true
}
// possible optimization: cache r so it doesn't need to be calculated again in ServeDNS
r, err := parseRequest(state.Name(), zone, multicluster)
if err != nil {
metadata.SetValueFunc(ctx, "kubernetes/parse-error", func() string {
return err.Error()
})
return ctx
}
metadata.SetValueFunc(ctx, "kubernetes/port-name", func() string {
return r.port
})
metadata.SetValueFunc(ctx, "kubernetes/protocol", func() string {
return r.protocol
})
metadata.SetValueFunc(ctx, "kubernetes/endpoint", func() string {
return r.endpoint
})
if multicluster {
metadata.SetValueFunc(ctx, "kubernetes/cluster", func() string {
return r.cluster
})
}
metadata.SetValueFunc(ctx, "kubernetes/service", func() string {
return r.service
})
metadata.SetValueFunc(ctx, "kubernetes/namespace", func() string {
return r.namespace
})
metadata.SetValueFunc(ctx, "kubernetes/kind", func() string {
return r.podOrSvc
})
return ctx
}
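// A minimal illustrative sketch (hypothetical helper, assuming the metadata plugin is enabled
// in the server block): a later plugin can read the values set above via metadata.ValueFunc.
func exampleClientNamespace(ctx context.Context) string {
if f := metadata.ValueFunc(ctx, "kubernetes/client-namespace"); f != nil {
return f() // namespace of the client pod, when known
}
return ""
}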
package kubernetes
import (
"context"
"net/url"
"time"
"github.com/coredns/coredns/plugin"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"k8s.io/client-go/tools/metrics"
)
var (
// requestLatency measures K8s rest client request latency, grouped by verb and host.
requestLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: plugin.Namespace,
Subsystem: "kubernetes",
Name: "rest_client_request_duration_seconds",
Help: "Request latency in seconds. Broken down by verb and host.",
Buckets: prometheus.DefBuckets,
NativeHistogramBucketFactor: plugin.NativeHistogramBucketFactor,
},
[]string{"verb", "host"},
)
// rateLimiterLatency measures K8s rest client rate limiter latency grouped by verb and host.
rateLimiterLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: plugin.Namespace,
Subsystem: "kubernetes",
Name: "rest_client_rate_limiter_duration_seconds",
Help: "Client side rate limiter latency in seconds. Broken down by verb and host.",
Buckets: prometheus.DefBuckets,
NativeHistogramBucketFactor: plugin.NativeHistogramBucketFactor,
},
[]string{"verb", "host"},
)
// requestResult measures K8s rest client request metrics grouped by status code, method & host.
requestResult = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "kubernetes",
Name: "rest_client_requests_total",
Help: "Number of HTTP requests, partitioned by status code, method, and host.",
},
[]string{"code", "method", "host"},
)
)
func init() {
metrics.Register(metrics.RegisterOpts{
RequestLatency: &latencyAdapter{m: requestLatency},
RateLimiterLatency: &latencyAdapter{m: rateLimiterLatency},
RequestResult: &resultAdapter{requestResult},
})
}
type latencyAdapter struct {
m *prometheus.HistogramVec
}
func (l *latencyAdapter) Observe(_ context.Context, verb string, u url.URL, latency time.Duration) {
l.m.WithLabelValues(verb, u.Host).Observe(latency.Seconds())
}
type resultAdapter struct {
m *prometheus.CounterVec
}
func (r *resultAdapter) Increment(_ context.Context, code, method, host string) {
r.m.WithLabelValues(code, method, host).Inc()
}
package kubernetes
// filteredNamespaceExists checks if namespace exists in this cluster
// according to any `namespace_labels` plugin configuration specified.
// Returns true even for namespaces not exposed by plugin configuration,
// see namespaceExposed.
func (k *Kubernetes) filteredNamespaceExists(namespace string) bool {
_, err := k.APIConn.GetNamespaceByName(namespace)
return err == nil
}
// configuredNamespace returns true when the namespace is exposed through the plugin
// `namespaces` configuration.
func (k *Kubernetes) configuredNamespace(namespace string) bool {
_, ok := k.Namespaces[namespace]
if len(k.Namespaces) > 0 && !ok {
return false
}
return true
}
func (k *Kubernetes) namespaceExposed(namespace string) bool {
return k.configuredNamespace(namespace) && k.filteredNamespaceExists(namespace)
}
package kubernetes
import (
"net"
"strings"
"github.com/miekg/dns"
)
func isDefaultNS(name, zone string) bool {
return strings.Index(name, defaultNSName) == 0 && strings.Index(name, zone) == len(defaultNSName)
}
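// A minimal illustrative sketch: isDefaultNS only matches the literal "ns.dns." prefix
// directly followed by the zone.
func exampleIsDefaultNS() bool {
return isDefaultNS("ns.dns.cluster.local.", "cluster.local.") // true
}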
// nsAddrs returns the A or AAAA records for the CoreDNS service in the cluster. If the service cannot be found,
// it returns a record for the local address of the machine we're running on.
func (k *Kubernetes) nsAddrs(external, headless bool, zone string) []dns.RR {
var (
svcNames []string
svcIPs []net.IP
foundEndpoint bool
)
// Find the CoreDNS Endpoints
for _, localIP := range k.localIPs {
endpoints := k.APIConn.EpIndexReverse(localIP.String())
// Collect IPs for all Services of the Endpoints
for _, endpoint := range endpoints {
foundEndpoint = true
svcs := k.APIConn.SvcIndex(endpoint.Index)
for _, svc := range svcs {
if external {
svcName := strings.Join([]string{svc.Name, svc.Namespace, zone}, ".")
if headless && svc.Headless() {
for _, s := range endpoint.Subsets {
for _, a := range s.Addresses {
svcNames = append(svcNames, endpointHostname(a, k.endpointNameMode)+"."+svcName)
svcIPs = append(svcIPs, net.ParseIP(a.IP))
}
}
} else {
for _, exIP := range svc.ExternalIPs {
svcNames = append(svcNames, svcName)
svcIPs = append(svcIPs, net.ParseIP(exIP))
}
}
continue
}
svcName := strings.Join([]string{svc.Name, svc.Namespace, Svc, zone}, ".")
if svc.Headless() {
// For a headless service, use the endpoints IPs
for _, s := range endpoint.Subsets {
for _, a := range s.Addresses {
svcNames = append(svcNames, endpointHostname(a, k.endpointNameMode)+"."+svcName)
svcIPs = append(svcIPs, net.ParseIP(a.IP))
}
}
} else {
for _, clusterIP := range svc.ClusterIPs {
svcNames = append(svcNames, svcName)
svcIPs = append(svcIPs, net.ParseIP(clusterIP))
}
}
}
}
}
// If no CoreDNS endpoints were found, use the localIPs directly
if !foundEndpoint {
svcIPs = make([]net.IP, len(k.localIPs))
svcNames = make([]string, len(k.localIPs))
for i, localIP := range k.localIPs {
svcNames[i] = defaultNSName + zone
svcIPs[i] = localIP
}
}
// Create an RR slice of collected IPs
rrs := make([]dns.RR, len(svcIPs))
for i, ip := range svcIPs {
if ip.To4() == nil {
rr := new(dns.AAAA)
rr.Hdr.Class = dns.ClassINET
rr.Hdr.Rrtype = dns.TypeAAAA
rr.Hdr.Name = svcNames[i]
rr.AAAA = ip
rrs[i] = rr
continue
}
rr := new(dns.A)
rr.Hdr.Class = dns.ClassINET
rr.Hdr.Rrtype = dns.TypeA
rr.Hdr.Name = svcNames[i]
rr.A = ip
rrs[i] = rr
}
return rrs
}
const defaultNSName = "ns.dns."
package object
import (
"fmt"
discovery "k8s.io/api/discovery/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
)
// Endpoints is a stripped down api.Endpoints with only the items we need for CoreDNS.
type Endpoints struct {
// Don't add new fields to this struct without talking to the CoreDNS maintainers.
Version string
Name string
Namespace string
Index string
IndexIP []string
Subsets []EndpointSubset
*Empty
}
// EndpointSubset is a group of addresses with a common set of ports. The
// expanded set of endpoints is the Cartesian product of Addresses x Ports.
type EndpointSubset struct {
Addresses []EndpointAddress
Ports []EndpointPort
}
// EndpointAddress is a tuple that describes a single IP address.
type EndpointAddress struct {
IP string
Hostname string
NodeName string
TargetRefName string
}
// EndpointPort is a tuple that describes a single port.
type EndpointPort struct {
Port int32
Name string
Protocol string
}
// EndpointsKey returns the string used for the index.
func EndpointsKey(name, namespace string) string { return name + "." + namespace }
// EndpointSliceToEndpoints converts a *discovery.EndpointSlice to a *Endpoints.
func EndpointSliceToEndpoints(obj meta.Object) (meta.Object, error) {
ends, ok := obj.(*discovery.EndpointSlice)
if !ok {
return nil, fmt.Errorf("unexpected object %v", obj)
}
e := &Endpoints{
Version: ends.GetResourceVersion(),
Name: ends.GetName(),
Namespace: ends.GetNamespace(),
Index: EndpointsKey(ends.Labels[discovery.LabelServiceName], ends.GetNamespace()),
Subsets: make([]EndpointSubset, 1),
}
if len(ends.Ports) == 0 {
// Add sentinel if there are no ports.
e.Subsets[0].Ports = []EndpointPort{{Port: -1}}
} else {
e.Subsets[0].Ports = make([]EndpointPort, len(ends.Ports))
for k, p := range ends.Ports {
port := int32(-1)
name := ""
protocol := ""
if p.Port != nil {
port = *p.Port
}
if p.Name != nil {
name = *p.Name
}
if p.Protocol != nil {
protocol = string(*p.Protocol)
}
ep := EndpointPort{Port: port, Name: name, Protocol: protocol}
e.Subsets[0].Ports[k] = ep
}
}
for _, end := range ends.Endpoints {
if !endpointsliceReady(end.Conditions.Ready) {
continue
}
for _, a := range end.Addresses {
ea := EndpointAddress{IP: a}
if end.Hostname != nil {
ea.Hostname = *end.Hostname
}
// ignore pod names that are too long to be a valid label
if end.TargetRef != nil && len(end.TargetRef.Name) < 64 {
ea.TargetRefName = end.TargetRef.Name
}
if end.NodeName != nil {
ea.NodeName = *end.NodeName
}
e.Subsets[0].Addresses = append(e.Subsets[0].Addresses, ea)
e.IndexIP = append(e.IndexIP, a)
}
}
*ends = discovery.EndpointSlice{}
return e, nil
}
func endpointsliceReady(ready *bool) bool {
// Per API docs: a nil value indicates an unknown state. In most cases consumers
// should interpret this unknown state as ready.
if ready == nil {
return true
}
return *ready
}
// CopyWithoutSubsets copies e, without the subsets.
func (e *Endpoints) CopyWithoutSubsets() *Endpoints {
e1 := &Endpoints{
Version: e.Version,
Name: e.Name,
Namespace: e.Namespace,
Index: e.Index,
IndexIP: make([]string, len(e.IndexIP)),
}
copy(e1.IndexIP, e.IndexIP)
return e1
}
var _ runtime.Object = &Endpoints{}
// DeepCopyObject implements the ObjectKind interface.
func (e *Endpoints) DeepCopyObject() runtime.Object {
e1 := &Endpoints{
Version: e.Version,
Name: e.Name,
Namespace: e.Namespace,
Index: e.Index,
IndexIP: make([]string, len(e.IndexIP)),
Subsets: make([]EndpointSubset, len(e.Subsets)),
}
copy(e1.IndexIP, e.IndexIP)
for i, eps := range e.Subsets {
sub := EndpointSubset{
Addresses: make([]EndpointAddress, len(eps.Addresses)),
Ports: make([]EndpointPort, len(eps.Ports)),
}
for j, a := range eps.Addresses {
ea := EndpointAddress{IP: a.IP, Hostname: a.Hostname, NodeName: a.NodeName, TargetRefName: a.TargetRefName}
sub.Addresses[j] = ea
}
for k, p := range eps.Ports {
ep := EndpointPort{Port: p.Port, Name: p.Name, Protocol: p.Protocol}
sub.Ports[k] = ep
}
e1.Subsets[i] = sub
}
return e1
}
// GetNamespace implements the metav1.Object interface.
func (e *Endpoints) GetNamespace() string { return e.Namespace }
// SetNamespace implements the metav1.Object interface.
func (e *Endpoints) SetNamespace(namespace string) {}
// GetName implements the metav1.Object interface.
func (e *Endpoints) GetName() string { return e.Name }
// SetName implements the metav1.Object interface.
func (e *Endpoints) SetName(name string) {}
// GetResourceVersion implements the metav1.Object interface.
func (e *Endpoints) GetResourceVersion() string { return e.Version }
// SetResourceVersion implements the metav1.Object interface.
func (e *Endpoints) SetResourceVersion(version string) {}
package object
import (
"fmt"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/cache"
)
// NewIndexerInformer is a copy of the cache.NewIndexerInformer function, but allows a custom process function.
func NewIndexerInformer(lw cache.ListerWatcher, objType runtime.Object, h cache.ResourceEventHandler, indexers cache.Indexers, builder ProcessorBuilder) (cache.Indexer, cache.Controller) {
clientState := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, indexers)
cfg := &cache.Config{
Queue: cache.NewDeltaFIFOWithOptions(cache.DeltaFIFOOptions{KeyFunction: cache.MetaNamespaceKeyFunc, KnownObjects: clientState}),
ListerWatcher: lw,
ObjectType: objType,
FullResyncPeriod: defaultResyncPeriod,
Process: builder(clientState, h),
}
return clientState, cache.New(cfg)
}
// RecordLatencyFunc is a function for recording api object delta latency
type RecordLatencyFunc func(meta.Object)
// DefaultProcessor is based on the Process function from cache.NewIndexerInformer except it does a conversion.
func DefaultProcessor(convert ToFunc, recordLatency *EndpointLatencyRecorder) ProcessorBuilder {
return func(clientState cache.Indexer, h cache.ResourceEventHandler) cache.ProcessFunc {
return func(obj any, isInitialList bool) error {
for _, d := range obj.(cache.Deltas) {
if recordLatency != nil {
if o, ok := d.Object.(meta.Object); ok {
recordLatency.init(o)
}
}
switch d.Type {
case cache.Sync, cache.Added, cache.Updated:
obj, err := convert(d.Object.(meta.Object))
if err != nil {
if err == errPodTerminating {
continue
}
return err
}
if old, exists, err := clientState.Get(obj); err == nil && exists {
if err := clientState.Update(obj); err != nil {
return err
}
h.OnUpdate(old, obj)
} else {
if err := clientState.Add(obj); err != nil {
return err
}
h.OnAdd(obj, isInitialList)
}
if recordLatency != nil {
recordLatency.record()
}
case cache.Deleted:
var obj any
obj, ok := d.Object.(cache.DeletedFinalStateUnknown)
if !ok {
var err error
metaObj, ok := d.Object.(meta.Object)
if !ok {
return fmt.Errorf("unexpected object %v", d.Object)
}
obj, err = convert(metaObj)
if err != nil && err != errPodTerminating {
return err
}
}
if err := clientState.Delete(obj); err != nil {
return err
}
h.OnDelete(obj)
if !ok && recordLatency != nil {
recordLatency.record()
}
}
}
return nil
}
}
}
const defaultResyncPeriod = 0
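// A minimal illustrative sketch (hypothetical helper): NewIndexerInformer and DefaultProcessor
// are combined so the cache stores the converted, stripped-down CoreDNS types instead of the
// full API objects.
func exampleInformer(lw cache.ListerWatcher, objType runtime.Object, h cache.ResourceEventHandler, convert ToFunc) (cache.Indexer, cache.Controller) {
// No extra indexers and no latency recording; convert runs on every add/update/delete.
return NewIndexerInformer(lw, objType, h, cache.Indexers{}, DefaultProcessor(convert, nil))
}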
package object
import (
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
)
var (
// DNSProgrammingLatency is defined as the time it took to program a DNS instance - from the time
// a service or pod has changed to the time the change was propagated and was available to be
// served by a DNS server.
// The definition of this SLI can be found at https://github.com/kubernetes/community/blob/master/sig-scalability/slos/dns_programming_latency.md
// Note that the metric is partially based on the time exported by the endpoints controller on
// the master machine. The measurement may be inaccurate if there is a clock drift between the
// node and master machine.
// The service_kind label can be one of:
// * cluster_ip
// * headless_with_selector
// * headless_without_selector
DNSProgrammingLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: plugin.Namespace,
Subsystem: "kubernetes",
Name: "dns_programming_duration_seconds",
// From 1 millisecond to ~17 minutes.
Buckets: prometheus.ExponentialBuckets(0.001, 2, 20),
NativeHistogramBucketFactor: plugin.NativeHistogramBucketFactor,
Help: "Histogram of the time (in seconds) it took to program a dns instance.",
}, []string{"service_kind"})
// DurationSinceFunc returns the duration elapsed since the given time.
// Added as a global variable to allow injection for testing.
DurationSinceFunc = time.Since
)
// EndpointLatencyRecorder records latency metric for endpoint objects
type EndpointLatencyRecorder struct {
TT time.Time
ServiceFunc func(meta.Object) []*Service
Services []*Service
}
func (l *EndpointLatencyRecorder) init(o meta.Object) {
l.Services = l.ServiceFunc(o)
l.TT = time.Time{}
stringVal, ok := o.GetAnnotations()[api.EndpointsLastChangeTriggerTime]
if ok {
tt, err := time.Parse(time.RFC3339Nano, stringVal)
if err != nil {
log.Warningf("DnsProgrammingLatency cannot be calculated for Endpoints '%s/%s'; invalid %q annotation RFC3339 value of %q",
o.GetNamespace(), o.GetName(), api.EndpointsLastChangeTriggerTime, stringVal)
// In case of an error, tt is the zero time.Time, which is ignored downstream.
}
l.TT = tt
}
}
func (l *EndpointLatencyRecorder) record() {
// isHeadless indicates whether the endpoints object belongs to a headless
// service (i.e. clusterIP = None). Note that this can be a false negative if the service
// informer is lagging, i.e. we may not see a recently created service. Given that services
// don't change very often (compared to the much more frequent endpoints changes), cases where
// this method returns the wrong answer should be relatively rare. Because of that we intentionally accept this
// flaw to keep the solution simple.
isHeadless := len(l.Services) == 1 && l.Services[0].Headless()
if !isHeadless || l.TT.IsZero() {
return
}
// If we're here it means that the Endpoints object is for a headless service and that
// the Endpoints object was created by the endpoints-controller (because the
// LastChangeTriggerTime annotation is set). It means that the corresponding service is a
// "headless service with selector".
DNSProgrammingLatency.WithLabelValues("headless_with_selector").
Observe(DurationSinceFunc(l.TT).Seconds())
}
package object
import (
"maps"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
mcs "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1"
)
// MultiClusterEndpoints is a stripped down api.Endpoints with only the items we need for CoreDNS.
type MultiClusterEndpoints struct {
Endpoints
ClusterId string
*Empty
}
// MultiClusterEndpointsKey returns the string used for the index.
func MultiClusterEndpointsKey(name, namespace string) string { return name + "." + namespace }
// EndpointSliceToMultiClusterEndpoints converts a *discovery.EndpointSlice to a *MultiClusterEndpoints.
func EndpointSliceToMultiClusterEndpoints(obj meta.Object) (meta.Object, error) {
labels := maps.Clone(obj.GetLabels())
ends, err := EndpointSliceToEndpoints(obj)
if err != nil {
return nil, err
}
e := &MultiClusterEndpoints{
Endpoints: *ends.(*Endpoints),
ClusterId: labels[mcs.LabelSourceCluster],
}
e.Index = MultiClusterEndpointsKey(labels[mcs.LabelServiceName], ends.GetNamespace())
return e, nil
}
var _ runtime.Object = &MultiClusterEndpoints{}
// DeepCopyObject implements the ObjectKind interface.
func (e *MultiClusterEndpoints) DeepCopyObject() runtime.Object {
e1 := &MultiClusterEndpoints{
ClusterId: e.ClusterId,
Endpoints: *e.Endpoints.DeepCopyObject().(*Endpoints),
}
return e1
}
// GetNamespace implements the metav1.Object interface.
func (e *MultiClusterEndpoints) GetNamespace() string { return e.Endpoints.GetNamespace() }
// SetNamespace implements the metav1.Object interface.
func (e *MultiClusterEndpoints) SetNamespace(namespace string) {}
// GetName implements the metav1.Object interface.
func (e *MultiClusterEndpoints) GetName() string { return e.Endpoints.GetName() }
// SetName implements the metav1.Object interface.
func (e *MultiClusterEndpoints) SetName(name string) {}
// GetResourceVersion implements the metav1.Object interface.
func (e *MultiClusterEndpoints) GetResourceVersion() string { return e.Endpoints.GetResourceVersion() }
// SetResourceVersion implements the metav1.Object interface.
func (e *MultiClusterEndpoints) SetResourceVersion(version string) {}
package object
import (
"fmt"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
)
// Namespace is a stripped down api.Namespace with only the items we need for CoreDNS.
type Namespace struct {
// Don't add new fields to this struct without talking to the CoreDNS maintainers.
Version string
Name string
*Empty
}
// ToNamespace converts an api.Namespace to a *Namespace.
func ToNamespace(obj meta.Object) (meta.Object, error) {
ns, ok := obj.(*api.Namespace)
if !ok {
return nil, fmt.Errorf("unexpected object %v", obj)
}
n := &Namespace{
Version: ns.GetResourceVersion(),
Name: ns.GetName(),
}
*ns = api.Namespace{}
return n, nil
}
var _ runtime.Object = &Namespace{}
// DeepCopyObject implements the ObjectKind interface.
func (n *Namespace) DeepCopyObject() runtime.Object {
n1 := &Namespace{
Version: n.Version,
Name: n.Name,
}
return n1
}
// GetNamespace implements the metav1.Object interface.
func (n *Namespace) GetNamespace() string { return "" }
// SetNamespace implements the metav1.Object interface.
func (n *Namespace) SetNamespace(namespace string) {}
// GetName implements the metav1.Object interface.
func (n *Namespace) GetName() string { return n.Name }
// SetName implements the metav1.Object interface.
func (n *Namespace) SetName(name string) {}
// GetResourceVersion implements the metav1.Object interface.
func (n *Namespace) GetResourceVersion() string { return n.Version }
// SetResourceVersion implements the metav1.Object interface.
func (n *Namespace) SetResourceVersion(version string) {}
// Package object holds functions that convert objects from the k8s API into more
// memory-efficient structures.
//
// Adding new fields to any of the structures defined in pod.go, endpoint.go
// and service.go should not be done lightly, as this increases memory use
// and will lead to OOMs in the k8s scale test.
//
// We can do some optimizations here as well. We store IP addresses as strings;
// these might be moved to uint32 (for v4) for instance, but then we need to
// convert them again.
//
// Also, the use of msg.Service in this plugin may be deprecated at some point, as
// we don't use most of its features anyway; that would free us from the *etcd*
// dependency, where msg.Service is defined, and should save some memory/CPU as we
// convert to and from msg.Services.
package object
import (
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/cache"
)
// ToFunc converts one v1.Object to another v1.Object.
type ToFunc func(v1.Object) (v1.Object, error)
// ProcessorBuilder returns function to process cache events.
type ProcessorBuilder func(cache.Indexer, cache.ResourceEventHandler) cache.ProcessFunc
// Empty is an empty struct.
type Empty struct{}
// GetObjectKind implements the ObjectKind interface as a noop.
func (e *Empty) GetObjectKind() schema.ObjectKind { return schema.EmptyObjectKind }
// GetGenerateName implements the metav1.Object interface.
func (e *Empty) GetGenerateName() string { return "" }
// SetGenerateName implements the metav1.Object interface.
func (e *Empty) SetGenerateName(name string) {}
// GetUID implements the metav1.Object interface.
func (e *Empty) GetUID() types.UID { return "" }
// SetUID implements the metav1.Object interface.
func (e *Empty) SetUID(uid types.UID) {}
// GetGeneration implements the metav1.Object interface.
func (e *Empty) GetGeneration() int64 { return 0 }
// SetGeneration implements the metav1.Object interface.
func (e *Empty) SetGeneration(generation int64) {}
// GetSelfLink implements the metav1.Object interface.
func (e *Empty) GetSelfLink() string { return "" }
// SetSelfLink implements the metav1.Object interface.
func (e *Empty) SetSelfLink(selfLink string) {}
// GetCreationTimestamp implements the metav1.Object interface.
func (e *Empty) GetCreationTimestamp() v1.Time { return v1.Time{} }
// SetCreationTimestamp implements the metav1.Object interface.
func (e *Empty) SetCreationTimestamp(timestamp v1.Time) {}
// GetDeletionTimestamp implements the metav1.Object interface.
func (e *Empty) GetDeletionTimestamp() *v1.Time { return &v1.Time{} }
// SetDeletionTimestamp implements the metav1.Object interface.
func (e *Empty) SetDeletionTimestamp(timestamp *v1.Time) {}
// GetDeletionGracePeriodSeconds implements the metav1.Object interface.
func (e *Empty) GetDeletionGracePeriodSeconds() *int64 { return nil }
// SetDeletionGracePeriodSeconds implements the metav1.Object interface.
func (e *Empty) SetDeletionGracePeriodSeconds(*int64) {}
// GetLabels implements the metav1.Object interface.
func (e *Empty) GetLabels() map[string]string { return nil }
// SetLabels implements the metav1.Object interface.
func (e *Empty) SetLabels(labels map[string]string) {}
// GetAnnotations implements the metav1.Object interface.
func (e *Empty) GetAnnotations() map[string]string { return nil }
// SetAnnotations implements the metav1.Object interface.
func (e *Empty) SetAnnotations(annotations map[string]string) {}
// GetFinalizers implements the metav1.Object interface.
func (e *Empty) GetFinalizers() []string { return nil }
// SetFinalizers implements the metav1.Object interface.
func (e *Empty) SetFinalizers(finalizers []string) {}
// GetOwnerReferences implements the metav1.Object interface.
func (e *Empty) GetOwnerReferences() []v1.OwnerReference { return nil }
// SetOwnerReferences implements the metav1.Object interface.
func (e *Empty) SetOwnerReferences([]v1.OwnerReference) {}
// GetZZZ_DeprecatedClusterName implements the metav1.Object interface.
func (e *Empty) GetZZZ_DeprecatedClusterName() string { return "" }
// SetZZZ_DeprecatedClusterName implements the metav1.Object interface.
func (e *Empty) SetZZZ_DeprecatedClusterName(clusterName string) {}
// GetManagedFields implements the metav1.Object interface.
func (e *Empty) GetManagedFields() []v1.ManagedFieldsEntry { return nil }
// SetManagedFields implements the metav1.Object interface.
func (e *Empty) SetManagedFields(managedFields []v1.ManagedFieldsEntry) {}
package object
import (
"errors"
"fmt"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
)
// Pod is a stripped down api.Pod with only the items we need for CoreDNS.
type Pod struct {
// Don't add new fields to this struct without talking to the CoreDNS maintainers.
Version string
PodIP string
Name string
Namespace string
Labels map[string]string
*Empty
}
var errPodTerminating = errors.New("pod terminating")
// ToPod converts an api.Pod to a *Pod.
func ToPod(obj meta.Object) (meta.Object, error) {
apiPod, ok := obj.(*api.Pod)
if !ok {
return nil, fmt.Errorf("unexpected object %v", obj)
}
pod := &Pod{
Version: apiPod.GetResourceVersion(),
PodIP: apiPod.Status.PodIP,
Namespace: apiPod.GetNamespace(),
Name: apiPod.GetName(),
Labels: apiPod.GetLabels(),
}
t := apiPod.DeletionTimestamp
if t != nil && !(*t).Time.IsZero() {
// if the pod is in the process of termination, return an error so it can be ignored
// during add/update event processing
return pod, errPodTerminating
}
*apiPod = api.Pod{}
return pod, nil
}
var _ runtime.Object = &Pod{}
// DeepCopyObject implements the ObjectKind interface.
func (p *Pod) DeepCopyObject() runtime.Object {
p1 := &Pod{
Version: p.Version,
PodIP: p.PodIP,
Namespace: p.Namespace,
Name: p.Name,
}
return p1
}
// GetNamespace implements the metav1.Object interface.
func (p *Pod) GetNamespace() string { return p.Namespace }
// SetNamespace implements the metav1.Object interface.
func (p *Pod) SetNamespace(namespace string) {}
// GetName implements the metav1.Object interface.
func (p *Pod) GetName() string { return p.Name }
// SetName implements the metav1.Object interface.
func (p *Pod) SetName(name string) {}
// GetResourceVersion implements the metav1.Object interface.
func (p *Pod) GetResourceVersion() string { return p.Version }
// SetResourceVersion implements the metav1.Object interface.
func (p *Pod) SetResourceVersion(version string) {}
package object
import (
"fmt"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
)
// Service is a stripped down api.Service with only the items we need for CoreDNS.
type Service struct {
// Don't add new fields to this struct without talking to the CoreDNS maintainers.
Version string
Name string
Namespace string
Index string
ClusterIPs []string
Type api.ServiceType
ExternalName string
Ports []api.ServicePort
// ExternalIPs we may want to export.
ExternalIPs []string
*Empty
}
// ServiceKey returns the string used for the index.
func ServiceKey(name, namespace string) string { return name + "." + namespace }
// ToService converts an api.Service to a *Service.
func ToService(obj meta.Object) (meta.Object, error) {
svc, ok := obj.(*api.Service)
if !ok {
return nil, fmt.Errorf("unexpected object %v", obj)
}
s := &Service{
Version: svc.GetResourceVersion(),
Name: svc.GetName(),
Namespace: svc.GetNamespace(),
Index: ServiceKey(svc.GetName(), svc.GetNamespace()),
Type: svc.Spec.Type,
ExternalName: svc.Spec.ExternalName,
ExternalIPs: make([]string, len(svc.Status.LoadBalancer.Ingress)+len(svc.Spec.ExternalIPs)),
}
if len(svc.Spec.ClusterIPs) > 0 {
s.ClusterIPs = make([]string, len(svc.Spec.ClusterIPs))
copy(s.ClusterIPs, svc.Spec.ClusterIPs)
} else {
s.ClusterIPs = []string{svc.Spec.ClusterIP}
}
if len(svc.Spec.Ports) == 0 {
// Add sentinel if there are no ports.
s.Ports = []api.ServicePort{{Port: -1}}
} else {
s.Ports = make([]api.ServicePort, len(svc.Spec.Ports))
copy(s.Ports, svc.Spec.Ports)
}
li := copy(s.ExternalIPs, svc.Spec.ExternalIPs)
for i, lb := range svc.Status.LoadBalancer.Ingress {
if lb.IP != "" {
s.ExternalIPs[li+i] = lb.IP
continue
}
s.ExternalIPs[li+i] = lb.Hostname
}
*svc = api.Service{}
return s, nil
}
// Headless returns true if the service is headless
func (s *Service) Headless() bool {
return s.ClusterIPs[0] == api.ClusterIPNone
}
var _ runtime.Object = &Service{}
// DeepCopyObject implements the ObjectKind interface.
func (s *Service) DeepCopyObject() runtime.Object {
s1 := &Service{
Version: s.Version,
Name: s.Name,
Namespace: s.Namespace,
Index: s.Index,
Type: s.Type,
ExternalName: s.ExternalName,
ClusterIPs: make([]string, len(s.ClusterIPs)),
Ports: make([]api.ServicePort, len(s.Ports)),
ExternalIPs: make([]string, len(s.ExternalIPs)),
}
copy(s1.ClusterIPs, s.ClusterIPs)
copy(s1.Ports, s.Ports)
copy(s1.ExternalIPs, s.ExternalIPs)
return s1
}
// GetNamespace implements the metav1.Object interface.
func (s *Service) GetNamespace() string { return s.Namespace }
// SetNamespace implements the metav1.Object interface.
func (s *Service) SetNamespace(namespace string) {}
// GetName implements the metav1.Object interface.
func (s *Service) GetName() string { return s.Name }
// SetName implements the metav1.Object interface.
func (s *Service) SetName(name string) {}
// GetResourceVersion implements the metav1.Object interface.
func (s *Service) GetResourceVersion() string { return s.Version }
// SetResourceVersion implements the metav1.Object interface.
func (s *Service) SetResourceVersion(version string) {}
package object
import (
"fmt"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
mcs "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1"
)
// ServiceImport is a stripped down api.ServiceImport with only the items we need for CoreDNS.
type ServiceImport struct {
Version string
Name string
Namespace string
Index string
ClusterIPs []string
Type mcs.ServiceImportType
Ports []mcs.ServicePort
*Empty
}
// ServiceImportKey returns the string used for the index.
func ServiceImportKey(name, namespace string) string { return name + "." + namespace }
// ToServiceImport converts a v1alpha1.ServiceImport to a *ServiceImport.
func ToServiceImport(obj meta.Object) (meta.Object, error) {
svc, ok := obj.(*mcs.ServiceImport)
if !ok {
return nil, fmt.Errorf("unexpected object %v", obj)
}
s := &ServiceImport{
Version: svc.GetResourceVersion(),
Name: svc.GetName(),
Namespace: svc.GetNamespace(),
Index: ServiceImportKey(svc.GetName(), svc.GetNamespace()),
Type: svc.Spec.Type,
}
if len(svc.Spec.IPs) > 0 {
s.ClusterIPs = make([]string, len(svc.Spec.IPs))
copy(s.ClusterIPs, svc.Spec.IPs)
}
if len(svc.Spec.Ports) > 0 {
s.Ports = make([]mcs.ServicePort, len(svc.Spec.Ports))
copy(s.Ports, svc.Spec.Ports)
}
*svc = mcs.ServiceImport{}
return s, nil
}
var _ runtime.Object = &ServiceImport{}
// Headless returns true if the service is headless
func (s *ServiceImport) Headless() bool {
return s.Type == mcs.Headless
}
// DeepCopyObject implements the ObjectKind interface.
func (s *ServiceImport) DeepCopyObject() runtime.Object {
s1 := &ServiceImport{
Version: s.Version,
Name: s.Name,
Namespace: s.Namespace,
Index: s.Index,
Type: s.Type,
ClusterIPs: make([]string, len(s.ClusterIPs)),
Ports: make([]mcs.ServicePort, len(s.Ports)),
}
copy(s1.ClusterIPs, s.ClusterIPs)
copy(s1.Ports, s.Ports)
return s1
}
// GetNamespace implements the metav1.Object interface.
func (s *ServiceImport) GetNamespace() string { return s.Namespace }
// SetNamespace implements the metav1.Object interface.
func (s *ServiceImport) SetNamespace(namespace string) {}
// GetName implements the metav1.Object interface.
func (s *ServiceImport) GetName() string { return s.Name }
// SetName implements the metav1.Object interface.
func (s *ServiceImport) SetName(name string) {}
// GetResourceVersion implements the metav1.Object interface.
func (s *ServiceImport) GetResourceVersion() string { return s.Version }
// SetResourceVersion implements the metav1.Object interface.
func (s *ServiceImport) SetResourceVersion(version string) {}
package kubernetes
import (
"strings"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/miekg/dns"
)
type recordRequest struct {
// The named port from the Kubernetes DNS spec; this is the service part (think _https) of a well formed
// SRV record.
port string
// The protocol is usually _udp or _tcp (if set), and comes from the protocol part of a well formed
// SRV record.
protocol string
endpoint string
cluster string
// The servicename used in Kubernetes.
service string
// The namespace used in Kubernetes.
namespace string
// Each name can be for a pod or a service; here we track what we've seen, either "pod" or "service".
podOrSvc string
}
// parseRequest parses the qname to find all the elements we need for querying k8s. Anything
// that is not parsed will have the wildcard "*" value (except r.endpoint).
// Potential underscores are stripped from _port and _protocol.
func parseRequest(name, zone string, multicluster bool) (r recordRequest, err error) {
// 4 Possible cases:
// 1. _port._protocol.service.namespace.pod|svc.zone
// 2. (endpoint): endpoint.service.namespace.pod|svc.zone
// 3. (service): service.namespace.pod|svc.zone
// 4. (endpoint multicluster): endpoint.cluster.service.namespace.pod|svc.zone
base, _ := dnsutil.TrimZone(name, zone)
// return NODATA for apex queries
if base == "" || base == Svc || base == Pod {
return r, nil
}
segs := dns.SplitDomainName(base)
last := len(segs) - 1
if last < 0 {
return r, nil
}
r.podOrSvc = segs[last]
if r.podOrSvc != Pod && r.podOrSvc != Svc {
return r, errInvalidRequest
}
last--
if last < 0 {
return r, nil
}
r.namespace = segs[last]
last--
if last < 0 {
return r, nil
}
r.service = segs[last]
last--
if last < 0 {
return r, nil
}
// Because of ambiguity we check the labels left: 1: an endpoint. 2: port and protocol or endpoint and clusterid.
// Anything else is a query that is too long to answer and can safely be delegated to return an nxdomain.
switch last {
case 0: // endpoint only
r.endpoint = segs[last]
case 1: // port and protocol, or endpoint and cluster id
if !multicluster || strings.HasPrefix(segs[last], "_") || strings.HasPrefix(segs[last-1], "_") {
r.protocol = stripUnderscore(segs[last])
r.port = stripUnderscore(segs[last-1])
} else {
r.cluster = segs[last]
r.endpoint = segs[last-1]
}
default: // too long
return r, errInvalidRequest
}
return r, nil
}
// stripUnderscore removes a prefixed underscore from s.
func stripUnderscore(s string) string {
if len(s) == 0 {
return s
}
if s[0] != '_' {
return s
}
return s[1:]
}
// String returns a string representation of r; it just returns all fields concatenated with dots.
// This is mostly used in tests.
func (r recordRequest) String() string {
s := r.port
s += "." + r.protocol
s += "." + r.endpoint
s += "." + r.cluster
s += "." + r.service
s += "." + r.namespace
s += "." + r.podOrSvc
return s
}
package kubernetes
// Ready implements the ready.Readiness interface.
func (k *Kubernetes) Ready() bool { return k.APIConn.HasSynced() }
package kubernetes
import (
"context"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/request"
)
// Reverse implements the ServiceBackend interface.
func (k *Kubernetes) Reverse(ctx context.Context, state request.Request, exact bool, opt plugin.Options) ([]msg.Service, error) {
ip := dnsutil.ExtractAddressFromReverse(state.Name())
if ip == "" {
_, e := k.Records(ctx, state, exact)
return nil, e
}
records := k.serviceRecordForIP(ip, state.Name())
if len(records) == 0 {
return records, errNoItems
}
return records, nil
}
// serviceRecordForIP gets a service record with a cluster IP matching the ip argument.
// If no service cluster IP matches, it checks all endpoints.
func (k *Kubernetes) serviceRecordForIP(ip, name string) []msg.Service {
// First check services with cluster ips
for _, service := range k.APIConn.SvcIndexReverse(ip) {
if len(k.Namespaces) > 0 && !k.namespaceExposed(service.Namespace) {
continue
}
domain := strings.Join([]string{service.Name, service.Namespace, Svc, k.primaryZone()}, ".")
return []msg.Service{{Host: domain, TTL: k.ttl}}
}
// If no cluster ips match, search endpoints
var svcs []msg.Service
for _, ep := range k.APIConn.EpIndexReverse(ip) {
if len(k.Namespaces) > 0 && !k.namespaceExposed(ep.Namespace) {
continue
}
for _, eps := range ep.Subsets {
for _, addr := range eps.Addresses {
if addr.IP == ip {
domain := strings.Join([]string{endpointHostname(addr, k.endpointNameMode), ep.Index, Svc, k.primaryZone()}, ".")
svcs = append(svcs, msg.Service{Host: domain, TTL: k.ttl})
}
}
}
}
return svcs
}
package kubernetes
import (
"context"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/go-logr/logr"
"github.com/miekg/dns"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
_ "k8s.io/client-go/plugin/pkg/client/auth/oidc" // pull this in here, because we want it excluded if plugin.cfg doesn't have k8s
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog/v2"
)
const pluginName = "kubernetes"
var log = clog.NewWithPlugin(pluginName)
func init() { plugin.Register(pluginName, setup) }
func setup(c *caddy.Controller) error {
// Do not call klog.InitFlags(nil) here. It will cause reload to panic.
klog.SetLogger(logr.New(&loggerAdapter{log}))
k, err := kubernetesParse(c)
if err != nil {
return plugin.Error(pluginName, err)
}
onStart, onShut, err := k.InitKubeCache(context.Background())
if err != nil {
return plugin.Error(pluginName, err)
}
if onStart != nil {
c.OnStartup(onStart)
}
if onShut != nil {
c.OnShutdown(onShut)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
k.Next = next
return k
})
// get locally bound addresses
c.OnStartup(func() error {
k.localIPs = boundIPs(c)
return nil
})
return nil
}
func kubernetesParse(c *caddy.Controller) (*Kubernetes, error) {
var (
k8s *Kubernetes
err error
)
i := 0
for c.Next() {
if i > 0 {
return nil, plugin.ErrOnce
}
i++
k8s, err = ParseStanza(c)
if err != nil {
return k8s, err
}
}
return k8s, nil
}
// ParseStanza parses a kubernetes stanza
func ParseStanza(c *caddy.Controller) (*Kubernetes, error) {
k8s := New([]string{""})
k8s.autoPathSearch = searchFromResolvConf()
opts := dnsControlOpts{
initEndpointsCache: true,
ignoreEmptyService: false,
}
k8s.opts = opts
k8s.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
k8s.primaryZoneIndex = -1
for i, z := range k8s.Zones {
if dnsutil.IsReverse(z) > 0 {
continue
}
k8s.primaryZoneIndex = i
break
}
if k8s.primaryZoneIndex == -1 {
return nil, errors.New("non-reverse zone name must be used")
}
k8s.Upstream = upstream.New()
k8s.startupTimeout = time.Second * 5
for c.NextBlock() {
switch c.Val() {
case "endpoint_pod_names":
args := c.RemainingArgs()
if len(args) > 0 {
return nil, c.ArgErr()
}
k8s.endpointNameMode = true
continue
case "pods":
args := c.RemainingArgs()
if len(args) == 1 {
switch args[0] {
case podModeDisabled, podModeInsecure, podModeVerified:
k8s.podMode = args[0]
default:
return nil, fmt.Errorf("wrong value for pods: %s, must be one of: disabled, verified, insecure", args[0])
}
continue
}
return nil, c.ArgErr()
case "namespaces":
args := c.RemainingArgs()
if len(args) > 0 {
for _, a := range args {
k8s.Namespaces[a] = struct{}{}
}
continue
}
return nil, c.ArgErr()
case "endpoint":
args := c.RemainingArgs()
if len(args) > 0 {
// Multiple endpoints are deprecated but can still be specified;
// only the first one is used.
k8s.APIServerList = args
if len(args) > 1 {
log.Warningf("Multiple endpoints have been deprecated, only the first specified endpoint '%s' is used", args[0])
}
continue
}
return nil, c.ArgErr()
case "tls": // cert key cacertfile
args := c.RemainingArgs()
if len(args) == 3 {
k8s.APIClientCert, k8s.APIClientKey, k8s.APICertAuth = args[0], args[1], args[2]
continue
}
return nil, c.ArgErr()
case "labels":
args := c.RemainingArgs()
if len(args) > 0 {
labelSelectorString := strings.Join(args, " ")
ls, err := meta.ParseToLabelSelector(labelSelectorString)
if err != nil {
return nil, fmt.Errorf("unable to parse label selector value: '%v': %v", labelSelectorString, err)
}
k8s.opts.labelSelector = ls
continue
}
return nil, c.ArgErr()
case "namespace_labels":
args := c.RemainingArgs()
if len(args) > 0 {
namespaceLabelSelectorString := strings.Join(args, " ")
nls, err := meta.ParseToLabelSelector(namespaceLabelSelectorString)
if err != nil {
return nil, fmt.Errorf("unable to parse namespace_label selector value: '%v': %v", namespaceLabelSelectorString, err)
}
k8s.opts.namespaceLabelSelector = nls
continue
}
return nil, c.ArgErr()
case "fallthrough":
k8s.Fall.SetZonesFromArgs(c.RemainingArgs())
case "ttl":
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
t, err := strconv.Atoi(args[0])
if err != nil {
return nil, err
}
if t < 0 || t > 3600 {
return nil, c.Errf("ttl must be in range [0, 3600]: %d", t)
}
k8s.ttl = uint32(t)
case "noendpoints":
if len(c.RemainingArgs()) != 0 {
return nil, c.ArgErr()
}
k8s.opts.initEndpointsCache = false
case "ignore":
args := c.RemainingArgs()
if len(args) > 0 {
ignore := args[0]
if ignore == "empty_service" {
k8s.opts.ignoreEmptyService = true
continue
}
return nil, fmt.Errorf("unable to parse ignore value: '%v'", ignore)
}
case "kubeconfig":
args := c.RemainingArgs()
if len(args) != 1 && len(args) != 2 {
return nil, c.ArgErr()
}
overrides := &clientcmd.ConfigOverrides{}
if len(args) == 2 {
overrides.CurrentContext = args[1]
}
config := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
&clientcmd.ClientConfigLoadingRules{ExplicitPath: args[0]},
overrides,
)
k8s.ClientConfig = config
case "multicluster":
k8s.opts.multiclusterZones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), []string{})
case "startup_timeout":
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
} else {
var err error
k8s.startupTimeout, err = time.ParseDuration(args[0])
if err != nil {
return nil, fmt.Errorf("failed to parse startup_timeout: %v, %s", args[0], err)
}
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
if len(k8s.Namespaces) != 0 && k8s.opts.namespaceLabelSelector != nil {
return nil, c.Errf("namespaces and namespace_labels cannot both be set")
}
for _, multiclusterZone := range k8s.opts.multiclusterZones {
if !slices.Contains(k8s.Zones, multiclusterZone) {
return nil, c.Errf("is not authoritative for the multicluster zone %s", multiclusterZone)
}
}
return k8s, nil
}
func searchFromResolvConf() []string {
rc, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil
}
plugin.Zones(rc.Search).Normalize()
return rc.Search
}
package kubernetes
import (
"context"
"math"
"net"
"sort"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/transfer"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
api "k8s.io/api/core/v1"
)
// Transfer implements the transfer.Transfer interface.
func (k *Kubernetes) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
match := plugin.Zones(k.Zones).Matches(zone)
if match == "" {
return nil, transfer.ErrNotAuthoritative
}
// state is not used here, hence the empty request.Request{}
soa, err := plugin.SOA(context.TODO(), k, zone, request.Request{}, plugin.Options{})
if err != nil {
return nil, transfer.ErrNotAuthoritative
}
ch := make(chan []dns.RR)
zonePath := msg.Path(zone, "coredns")
go func() {
// ixfr fallback
if serial != 0 && soa[0].(*dns.SOA).Serial == serial {
ch <- soa
close(ch)
return
}
ch <- soa
nsAddrs := k.nsAddrs(false, false, zone)
nsHosts := make(map[string]struct{})
for _, nsAddr := range nsAddrs {
nsHost := nsAddr.Header().Name
if _, ok := nsHosts[nsHost]; !ok {
nsHosts[nsHost] = struct{}{}
ch <- []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: zone, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: k.ttl}, Ns: nsHost}}
}
}
ch <- nsAddrs
if !k.isMultiClusterZone(zone) {
k.transferServices(ch, zonePath)
} else {
k.transferMultiClusterServices(ch, zonePath)
}
ch <- soa
close(ch)
}()
return ch, nil
}
func (k *Kubernetes) transferServices(ch chan []dns.RR, zonePath string) {
serviceList := k.APIConn.ServiceList()
sort.Slice(serviceList, func(i, j int) bool {
return serviceList[i].Name < serviceList[j].Name
})
for _, svc := range serviceList {
if !k.namespaceExposed(svc.Namespace) {
continue
}
svcBase := []string{zonePath, Svc, svc.Namespace, svc.Name}
switch svc.Type {
case api.ServiceTypeClusterIP, api.ServiceTypeNodePort, api.ServiceTypeLoadBalancer:
clusterIP := net.ParseIP(svc.ClusterIPs[0])
if clusterIP != nil {
var host string
for _, ip := range svc.ClusterIPs {
s := msg.Service{Host: ip, TTL: k.ttl}
s.Key = strings.Join(svcBase, "/")
// Change host from IP to Name for SRV records
host = emitAddressRecord(ch, s)
}
for _, p := range svc.Ports {
s := msg.Service{Host: host, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join(svcBase, "/")
// Need to generate this to handle use cases for peer-finder
// ref: https://github.com/coredns/coredns/pull/823
ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), 100)}
// As per spec unnamed ports do not have a srv record
// https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
if p.Name == "" {
continue
}
s.Key = strings.Join(append(svcBase, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+p.Name)), "/")
ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), 100)}
}
// Skip endpoint discovery if clusterIP is defined
continue
}
endpointsList := k.APIConn.EpIndex(svc.Name + "." + svc.Namespace)
for _, ep := range endpointsList {
for _, eps := range ep.Subsets {
srvWeight := calcSRVWeight(len(eps.Addresses))
for _, addr := range eps.Addresses {
s := msg.Service{Host: addr.IP, TTL: k.ttl}
s.Key = strings.Join(svcBase, "/")
// We don't need to change the msg.Service host from IP to Name yet
// so disregard the return value here
emitAddressRecord(ch, s)
s.Key = strings.Join(append(svcBase, endpointHostname(addr, k.endpointNameMode)), "/")
// Change host from IP to Name for SRV records
host := emitAddressRecord(ch, s)
s.Host = host
for _, p := range eps.Ports {
// As per spec unnamed ports do not have a srv record
// https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
if p.Name == "" {
continue
}
s.Port = int(p.Port)
s.Key = strings.Join(append(svcBase, strings.ToLower("_"+p.Protocol), strings.ToLower("_"+p.Name)), "/")
ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), srvWeight)}
}
}
}
}
case api.ServiceTypeExternalName:
s := msg.Service{Key: strings.Join(svcBase, "/"), Host: svc.ExternalName, TTL: k.ttl}
if t, _ := s.HostType(); t == dns.TypeCNAME {
ch <- []dns.RR{s.NewCNAME(msg.Domain(s.Key), s.Host)}
}
}
}
}
func (k *Kubernetes) transferMultiClusterServices(ch chan []dns.RR, zonePath string) {
serviceImportList := k.APIConn.ServiceImportList()
sort.Slice(serviceImportList, func(i, j int) bool {
return serviceImportList[i].Name < serviceImportList[j].Name
})
for _, svcImport := range serviceImportList {
if !k.namespaceExposed(svcImport.Namespace) {
continue
}
svcBase := []string{zonePath, Svc, svcImport.Namespace, svcImport.Name}
var clusterIP net.IP
if len(svcImport.ClusterIPs) > 0 {
clusterIP = net.ParseIP(svcImport.ClusterIPs[0])
}
if clusterIP != nil {
var host string
for _, ip := range svcImport.ClusterIPs {
s := msg.Service{Host: ip, TTL: k.ttl}
s.Key = strings.Join(svcBase, "/")
// Change host from IP to Name for SRV records
host = emitAddressRecord(ch, s)
}
for _, p := range svcImport.Ports {
s := msg.Service{Host: host, Port: int(p.Port), TTL: k.ttl}
s.Key = strings.Join(svcBase, "/")
// Need to generate this to handle use cases for peer-finder
// ref: https://github.com/coredns/coredns/pull/823
ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), 100)}
// As per spec unnamed ports do not have a srv record
// https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
if p.Name == "" {
continue
}
s.Key = strings.Join(append(svcBase, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+p.Name)), "/")
ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), 100)}
}
// Skip endpoint discovery if clusterIP is defined
continue
}
endpointsList := k.APIConn.McEpIndex(svcImport.Name + "." + svcImport.Namespace)
for _, ep := range endpointsList {
for _, eps := range ep.Subsets {
srvWeight := calcSRVWeight(len(eps.Addresses))
for _, addr := range eps.Addresses {
s := msg.Service{Host: addr.IP, TTL: k.ttl}
s.Key = strings.Join(svcBase, "/")
// We don't need to change the msg.Service host from IP to Name yet
// so disregard the return value here
emitAddressRecord(ch, s)
s.Key = strings.Join(append(svcBase, endpointHostname(addr, k.endpointNameMode)), "/")
// Change host from IP to Name for SRV records
host := emitAddressRecord(ch, s)
s.Host = host
for _, p := range eps.Ports {
// As per spec unnamed ports do not have a srv record
// https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records
if p.Name == "" {
continue
}
s.Port = int(p.Port)
s.Key = strings.Join(append(svcBase, strings.ToLower("_"+p.Protocol), strings.ToLower("_"+p.Name)), "/")
ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), srvWeight)}
}
}
}
}
}
}
// emitAddressRecord generates a new A or AAAA record based on the msg.Service and writes it to a channel.
// emitAddressRecord returns the host name from the generated record.
func emitAddressRecord(c chan<- []dns.RR, s msg.Service) string {
ip := net.ParseIP(s.Host)
dnsType, _ := s.HostType()
switch dnsType {
case dns.TypeA:
r := s.NewA(msg.Domain(s.Key), ip)
c <- []dns.RR{r}
return r.Hdr.Name
case dns.TypeAAAA:
r := s.NewAAAA(msg.Domain(s.Key), ip)
c <- []dns.RR{r}
return r.Hdr.Name
}
return ""
}
// calcSRVWeight borrows the logic implemented in plugin.SRV for dynamically
// calculating the SRV weight and priority.
func calcSRVWeight(numservices int) uint16 {
services := make([]msg.Service, 0, numservices)
for range numservices {
services = append(services, msg.Service{})
}
w := make(map[int]int)
for _, serv := range services {
weight := 100
if serv.Weight != 0 {
weight = serv.Weight
}
if _, ok := w[serv.Priority]; !ok {
w[serv.Priority] = weight
continue
}
w[serv.Priority] += weight
}
weight := uint16(math.Floor((100.0 / float64(w[0])) * 100))
// weight should be at least 1
if weight == 0 {
weight = 1
}
return weight
}
// Package loadbalance is a plugin for rewriting responses to do "load balancing".
package loadbalance
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
// LoadBalance is a plugin to rewrite responses for "load balancing".
type LoadBalance struct {
Next plugin.Handler
shuffle func(*dns.Msg) *dns.Msg
}
// ServeDNS implements the plugin.Handler interface.
func (lb LoadBalance) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rw := &LoadBalanceResponseWriter{ResponseWriter: w, shuffle: lb.shuffle}
return plugin.NextOrFailure(lb.Name(), lb.Next, ctx, rw, r)
}
// Name implements the Handler interface.
func (lb LoadBalance) Name() string { return "loadbalance" }
package loadbalance
import (
"github.com/miekg/dns"
)
const (
randomShufflePolicy = "round_robin"
weightedRoundRobinPolicy = "weighted"
)
// LoadBalanceResponseWriter is a response writer that shuffles A, AAAA and MX records.
type LoadBalanceResponseWriter struct {
dns.ResponseWriter
shuffle func(*dns.Msg) *dns.Msg
}
// WriteMsg implements the dns.ResponseWriter interface.
func (r *LoadBalanceResponseWriter) WriteMsg(res *dns.Msg) error {
if res.Rcode != dns.RcodeSuccess {
return r.ResponseWriter.WriteMsg(res)
}
if res.Question[0].Qtype == dns.TypeAXFR || res.Question[0].Qtype == dns.TypeIXFR {
return r.ResponseWriter.WriteMsg(res)
}
return r.ResponseWriter.WriteMsg(r.shuffle(res))
}
func randomShuffle(res *dns.Msg) *dns.Msg {
res.Answer = roundRobin(res.Answer)
res.Ns = roundRobin(res.Ns)
res.Extra = roundRobin(res.Extra)
return res
}
func roundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{}
address := []dns.RR{}
mx := []dns.RR{}
rest := []dns.RR{}
for _, r := range in {
switch r.Header().Rrtype {
case dns.TypeCNAME:
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default:
rest = append(rest, r)
}
}
roundRobinShuffle(address)
roundRobinShuffle(mx)
out := append(cname, rest...)
out = append(out, address...)
out = append(out, mx...)
return out
}
func roundRobinShuffle(records []dns.RR) {
switch l := len(records); l {
case 0, 1:
break
case 2:
if dns.Id()%2 == 0 {
records[0], records[1] = records[1], records[0]
}
default:
for j := range l {
p := j + (int(dns.Id()) % (l - j))
if j == p {
continue
}
records[j], records[p] = records[p], records[j]
}
}
}
// Write implements the dns.ResponseWriter interface.
func (r *LoadBalanceResponseWriter) Write(buf []byte) (int, error) {
// Should we pack and unpack here to fiddle with the packet... Not likely.
log.Warning("LoadBalance called with Write: not shuffling records")
n, err := r.ResponseWriter.Write(buf)
return n, err
}
package loadbalance
import (
"net"
"github.com/miekg/dns"
)
func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg {
msg.Answer = reorderRecords(msg.Answer, subnets)
msg.Extra = reorderRecords(msg.Extra, subnets)
return msg
}
func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR {
var cname, address, mx, rest []dns.RR
for _, r := range records {
switch r.Header().Rrtype {
case dns.TypeCNAME:
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default:
rest = append(rest, r)
}
}
sorted := sortBySubnetPriority(address, subnets)
out := append([]dns.RR{}, cname...)
out = append(out, sorted...)
out = append(out, mx...)
out = append(out, rest...)
return out
}
func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR {
matched := make([]dns.RR, 0, len(records))
seen := make(map[int]bool)
for _, subnet := range subnets {
for i, r := range records {
if seen[i] {
continue
}
ip := extractIP(r)
if ip != nil && subnet.Contains(ip) {
matched = append(matched, r)
seen[i] = true
}
}
}
unmatched := make([]dns.RR, 0, len(records)-len(matched))
for i, r := range records {
if !seen[i] {
unmatched = append(unmatched, r)
}
}
return append(matched, unmatched...)
}
func extractIP(rr dns.RR) net.IP {
switch r := rr.(type) {
case *dns.A:
return r.A
case *dns.AAAA:
return r.AAAA
default:
return nil
}
}
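// Illustrative sketch (not part of the plugin): addresses inside a preferred subnet are moved
// ahead of the remaining answers while the relative order is otherwise preserved.
//
//	_, subnet, _ := net.ParseCIDR("10.0.0.0/8")
//	msg := new(dns.Msg)
//	msg.Answer = []dns.RR{
//		&dns.A{Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeA, Class: dns.ClassINET}, A: net.ParseIP("192.0.2.1")},
//		&dns.A{Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeA, Class: dns.ClassINET}, A: net.ParseIP("10.1.2.3")},
//	}
//	msg = reorderPreferredSubnets(msg, []*net.IPNet{subnet})
//	// msg.Answer now lists 10.1.2.3 before 192.0.2.1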
package loadbalance
import (
"errors"
"fmt"
"net"
"path/filepath"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("loadbalance")
var errOpen = errors.New("weight file open error")
func init() { plugin.Register("loadbalance", setup) }
type lbFuncs struct {
shuffleFunc func(*dns.Msg) *dns.Msg
onStartUpFunc func() error
onShutdownFunc func() error
weighted *weightedRR // used in unit tests only
preferSubnets []*net.IPNet
}
func setup(c *caddy.Controller) error {
lb, err := parse(c)
if err != nil {
return plugin.Error("loadbalance", err)
}
if lb.onStartUpFunc != nil {
c.OnStartup(lb.onStartUpFunc)
}
if lb.onShutdownFunc != nil {
c.OnShutdown(lb.onShutdownFunc)
}
shuffle := lb.shuffleFunc
if len(lb.preferSubnets) > 0 {
original := shuffle
shuffle = func(res *dns.Msg) *dns.Msg {
return reorderPreferredSubnets(original(res), lb.preferSubnets)
}
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return LoadBalance{Next: next, shuffle: shuffle}
})
return nil
}
func parse(c *caddy.Controller) (*lbFuncs, error) {
config := dnsserver.GetConfig(c)
lb := &lbFuncs{}
for c.Next() {
args := c.RemainingArgs()
if len(args) == 0 {
lb.shuffleFunc = randomShuffle
} else {
switch args[0] {
case randomShufflePolicy:
if len(args) > 1 {
return nil, c.Errf("unknown property for %s", args[0])
}
lb.shuffleFunc = randomShuffle
case weightedRoundRobinPolicy:
if len(args) < 2 {
return nil, c.Err("missing weight file argument")
}
if len(args) > 2 {
return nil, c.Err("unexpected argument(s)")
}
weightFileName := args[1]
if !filepath.IsAbs(weightFileName) && config.Root != "" {
weightFileName = filepath.Join(config.Root, weightFileName)
}
reload := 30 * time.Second
for c.NextBlock() {
switch c.Val() {
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return nil, c.Err("reload duration value is missing")
}
if len(t) > 1 {
return nil, c.Err("unexpected argument")
}
var err error
reload, err = time.ParseDuration(t[0])
if err != nil {
return nil, c.Errf("invalid reload duration '%s'", t[0])
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
*lb = *createWeightedFuncs(weightFileName, reload)
default:
return nil, fmt.Errorf("unknown policy: %s", args[0])
}
}
for c.NextBlock() {
switch c.Val() {
case "prefer":
cidrs := c.RemainingArgs()
for _, cidr := range cidrs {
_, subnet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, c.Errf("invalid CIDR %q: %v", cidr, err)
}
lb.preferSubnets = append(lb.preferSubnets, subnet)
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
}
return lb, nil
}
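// Illustrative sketch (not part of the plugin): two alternative Corefile stanzas accepted by
// the parser above; "weights.txt" is a placeholder file name.
//
//	loadbalance round_robin {
//	    prefer 10.0.0.0/8 2001:db8::/32
//	}
//
//	loadbalance weighted weights.txt {
//	    reload 10s
//	}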
package loadbalance
import (
"bufio"
"bytes"
"crypto/md5"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
type (
// "weighted-round-robin" policy specific data
weightedRR struct {
fileName string
reload time.Duration
md5sum [md5.Size]byte
domains map[string]weights
randomGen
mutex sync.Mutex
}
// Per domain weights
weights []*weightItem
// Weight assigned to an address
weightItem struct {
address net.IP
value uint8
}
// Random uint generator
randomGen interface {
randInit()
randUint(limit uint) uint
}
)
// Random uint generator
type randomUint struct {
rn *rand.Rand
}
func (r *randomUint) randInit() {
r.rn = rand.New(rand.NewSource(time.Now().UnixNano()))
}
func (r *randomUint) randUint(limit uint) uint {
return uint(r.rn.Intn(int(limit)))
}
func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg {
switch res.Question[0].Qtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeSRV:
res.Answer = w.weightedRoundRobin(res.Answer)
res.Extra = w.weightedRoundRobin(res.Extra)
}
return res
}
func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error {
err := w.updateWeights()
if errors.Is(err, errOpen) && w.reload != 0 {
log.Warningf("Failed to open weight file:%v. Will try again in %v",
err, w.reload)
} else if err != nil {
return plugin.Error("loadbalance", err)
}
// start periodic weight file reload go routine
w.periodicWeightUpdate(stopReloadChan)
return nil
}
func createWeightedFuncs(weightFileName string,
reload time.Duration) *lbFuncs {
lb := &lbFuncs{
weighted: &weightedRR{
fileName: weightFileName,
reload: reload,
randomGen: &randomUint{},
},
}
lb.weighted.randInit()
lb.shuffleFunc = func(res *dns.Msg) *dns.Msg {
return weightedShuffle(res, lb.weighted)
}
stopReloadChan := make(chan bool)
lb.onStartUpFunc = func() error {
return weightedOnStartUp(lb.weighted, stopReloadChan)
}
lb.onShutdownFunc = func() error {
// stop the periodic weight reload goroutine
close(stopReloadChan)
return nil
}
return lb
}
// Apply weighted round robin policy to the answer
func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{}
address := []dns.RR{}
mx := []dns.RR{}
rest := []dns.RR{}
for _, r := range in {
switch r.Header().Rrtype {
case dns.TypeCNAME:
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default:
rest = append(rest, r)
}
}
if len(address) == 0 {
// no change
return in
}
w.setTopRecord(address)
out := append(cname, rest...)
out = append(out, address...)
out = append(out, mx...)
return out
}
// Move the next expected address to the first position in the result list
func (w *weightedRR) setTopRecord(address []dns.RR) {
itop := w.topAddressIndex(address)
if itop < 0 {
// internal error
return
}
if itop != 0 {
// swap the selected top entry with the actual one
address[0], address[itop] = address[itop], address[0]
}
}
// Compute the top (first) address index
func (w *weightedRR) topAddressIndex(address []dns.RR) int {
w.mutex.Lock()
defer w.mutex.Unlock()
// Determine the weight value for each address in the answer
var wsum uint
type waddress struct {
index int
weight uint8
}
weightedAddr := make([]waddress, len(address))
for i, ar := range address {
wa := &weightedAddr[i]
wa.index = i
wa.weight = 1 // default weight
var ip net.IP
switch ar.Header().Rrtype {
case dns.TypeA:
ip = ar.(*dns.A).A
case dns.TypeAAAA:
ip = ar.(*dns.AAAA).AAAA
}
ws := w.domains[ar.Header().Name]
for _, w := range ws {
if w.address.Equal(ip) {
wa.weight = w.value
break
}
}
wsum += uint(wa.weight)
}
// Select the first (top) IP
sort.Slice(weightedAddr, func(i, j int) bool {
return weightedAddr[i].weight > weightedAddr[j].weight
})
v := w.randUint(wsum)
var psum uint
for _, wa := range weightedAddr {
psum += uint(wa.weight)
if v < psum {
return wa.index
}
}
// we should never reach this
log.Errorf("Internal error: cannot find top address (randv:%v wsum:%v)", v, wsum)
return -1
}
// Start a goroutine to update weights from the weight file periodically
func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) {
if w.reload == 0 {
return
}
go func() {
ticker := time.NewTicker(w.reload)
for {
select {
case <-stopReload:
return
case <-ticker.C:
err := w.updateWeights()
if err != nil {
log.Error(err)
}
}
}
}()
}
// Update weights from weight file
func (w *weightedRR) updateWeights() error {
reader, err := os.Open(filepath.Clean(w.fileName))
if err != nil {
return errOpen
}
defer reader.Close()
// check if the contents have changed
var buf bytes.Buffer
tee := io.TeeReader(reader, &buf)
content, err := io.ReadAll(tee)
if err != nil {
return err
}
md5sum := md5.Sum(content)
if md5sum == w.md5sum {
// file contents have not changed
return nil
}
w.md5sum = md5sum
scanner := bufio.NewScanner(&buf)
// Parse the weight file contents
domains, err := w.parseWeights(scanner)
if err != nil {
return err
}
// access to weights must be protected
w.mutex.Lock()
w.domains = domains
w.mutex.Unlock()
log.Infof("Successfully reloaded weight file %s", w.fileName)
return nil
}
// Parse the weight file contents
func (w *weightedRR) parseWeights(scanner *bufio.Scanner) (map[string]weights, error) {
var dname string
var ws weights
domains := make(map[string]weights)
for scanner.Scan() {
nextLine := strings.TrimSpace(scanner.Text())
if len(nextLine) == 0 || nextLine[0:1] == "#" {
// Empty and comment lines are ignored
continue
}
fields := strings.Fields(nextLine)
switch len(fields) {
case 1:
// (domain) name sanity check
if net.ParseIP(fields[0]) != nil {
return nil, fmt.Errorf("wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)",
fields[0], w.fileName)
}
dname = fields[0]
// add the root domain if it is missing
if dname[len(dname)-1] != '.' {
dname += "."
}
var ok bool
ws, ok = domains[dname]
if !ok {
ws = make(weights, 0)
domains[dname] = ws
}
case 2:
// IP address and weight value
ip := net.ParseIP(fields[0])
if ip == nil {
return nil, fmt.Errorf("wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName)
}
weight, err := strconv.ParseUint(fields[1], 10, 8)
if err != nil || weight == 0 {
return nil, fmt.Errorf("wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName)
}
witem := &weightItem{address: ip, value: uint8(weight)}
if dname == "" {
return nil, fmt.Errorf("missing domain name in weight file %s", w.fileName)
}
ws = append(ws, witem)
domains[dname] = ws
default:
return nil, fmt.Errorf("could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("weight file %s parsing error:%s", w.fileName, err)
}
return domains, nil
}
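// Illustrative sketch (not part of the plugin): a weight file in the format parseWeights
// expects. A line with a single field starts a domain section; the address/weight pairs that
// follow belong to it, and lines starting with '#' are ignored.
//
//	# weights for www
//	www.example.org
//	192.0.2.10 3
//	192.0.2.11 1
//
//	api.example.org.
//	2001:db8::1 2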
package local
import (
"context"
"net"
"strings"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("local")
// Local is a plugin that returns standard replies for local queries.
type Local struct {
Next plugin.Handler
}
var zones = []string{"localhost.", "0.in-addr.arpa.", "127.in-addr.arpa.", "255.in-addr.arpa."}
func soaFromOrigin(origin string) []dns.RR {
hdr := dns.RR_Header{Name: origin, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeSOA}
return []dns.RR{&dns.SOA{Hdr: hdr, Ns: "localhost.", Mbox: "root.localhost.", Serial: 1, Refresh: 0, Retry: 0, Expire: 0, Minttl: ttl}}
}
func nsFromOrigin(origin string) []dns.RR {
hdr := dns.RR_Header{Name: origin, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeNS}
return []dns.RR{&dns.NS{Hdr: hdr, Ns: "localhost."}}
}
// ServeDNS implements the plugin.Handler interface.
func (l Local) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.QName()
lc := len("localhost.")
if len(state.Name()) > lc && strings.HasPrefix(state.Name(), "localhost.") {
// we have multiple labels, but the first one is localhost, intercept this and return 127.0.0.1 or ::1
log.Debugf("Intercepting localhost query for %q %s, from %s", state.Name(), state.Type(), state.IP())
LocalhostCount.Inc()
reply := doLocalhost(state)
w.WriteMsg(reply)
return 0, nil
}
zone := plugin.Zones(zones).Matches(qname)
if zone == "" {
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}
m := new(dns.Msg)
m.SetReply(r)
zone = qname[len(qname)-len(zone):]
switch q := state.Name(); q {
case "localhost.", "0.in-addr.arpa.", "127.in-addr.arpa.", "255.in-addr.arpa.":
switch state.QType() {
case dns.TypeA:
if q != "localhost." {
// nodata
m.Ns = soaFromOrigin(qname)
break
}
hdr := dns.RR_Header{Name: qname, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeA}
m.Answer = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP("127.0.0.1").To4()}}
case dns.TypeAAAA:
if q != "localhost." {
// nodata
m.Ns = soaFromOrigin(qname)
break
}
hdr := dns.RR_Header{Name: qname, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeAAAA}
m.Answer = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP("::1")}}
case dns.TypeSOA:
m.Answer = soaFromOrigin(qname)
case dns.TypeNS:
m.Answer = nsFromOrigin(qname)
default:
// nodata
m.Ns = soaFromOrigin(qname)
}
case "1.0.0.127.in-addr.arpa.":
switch state.QType() {
case dns.TypePTR:
hdr := dns.RR_Header{Name: qname, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypePTR}
m.Answer = []dns.RR{&dns.PTR{Hdr: hdr, Ptr: "localhost."}}
default:
// nodata
m.Ns = soaFromOrigin(zone)
}
}
if len(m.Answer) == 0 && len(m.Ns) == 0 {
m.Ns = soaFromOrigin(zone)
m.Rcode = dns.RcodeNameError
}
w.WriteMsg(m)
return 0, nil
}
// Name implements the plugin.Handler interface.
func (l Local) Name() string { return "local" }
func doLocalhost(state request.Request) *dns.Msg {
m := new(dns.Msg)
m.SetReply(state.Req)
switch state.QType() {
case dns.TypeA:
hdr := dns.RR_Header{Name: state.QName(), Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeA}
m.Answer = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP("127.0.0.1").To4()}}
case dns.TypeAAAA:
hdr := dns.RR_Header{Name: state.QName(), Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeAAAA}
m.Answer = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP("::1")}}
default:
// nodata
m.Ns = soaFromOrigin(state.QName())
}
return m
}
const ttl = 604800
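// Illustrative sketch (not part of the plugin): typical answers produced by this handler.
//
//	localhost.              A    -> 127.0.0.1
//	localhost.              AAAA -> ::1
//	1.0.0.127.in-addr.arpa. PTR  -> localhost.
//	0.in-addr.arpa.         A    -> NODATA with the zone SOA in the authority section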
package local
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("local", setup) }
func setup(c *caddy.Controller) error {
l := Local{}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
l.Next = next
return l
})
return nil
}
// Package log implements a basic but useful request (access) logging plugin.
package log
import (
"context"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnstest"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/replacer"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Logger is a basic request logging plugin.
type Logger struct {
Next plugin.Handler
Rules []Rule
repl replacer.Replacer
}
// ServeDNS implements the plugin.Handler interface.
func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
name := state.Name()
for _, rule := range l.Rules {
if !plugin.Name(rule.NameScope).Matches(name) {
continue
}
rrw := dnstest.NewRecorder(w)
rc, err := plugin.NextOrFailure(l.Name(), l.Next, ctx, rrw, r)
// If we don't set up a class in config, the default "all" will be added
// and we shouldn't have an empty rule.Class.
_, ok := rule.Class[response.All]
var ok1 bool
if !ok {
tpe, _ := response.Typify(rrw.Msg, time.Now().UTC())
class := response.Classify(tpe)
_, ok1 = rule.Class[class]
}
if ok || ok1 {
logstr := l.repl.Replace(ctx, state, rrw, rule.Format)
clog.Info(logstr)
}
return rc, err
}
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}
// Name implements the Handler interface.
func (l Logger) Name() string { return "log" }
// Rule configures the logging plugin.
type Rule struct {
NameScope string
Class map[response.Class]struct{}
Format string
}
const (
// CommonLogFormat is the common log format.
CommonLogFormat = `{remote}:{port} ` + replacer.EmptyValue + ` {>id} "{type} {class} {name} {proto} {size} {>do} {>bufsize}" {rcode} {>rflags} {rsize} {duration}`
// CombinedLogFormat is the combined log format.
CombinedLogFormat = CommonLogFormat + ` "{>opcode}"`
// DefaultLogFormat is the default log format.
DefaultLogFormat = CommonLogFormat
)
package log
import (
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/replacer"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns"
)
func init() { plugin.Register("log", setup) }
func setup(c *caddy.Controller) error {
rules, err := logParse(c)
if err != nil {
return plugin.Error("log", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Logger{Next: next, Rules: rules, repl: replacer.New()}
})
return nil
}
func logParse(c *caddy.Controller) ([]Rule, error) {
var rules []Rule
for c.Next() {
args := c.RemainingArgs()
length := len(rules)
switch len(args) {
case 0:
// Nothing specified; use defaults
rules = append(rules, Rule{
NameScope: ".",
Format: DefaultLogFormat,
Class: make(map[response.Class]struct{}),
})
case 1:
rules = append(rules, Rule{
NameScope: dns.Fqdn(args[0]),
Format: DefaultLogFormat,
Class: make(map[response.Class]struct{}),
})
default:
// Name scopes, and maybe a format specified
format := DefaultLogFormat
if strings.Contains(args[len(args)-1], "{") {
format = args[len(args)-1]
format = strings.ReplaceAll(format, "{common}", CommonLogFormat)
format = strings.ReplaceAll(format, "{combined}", CombinedLogFormat)
args = args[:len(args)-1]
}
for _, str := range args {
rules = append(rules, Rule{
NameScope: dns.Fqdn(str),
Format: format,
Class: make(map[response.Class]struct{}),
})
}
}
// Class refinements in an extra block.
classes := make(map[response.Class]struct{})
for c.NextBlock() {
switch c.Val() {
// class followed by combinations of all, denial, error and success.
case "class":
classesArgs := c.RemainingArgs()
if len(classesArgs) == 0 {
return nil, c.ArgErr()
}
for _, c := range classesArgs {
cls, err := response.ClassFromString(c)
if err != nil {
return nil, err
}
classes[cls] = struct{}{}
}
default:
return nil, c.ArgErr()
}
}
if len(classes) == 0 {
classes[response.All] = struct{}{}
}
for i := len(rules) - 1; i >= length; i-- {
rules[i].Class = classes
}
}
return rules, nil
}
package loop
import (
"context"
"sync"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("loop")
// Loop is a plugin that implements loop detection by sending a "random" query.
type Loop struct {
Next plugin.Handler
zone string
qname string
addr string
sync.RWMutex
i int
off bool
}
// New returns a new initialized Loop.
func New(zone string) *Loop { return &Loop{zone: zone, qname: qname(zone)} }
// ServeDNS implements the plugin.Handler interface.
func (l *Loop) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if r.Question[0].Qtype != dns.TypeHINFO {
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}
if l.disabled() {
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}
state := request.Request{W: w, Req: r}
zone := plugin.Zones([]string{l.zone}).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}
if state.Name() == l.qname {
l.inc()
}
if l.seen() > 2 {
log.Fatalf(`Loop (%s -> %s) detected for zone %q, see https://coredns.io/plugins/loop#troubleshooting. Query: "HINFO %s"`, state.RemoteAddr(), l.address(), l.zone, l.qname)
}
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
}
// Name implements the plugin.Handler interface.
func (l *Loop) Name() string { return "loop" }
func (l *Loop) exchange(addr string) (*dns.Msg, error) {
m := new(dns.Msg)
m.SetQuestion(l.qname, dns.TypeHINFO)
return dns.Exchange(m, addr)
}
func (l *Loop) seen() int {
l.RLock()
defer l.RUnlock()
return l.i
}
func (l *Loop) inc() {
l.Lock()
defer l.Unlock()
l.i++
}
func (l *Loop) reset() {
l.Lock()
defer l.Unlock()
l.i = 0
}
func (l *Loop) setDisabled() {
l.Lock()
defer l.Unlock()
l.off = true
}
func (l *Loop) disabled() bool {
l.RLock()
defer l.RUnlock()
return l.off
}
func (l *Loop) setAddress(addr string) {
l.Lock()
defer l.Unlock()
l.addr = addr
}
func (l *Loop) address() string {
l.RLock()
defer l.RUnlock()
return l.addr
}
package loop
import (
"net"
"strconv"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/rand"
)
func init() { plugin.Register("loop", setup) }
func setup(c *caddy.Controller) error {
l, err := parse(c)
if err != nil {
return plugin.Error("loop", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
l.Next = next
return l
})
// Send a query to ourselves and see if it ends up with us again.
c.OnStartup(func() error {
// Another Go function, otherwise we block startup and can't send the packet.
go func() {
deadline := time.Now().Add(30 * time.Second)
conf := dnsserver.GetConfig(c)
lh := ""
if len(conf.ListenHosts) > 0 {
lh = conf.ListenHosts[0]
}
addr := net.JoinHostPort(lh, conf.Port)
for time.Now().Before(deadline) {
l.setAddress(addr)
if _, err := l.exchange(addr); err != nil {
l.reset()
time.Sleep(1 * time.Second)
continue
}
go func() {
time.Sleep(2 * time.Second)
l.setDisabled()
}()
break
}
l.setDisabled()
}()
return nil
})
return nil
}
func parse(c *caddy.Controller) (*Loop, error) {
i := 0
zones := []string{"."}
for c.Next() {
if i > 0 {
return nil, plugin.ErrOnce
}
i++
if c.NextArg() {
return nil, c.ArgErr()
}
if len(c.ServerBlockKeys) > 0 {
z := plugin.Host(c.ServerBlockKeys[0]).NormalizeExact()
if len(z) > 0 {
zones = z
}
}
}
return New(zones[0]), nil
}
// qname returns a random name: <rand.Int()>.<rand.Int()>.<zone>.
func qname(zone string) string {
l1 := strconv.Itoa(r.Int())
l2 := strconv.Itoa(r.Int())
return dnsutil.Join(l1, l2, zone)
}
var r = rand.New(time.Now().UnixNano())
package metadata
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Metadata implements collecting metadata information from all plugins that
// implement the Provider interface.
type Metadata struct {
Zones []string
Providers []Provider
Next plugin.Handler
}
// Name implements the Handler interface.
func (m *Metadata) Name() string { return "metadata" }
// ContextWithMetadata is exported for use by provider tests
func ContextWithMetadata(ctx context.Context) context.Context {
return context.WithValue(ctx, key{}, md{})
}
// ServeDNS implements the plugin.Handler interface.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, w, r)
return rcode, err
}
// Collect will retrieve metadata functions from each metadata provider and update the context
func (m *Metadata) Collect(ctx context.Context, state request.Request) context.Context {
ctx = ContextWithMetadata(ctx)
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
// Go through all Providers and collect metadata.
for _, p := range m.Providers {
ctx = p.Metadata(ctx, state)
}
}
return ctx
}
// Package metadata provides an API that allows plugins to add metadata to the context.
// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata
// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to
// execute it is its responsibility to provide some form of caching. During the handling of a
// query it is expected the metadata stays constant.
//
// Basic example:
//
// Implement the Provider interface for a plugin p:
//
// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
// metadata.SetValueFunc(ctx, "test/something", func() string { return "myvalue" })
// return ctx
// }
//
// Basic example with caching:
//
// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
// cached := ""
// f := func() string {
// if cached != "" {
// return cached
// }
// cached = expensiveFunc()
// return cached
// }
// metadata.SetValueFunc(ctx, "test/something", f)
// return ctx
// }
//
// If you need access to this metadata from another plugin:
//
// // ...
// valueFunc := metadata.ValueFunc(ctx, "test/something")
// value := valueFunc()
// // use 'value'
package metadata
import (
"context"
"strings"
"github.com/coredns/coredns/request"
)
// Provider interface needs to be implemented by each plugin willing to provide
// metadata information for other plugins.
type Provider interface {
// Metadata adds metadata to the context and returns a (potentially) new context.
// Note: this method should work quickly, because it is called for every request
// from the metadata plugin.
Metadata(ctx context.Context, state request.Request) context.Context
}
// Func is the type of function stored in the metadata; when called it returns the value of the label.
type Func func() string
// IsLabel checks that the provided name is a valid label name, i.e. two or more words separated by a slash.
func IsLabel(label string) bool {
p := strings.Index(label, "/")
if p <= 0 || p >= len(label)-1 {
// cannot accept namespace empty nor label empty
return false
}
return true
}
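// Illustrative sketch (not part of the plugin):
//
//	IsLabel("kubernetes/client-namespace") // true
//	IsLabel("kubernetes")                  // false: no slash
//	IsLabel("/namespace")                  // false: empty plugin part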
// Labels returns all metadata keys stored in the context. These label names should be named
// as: plugin/NAME, where NAME is something descriptive.
func Labels(ctx context.Context) []string {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
return keys(m)
}
}
return nil
}
// ValueFuncs returns the map[string]Func from the context, or nil if it does not exist.
func ValueFuncs(ctx context.Context) map[string]Func {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
return m
}
}
return nil
}
// ValueFunc returns the value function of label. If none can be found nil is returned. Calling the
// function returns the value of the label.
func ValueFunc(ctx context.Context, label string) Func {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
return m[label]
}
}
return nil
}
// SetValueFunc sets the metadata label to the value function. If no metadata can be found this is a noop and
// false is returned. Any existing value is overwritten.
func SetValueFunc(ctx context.Context, label string, f Func) bool {
if metadata := ctx.Value(key{}); metadata != nil {
if m, ok := metadata.(md); ok {
m[label] = f
return true
}
}
return false
}
// md is metadata information storage.
type md map[string]Func
// key defines the type of key that is used to save metadata into the context.
type key struct{}
func keys(m map[string]Func) []string {
s := make([]string, len(m))
i := 0
for k := range m {
s[i] = k
i++
}
return s
}
package metadata
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("metadata", setup) }
func setup(c *caddy.Controller) error {
m, err := metadataParse(c)
if err != nil {
return err
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
m.Next = next
return m
})
c.OnStartup(func() error {
plugins := dnsserver.GetConfig(c).Handlers()
for _, p := range plugins {
if met, ok := p.(Provider); ok {
m.Providers = append(m.Providers, met)
}
}
return nil
})
return nil
}
func metadataParse(c *caddy.Controller) (*Metadata, error) {
m := &Metadata{}
c.Next()
m.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
if c.NextBlock() || c.Next() {
return nil, plugin.Error("metadata", c.ArgErr())
}
return m, nil
}
package metrics
import (
"context"
"github.com/coredns/coredns/core/dnsserver"
)
// WithServer returns the current server handling the request. It returns the
// server listening address: <scheme>://[<bind>]:<port> Normally this is
// something like "dns://:53", but if the bind plugin is used, i.e. "bind
// 127.0.0.53", it will be "dns://127.0.0.53:53", etc. If not address is found
// the empty string is returned.
//
// Basic usage with a metric:
//
// <metric>.WithLabelValues(metrics.WithServer(ctx), labels..).Add(1)
func WithServer(ctx context.Context) string {
srv := ctx.Value(dnsserver.Key{})
if srv == nil {
return ""
}
return srv.(*dnsserver.Server).Addr
}
// WithView returns the name of the view currently handling the request, if a view is defined.
//
// Basic usage with a metric:
//
// <metric>.WithLabelValues(metrics.WithView(ctx), labels..).Add(1)
func WithView(ctx context.Context) string {
v := ctx.Value(dnsserver.ViewKey{})
if v == nil {
return ""
}
return v.(string)
}
package metrics
import (
"context"
"path/filepath"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the Handler interface.
func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
// Capture the original request size before any plugins modify it
originalSize := r.Len()
qname := state.QName()
zone := plugin.Zones(m.ZoneNames()).Matches(qname)
if zone == "" {
zone = "."
}
// Record response to get status code and size of the reply.
rw := NewRecorder(w)
status, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, rw, r)
rc := rw.Rcode
if !plugin.ClientWrite(status) {
// when no response was written, fallback to status returned from next plugin as this status
// is actually used as rcode of DNS response
// see https://github.com/coredns/coredns/blob/master/core/dnsserver/server.go#L318
rc = status
}
plugin := m.authoritativePlugin(rw.Caller)
// Pass the original request size to vars.Report
vars.Report(WithServer(ctx), state, zone, WithView(ctx), rcode.ToString(rc), plugin,
rw.Len, rw.Start, vars.WithOriginalReqSize(originalSize))
return status, err
}
// Name implements the Handler interface.
func (m *Metrics) Name() string { return "prometheus" }
// authoritativePlugin returns which plugin made the write; if none is found the empty string is returned.
func (m *Metrics) authoritativePlugin(caller [3]string) string {
// The elements of caller contain the full path of the calling file; the plugin name is the second to last path element:
// .../coredns/plugin/whoami/whoami.go --> whoami
// this is likely FS specific, so use filepath.
for _, c := range caller {
plug := filepath.Base(filepath.Dir(c))
if _, ok := m.plugins[plug]; ok {
return plug
}
}
return ""
}
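// For example, a recorded caller path is reduced to the name of its enclosing
// directory, which is the plugin name (the path below is illustrative):
//
//	filepath.Base(filepath.Dir("/build/coredns/plugin/whoami/whoami.go")) // "whoami"
//
// Only names present in m.plugins are reported; anything else resolves to the
// empty string.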
// Package metrics implements a handler and plugin that provides Prometheus metrics.
package metrics
import (
"context"
"net"
"net/http"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics.
type Metrics struct {
Next plugin.Handler
Addr string
Reg *prometheus.Registry
ln net.Listener
lnSetup bool
mux *http.ServeMux
srv *http.Server
zoneNames []string
zoneMap map[string]struct{}
zoneMu sync.RWMutex
plugins map[string]struct{} // all available plugins, used to determine which plugin made the client write
}
// New returns a new instance of Metrics with the given address.
func New(addr string) *Metrics {
met := &Metrics{
Addr: addr,
Reg: prometheus.DefaultRegisterer.(*prometheus.Registry),
zoneMap: make(map[string]struct{}),
plugins: pluginList(caddy.ListPlugins()),
}
return met
}
// MustRegister registers the collector c with m.Reg. Duplicate registrations are
// ignored; any other registration error is fatal.
func (m *Metrics) MustRegister(c prometheus.Collector) {
err := m.Reg.Register(c)
if err != nil {
// ignore any duplicate error, but fatal on any other kind of error
if _, ok := err.(prometheus.AlreadyRegisteredError); !ok {
log.Fatalf("Cannot register metrics collector: %s", err)
}
}
}
// AddZone adds zone z to m.
func (m *Metrics) AddZone(z string) {
m.zoneMu.Lock()
m.zoneMap[z] = struct{}{}
m.zoneNames = keys(m.zoneMap)
m.zoneMu.Unlock()
}
// RemoveZone removes zone z from m.
func (m *Metrics) RemoveZone(z string) {
m.zoneMu.Lock()
delete(m.zoneMap, z)
m.zoneNames = keys(m.zoneMap)
m.zoneMu.Unlock()
}
// ZoneNames returns the zones of m.
func (m *Metrics) ZoneNames() []string {
m.zoneMu.RLock()
s := m.zoneNames
m.zoneMu.RUnlock()
return s
}
// OnStartup sets up the metrics on startup.
func (m *Metrics) OnStartup() error {
ln, err := reuseport.Listen("tcp", m.Addr)
if err != nil {
log.Errorf("Failed to start metrics handler: %s", err)
return err
}
m.ln = ln
m.lnSetup = true
m.mux = http.NewServeMux()
m.mux.Handle("/metrics", promhttp.HandlerFor(m.Reg, promhttp.HandlerOpts{}))
// creating some helper variables to avoid data races on m.srv and m.ln
server := &http.Server{
Handler: m.mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 5 * time.Second,
}
m.srv = server
go func() {
server.Serve(ln)
}()
ListenAddr = ln.Addr().String() // For tests.
return nil
}
// OnRestart stops the listener on reload.
func (m *Metrics) OnRestart() error {
if !m.lnSetup {
return nil
}
u.Unset(m.Addr)
return m.stopServer()
}
func (m *Metrics) stopServer() error {
if !m.lnSetup {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
if err := m.srv.Shutdown(ctx); err != nil {
log.Infof("Failed to stop prometheus http server: %s", err)
return err
}
m.lnSetup = false
m.ln.Close()
return nil
}
// OnFinalShutdown tears down the metrics listener on shutdown and restart.
func (m *Metrics) OnFinalShutdown() error { return m.stopServer() }
func keys(m map[string]struct{}) []string {
sx := []string{}
for k := range m {
sx = append(sx, k)
}
return sx
}
// pluginList iterates over the plugin map returned from caddy and strips the "dns." prefix from each name.
func pluginList(m map[string][]string) map[string]struct{} {
pm := map[string]struct{}{}
for _, p := range m["others"] {
// only add 'dns.' plugins
if len(p) > 3 {
pm[p[4:]] = struct{}{}
continue
}
}
return pm
}
// ListenAddr is assigned the address of the prometheus listener. Its use is mainly in tests where
// we listen on "localhost:0" and need to retrieve the actual address.
var ListenAddr string
// shutdownTimeout is the maximum amount of time the metrics plugin will wait
// before erroring when it tries to close the metrics server
const shutdownTimeout time.Duration = time.Second * 5
var buildInfo = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: plugin.Namespace,
Name: "build_info",
Help: "A metric with a constant '1' value labeled by version, revision, and goversion from which CoreDNS was built.",
}, []string{"version", "revision", "goversion"})
package metrics
import (
"runtime"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/miekg/dns"
)
// Recorder is a dnstest.Recorder specific to the metrics plugin.
type Recorder struct {
*dnstest.Recorder
// Caller[N] holds the file path returned by the call to runtime.Caller(N+1)
Caller [3]string
}
// NewRecorder makes and returns a new Recorder.
func NewRecorder(w dns.ResponseWriter) *Recorder { return &Recorder{Recorder: dnstest.NewRecorder(w)} }
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
_, r.Caller[0], _, _ = runtime.Caller(1)
_, r.Caller[1], _, _ = runtime.Caller(2)
_, r.Caller[2], _, _ = runtime.Caller(3)
return r.Recorder.WriteMsg(res)
}
package metrics
import (
"sync"
"github.com/prometheus/client_golang/prometheus"
)
type reg struct {
sync.RWMutex
r map[string]*prometheus.Registry
}
func newReg() *reg { return &reg{r: make(map[string]*prometheus.Registry)} }
// getOrSet sets the registry for addr if it is not already there and returns the input. Otherwise it
// returns the previously set value.
func (r *reg) getOrSet(addr string, pr *prometheus.Registry) *prometheus.Registry {
r.Lock()
defer r.Unlock()
if v, ok := r.r[addr]; ok {
return v
}
r.r[addr] = pr
return pr
}
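// A minimal sketch of the intended behavior: server blocks exporting metrics on
// the same address share one registry, while a different address gets its own.
//
//	r := newReg()
//	a := r.getOrSet(":9153", prometheus.NewRegistry())
//	b := r.getOrSet(":9153", prometheus.NewRegistry()) // b == a, the first registry wins
//	c := r.getOrSet(":9253", prometheus.NewRegistry()) // c is a distinct registry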
package metrics
import (
"net"
"runtime"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/coremain"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/uniq"
)
var (
log = clog.NewWithPlugin("prometheus")
u = uniq.New()
registry = newReg()
)
func init() { plugin.Register("prometheus", setup) }
func setup(c *caddy.Controller) error {
m, err := parse(c)
if err != nil {
return plugin.Error("prometheus", err)
}
m.Reg = registry.getOrSet(m.Addr, m.Reg)
c.OnStartup(func() error { m.Reg = registry.getOrSet(m.Addr, m.Reg); u.Set(m.Addr, m.OnStartup); return nil })
c.OnRestartFailed(func() error { m.Reg = registry.getOrSet(m.Addr, m.Reg); u.Set(m.Addr, m.OnStartup); return nil })
c.OnStartup(func() error { return u.ForEach() })
c.OnRestartFailed(func() error { return u.ForEach() })
c.OnStartup(func() error {
conf := dnsserver.GetConfig(c)
for _, h := range conf.ListenHosts {
addrstr := conf.Transport + "://" + net.JoinHostPort(h, conf.Port)
for _, p := range conf.Handlers() {
vars.PluginEnabled.WithLabelValues(addrstr, conf.Zone, conf.ViewName, p.Name()).Set(1)
}
}
return nil
})
c.OnRestartFailed(func() error {
conf := dnsserver.GetConfig(c)
for _, h := range conf.ListenHosts {
addrstr := conf.Transport + "://" + net.JoinHostPort(h, conf.Port)
for _, p := range conf.Handlers() {
vars.PluginEnabled.WithLabelValues(addrstr, conf.Zone, conf.ViewName, p.Name()).Set(1)
}
}
return nil
})
c.OnRestart(m.OnRestart)
c.OnRestart(func() error { vars.PluginEnabled.Reset(); return nil })
c.OnFinalShutdown(m.OnFinalShutdown)
// Initialize metrics.
buildInfo.WithLabelValues(coremain.CoreVersion, coremain.GitCommit, runtime.Version()).Set(1)
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
m.Next = next
return m
})
return nil
}
func parse(c *caddy.Controller) (*Metrics, error) {
met := New(defaultAddr)
i := 0
for c.Next() {
if i > 0 {
return nil, plugin.ErrOnce
}
i++
zones := plugin.OriginsFromArgsOrServerBlock(nil /* args */, c.ServerBlockKeys)
for _, z := range zones {
met.AddZone(z)
}
args := c.RemainingArgs()
switch len(args) {
case 0:
case 1:
met.Addr = args[0]
_, _, e := net.SplitHostPort(met.Addr)
if e != nil {
return met, e
}
default:
return met, c.ArgErr()
}
}
return met, nil
}
// defaultAddr is the address where the metrics are exported by default.
const defaultAddr = "localhost:9153"
package vars
import (
"github.com/miekg/dns"
)
var monitorType = map[uint16]struct{}{
dns.TypeAAAA: {},
dns.TypeA: {},
dns.TypeCNAME: {},
dns.TypeDNSKEY: {},
dns.TypeDS: {},
dns.TypeMX: {},
dns.TypeNSEC3: {},
dns.TypeNSEC: {},
dns.TypeNS: {},
dns.TypePTR: {},
dns.TypeRRSIG: {},
dns.TypeSOA: {},
dns.TypeSRV: {},
dns.TypeTXT: {},
dns.TypeHTTPS: {},
// Meta Qtypes
dns.TypeIXFR: {},
dns.TypeAXFR: {},
dns.TypeANY: {},
}
// qTypeString returns the text representation of qtype if it is listed in monitorType.
// RR types not in that list are returned as "other".
func qTypeString(qtype uint16) string {
if _, known := monitorType[qtype]; known {
return dns.Type(qtype).String()
}
return "other"
}
package vars
import (
"time"
"github.com/coredns/coredns/request"
)
// ReportOptions is a struct that contains available options for the Report function.
type ReportOptions struct {
OriginalReqSize int
}
// ReportOption defines a function that modifies ReportOptions
type ReportOption func(*ReportOptions)
// WithOriginalReqSize returns an option to set the original request size
func WithOriginalReqSize(size int) ReportOption {
return func(opts *ReportOptions) {
opts.OriginalReqSize = size
}
}
// Report reports the metrics data associated with request. This function is exported because it is also
// called from core/dnsserver to report requests hitting the server that should not be handled and are thus
// not sent down the plugin chain.
func Report(server string, req request.Request, zone, view, rcode, plugin string,
size int, start time.Time, opts ...ReportOption) {
options := ReportOptions{
OriginalReqSize: 0,
}
for _, opt := range opts {
opt(&options)
}
// Proto and Family.
net := req.Proto()
fam := "1"
if req.Family() == 2 {
fam = "2"
}
if req.Do() {
RequestDo.WithLabelValues(server, zone, view).Inc()
}
qType := qTypeString(req.QType())
RequestCount.WithLabelValues(server, zone, view, net, fam, qType).Inc()
RequestDuration.WithLabelValues(server, zone, view).Observe(time.Since(start).Seconds())
ResponseSize.WithLabelValues(server, zone, view, net).Observe(float64(size))
reqSize := req.Len()
if options.OriginalReqSize > 0 {
reqSize = options.OriginalReqSize
}
RequestSize.WithLabelValues(server, zone, view, net).Observe(float64(reqSize))
ResponseRcode.WithLabelValues(server, zone, view, rcode, plugin).Inc()
}
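// The variadic options follow the functional-option pattern, so new report
// parameters can be added without touching existing call sites. A hypothetical
// extra option (WithSomething and the Something field do not exist in this
// package) would look like:
//
//	func WithSomething(v int) ReportOption {
//		return func(opts *ReportOptions) { /* opts.Something = v */ }
//	}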
package minimal
import (
"context"
"fmt"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns"
)
// minimalHandler implements the plugin.Handler interface.
type minimalHandler struct {
Next plugin.Handler
}
func (m *minimalHandler) Name() string { return "minimal" }
func (m *minimalHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
nw := nonwriter.New(w)
rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, nw, r)
if err != nil {
return rcode, err
}
ty, _ := response.Typify(nw.Msg, time.Now().UTC())
cl := response.Classify(ty)
// pass the response through unchanged if it is a Denial or an Error, or if its type is Delegation
if cl == response.Denial || cl == response.Error || ty == response.Delegation {
w.WriteMsg(nw.Msg)
return 0, nil
}
if ty != response.NoError {
w.WriteMsg(nw.Msg)
return 0, plugin.Error("minimal", fmt.Errorf("unhandled response type %q for %q", ty, nw.Msg.Question[0].Name))
}
// copy over the original Msg params, deep copy not required as RRs are not modified
d := &dns.Msg{
MsgHdr: nw.Msg.MsgHdr,
Compress: nw.Msg.Compress,
Question: nw.Msg.Question,
Answer: nw.Msg.Answer,
Ns: nil,
Extra: nil,
}
w.WriteMsg(d)
return 0, nil
}
package minimal
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() {
plugin.Register("minimal", setup)
}
func setup(c *caddy.Controller) error {
c.Next()
if c.NextArg() {
return plugin.Error("minimal", c.ArgErr())
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return &minimalHandler{Next: next}
})
return nil
}
package multisocket
import (
"fmt"
"runtime"
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
const pluginName = "multisocket"
const maxNumSockets = 1024
func init() { plugin.Register(pluginName, setup) }
func setup(c *caddy.Controller) error {
err := parseNumSockets(c)
if err != nil {
return plugin.Error(pluginName, err)
}
return nil
}
func parseNumSockets(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
c.Next() // "multisocket"
args := c.RemainingArgs()
if len(args) > 1 || c.Next() {
return c.ArgErr()
}
if len(args) == 0 {
// Nothing specified; use default that is equal to GOMAXPROCS.
config.NumSockets = runtime.GOMAXPROCS(0)
return nil
}
numSockets, err := strconv.Atoi(args[0])
if err != nil {
return fmt.Errorf("invalid num sockets: %w", err)
}
if numSockets < 1 {
return fmt.Errorf("num sockets can not be zero or negative: %d", numSockets)
}
if numSockets > maxNumSockets {
return fmt.Errorf("num sockets exceeds maximum (%d): %d", maxNumSockets, numSockets)
}
config.NumSockets = numSockets
return nil
}
package nomad
import (
"net"
"github.com/hashicorp/nomad/api"
"github.com/miekg/dns"
)
func addSRVRecord(m *dns.Msg, s *api.ServiceRegistration, header dns.RR_Header, originalQName string, addr net.IP, ttl uint32) error {
srvRecord := &dns.SRV{
Hdr: header,
Target: originalQName,
Port: uint16(s.Port),
Priority: 10,
Weight: 10,
}
m.Answer = append(m.Answer, srvRecord)
if addr.To4() == nil {
addExtrasToAAAARecord(m, originalQName, ttl, addr)
} else {
addExtrasToARecord(m, originalQName, ttl, addr)
}
return nil
}
func addExtrasToARecord(m *dns.Msg, originalQName string, ttl uint32, addr net.IP) {
header := dns.RR_Header{
Name: originalQName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
}
m.Extra = append(m.Extra, &dns.A{Hdr: header, A: addr})
}
func addExtrasToAAAARecord(m *dns.Msg, originalQName string, ttl uint32, addr net.IP) {
header := dns.RR_Header{
Name: originalQName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
}
m.Extra = append(m.Extra, &dns.AAAA{Hdr: header, AAAA: addr})
}
func addARecord(m *dns.Msg, header dns.RR_Header, addr net.IP) {
m.Answer = append(m.Answer, &dns.A{Hdr: header, A: addr})
}
func addAAAARecord(m *dns.Msg, header dns.RR_Header, addr net.IP) {
m.Answer = append(m.Answer, &dns.AAAA{Hdr: header, AAAA: addr})
}
func createSOARecord(originalQName string, ttl uint32, zone string) *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{Name: dns.Fqdn(originalQName), Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: ttl},
Ns: dns.Fqdn("ns1." + originalQName),
Mbox: dns.Fqdn("hostmaster." + zone),
Serial: 0,
Refresh: 3600,
Retry: 600,
Expire: 86400,
Minttl: 30,
}
}
package nomad
import (
"context"
"fmt"
"net"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/hashicorp/nomad/api"
"github.com/miekg/dns"
)
const pluginName = "nomad"
var (
log = clog.NewWithPlugin(pluginName)
defaultTTL = 30
)
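// Nomad is a plugin that resolves DNS queries for Nomad service registrations,
// using one of the configured Nomad API clients.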
type Nomad struct {
Next plugin.Handler
ttl uint32
Zone string
clients []*api.Client
current int
}
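// Name implements the plugin.Handler interface.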
func (n *Nomad) Name() string {
return pluginName
}
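// ServeDNS implements the plugin.Handler interface.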
func (n Nomad) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname, originalQName, err := processQName(state.Name(), n.Zone)
if err != nil {
return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
}
namespace, serviceName, err := extractNamespaceAndService(qname)
if err != nil {
return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
}
m, header := initializeMessage(state, n.ttl)
svcRegistrations, _, err := fetchServiceRegistrations(n, serviceName, namespace)
if err != nil {
log.Warning(err)
return handleServiceLookupError(w, m, ctx, namespace)
}
if len(svcRegistrations) == 0 {
return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
}
if err := addServiceResponses(m, svcRegistrations, header, state.QType(), originalQName, n.ttl); err != nil {
return handleResponseError(n, w, m, originalQName, n.ttl, ctx, namespace, err)
}
err = w.WriteMsg(m)
requestSuccessCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
return dns.RcodeSuccess, err
}
func processQName(qname, zone string) (string, string, error) {
original := dns.Fqdn(qname)
base, err := dnsutil.TrimZone(original, dns.Fqdn(zone))
return base, original, err
}
func extractNamespaceAndService(qname string) (string, string, error) {
qnameSplit := dns.SplitDomainName(qname)
if len(qnameSplit) < 2 {
return "", "", fmt.Errorf("invalid query name")
}
return qnameSplit[1], qnameSplit[0], nil
}
func initializeMessage(state request.Request, ttl uint32) (*dns.Msg, dns.RR_Header) {
m := new(dns.Msg)
m.SetReply(state.Req)
m.Authoritative, m.Compress, m.Rcode = true, true, dns.RcodeSuccess
header := dns.RR_Header{
Name: state.QName(),
Rrtype: state.QType(),
Class: dns.ClassINET,
Ttl: ttl,
}
return m, header
}
func fetchServiceRegistrations(n Nomad, serviceName, namespace string) ([]*api.ServiceRegistration, *api.QueryMeta, error) {
log.Debugf("Looking up record for svc: %s namespace: %s", serviceName, namespace)
nc, err := n.getClient()
if err != nil {
return nil, nil, err
}
return nc.Services().Get(serviceName, &api.QueryOptions{Namespace: namespace})
}
func handleServiceLookupError(w dns.ResponseWriter, m *dns.Msg, ctx context.Context, namespace string) (int, error) {
m.Rcode = dns.RcodeSuccess
err := w.WriteMsg(m)
requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
return dns.RcodeServerFailure, err
}
func addServiceResponses(m *dns.Msg, svcRegistrations []*api.ServiceRegistration, header dns.RR_Header, qtype uint16, originalQName string, ttl uint32) error {
for _, s := range svcRegistrations {
addr := net.ParseIP(s.Address)
if addr == nil {
return fmt.Errorf("error parsing IP address")
}
switch qtype {
case dns.TypeA:
if addr.To4() == nil {
continue
}
addARecord(m, header, addr)
case dns.TypeAAAA:
if addr.To4() != nil {
continue
}
addAAAARecord(m, header, addr)
case dns.TypeSRV:
err := addSRVRecord(m, s, header, originalQName, addr, ttl)
if err != nil {
return err
}
default:
m.Rcode = dns.RcodeNotImplemented
return fmt.Errorf("query type not implemented")
}
}
return nil
}
func handleResponseError(n Nomad, w dns.ResponseWriter, m *dns.Msg, originalQName string, ttl uint32, ctx context.Context, namespace string, err error) (int, error) {
m.Rcode = dns.RcodeNameError
m.Answer = append(m.Answer, createSOARecord(originalQName, ttl, n.Zone))
if writeErr := w.WriteMsg(m); writeErr != nil {
return dns.RcodeServerFailure, fmt.Errorf("write message error: %w", writeErr)
}
requestFailedCount.WithLabelValues(metrics.WithServer(ctx), namespace).Inc()
return dns.RcodeSuccess, err
}
package nomad
// Ready signals when the plugin is ready for use.
// In case of Nomad, when the ping to the Nomad API is successful
// the plugin is ready.
func (n Nomad) Ready() bool {
client, _ := n.getClient()
return client != nil
}
package nomad
import (
"fmt"
"strconv"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
nomad "github.com/hashicorp/nomad/api"
)
// init registers this plugin.
func init() { plugin.Register(pluginName, setup) }
// setup is the function that gets called when the config parser sees the token "nomad". Setup is responsible
// for parsing any extra options the nomad plugin may have. The first token this function sees is "nomad".
func setup(c *caddy.Controller) error {
n := &Nomad{
ttl: uint32(defaultTTL),
clients: make([]*nomad.Client, 0),
current: -1,
}
// Parse the configuration, including the zone argument
if err := parse(c, n); err != nil {
return plugin.Error("nomad", err)
}
c.OnStartup(func() error {
var err error
for idx, client := range n.clients {
_, err = client.Agent().Self()
if err == nil {
n.current = idx
return nil
}
}
return err
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
n.Next = next
return n
})
return nil
}
func parse(c *caddy.Controller, n *Nomad) error {
var token string
addresses := []string{} // Multiple addresses are stored here
// Expect the first token to be "nomad"
if !c.Next() {
return c.Err("expected 'nomad' token")
}
// Check for the zone argument
args := c.RemainingArgs()
if len(args) == 0 {
n.Zone = "service.nomad"
} else {
n.Zone = args[0]
}
// Parse the configuration block
for c.NextBlock() {
selector := strings.ToLower(c.Val())
switch selector {
case "address":
args := c.RemainingArgs()
if len(args) == 0 {
return c.Err("at least one address is required")
}
addresses = append(addresses, args...)
case "token":
args := c.RemainingArgs()
if len(args) != 1 {
return c.Err("exactly one token is required")
}
token = args[0]
case "ttl":
args := c.RemainingArgs()
if len(args) != 1 {
return c.Err("exactly one ttl value is required")
}
t, err := strconv.Atoi(args[0])
if err != nil {
return c.Err("error parsing ttl: " + err.Error())
}
if t < 0 || t > 3600 {
return c.Errf("ttl must be in range [0, 3600]: %d", t)
}
n.ttl = uint32(t)
default:
return c.Errf("unknown property '%s'", selector)
}
}
// Push an empty address to create a client solely based on the defaults.
if len(addresses) == 0 {
addresses = append(addresses, "")
}
for _, addr := range addresses {
cfg := nomad.DefaultConfig()
if len(addr) > 0 {
cfg.Address = addr
}
if len(token) > 0 {
cfg.SecretID = token
}
client, err := nomad.NewClient(cfg)
if err != nil {
return plugin.Error("nomad", err)
}
n.clients = append(n.clients, client) // Store all clients
}
return nil
}
func (n *Nomad) getClient() (*nomad.Client, error) {
// Don't bother querying Agent().Self() if there is only one client.
if len(n.clients) == 1 {
return n.clients[0], nil
}
for i := range len(n.clients) {
idx := (n.current + i) % len(n.clients)
_, err := n.clients[idx].Agent().Self()
if err == nil {
n.current = idx
return n.clients[idx], nil
}
}
return nil, fmt.Errorf("no Nomad client available")
}
package plugin
import (
"fmt"
"net"
"runtime"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/pkg/cidr"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/miekg/dns"
)
// See core/dnsserver/address.go - we should unify these two impls.
// Zones represents a list of zone names.
type Zones []string
// Matches checks if qname is a subdomain of any of the zones in z. The match
// will return the most specific zone that matches. The empty string
// signals a not found condition.
func (z Zones) Matches(qname string) string {
zone := ""
for _, zname := range z {
if dns.IsSubDomain(zname, qname) {
// We want the *longest* matching zone, otherwise we may end up in a parent
if len(zname) > len(zone) {
zone = zname
}
}
}
return zone
}
// Normalize fully qualifies all zones in z. The zones in z must be domain names, without
// a port or protocol prefix.
func (z Zones) Normalize() {
for i := range z {
z[i] = Name(z[i]).Normalize()
}
}
// Name represents a domain name.
type Name string
// Matches checks to see if other is a subdomain (or the same domain) of n.
// This method assures that names can be easily and consistently matched.
func (n Name) Matches(child string) bool {
if dns.Name(n) == dns.Name(child) {
return true
}
return dns.IsSubDomain(string(n), child)
}
// Normalize lowercases and makes n fully qualified.
func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) }
type (
// Host represents a host from the Corefile, may contain port.
Host string
)
// Normalize will return the host portion of host, stripping
// off any port or transport. The host will also be fully qualified and lowercased.
// An empty string is returned on failure.
// Deprecated: use OriginsFromArgsOrServerBlock or NormalizeExact
func (h Host) Normalize() string {
var caller string
if _, file, line, ok := runtime.Caller(1); ok {
caller = fmt.Sprintf("(%v line %d) ", file, line)
}
log.Warning("An external plugin " + caller + "is using the deprecated function Normalize. " +
"This will be removed in a future versions of CoreDNS. The plugin should be updated to use " +
"OriginsFromArgsOrServerBlock or NormalizeExact instead.")
s := string(h)
_, s = parse.Transport(s)
// The error can be ignored here, because this function is called after the corefile has already been vetted.
hosts, _, err := SplitHostPort(s)
if err != nil {
return ""
}
return Name(hosts[0]).Normalize()
}
// MustNormalize will return the host portion of host, stripping
// off any port or transport. The host will also be fully qualified and lowercased.
// An error is returned on failure.
// Deprecated: use OriginsFromArgsOrServerBlock or NormalizeExact
func (h Host) MustNormalize() (string, error) {
var caller string
if _, file, line, ok := runtime.Caller(1); ok {
caller = fmt.Sprintf("(%v line %d) ", file, line)
}
log.Warning("An external plugin " + caller + "is using the deprecated function MustNormalize. " +
"This will be removed in a future versions of CoreDNS. The plugin should be updated to use " +
"OriginsFromArgsOrServerBlock or NormalizeExact instead.")
s := string(h)
_, s = parse.Transport(s)
// The error can be ignored here, because this function is called after the corefile has already been vetted.
hosts, _, err := SplitHostPort(s)
if err != nil {
return "", err
}
return Name(hosts[0]).Normalize(), nil
}
// NormalizeExact will return the host portion of host, stripping
// off any port or transport. The host will also be fully qualified and lowercased.
// An empty slice is returned on failure.
func (h Host) NormalizeExact() []string {
// The error can be ignored here, because this function should only be called after the corefile has already been vetted.
s := string(h)
_, s = parse.Transport(s)
hosts, _, err := SplitHostPort(s)
if err != nil {
return nil
}
for i := range hosts {
hosts[i] = Name(hosts[i]).Normalize()
}
return hosts
}
// SplitHostPort splits s into a host (or hosts) and a port portion, taking reverse address notation into account.
// The string s should *not* be prefixed with any protocol, i.e. dns://. SplitHostPort can return
// multiple hosts when a reverse notation on a non-octet boundary is given.
func SplitHostPort(s string) (hosts []string, port string, err error) {
// If there is: :[0-9]+ on the end we assume this is the port. This works for (ascii) domain
// names and our reverse syntax, which always needs a /mask *before* the port.
// So from the back, find first colon, and then check if it's a number.
colon := strings.LastIndex(s, ":")
if colon == len(s)-1 {
return nil, "", fmt.Errorf("expecting data after last colon: %q", s)
}
if colon != -1 {
if p, err := strconv.Atoi(s[colon+1:]); err == nil {
port = strconv.Itoa(p)
s = s[:colon]
}
}
// TODO(miek): this should take escaping into account.
if len(s) > 255 {
return nil, "", fmt.Errorf("specified zone is too long: %d > 255", len(s))
}
if _, ok := dns.IsDomainName(s); !ok {
return nil, "", fmt.Errorf("zone is not a valid domain name: %s", s)
}
// Check if it parses as a reverse zone, if so we use that. Must be fully specified IP and mask.
_, n, err := net.ParseCIDR(s)
if err != nil {
return []string{s}, port, nil
}
if s[0] == ':' || (s[0] == '0' && strings.Contains(s, ":")) {
return nil, "", fmt.Errorf("invalid CIDR %s", s)
}
// now check if multiple hosts must be returned.
nets := cidr.Split(n)
hosts = cidr.Reverse(nets)
return hosts, port, nil
}
// OriginsFromArgsOrServerBlock returns the normalized args if that slice
// is not empty, otherwise the serverblock slice is returned (in a newly copied slice).
func OriginsFromArgsOrServerBlock(args, serverblock []string) []string {
if len(args) == 0 {
s := make([]string, len(serverblock))
copy(s, serverblock)
for i := range s {
sx := Host(s[i]).NormalizeExact() // expansion of these already happened in dnsserver/register.go
if len(sx) == 0 {
continue
}
s[i] = sx[0]
}
return s
}
s := []string{}
for i := range args {
sx := Host(args[i]).NormalizeExact()
if len(sx) == 0 {
continue // silently ignores errors.
}
s = append(s, sx...)
}
return s
}
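// A minimal sketch of the two shapes SplitHostPort can return (the outputs in
// the comments are illustrative):
//
//	hosts, port, _ := SplitHostPort("example.org:1053")
//	// hosts = ["example.org"], port = "1053"
//
//	hosts, _, _ = SplitHostPort("10.0.0.0/15")
//	// a /15 does not fall on an octet boundary, so two reverse zones are
//	// returned: ["0.10.in-addr.arpa.", "1.10.in-addr.arpa."]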
// Package nsid implements the NSID protocol.
package nsid
import (
"context"
"encoding/hex"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
// Nsid is a plugin that adds the NSID EDNS0 option to responses.
type Nsid struct {
Next plugin.Handler
Data string
}
// ResponseWriter is a response writer that adds the NSID option to the response.
type ResponseWriter struct {
dns.ResponseWriter
Data string
request *dns.Msg
}
// ServeDNS implements the plugin.Handler interface.
func (n Nsid) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if option := r.IsEdns0(); option != nil {
for _, o := range option.Option {
if _, ok := o.(*dns.EDNS0_NSID); ok {
nw := &ResponseWriter{ResponseWriter: w, Data: n.Data, request: r}
return plugin.NextOrFailure(n.Name(), n.Next, ctx, nw, r)
}
}
}
return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r)
}
// WriteMsg implements the dns.ResponseWriter interface.
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
if w.request.IsEdns0() != nil && res.IsEdns0() == nil {
res.SetEdns0(w.request.IsEdns0().UDPSize(), true)
}
if option := res.IsEdns0(); option != nil {
var exists bool
for _, o := range option.Option {
if e, ok := o.(*dns.EDNS0_NSID); ok {
e.Code = dns.EDNS0NSID
e.Nsid = hex.EncodeToString([]byte(w.Data))
exists = true
}
}
// Append the NSID if it doesn't exist in EDNS0 options
if !exists {
option.Option = append(option.Option, &dns.EDNS0_NSID{
Code: dns.EDNS0NSID,
Nsid: hex.EncodeToString([]byte(w.Data)),
})
}
}
return w.ResponseWriter.WriteMsg(res)
}
// Name implements the Handler interface.
func (n Nsid) Name() string { return "nsid" }
package nsid
import (
"os"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("nsid", setup) }
func setup(c *caddy.Controller) error {
nsid, err := nsidParse(c)
if err != nil {
return plugin.Error("nsid", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Nsid{Next: next, Data: nsid}
})
return nil
}
func nsidParse(c *caddy.Controller) (string, error) {
// Use hostname as the default
nsid, err := os.Hostname()
if err != nil {
nsid = "localhost"
}
i := 0
for c.Next() {
if i > 0 {
return nsid, plugin.ErrOnce
}
i++
args := c.RemainingArgs()
if len(args) > 0 {
nsid = strings.Join(args, " ")
}
}
return nsid, nil
}
// Package cache implements a cache. The cache holds 256 shards, each shard
// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it
// just randomly evicts elements when it gets full.
package cache
import (
"hash/fnv"
"sync"
)
// Hash returns the FNV hash of what.
func Hash(what []byte) uint64 {
h := fnv.New64()
h.Write(what)
return h.Sum64()
}
// Cache is a sharded cache with random eviction.
type Cache struct {
shards [shardSize]*shard
}
// shard is a cache with random eviction.
type shard struct {
items map[uint64]any
size int
sync.RWMutex
}
// New returns a new cache.
func New(size int) *Cache {
ssize := max(size/shardSize, 4)
c := &Cache{}
// Initialize all the shards
for i := range shardSize {
c.shards[i] = newShard(ssize)
}
return c
}
// Add adds a new element to the cache. If the element already exists it is overwritten.
// Returns true if an existing element was evicted to make room for this element.
func (c *Cache) Add(key uint64, el any) bool {
shard := key & (shardSize - 1)
return c.shards[shard].Add(key, el)
}
// Get looks up element index under key.
func (c *Cache) Get(key uint64) (any, bool) {
shard := key & (shardSize - 1)
return c.shards[shard].Get(key)
}
// Remove removes the element indexed with key.
func (c *Cache) Remove(key uint64) {
shard := key & (shardSize - 1)
c.shards[shard].Remove(key)
}
// Len returns the number of elements in the cache.
func (c *Cache) Len() int {
l := 0
for _, s := range &c.shards {
l += s.Len()
}
return l
}
// Walk walks each shard in the cache.
func (c *Cache) Walk(f func(map[uint64]any, uint64) bool) {
for _, s := range &c.shards {
s.Walk(f)
}
}
// newShard returns a new shard with size.
func newShard(size int) *shard { return &shard{items: make(map[uint64]any), size: size} }
// Add adds an element indexed by key into the cache. Any existing element is overwritten.
// Returns true if an existing element was evicted to make room for this element.
func (s *shard) Add(key uint64, el any) bool {
eviction := false
s.Lock()
if len(s.items) >= s.size {
if _, ok := s.items[key]; !ok {
for k := range s.items {
delete(s.items, k)
eviction = true
break
}
}
}
s.items[key] = el
s.Unlock()
return eviction
}
// Remove removes the element indexed by key from the cache.
func (s *shard) Remove(key uint64) {
s.Lock()
delete(s.items, key)
s.Unlock()
}
// Evict removes a random element from the cache.
func (s *shard) Evict() {
s.Lock()
for k := range s.items {
delete(s.items, k)
break
}
s.Unlock()
}
// Get looks up the element indexed under key.
func (s *shard) Get(key uint64) (any, bool) {
s.RLock()
el, found := s.items[key]
s.RUnlock()
return el, found
}
// Len returns the current length of the shard.
func (s *shard) Len() int {
s.RLock()
l := len(s.items)
s.RUnlock()
return l
}
// Walk walks the shard; for each element the function f is executed while holding a write lock.
func (s *shard) Walk(f func(map[uint64]any, uint64) bool) {
s.RLock()
items := make([]uint64, len(s.items))
i := 0
for k := range s.items {
items[i] = k
i++
}
s.RUnlock()
for _, k := range items {
s.Lock()
ok := f(s.items, k)
s.Unlock()
if !ok {
return
}
}
}
const shardSize = 256
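// A minimal usage sketch, keying entries by the FNV hash of the query name;
// the stored value is whatever the caller wants to cache:
//
//	c := New(1024)
//	key := Hash([]byte("example.org."))
//	c.Add(key, "cached value")
//	if v, ok := c.Get(key); ok {
//		_ = v // "cached value"
//	}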
// Package cidr contains functions that deal with classless reverse zones in the DNS.
package cidr
import (
"math"
"net"
"strings"
"github.com/apparentlymart/go-cidr/cidr"
"github.com/miekg/dns"
)
// Split returns a slice of non-overlapping subnets that in union equal the subnet n,
// and where each subnet falls on a reverse name segment boundary.
// For IPv4 this is any multiple of 8 bits (/8, /16, /24 or /32);
// for IPv6 this is any multiple of 4 bits.
func Split(n *net.IPNet) []string {
boundary := 8
nstr := n.String()
if strings.Contains(nstr, ":") {
boundary = 4
}
ones, _ := n.Mask.Size()
if ones%boundary == 0 {
return []string{n.String()}
}
mask := int(math.Ceil(float64(ones)/float64(boundary))) * boundary
networks := nets(n, mask)
cidrs := make([]string, len(networks))
for i := range networks {
cidrs[i] = networks[i].String()
}
return cidrs
}
// nets returns a slice of prefixes with the desired mask, subnetted from the original network.
func nets(network *net.IPNet, newPrefixLen int) []*net.IPNet {
prefixLen, _ := network.Mask.Size()
maxSubnets := int(math.Exp2(float64(newPrefixLen)) / math.Exp2(float64(prefixLen)))
nets := []*net.IPNet{{IP: network.IP, Mask: net.CIDRMask(newPrefixLen, 8*len(network.IP))}}
for i := 1; i < maxSubnets; i++ {
next, exceeds := cidr.NextSubnet(nets[len(nets)-1], newPrefixLen)
nets = append(nets, next)
if exceeds {
break
}
}
return nets
}
// Reverse returns the reverse zones that are authoritative for each net in nets.
func Reverse(nets []string) []string {
rev := make([]string, len(nets))
for i := range nets {
ip, n, _ := net.ParseCIDR(nets[i])
r, err1 := dns.ReverseAddr(ip.String())
if err1 != nil {
continue
}
ones, bits := n.Mask.Size()
// get the size, in bits, of each portion of hostname defined in the reverse address. (8 for IPv4, 4 for IPv6)
sizeDigit := 8
if len(n.IP) == net.IPv6len {
sizeDigit = 4
}
// Get the first lower octet boundary to see what encompassing zone we should be authoritative for.
mod := (bits - ones) % sizeDigit
nearest := (bits - ones) + mod
offset := 0
var end bool
for range nearest / sizeDigit {
offset, end = dns.NextLabel(r, offset)
if end {
break
}
}
rev[i] = r[offset:]
}
return rev
}
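// A minimal sketch of the boundary rounding performed by Split:
//
//	_, n, _ := net.ParseCIDR("10.0.0.0/15")
//	Split(n) // ["10.0.0.0/16", "10.1.0.0/16"]
//
// For IPv6 the rounding happens on nibble boundaries, so for example a /21 is
// split into eight /24 subnets.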
package dnstest
import (
"time"
"github.com/miekg/dns"
)
// MultiRecorder is a type of ResponseWriter that captures all messages written to it.
type MultiRecorder struct {
Len int
Msgs []*dns.Msg
Start time.Time
dns.ResponseWriter
}
// NewMultiRecorder makes and returns a new MultiRecorder.
func NewMultiRecorder(w dns.ResponseWriter) *MultiRecorder {
return &MultiRecorder{
ResponseWriter: w,
Msgs: make([]*dns.Msg, 0),
Start: time.Now(),
}
}
// WriteMsg records the message and its length written to it and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *MultiRecorder) WriteMsg(res *dns.Msg) error {
r.Len += res.Len()
r.Msgs = append(r.Msgs, res)
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the length of the messages that get written to it.
func (r *MultiRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.Len += n
}
return n, err
}
// Package dnstest allows for easy testing of DNS clients against a test server.
package dnstest
import (
"time"
"github.com/miekg/dns"
)
// Recorder is a type of ResponseWriter that captures
// the rcode written to it and also the size of the message
// written in the response. An rcode does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type Recorder struct {
dns.ResponseWriter
Rcode int
Len int
Msg *dns.Msg
Start time.Time
}
// NewRecorder makes and returns a new Recorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func NewRecorder(w dns.ResponseWriter) *Recorder {
return &Recorder{
ResponseWriter: w,
Rcode: 0,
Msg: nil,
Start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
r.Rcode = res.Rcode
// We may get called multiple times (axfr for instance).
// Save the last message, but add the sizes.
r.Len += res.Len()
r.Msg = res
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the length of the message that gets written.
func (r *Recorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.Len += n
}
return n, err
}
package dnstest
import (
"net"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/miekg/dns"
)
// A Server is a DNS server listening on a system-chosen port on the local
// loopback interface, for use in end-to-end DNS tests.
type Server struct {
Addr string // Address where the server is listening.
s1 *dns.Server // udp
s2 *dns.Server // tcp
}
// NewServer starts and returns a new Server. The caller should call Close when
// finished, to shut it down.
func NewServer(f dns.HandlerFunc) *Server {
dns.HandleFunc(".", f)
ch1 := make(chan bool)
ch2 := make(chan bool)
s1 := &dns.Server{} // udp
s2 := &dns.Server{} // tcp
for range 5 { // 5 attempts
s2.Listener, _ = reuseport.Listen("tcp", ":0")
if s2.Listener == nil {
continue
}
s1.PacketConn, _ = net.ListenPacket("udp", s2.Listener.Addr().String())
if s1.PacketConn != nil {
break
}
// perhaps the UDP port is in use, try again
s2.Listener.Close()
s2.Listener = nil
}
if s2.Listener == nil {
panic("dnstest.NewServer(): failed to create new server")
}
s1.NotifyStartedFunc = func() { close(ch1) }
s2.NotifyStartedFunc = func() { close(ch2) }
go s1.ActivateAndServe()
go s2.ActivateAndServe()
<-ch1
<-ch2
return &Server{s1: s1, s2: s2, Addr: s2.Listener.Addr().String()}
}
// NewMultipleServer starts and returns a new Server that sets the handler on each
// underlying server directly instead of registering it globally, so multiple servers
// with different handlers can coexist. The caller should call Close when finished, to shut it down.
func NewMultipleServer(f dns.HandlerFunc) *Server {
ch1 := make(chan bool)
ch2 := make(chan bool)
s1 := &dns.Server{
Handler: f,
} // udp
s2 := &dns.Server{
Handler: f,
} // tcp
for range 5 { // 5 attempts
s2.Listener, _ = reuseport.Listen("tcp", ":0")
if s2.Listener == nil {
continue
}
s1.PacketConn, _ = net.ListenPacket("udp", s2.Listener.Addr().String())
if s1.PacketConn != nil {
break
}
// perhaps the UDP port is in use, try again
s2.Listener.Close()
s2.Listener = nil
}
if s2.Listener == nil {
panic("dnstest.NewServer(): failed to create new server")
}
s1.NotifyStartedFunc = func() { close(ch1) }
s2.NotifyStartedFunc = func() { close(ch2) }
go s1.ActivateAndServe()
go s2.ActivateAndServe()
<-ch1
<-ch2
return &Server{s1: s1, s2: s2, Addr: s2.Listener.Addr().String()}
}
// Close shuts down the server.
func (s *Server) Close() {
s.s1.Shutdown()
s.s2.Shutdown()
}
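// A minimal usage sketch for a test, with github.com/miekg/dns as the client:
//
//	s := NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
//		m := new(dns.Msg)
//		m.SetReply(r)
//		w.WriteMsg(m)
//	})
//	defer s.Close()
//
//	q := new(dns.Msg)
//	q.SetQuestion("example.org.", dns.TypeA)
//	resp, err := dns.Exchange(q, s.Addr)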
package dnsutil
import "github.com/miekg/dns"
// DuplicateCNAME returns true if r already exists in records.
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}
package dnsutil
import (
"strings"
"github.com/miekg/dns"
)
// Join joins labels to form a fully qualified domain name. If the last label is
// the root label it is ignored. No other syntax checks are performed.
func Join(labels ...string) string {
ll := len(labels)
if labels[ll-1] == "." {
return strings.Join(labels[:ll-1], ".") + "."
}
return dns.Fqdn(strings.Join(labels, "."))
}
package dnsutil
import (
"net"
"strings"
)
// ExtractAddressFromReverse turns a standard PTR reverse record name
// into an IP address. This works for ipv4 or ipv6.
//
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
// fails the empty string is returned.
func ExtractAddressFromReverse(reverseName string) string {
f := reverse
var search string
switch {
case strings.HasSuffix(reverseName, IP4arpa):
search = strings.TrimSuffix(reverseName, IP4arpa)
case strings.HasSuffix(reverseName, IP6arpa):
search = strings.TrimSuffix(reverseName, IP6arpa)
f = reverse6
default:
return ""
}
// Reverse the segments and then combine them.
return f(strings.Split(search, "."))
}
// IsReverse returns 0 if name is not in a reverse zone. Anything > 0 indicates
// name is in a reverse zone. The returned integer will be 1 for in-addr.arpa. (IPv4)
// and 2 for ip6.arpa. (IPv6).
func IsReverse(name string) int {
if strings.HasSuffix(name, IP4arpa) {
return 1
}
if strings.HasSuffix(name, IP6arpa) {
return 2
}
return 0
}
func reverse(slice []string) string {
for i := range len(slice) / 2 {
j := len(slice) - i - 1
slice[i], slice[j] = slice[j], slice[i]
}
ip := net.ParseIP(strings.Join(slice, ".")).To4()
if ip == nil {
return ""
}
return ip.String()
}
// reverse6 reverse the segments and combine them according to RFC3596:
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2
// is reversed to 2001:db8::567:89ab
func reverse6(slice []string) string {
for i := range len(slice) / 2 {
j := len(slice) - i - 1
slice[i], slice[j] = slice[j], slice[i]
}
slice6 := []string{}
for i := range len(slice) / 4 {
slice6 = append(slice6, strings.Join(slice[i*4:i*4+4], ""))
}
ip := net.ParseIP(strings.Join(slice6, ":")).To16()
if ip == nil {
return ""
}
return ip.String()
}
const (
// IP4arpa is the reverse tree suffix for v4 IP addresses.
IP4arpa = ".in-addr.arpa."
// IP6arpa is the reverse tree suffix for v6 IP addresses.
IP6arpa = ".ip6.arpa."
)
package dnsutil
import (
"time"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns"
)
// MinimalTTL scans the message and returns the lowest TTL found, taking into account the response.Type of the message.
func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration {
if mt != response.NoError && mt != response.NameError && mt != response.NoData {
return MinimalDefaultTTL
}
// No records or OPT is the only record, return a short ttl as a fail safe.
if len(m.Answer)+len(m.Ns) == 0 &&
(len(m.Extra) == 0 || (len(m.Extra) == 1 && m.Extra[0].Header().Rrtype == dns.TypeOPT)) {
return MinimalDefaultTTL
}
minTTL := MaximumDefaulTTL
for _, r := range m.Answer {
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
for _, r := range m.Ns {
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
for _, r := range m.Extra {
if r.Header().Rrtype == dns.TypeOPT {
// OPT records use TTL field for extended rcode and flags
continue
}
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
return minTTL
}
const (
// MinimalDefaultTTL is the absolute lowest TTL we use in CoreDNS.
MinimalDefaultTTL = 5 * time.Second
// MaximumDefaulTTL is the maximum TTL we use on RRsets in CoreDNS.
// TODO: rename as MaximumDefaultTTL
MaximumDefaulTTL = 1 * time.Hour
)
package dnsutil
import (
"errors"
"github.com/miekg/dns"
)
// TrimZone removes the zone component from q. It returns the trimmed
// name or an error if zone is longer than qname. The trimmed name will be returned
// without a trailing dot.
func TrimZone(q string, z string) (string, error) {
zl := dns.CountLabel(z)
i, ok := dns.PrevLabel(q, zl)
if ok || i-1 < 0 {
return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z)
}
// This includes the '.', remove on return
return q[:i-1], nil
}
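// For example, with both names fully qualified:
//
//	TrimZone("www.example.org.", "example.org.") // "www", nil
//	TrimZone("example.org.", "www.example.org.") // "", error: zone has more labels than qname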
package doh
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"github.com/miekg/dns"
)
// MimeType is the DoH mimetype that should be used.
const MimeType = "application/dns-message"
// Path is the URL path that should be used.
const Path = "/dns-query"
// NewRequest returns a new DoH request given a HTTP method, URL and dns.Msg.
//
// The URL should not have a path, so please exclude /dns-query. The URL will
// be prefixed with https:// by default, unless it's already prefixed with
// either http:// or https://.
func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) {
buf, err := m.Pack()
if err != nil {
return nil, err
}
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
url = "https://" + url
}
switch method {
case http.MethodGet:
b64 := base64.RawURLEncoding.EncodeToString(buf)
req, err := http.NewRequest(
http.MethodGet,
fmt.Sprintf("%s%s?dns=%s", url, Path, b64),
nil,
)
if err != nil {
return req, err
}
req.Header.Set("Content-Type", MimeType)
req.Header.Set("Accept", MimeType)
return req, nil
case http.MethodPost:
req, err := http.NewRequest(
http.MethodPost,
fmt.Sprintf("%s%s", url, Path),
bytes.NewReader(buf),
)
if err != nil {
return req, err
}
req.Header.Set("Content-Type", MimeType)
req.Header.Set("Accept", MimeType)
return req, nil
default:
return nil, fmt.Errorf("method not allowed: %s", method)
}
}
// ResponseToMsg converts a http.Response to a dns message.
func ResponseToMsg(resp *http.Response) (*dns.Msg, error) {
defer resp.Body.Close()
return toMsg(resp.Body)
}
// RequestToMsg converts a http.Request to a dns message.
func RequestToMsg(req *http.Request) (*dns.Msg, error) {
switch req.Method {
case http.MethodGet:
return requestToMsgGet(req)
case http.MethodPost:
return requestToMsgPost(req)
default:
return nil, fmt.Errorf("method not allowed: %s", req.Method)
}
}
// requestToMsgPost extracts the dns message from the request body.
func requestToMsgPost(req *http.Request) (*dns.Msg, error) {
defer req.Body.Close()
return toMsg(req.Body)
}
// requestToMsgGet extracts the dns message from the GET request.
func requestToMsgGet(req *http.Request) (*dns.Msg, error) {
values := req.URL.Query()
b64, ok := values["dns"]
if !ok {
return nil, fmt.Errorf("no 'dns' query parameter found")
}
if len(b64) != 1 {
return nil, fmt.Errorf("multiple 'dns' query values found")
}
return base64ToMsg(b64[0])
}
func toMsg(r io.ReadCloser) (*dns.Msg, error) {
buf, err := io.ReadAll(http.MaxBytesReader(nil, r, 65536))
if err != nil {
return nil, err
}
m := new(dns.Msg)
err = m.Unpack(buf)
return m, err
}
func base64ToMsg(b64 string) (*dns.Msg, error) {
buf, err := b64Enc.DecodeString(b64)
if err != nil {
return nil, err
}
m := new(dns.Msg)
err = m.Unpack(buf)
return m, err
}
var b64Enc = base64.RawURLEncoding
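// A minimal client-side sketch; the resolver name "dns.example.com" is only an
// example:
//
//	q := new(dns.Msg)
//	q.SetQuestion("example.org.", dns.TypeA)
//	req, _ := NewRequest(http.MethodGet, "dns.example.com", q)
//	resp, _ := http.DefaultClient.Do(req)
//	answer, _ := ResponseToMsg(resp) // *dns.Msg unpacked from the DoH body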
package durations
import (
"fmt"
"strconv"
"time"
)
// NewDurationFromArg returns a time.Duration from a configuration argument
// (string) which has come from the Corefile. The argument has some basic
// validation applied before returning a time.Duration. If the argument has no
// time unit specified and is numeric the argument will be treated as seconds
// rather than Go's default of nanoseconds.
func NewDurationFromArg(arg string) (time.Duration, error) {
_, err := strconv.Atoi(arg)
if err == nil {
arg = arg + "s"
}
d, err := time.ParseDuration(arg)
if err != nil {
return 0, fmt.Errorf("failed to parse duration '%s'", arg)
}
return d, nil
}
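// For example:
//
//	NewDurationFromArg("30")    // 30 * time.Second (a bare number means seconds)
//	NewDurationFromArg("1m30s") // 90 * time.Second
//	NewDurationFromArg("abc")   // error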
// Package edns provides functions useful for adding/inspecting OPT records to/in messages.
package edns
import (
"errors"
"sync"
"github.com/miekg/dns"
)
var sup = &supported{m: make(map[uint16]struct{})}
type supported struct {
m map[uint16]struct{}
sync.RWMutex
}
// SetSupportedOption adds a new supported option to the set of EDNS0 options that we support. Plugins typically call
// this in their setup code to signal support for a new option.
// By default we support:
// dns.EDNS0NSID, dns.EDNS0EXPIRE, dns.EDNS0COOKIE, dns.EDNS0TCPKEEPALIVE, dns.EDNS0PADDING. These
// values are not in this map and checked directly in the server.
func SetSupportedOption(option uint16) {
sup.Lock()
sup.m[option] = struct{}{}
sup.Unlock()
}
// SupportedOption returns true if the option code is supported as an extra EDNS0 option.
func SupportedOption(option uint16) bool {
sup.RLock()
_, ok := sup.m[option]
sup.RUnlock()
return ok
}
// Version checks the EDNS version in the request. If error
// is nil everything is OK and we can invoke the plugin. If non-nil, the
// returned Msg is valid to be returned to the client (and should be).
func Version(req *dns.Msg) (*dns.Msg, error) {
opt := req.IsEdns0()
if opt == nil {
return nil, nil
}
if opt.Version() == 0 {
return nil, nil
}
m := new(dns.Msg)
m.SetReply(req)
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetVersion(0)
m.Rcode = dns.RcodeBadVers
o.SetExtendedRcode(dns.RcodeBadVers)
m.Extra = []dns.RR{o}
return m, errors.New("EDNS0 BADVERS")
}
// Size returns a normalized size based on proto.
func Size(proto string, size uint16) uint16 {
if proto == "tcp" {
return dns.MaxMsgSize
}
if size < dns.MinMsgSize {
return dns.MinMsgSize
}
return size
}
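// A minimal sketch of a plugin declaring support for an extra EDNS0 option in
// its setup code (client subnet is used purely as an example):
//
//	edns.SetSupportedOption(dns.EDNS0SUBNET)
//
// After this, SupportedOption(dns.EDNS0SUBNET) reports true.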
package expression
import (
"context"
"errors"
"net"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/request"
)
// DefaultEnv returns the default set of custom state variables and functions available for use in expression evaluation.
func DefaultEnv(ctx context.Context, state *request.Request) map[string]any {
return map[string]any{
"incidr": func(ipStr, cidrStr string) (bool, error) {
ip := net.ParseIP(ipStr)
if ip == nil {
return false, errors.New("first argument is not an IP address")
}
_, cidr, err := net.ParseCIDR(cidrStr)
if err != nil {
return false, err
}
return cidr.Contains(ip), nil
},
"metadata": func(label string) string {
f := metadata.ValueFunc(ctx, label)
if f == nil {
return ""
}
return f()
},
"type": state.Type,
"name": state.Name,
"class": state.Class,
"proto": state.Proto,
"size": state.Len,
"client_ip": state.IP,
"port": state.Port,
"id": func() int { return int(state.Req.Id) },
"opcode": func() int { return state.Req.Opcode },
"do": state.Do,
"bufsize": state.Size,
"server_ip": state.LocalIP,
"server_port": state.LocalPort,
}
}
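// A minimal sketch of calling one of these entries directly; an expression
// engine would normally do the lookup and call, and ctx and state are assumed
// to come from the request being served:
//
//	env := DefaultEnv(ctx, state)
//	incidr := env["incidr"].(func(string, string) (bool, error))
//	ok, _ := incidr("192.0.2.1", "192.0.2.0/24") // true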
// Package fall handles the fallthrough logic used in plugins that support it. Be careful when including this
// functionality in your plugin. Why? In the DNS only 1 source is authoritative for a set of names. Fallthrough
// breaks this convention by allowing a plugin to query multiple sources, depending on the replies it got so far.
//
// This may cause issues in downstream caches, where different answers for the same query can potentially confuse clients.
// On the other hand this is a powerful feature that can aid in migration or other edge cases.
//
// The take away: be mindful of this and don't blindly assume it's a good feature to have in your plugin.
//
// See https://github.com/coredns/coredns/issues/2723 for some discussion on this, which includes this quote:
//
// TL;DR: `fallthrough` is indeed risky and hackish, but still a good feature of CoreDNS as it allows to quickly answer boring edge cases.
package fall
import (
"slices"
"github.com/coredns/coredns/plugin"
)
// F can be nil to allow for no fallthrough, empty to allow all zones to fall through, or
// contain a zone list that is checked.
type F struct {
Zones []string
}
// Through will check if we should fallthrough for qname. Note that we've named the
// variable in each plugin "Fall", so this then reads Fall.Through().
func (f F) Through(qname string) bool {
return plugin.Zones(f.Zones).Matches(qname) != ""
}
// setZones will set zones in f.
func (f *F) setZones(zones []string) {
z := []string{}
for i := range zones {
z = append(z, plugin.Host(zones[i]).NormalizeExact()...)
}
f.Zones = z
}
// SetZonesFromArgs sets zones in f to the passed value or to "." if the slice is empty.
func (f *F) SetZonesFromArgs(zones []string) {
if len(zones) == 0 {
f.setZones(Root.Zones)
return
}
f.setZones(zones)
}
// Equal returns true if f and g are equal.
func (f *F) Equal(g F) bool {
return slices.Equal(f.Zones, g.Zones)
}
// Zero is a zero valued F.
var Zero = func() F {
return F{[]string{}}
}()
// Root is an F set to only ".".
var Root = func() F {
return F{[]string{"."}}
}()
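// A minimal usage sketch:
//
//	var f F
//	f.SetZonesFromArgs([]string{"example.org"})
//	f.Through("www.example.org.") // true
//	f.Through("example.net.")     // false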
// Package fuzz contains functions that enable fuzzing of plugins.
package fuzz
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
// Do will fuzz p - used by gofuzz. See Makefile.fuzz for comments and context.
func Do(p plugin.Handler, data []byte) int {
ctx := context.TODO()
r := new(dns.Msg)
if err := r.Unpack(data); err != nil {
return 0 // plugin will never be called when this happens.
}
// If the data unpacks into a dns msg but does not have a proper question section, discard it.
// The server parts make sure this is true before calling the plugins; mimic this behavior.
if len(r.Question) == 0 {
return 0
}
if _, err := p.ServeDNS(ctx, &test.ResponseWriter{}, r); err != nil {
return 1
}
return 0
}
package log
import (
"sync"
)
// Listener listens for all log prints of plugin loggers, i.e. loggers with a plugin name.
// When a plugin logger gets called, it first calls the same method on every registered Listener.
// For example, the external plugin k8s_event replicates log prints to Kubernetes events.
type Listener interface {
Name() string
Debug(plugin string, v ...any)
Debugf(plugin string, format string, v ...any)
Info(plugin string, v ...any)
Infof(plugin string, format string, v ...any)
Warning(plugin string, v ...any)
Warningf(plugin string, format string, v ...any)
Error(plugin string, v ...any)
Errorf(plugin string, format string, v ...any)
Fatal(plugin string, v ...any)
Fatalf(plugin string, format string, v ...any)
}
type listeners struct {
listeners []Listener
sync.RWMutex
}
var ls *listeners
func init() {
ls = &listeners{}
ls.listeners = make([]Listener, 0)
}
// RegisterListener registers a listener object.
func RegisterListener(new Listener) error {
ls.Lock()
defer ls.Unlock()
for k, l := range ls.listeners {
if l.Name() == new.Name() {
ls.listeners[k] = new
return nil
}
}
ls.listeners = append(ls.listeners, new)
return nil
}
// DeregisterListener deregisters a listener object.
func DeregisterListener(old Listener) error {
ls.Lock()
defer ls.Unlock()
for k, l := range ls.listeners {
if l.Name() == old.Name() {
ls.listeners = append(ls.listeners[:k], ls.listeners[k+1:]...)
return nil
}
}
return nil
}
func (ls *listeners) debug(plugin string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Debug(plugin, v...)
}
ls.RUnlock()
}
func (ls *listeners) debugf(plugin string, format string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Debugf(plugin, format, v...)
}
ls.RUnlock()
}
func (ls *listeners) info(plugin string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Info(plugin, v...)
}
ls.RUnlock()
}
func (ls *listeners) infof(plugin string, format string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Infof(plugin, format, v...)
}
ls.RUnlock()
}
func (ls *listeners) warning(plugin string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Warning(plugin, v...)
}
ls.RUnlock()
}
func (ls *listeners) warningf(plugin string, format string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Warningf(plugin, format, v...)
}
ls.RUnlock()
}
func (ls *listeners) error(plugin string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Error(plugin, v...)
}
ls.RUnlock()
}
func (ls *listeners) errorf(plugin string, format string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Errorf(plugin, format, v...)
}
ls.RUnlock()
}
func (ls *listeners) fatal(plugin string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Fatal(plugin, v...)
}
ls.RUnlock()
}
func (ls *listeners) fatalf(plugin string, format string, v ...any) {
ls.RLock()
for _, l := range ls.listeners {
l.Fatalf(plugin, format, v...)
}
ls.RUnlock()
}
// Package log implements a small wrapper around the std lib log package. It
// implements log levels by prefixing the logs with [INFO], [DEBUG], [WARNING]
// or [ERROR]. Debug logging is available and enabled if the *debug* plugin is
// used.
//
// log.Info("this is some logging"), will log on the Info level.
//
// log.Debug("this is debug output"), will log in the Debug level, etc.
package log
import (
"fmt"
"io"
golog "log"
"os"
"sync/atomic"
)
// D controls whether we should output debug logs; it can be toggled with Set
// and Clear.
var D = &d{}
type d struct {
on atomic.Bool
}
// Set enables debug logging.
func (d *d) Set() {
d.on.Store(true)
}
// Clear disables debug logging.
func (d *d) Clear() {
d.on.Store(false)
}
// Value reports whether debug logging is enabled.
func (d *d) Value() bool {
return d.on.Load()
}
// logf calls log.Printf prefixed with level.
func logf(level, format string, v ...any) {
golog.Print(level, fmt.Sprintf(format, v...))
}
// log calls log.Print prefixed with level.
func log(level string, v ...any) {
golog.Print(level, fmt.Sprint(v...))
}
// Debug is equivalent to log.Print(), but prefixed with "[DEBUG] ". It only outputs something
// if debug logging (D) is enabled.
func Debug(v ...any) {
if !D.Value() {
return
}
log(debug, v...)
}
// Debugf is equivalent to log.Printf(), but prefixed with "[DEBUG] ". It only outputs something
// if debug logging (D) is enabled.
func Debugf(format string, v ...any) {
if !D.Value() {
return
}
logf(debug, format, v...)
}
// Info is equivalent to log.Print, but prefixed with "[INFO] ".
func Info(v ...any) { log(info, v...) }
// Infof is equivalent to log.Printf, but prefixed with "[INFO] ".
func Infof(format string, v ...any) { logf(info, format, v...) }
// Warning is equivalent to log.Print, but prefixed with "[WARNING] ".
func Warning(v ...any) { log(warning, v...) }
// Warningf is equivalent to log.Printf, but prefixed with "[WARNING] ".
func Warningf(format string, v ...any) { logf(warning, format, v...) }
// Error is equivalent to log.Print, but prefixed with "[ERROR] ".
func Error(v ...any) { log(err, v...) }
// Errorf is equivalent to log.Printf, but prefixed with "[ERROR] ".
func Errorf(format string, v ...any) { logf(err, format, v...) }
// Fatal is equivalent to log.Print, but prefixed with "[FATAL] ", and calling
// os.Exit(1).
func Fatal(v ...any) { log(fatal, v...); os.Exit(1) }
// Fatalf is equivalent to log.Printf, but prefixed with "[FATAL] ", and calling
// os.Exit(1)
func Fatalf(format string, v ...any) { logf(fatal, format, v...); os.Exit(1) }
// Discard sets the log output to /dev/null.
func Discard() { golog.SetOutput(io.Discard) }
const (
debug = "[DEBUG] "
err = "[ERROR] "
fatal = "[FATAL] "
info = "[INFO] "
warning = "[WARNING] "
)
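// Illustrative usage sketch (not part of the original source): debug prints
// are silent until D is set, which the *debug* plugin does at setup time.
//
//	log.Debug("connection accepted") // no output by default
//	log.D.Set()
//	log.Debugf("cache size: %d", 42) // "[DEBUG] cache size: 42"
//	log.Info("ready to serve")       // "[INFO] ready to serve"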
package log
import (
"fmt"
"os"
)
// P is a logger that includes the plugin doing the logging.
type P struct {
plugin string
}
// NewWithPlugin returns a logger that includes "plugin/name: " in the log message.
// I.e. [INFO] plugin/<name>: message.
func NewWithPlugin(name string) P { return P{"plugin/" + name + ": "} }
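// Illustrative usage sketch (not part of the original source): plugins usually
// create one package-level logger and use it everywhere; "example" is a
// hypothetical plugin name and addr a hypothetical variable.
//
//	var log = clog.NewWithPlugin("example") // clog = plugin/pkg/log
//
//	log.Infof("listening on %s", addr) // "[INFO] plugin/example: listening on ..."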
func (p P) logf(level, format string, v ...any) {
log(level, p.plugin, fmt.Sprintf(format, v...))
}
func (p P) log(level string, v ...any) {
log(level+p.plugin, v...)
}
// Debug logs as log.Debug.
func (p P) Debug(v ...any) {
if !D.Value() {
return
}
ls.debug(p.plugin, v...)
p.log(debug, v...)
}
// Debugf logs as log.Debugf.
func (p P) Debugf(format string, v ...any) {
if !D.Value() {
return
}
ls.debugf(p.plugin, format, v...)
p.logf(debug, format, v...)
}
// Info logs as log.Info.
func (p P) Info(v ...any) {
ls.info(p.plugin, v...)
p.log(info, v...)
}
// Infof logs as log.Infof.
func (p P) Infof(format string, v ...any) {
ls.infof(p.plugin, format, v...)
p.logf(info, format, v...)
}
// Warning logs as log.Warning.
func (p P) Warning(v ...any) {
ls.warning(p.plugin, v...)
p.log(warning, v...)
}
// Warningf logs as log.Warningf.
func (p P) Warningf(format string, v ...any) {
ls.warningf(p.plugin, format, v...)
p.logf(warning, format, v...)
}
// Error logs as log.Error.
func (p P) Error(v ...any) {
ls.error(p.plugin, v...)
p.log(err, v...)
}
// Errorf logs as log.Errorf.
func (p P) Errorf(format string, v ...any) {
ls.errorf(p.plugin, format, v...)
p.logf(err, format, v...)
}
// Fatal logs as log.Fatal and calls os.Exit(1).
func (p P) Fatal(v ...any) {
ls.fatal(p.plugin, v...)
p.log(fatal, v...)
os.Exit(1)
}
// Fatalf logs as log.Fatalf and calls os.Exit(1).
func (p P) Fatalf(format string, v ...any) {
ls.fatalf(p.plugin, format, v...)
p.logf(fatal, format, v...)
os.Exit(1)
}
// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written.
package nonwriter
import (
"github.com/miekg/dns"
)
// Writer is a type of ResponseWriter that captures the message, but never writes to the client.
type Writer struct {
dns.ResponseWriter
Msg *dns.Msg
}
// New makes and returns a new Writer.
func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} }
// WriteMsg records the message, but doesn't write it itself.
func (w *Writer) WriteMsg(res *dns.Msg) error {
w.Msg = res
return nil
}
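// Illustrative usage sketch (not part of the original source): wrap the real
// ResponseWriter, let the rest of the chain "write" the reply, then inspect it
// without anything reaching the client. ctx, state and next are assumed to be
// in scope.
//
//	nw := nonwriter.New(state.W)
//	rcode, err := next.ServeDNS(ctx, nw, state.Req)
//	if err == nil && nw.Msg != nil {
//		// nw.Msg holds the response that would have been written.
//	}
//	_, _ = rcode, err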
package parse
import (
"errors"
"fmt"
"net"
"os"
"strings"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// ErrNoNameservers is returned by HostPortOrFile if no servers can be parsed.
var ErrNoNameservers = errors.New("no nameservers found")
// stripZone removes the IPv6 zone identifier (everything from the last '%'), e.g. "fe80::1%eth0" becomes "fe80::1".
func stripZone(host string) string {
if strings.Contains(host, "%") {
lastPercent := strings.LastIndex(host, "%")
newHost := host[:lastPercent]
return newHost
}
return host
}
// HostPortOrFile parses the strings in s; each string can either be an
// address, [scheme://]address:port, or a filename. The address part is
// validated, and in the case of a filename a resolv.conf-like file is assumed,
// parsed, and the nameservers found in it are returned.
func HostPortOrFile(s ...string) ([]string, error) {
var servers []string //nolint:prealloc // impossible to know the final length upfront
for _, h := range s {
trans, host := Transport(h)
if len(host) == 0 {
return servers, fmt.Errorf("invalid address: %q", h)
}
if trans == transport.UNIX {
servers = append(servers, trans+"://"+host)
continue
}
addr, _, err := net.SplitHostPort(host)
if err != nil {
// Parsing didn't work, it is not an addr:port combo
hostNoZone := stripZone(host)
if net.ParseIP(hostNoZone) == nil {
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
continue
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
var ss string
switch trans {
case transport.DNS:
ss = net.JoinHostPort(host, transport.Port)
case transport.TLS:
ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort)
case transport.QUIC:
ss = transport.QUIC + "://" + net.JoinHostPort(host, transport.QUICPort)
case transport.GRPC:
ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort)
case transport.HTTPS:
ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort)
}
servers = append(servers, ss)
continue
}
if net.ParseIP(stripZone(addr)) == nil {
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
continue
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
servers = append(servers, h)
}
if len(servers) == 0 {
return servers, ErrNoNameservers
}
return servers, nil
}
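// Illustrative usage sketch (not part of the original source): mixed input of
// a plain IP, a scheme-prefixed address and a resolv.conf path; the default
// ports come from plugin/pkg/transport.
//
//	servers, err := HostPortOrFile("8.8.8.8", "tls://9.9.9.9", "/etc/resolv.conf")
//	if err != nil {
//		// none of the arguments yielded a nameserver
//	}
//	// servers now contains e.g. "8.8.8.8:53", "tls://9.9.9.9:853" and the
//	// nameservers listed in /etc/resolv.conf.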
// Try to open this as a file first.
func tryFile(s string) ([]string, error) {
c, err := dns.ClientConfigFromFile(s)
if err == os.ErrNotExist {
return nil, fmt.Errorf("failed to open file %q: %q", s, err)
} else if err != nil {
return nil, err
}
servers := []string{}
for _, s := range c.Servers {
servers = append(servers, net.JoinHostPort(stripZone(s), c.Port))
}
return servers, nil
}
// HostPort checks whether the host part is a valid IP address; if the
// IP address is valid but no port is found, defaultPort is added.
func HostPort(s, defaultPort string) (string, error) {
addr, port, err := net.SplitHostPort(s)
if port == "" {
port = defaultPort
}
if err != nil {
if net.ParseIP(s) == nil {
return "", fmt.Errorf("must specify an IP address: `%s'", s)
}
return net.JoinHostPort(s, port), nil
}
if net.ParseIP(addr) == nil {
return "", fmt.Errorf("must specify an IP address: `%s'", addr)
}
return net.JoinHostPort(addr, port), nil
}
// Package parse contains functions that can be used in the setup code for plugins.
package parse
import (
"fmt"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// TransferIn parses transfer statements: 'transfer from [address...]'.
func TransferIn(c *caddy.Controller) (froms []string, err error) {
if !c.NextArg() {
return nil, c.ArgErr()
}
value := c.Val()
switch value {
default:
return nil, c.Errf("unknown property %s", value)
case "from":
froms = c.RemainingArgs()
if len(froms) == 0 {
return nil, c.ArgErr()
}
for i := range froms {
if froms[i] == "*" {
return nil, fmt.Errorf("can't use '*' in transfer from")
}
normalized, err := HostPort(froms[i], transport.Port)
if err != nil {
return nil, err
}
froms[i] = normalized
}
}
return froms, nil
}
package parse
import (
"strings"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// Transport returns the transport defined in s and a string where the
// transport prefix is removed (if there was any). If no transport is defined
// we default to transport.DNS.
func Transport(s string) (trans string, addr string) {
switch {
case strings.HasPrefix(s, transport.TLS+"://"):
s = s[len(transport.TLS+"://"):]
return transport.TLS, s
case strings.HasPrefix(s, transport.DNS+"://"):
s = s[len(transport.DNS+"://"):]
return transport.DNS, s
case strings.HasPrefix(s, transport.QUIC+"://"):
s = s[len(transport.QUIC+"://"):]
return transport.QUIC, s
case strings.HasPrefix(s, transport.GRPC+"://"):
s = s[len(transport.GRPC+"://"):]
return transport.GRPC, s
case strings.HasPrefix(s, transport.HTTPS+"://"):
s = s[len(transport.HTTPS+"://"):]
return transport.HTTPS, s
case strings.HasPrefix(s, transport.UNIX+"://"):
s = s[len(transport.UNIX+"://"):]
return transport.UNIX, s
}
return transport.DNS, s
}
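// Illustrative usage sketch (not part of the original source):
//
//	trans, addr := Transport("tls://8.8.8.8:853")
//	// trans == transport.TLS, addr == "8.8.8.8:853"
//
//	trans, addr = Transport("10.0.0.1:53")
//	// no prefix present, so trans == transport.DNS and addr is returned unchanged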
// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns, the upstream's Conn will already be cached. Depending on how you benchmark, this looks to be
// about 50% faster than opening a new connection for every client. It works with UDP and TCP and uses
// in-band health checking.
package proxy
import (
"context"
"errors"
"io"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
const (
ErrTransportStopped = "proxy: transport stopped"
ErrTransportStoppedDuringDial = "proxy: transport stopped during dial"
ErrTransportStoppedRetClosed = "proxy: transport stopped, ret channel closed"
ErrTransportStoppedDuringRetWait = "proxy: transport stopped during ret wait"
)
// limitTimeout is a utility function to auto-tune timeout values: the average
// observed time is moved towards the last observed delay, moderated by a weight;
// the next timeout to use will be double the computed average, limited by the
// minValue and maxValue frame.
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
rt := time.Duration(atomic.LoadInt64(currentAvg))
if rt < minValue {
return minValue
}
if rt < maxValue/2 {
return 2 * rt
}
return maxValue
}
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
dt := time.Duration(atomic.LoadInt64(currentAvg))
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
}
func (t *Transport) dialTimeout() time.Duration {
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
}
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
}
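// Worked example (not part of the original source): with cumulativeAvgWeight = 4,
// an average dial time of 2s and an observed dial of 6s, averageTimeout moves the
// average to 2s + (6s-2s)/4 = 3s. limitTimeout then doubles that (6s), clamped to
// the [minDialTimeout, maxDialTimeout] frame of 1s..30s, so the next Dial uses a
// 6s timeout.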
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"
}
// Check if transport is stopped before attempting to dial
select {
case <-t.stop:
return nil, false, errors.New(ErrTransportStopped)
default:
}
// Use select to avoid blocking if connManager has stopped
select {
case t.dial <- proto:
// Successfully sent dial request
case <-t.stop:
return nil, false, errors.New(ErrTransportStoppedDuringDial)
}
// Receive response with stop awareness
select {
case pc, ok := <-t.ret:
if !ok {
// ret channel was closed by connManager during stop
return nil, false, errors.New(ErrTransportStoppedRetClosed)
}
if pc != nil {
connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1)
return pc, true, nil
}
connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1)
reqTime := time.Now()
timeout := t.dialTimeout()
if proto == "tcp-tls" {
conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
t.updateDialTimeout(time.Since(reqTime))
return &persistConn{c: conn}, false, err
}
conn, err := dns.DialTimeout(proto, t.addr, timeout)
t.updateDialTimeout(time.Since(reqTime))
return &persistConn{c: conn}, false, err
case <-t.stop:
return nil, false, errors.New(ErrTransportStoppedDuringRetWait)
}
}
// Connect selects an upstream, sends the request and waits for a response.
func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
start := time.Now()
var proto string
switch {
case opts.ForceTCP: // TCP flag has precedence over UDP flag
proto = "tcp"
case opts.PreferUDP:
proto = "udp"
default:
proto = state.Proto()
}
pc, cached, err := p.transport.Dial(proto)
if err != nil {
return nil, err
}
// Set buffer size correctly for this client.
pc.c.UDPSize = max(uint16(state.Size()), 512)
pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
// Record the original Id before sending the query upstream.
originId := state.Req.Id
state.Req.Id = dns.Id()
defer func() {
state.Req.Id = originId
}()
if err := pc.c.WriteMsg(state.Req); err != nil {
pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
}
return nil, err
}
var ret *dns.Msg
pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
for {
ret, err = pc.c.ReadMsg()
if err != nil {
if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
// For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way.
// (e.g. sending >512 byte responses without an eDNS0 OPT RR).
// Instead of returning an error, return an empty response with TC bit set. This will make the
// client retry over TCP (if that's supported) or at least receive a clean
// error. The connection is still good so we break before the close.
// Truncate the response.
ret = truncateResponse(ret)
break
}
pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
}
// Restore the original Id after the upstream exchange.
if ret != nil {
ret.Id = originId
}
return ret, err
}
// drop out-of-order responses
if state.Req.Id == ret.Id {
break
}
}
// Restore the original Id after the upstream exchange.
ret.Id = originId
p.transport.Yield(pc)
rc, ok := dns.RcodeToString[ret.Rcode]
if !ok {
rc = strconv.Itoa(ret.Rcode)
}
requestDuration.WithLabelValues(p.proxyName, p.addr, rc).Observe(time.Since(start).Seconds())
return ret, nil
}
const cumulativeAvgWeight = 4
// shouldTruncateResponse determines whether a response should be truncated.
func shouldTruncateResponse(err error) bool {
// This is to handle a scenario in which upstream sets the TC bit, but doesn't truncate the response
// and we get ErrBuf instead of overflow.
if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) {
return true
} else if strings.Contains(err.Error(), "overflow") {
return true
}
return false
}
// truncateResponse returns an empty response with the TC (truncated) bit set.
func truncateResponse(response *dns.Msg) *dns.Msg {
// Clear out Answer, Extra, and Ns sections
response.Answer = nil
response.Extra = nil
response.Ns = nil
// Set TC bit to indicate truncation.
response.Truncated = true
return response
}
package proxy
import (
"crypto/tls"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// HealthChecker checks the upstream health.
type HealthChecker interface {
Check(*Proxy) error
SetTLSConfig(*tls.Config)
GetTLSConfig() *tls.Config
SetRecursionDesired(bool)
GetRecursionDesired() bool
SetDomain(domain string)
GetDomain() string
SetTCPTransport()
GetReadTimeout() time.Duration
SetReadTimeout(time.Duration)
GetWriteTimeout() time.Duration
SetWriteTimeout(time.Duration)
}
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
type dnsHc struct {
c *dns.Client
recursionDesired bool
domain string
proxyName string
}
// NewHealthChecker returns a new HealthChecker based on transport.
func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain string) HealthChecker {
switch trans {
case transport.DNS, transport.TLS:
c := new(dns.Client)
c.Net = "udp"
c.ReadTimeout = 1 * time.Second
c.WriteTimeout = 1 * time.Second
return &dnsHc{
c: c,
recursionDesired: recursionDesired,
domain: domain,
proxyName: proxyName,
}
}
log.Warningf("No healthchecker for transport %q", trans)
return nil
}
func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
h.c.Net = "tcp-tls"
h.c.TLSConfig = cfg
}
func (h *dnsHc) GetTLSConfig() *tls.Config {
return h.c.TLSConfig
}
func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
h.recursionDesired = recursionDesired
}
func (h *dnsHc) GetRecursionDesired() bool {
return h.recursionDesired
}
func (h *dnsHc) SetDomain(domain string) {
h.domain = domain
}
func (h *dnsHc) GetDomain() string {
return h.domain
}
func (h *dnsHc) SetTCPTransport() {
h.c.Net = "tcp"
}
func (h *dnsHc) GetReadTimeout() time.Duration {
return h.c.ReadTimeout
}
func (h *dnsHc) SetReadTimeout(t time.Duration) {
h.c.ReadTimeout = t
}
func (h *dnsHc) GetWriteTimeout() time.Duration {
return h.c.WriteTimeout
}
func (h *dnsHc) SetWriteTimeout(t time.Duration) {
h.c.WriteTimeout = t
}
// For the health check we send a ". IN NS" query (+[no]rec depending on the configured
// recursion desired flag) to the upstream. Dial timeouts and empty replies are considered
// fails; basically anything else constitutes a healthy upstream.
// Check is used as the up.Func in the up.Probe.
func (h *dnsHc) Check(p *Proxy) error {
err := h.send(p.addr)
if err != nil {
healthcheckFailureCount.WithLabelValues(p.proxyName, p.addr).Add(1)
p.incrementFails()
return err
}
atomic.StoreUint32(&p.fails, 0)
return nil
}
func (h *dnsHc) send(addr string) error {
ping := new(dns.Msg)
ping.SetQuestion(h.domain, dns.TypeNS)
ping.RecursionDesired = h.recursionDesired
m, _, err := h.c.Exchange(ping, addr)
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
if err != nil && m != nil {
// Silly check, something sane came back.
if m.Response || m.Opcode == dns.OpcodeQuery {
err = nil
}
}
return err
}
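// Illustrative usage sketch (not part of the original source): wiring a checker
// by hand; NewProxy further down already does this internally.
//
//	hc := NewHealthChecker("forward", transport.DNS, true, ".")
//	hc.SetReadTimeout(2 * time.Second)
//	hc.SetTCPTransport() // probe over TCP instead of UDP, if desired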
package proxy
import (
"crypto/tls"
"sort"
"time"
"github.com/miekg/dns"
)
// persistConn holds the dns.Conn and the last used time.
type persistConn struct {
c *dns.Conn
used time.Time
}
// Transport holds the persistent connection cache.
type Transport struct {
avgDialTime int64 // rolling average of dial time
conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired.
addr string
tlsConfig *tls.Config
proxyName string
dial chan string
yield chan *persistConn
ret chan *persistConn
stop chan bool
}
func newTransport(proxyName, addr string) *Transport {
t := &Transport{
avgDialTime: int64(maxDialTimeout / 2),
conns: [typeTotalCount][]*persistConn{},
expire: defaultExpire,
addr: addr,
dial: make(chan string),
yield: make(chan *persistConn),
ret: make(chan *persistConn),
stop: make(chan bool),
proxyName: proxyName,
}
return t
}
// connManager manages the persistent connection cache for UDP and TCP.
func (t *Transport) connManager() {
ticker := time.NewTicker(defaultExpire)
defer ticker.Stop()
Wait:
for {
select {
case proto := <-t.dial:
transtype := stringToTransportType(proto)
// take the last used conn - complexity O(1)
if stack := t.conns[transtype]; len(stack) > 0 {
pc := stack[len(stack)-1]
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[transtype] = stack[:len(stack)-1]
t.ret <- pc
continue Wait
}
// clear entire cache if the last conn is expired
t.conns[transtype] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
}
t.ret <- nil
case pc := <-t.yield:
transtype := t.transportTypeFromConn(pc)
t.conns[transtype] = append(t.conns[transtype], pc)
case <-ticker.C:
t.cleanup(false)
case <-t.stop:
t.cleanup(true)
close(t.ret)
return
}
}
}
// closeConns closes connections.
func closeConns(conns []*persistConn) {
for _, pc := range conns {
pc.c.Close()
}
}
// cleanup removes connections from cache.
func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire)
for transtype, stack := range t.conns {
if len(stack) == 0 {
continue
}
if all {
t.conns[transtype] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
continue
}
if stack[0].used.After(staleTime) {
continue
}
// connections in stack are sorted by "used"
good := sort.Search(len(stack), func(i int) bool {
return stack[i].used.After(staleTime)
})
t.conns[transtype] = stack[good:]
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack[:good])
}
}
// It is hard to pin a value to this; the important thing is to not block forever. Losing a cached connection is not terrible.
const yieldTimeout = 25 * time.Millisecond
// Yield returns the connection to transport for reuse.
func (t *Transport) Yield(pc *persistConn) {
pc.used = time.Now() // update used time
// Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This
// blocks the outer go-routine and stuff will just pile up. We time out when the send fails, as returning
// these connections is only an optimization anyway.
select {
case t.yield <- pc:
return
case <-time.After(yieldTimeout):
return
}
}
// Start starts the transport's connection manager.
func (t *Transport) Start() { go t.connManager() }
// Stop stops the transport's connection manager.
func (t *Transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport.
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport.
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
// GetTLSConfig returns the TLS config in transport.
func (t *Transport) GetTLSConfig() *tls.Config { return t.tlsConfig }
const (
defaultExpire = 10 * time.Second
minDialTimeout = 1 * time.Second
maxDialTimeout = 30 * time.Second
)
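// Illustrative usage sketch (not part of the original source): the typical
// lifecycle of a Transport as driven by Proxy.
//
//	t := newTransport("forward", "8.8.8.8:53")
//	t.SetExpire(10 * time.Second)
//	t.Start() // starts connManager
//	pc, cached, err := t.Dial("udp")
//	if err == nil {
//		// ... use pc.c for the exchange ...
//		t.Yield(pc) // hand the connection back for reuse
//	}
//	_ = cached
//	t.Stop() // stops connManager and closes any cached connections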
package proxy
import (
"crypto/tls"
"runtime"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/up"
)
// Proxy defines an upstream host.
type Proxy struct {
fails uint32
addr string
proxyName string
transport *Transport
readTimeout time.Duration
// health checking
probe *up.Probe
health HealthChecker
}
// NewProxy returns a new proxy.
func NewProxy(proxyName, addr, trans string) *Proxy {
p := &Proxy{
addr: addr,
fails: 0,
probe: up.New(),
readTimeout: 2 * time.Second,
transport: newTransport(proxyName, addr),
health: NewHealthChecker(proxyName, trans, true, "."),
proxyName: proxyName,
}
runtime.SetFinalizer(p, (*Proxy).finalizer)
return p
}
func (p *Proxy) Addr() string { return p.addr }
// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
p.transport.SetTLSConfig(cfg)
p.health.SetTLSConfig(cfg)
}
// SetExpire sets the expire duration in the lower p.transport.
func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
func (p *Proxy) GetHealthchecker() HealthChecker {
return p.health
}
func (p *Proxy) GetTransport() *Transport {
return p.transport
}
func (p *Proxy) Fails() uint32 {
return atomic.LoadUint32(&p.fails)
}
// Healthcheck kicks off a round of health checks for this proxy.
func (p *Proxy) Healthcheck() {
if p.health == nil {
log.Warning("No healthchecker")
return
}
p.probe.Do(func() error {
return p.health.Check(p)
})
}
// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
func (p *Proxy) Down(maxfails uint32) bool {
if maxfails == 0 {
return false
}
fails := atomic.LoadUint32(&p.fails)
return fails > maxfails
}
// Stop stops the health checking goroutine.
func (p *Proxy) Stop() { p.probe.Stop() }
func (p *Proxy) finalizer() { p.transport.Stop() }
// Start starts the proxy's healthchecking.
func (p *Proxy) Start(duration time.Duration) {
p.probe.Start(duration)
p.transport.Start()
}
func (p *Proxy) SetReadTimeout(duration time.Duration) {
p.readTimeout = duration
}
// incrementFails increments the number of fails safely.
func (p *Proxy) incrementFails() {
curVal := atomic.LoadUint32(&p.fails)
if curVal > curVal+1 {
// overflow occurred, do not update the counter again
return
}
atomic.AddUint32(&p.fails, 1)
}
const (
maxTimeout = 2 * time.Second
)
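// Illustrative usage sketch (not part of the original source): how a caller
// such as the forward plugin drives a Proxy; ctx and state are assumed to be
// in scope.
//
//	p := NewProxy("forward", "8.8.8.8:53", transport.DNS)
//	p.Start(500 * time.Millisecond) // health-check interval
//	defer p.Stop()
//
//	if !p.Down(3 /* maxfails */) {
//		ret, err := p.Connect(ctx, state, Options{PreferUDP: true})
//		_, _ = ret, err
//	}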
package proxy
import (
"net"
)
type transportType int
const (
typeUDP transportType = iota
typeTCP
typeTLS
typeTotalCount // keep this last
)
func stringToTransportType(s string) transportType {
switch s {
case "udp":
return typeUDP
case "tcp":
return typeTCP
case "tcp-tls":
return typeTLS
}
return typeUDP
}
func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
if _, ok := pc.c.Conn.(*net.UDPConn); ok {
return typeUDP
}
if t.tlsConfig == nil {
return typeTCP
}
return typeTLS
}
// Package rand provides a concurrency-safe random number generator: a
// thread-safe wrapper around math/rand for use in
// load balancing and server selection. It is NOT suitable for cryptographic
// purposes and should not be used for security-sensitive operations.
package rand
import (
"math/rand"
"sync"
)
// Rand is a concurrency-safe random number generator.
type Rand struct {
m sync.Mutex
r *rand.Rand
}
// New returns a new Rand from seed.
func New(seed int64) *Rand {
return &Rand{r: rand.New(rand.NewSource(seed))}
}
// Int returns a non-negative pseudo-random int from the Source in Rand.r.
func (r *Rand) Int() int {
r.m.Lock()
v := r.r.Int()
r.m.Unlock()
return v
}
// Perm returns, as a slice of n ints, a pseudo-random permutation of the
// integers in the half-open interval [0,n) from the Source in Rand.r.
func (r *Rand) Perm(n int) []int {
r.m.Lock()
v := r.r.Perm(n)
r.m.Unlock()
return v
}
package rcode
import (
"strconv"
"github.com/miekg/dns"
)
// ToString convert the rcode to the official DNS string, or to "RCODE"+value if the RCODE value is unknown.
func ToString(rcode int) string {
if str, ok := dns.RcodeToString[rcode]; ok {
return str
}
return "RCODE" + strconv.Itoa(rcode)
}
package replacer
import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Replacer replaces labels for values in strings.
type Replacer struct{}
// New makes a new Replacer. This only needs to be called once in the setup;
// after that, call Replace for each incoming message. A Replacer is safe for concurrent use.
func New() Replacer {
return Replacer{}
}
// Replace performs a replacement of values on s and returns the string with the replaced values.
func (r Replacer) Replace(ctx context.Context, state request.Request, rr *dnstest.Recorder, s string) string {
return loadFormat(s).Replace(ctx, state, rr)
}
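// Illustrative usage sketch (not part of the original source): the kind of use
// the log plugin makes of this; ctx, state and rr (a dnstest.Recorder wrapping
// the client's ResponseWriter) are assumed to be in scope.
//
//	repl := New()
//	line := repl.Replace(ctx, state, rr, "{remote}:{port} {type} {name} {rcode} {duration}")
//	// e.g. "192.0.2.1:53029 A example.org. NOERROR 0.000123s"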
const (
headerReplacer = "{>"
// EmptyValue is the default empty value.
EmptyValue = "-"
)
// labels are all supported labels that can be used in the default Replacer.
var labels = map[string]struct{}{
"{type}": {},
"{name}": {},
"{class}": {},
"{proto}": {},
"{size}": {},
"{remote}": {},
"{port}": {},
"{local}": {},
// Header values.
headerReplacer + "id}": {},
headerReplacer + "opcode}": {},
headerReplacer + "do}": {},
headerReplacer + "bufsize}": {},
// Recorded replacements.
"{rcode}": {},
"{rsize}": {},
"{duration}": {},
headerReplacer + "rflags}": {},
}
// appendValue appends the current value of label.
func appendValue(b []byte, state request.Request, rr *dnstest.Recorder, label string) []byte {
switch label {
// Recorded replacements.
case "{rcode}":
if rr == nil || rr.Msg == nil {
return append(b, EmptyValue...)
}
if rcode := dns.RcodeToString[rr.Rcode]; rcode != "" {
return append(b, rcode...)
}
return strconv.AppendInt(b, int64(rr.Rcode), 10)
case "{rsize}":
if rr == nil {
return append(b, EmptyValue...)
}
return strconv.AppendInt(b, int64(rr.Len), 10)
case "{duration}":
if rr == nil {
return append(b, EmptyValue...)
}
secs := time.Since(rr.Start).Seconds()
return append(strconv.AppendFloat(b, secs, 'f', -1, 64), 's')
case headerReplacer + "rflags}":
if rr != nil && rr.Msg != nil {
return appendFlags(b, rr.Msg.MsgHdr)
}
return append(b, EmptyValue...)
}
if (request.Request{}) == state {
return append(b, EmptyValue...)
}
switch label {
case "{type}":
return append(b, state.Type()...)
case "{name}":
return append(b, state.Name()...)
case "{class}":
return append(b, state.Class()...)
case "{proto}":
return append(b, state.Proto()...)
case "{size}":
return strconv.AppendInt(b, int64(state.Req.Len()), 10)
case "{remote}":
return appendAddrToRFC3986(b, state.IP())
case "{port}":
return append(b, state.Port()...)
case "{local}":
return appendAddrToRFC3986(b, state.LocalIP())
// Header placeholders (case-insensitive).
case headerReplacer + "id}":
return strconv.AppendInt(b, int64(state.Req.Id), 10)
case headerReplacer + "opcode}":
return strconv.AppendInt(b, int64(state.Req.Opcode), 10)
case headerReplacer + "do}":
return strconv.AppendBool(b, state.Do())
case headerReplacer + "bufsize}":
return strconv.AppendInt(b, int64(state.Size()), 10)
default:
return append(b, EmptyValue...)
}
}
// appendFlags checks all header flags and appends those
// that are set as a string separated with commas
func appendFlags(b []byte, h dns.MsgHdr) []byte {
origLen := len(b)
if h.Response {
b = append(b, "qr,"...)
}
if h.Authoritative {
b = append(b, "aa,"...)
}
if h.Truncated {
b = append(b, "tc,"...)
}
if h.RecursionDesired {
b = append(b, "rd,"...)
}
if h.RecursionAvailable {
b = append(b, "ra,"...)
}
if h.Zero {
b = append(b, "z,"...)
}
if h.AuthenticatedData {
b = append(b, "ad,"...)
}
if h.CheckingDisabled {
b = append(b, "cd,"...)
}
if n := len(b); n > origLen {
return b[:n-1] // trim trailing ','
}
return b
}
// appendAddrToRFC3986 will add brackets to the address if it is an IPv6 address.
func appendAddrToRFC3986(b []byte, addr string) []byte {
if strings.IndexByte(addr, ':') != -1 {
b = append(b, '[')
b = append(b, addr...)
b = append(b, ']')
} else {
b = append(b, addr...)
}
return b
}
type nodeType int
const (
typeLabel nodeType = iota // "{type}"
typeLiteral // "foo"
typeMetadata // "{/metadata}"
)
// A node represents a segment of a parsed format. For example: "A {type}"
// contains two nodes: "A " (literal); and "{type}" (label).
type node struct {
value string // Literal value, label or metadata label
typ nodeType
}
// A replacer is an ordered list of all the nodes in a format.
type replacer []node
func parseFormat(s string) replacer {
// Assume there is a literal between each label - it's cheaper to over-allocate
// once than to allocate twice.
rep := make(replacer, 0, strings.Count(s, "{")*2)
for {
// We find the right bracket then backtrack to find the left bracket.
// This allows us to handle formats like: "{ {foo} }".
j := strings.IndexByte(s, '}')
if j < 0 {
break
}
i := strings.LastIndexByte(s[:j], '{')
if i < 0 {
// Handle: "A } {foo}" by treating "A }" as a literal
rep = append(rep, node{
value: s[:j+1],
typ: typeLiteral,
})
s = s[j+1:]
continue
}
val := s[i : j+1]
var typ nodeType
switch _, ok := labels[val]; {
case ok:
typ = typeLabel
case strings.HasPrefix(val, "{/"):
// Strip "{/}" from metadata labels
val = val[2 : len(val)-1]
typ = typeMetadata
default:
// Given: "A {X}" val is "{X}" expand it to the whole literal.
val = s[:j+1]
typ = typeLiteral
}
// Append any leading literal. Given "A {type}" the literal is "A "
if i != 0 && typ != typeLiteral {
rep = append(rep, node{
value: s[:i],
typ: typeLiteral,
})
}
rep = append(rep, node{
value: val,
typ: typ,
})
s = s[j+1:]
}
if len(s) != 0 {
rep = append(rep, node{
value: s,
typ: typeLiteral,
})
}
return rep
}
var replacerCache sync.Map // map[string]replacer
func loadFormat(s string) replacer {
if v, ok := replacerCache.Load(s); ok {
return v.(replacer)
}
v, _ := replacerCache.LoadOrStore(s, parseFormat(s))
return v.(replacer)
}
// bufPool stores pointers to scratch buffers.
var bufPool = sync.Pool{
New: func() any {
return make([]byte, 0, 256)
},
}
func (r replacer) Replace(ctx context.Context, state request.Request, rr *dnstest.Recorder) string {
b := bufPool.Get().([]byte)
for _, s := range r {
switch s.typ {
case typeLabel:
b = appendValue(b, state, rr, s.value)
case typeLiteral:
b = append(b, s.value...)
case typeMetadata:
if fm := metadata.ValueFunc(ctx, s.value); fm != nil {
b = append(b, fm()...)
} else {
b = append(b, EmptyValue...)
}
}
}
s := string(b)
//nolint:staticcheck
bufPool.Put(b[:0])
return s
}
package response
import "fmt"
// Class holds sets of Types
type Class int
const (
// All is a meta class encompassing all the classes.
All Class = iota
// Success is a class for a successful response.
Success
// Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist)
Denial
// Error is a class for errors, right now defined as not Success and not Denial
Error
)
func (c Class) String() string {
switch c {
case All:
return "all"
case Success:
return "success"
case Denial:
return "denial"
case Error:
return "error"
}
return ""
}
// ClassFromString returns the class from the string s. If no class matches,
// the All class and an error are returned.
func ClassFromString(s string) (Class, error) {
switch s {
case "all":
return All, nil
case "success":
return Success, nil
case "denial":
return Denial, nil
case "error":
return Error, nil
}
return All, fmt.Errorf("invalid Class: %s", s)
}
// Classify classifies the Type t, it returns its Class.
func Classify(t Type) Class {
switch t {
case NoError, Delegation:
return Success
case NameError, NoData:
return Denial
case OtherError:
fallthrough
default:
return Error
}
}
package response
import (
"fmt"
"time"
"github.com/miekg/dns"
)
// Type is the type of the message.
type Type int
const (
// NoError indicates a positive reply
NoError Type = iota
// NameError is a NXDOMAIN in header, SOA in auth.
NameError
// ServerError is a set of errors we want to cache, for now it contains SERVFAIL and NOTIMPL.
ServerError
// NoData indicates name found, but not the type: NOERROR in header, SOA in auth.
NoData
// Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked).
Delegation
// Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR.
Meta
// Update is a dynamic update message.
Update
// OtherError indicates any other error: don't cache these.
OtherError
)
var toString = map[Type]string{
NoError: "NOERROR",
NameError: "NXDOMAIN",
ServerError: "SERVERERROR",
NoData: "NODATA",
Delegation: "DELEGATION",
Meta: "META",
Update: "UPDATE",
OtherError: "OTHERERROR",
}
func (t Type) String() string { return toString[t] }
// TypeFromString returns the type from the string s. If no type matches,
// an error is returned (together with the zero Type, NoError).
func TypeFromString(s string) (Type, error) {
for t, str := range toString {
if s == str {
return t, nil
}
}
return NoError, fmt.Errorf("invalid Type: %s", s)
}
// Typify classifies a message, it returns the Type.
func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) {
if m == nil {
return OtherError, nil
}
opt := m.IsEdns0()
do := false
if opt != nil {
do = opt.Do()
}
if m.Opcode == dns.OpcodeUpdate {
return Update, opt
}
// Check transfer and update first
if m.Opcode == dns.OpcodeNotify {
return Meta, opt
}
if len(m.Question) > 0 {
if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR {
return Meta, opt
}
}
// If our message contains any expired sigs and we care about that, we should return expired
if do {
if expired := typifyExpired(m, t); expired {
return OtherError, opt
}
}
if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {
return NoError, opt
}
soa := false
ns := 0
for _, r := range m.Ns {
if r.Header().Rrtype == dns.TypeSOA {
soa = true
continue
}
if r.Header().Rrtype == dns.TypeNS {
ns++
}
}
if soa && m.Rcode == dns.RcodeSuccess {
return NoData, opt
}
if soa && m.Rcode == dns.RcodeNameError {
return NameError, opt
}
if m.Rcode == dns.RcodeServerFailure || m.Rcode == dns.RcodeNotImplemented {
return ServerError, opt
}
if ns > 0 && m.Rcode == dns.RcodeSuccess {
return Delegation, opt
}
if m.Rcode == dns.RcodeSuccess {
return NoError, opt
}
return OtherError, opt
}
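// Illustrative usage sketch (not part of the original source): deciding how to
// treat a reply, in the style of a caching plugin.
//
//	t, opt := Typify(reply, time.Now().UTC())
//	switch Classify(t) {
//	case Success, Denial:
//		// cacheable
//	default:
//		// Error class: do not cache
//	}
//	_ = opt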
func typifyExpired(m *dns.Msg, t time.Time) bool {
if expired := typifyExpiredRRSIG(m.Answer, t); expired {
return true
}
if expired := typifyExpiredRRSIG(m.Ns, t); expired {
return true
}
if expired := typifyExpiredRRSIG(m.Extra, t); expired {
return true
}
return false
}
func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool {
for _, r := range rrs {
if r.Header().Rrtype != dns.TypeRRSIG {
continue
}
ok := r.(*dns.RRSIG).ValidityPeriod(t)
if !ok {
return true
}
}
return false
}
//go:build go1.11 && (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd)
package reuseport
import (
"context"
"net"
"syscall"
"github.com/coredns/coredns/plugin/pkg/log"
"golang.org/x/sys/unix"
)
func control(network, address string, c syscall.RawConn) error {
c.Control(func(fd uintptr) {
if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
log.Warningf("Failed to set SO_REUSEPORT on socket: %s", err)
}
})
return nil
}
// Listen announces on the local network address. See net.Listen for more information.
// If SO_REUSEPORT is available it will be set on the socket.
func Listen(network, addr string) (net.Listener, error) {
lc := net.ListenConfig{Control: control}
return lc.Listen(context.Background(), network, addr)
}
// ListenPacket announces on the local network address. See net.ListenPacket for more information.
// If SO_REUSEPORT is available it will be set on the socket.
func ListenPacket(network, addr string) (net.PacketConn, error) {
lc := net.ListenConfig{Control: control}
return lc.ListenPacket(context.Background(), network, addr)
}
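// Illustrative usage sketch (not part of the original source): because
// SO_REUSEPORT is set, several handlers in the same process can bind the same
// address during a reload; mux is assumed to be an *http.ServeMux.
//
//	ln, err := Listen("tcp", "127.0.0.1:8181")
//	if err != nil {
//		return err
//	}
//	go http.Serve(ln, mux)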
/*
Copyright 2012 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight
import "sync"
// call is an in-flight or completed Do call
type call struct {
wg sync.WaitGroup
val any
err error
}
// Group represents a class of work and forms a namespace in which
// units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[uint64]*call // lazily initialized
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
func (g *Group) Do(key uint64, fn func() (any, error)) (any, error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[uint64]*call)
}
if c, ok := g.m[key]; ok {
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
return c.val, c.err
}
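// Illustrative usage sketch (not part of the original source): collapsing
// concurrent identical lookups onto one piece of work; key, qname, qtype and
// lookup are hypothetical.
//
//	var g Group
//
//	v, err := g.Do(key, func() (any, error) {
//		return lookup(qname, qtype)
//	})
//	if err == nil {
//		msg := v.(*dns.Msg)
//		_ = msg
//	}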
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"time"
)
func setTLSDefaults(ctls *tls.Config) {
ctls.MinVersion = tls.VersionTLS12
ctls.MaxVersion = tls.VersionTLS13
ctls.CipherSuites = []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}
}
// NewTLSConfigFromArgs returns a TLS config based upon the passed
// in list of arguments. Typically these come straight from the
// Corefile.
// no args
// - creates a Config with no cert and using system CAs
// - use for a client that talks to a server with a public signed cert (CA installed in system)
// - the client will not be authenticated by the server since there is no cert
//
// one arg: the path to CA PEM file
// - creates a Config with no cert using a specific CA
// - use for a client that talks to a server with a private signed cert (CA not installed in system)
// - the client will not be authenticated by the server since there is no cert
//
// two args: path to cert PEM file, the path to private key PEM file
// - creates a Config with a cert, using system CAs to validate the other end
// - use for:
// - a server; or,
// - a client that talks to a server with a public cert and needs certificate-based authentication
// - the other end will authenticate this end via the provided cert
// - the cert of the other end will be verified via system CAs
//
// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file
// - creates a Config with the cert, using specified CA to validate the other end
// - use for:
// - a server; or,
// - a client that talks to a server with a privately signed cert and needs certificate-based
// authentication
// - the other end will authenticate this end via the provided cert
// - this end will verify the other end's cert using the specified CA
func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) {
var err error
var c *tls.Config
switch len(args) {
case 0:
// No client cert, use system CA
c, err = NewTLSClientConfig("")
case 1:
// No client cert, use specified CA
c, err = NewTLSClientConfig(args[0])
case 2:
// Client cert, use system CA
c, err = NewTLSConfig(args[0], args[1], "")
case 3:
// Client cert, use specified CA
c, err = NewTLSConfig(args[0], args[1], args[2])
default:
err = fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", len(args))
}
if err != nil {
return nil, err
}
return c, nil
}
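// Illustrative usage sketch (not part of the original source), matching the
// argument forms described above; the file paths are placeholders.
//
//	// no arguments: client config that relies on the system CAs
//	cfg, err := NewTLSConfigFromArgs()
//
//	// cert, key and CA: server config (or mTLS client) using a private CA
//	cfg, err = NewTLSConfigFromArgs("cert.pem", "key.pem", "ca.pem")
//	if err != nil {
//		return err
//	}
//	_ = cfg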
// NewTLSConfig returns a TLS config that includes a certificate
// Use for server TLS config or when using a client certificate
// If caPath is empty, system CAs will be used
func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("could not load TLS cert: %s", err)
}
roots, err := loadRoots(caPath)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots}
setTLSDefaults(tlsConfig)
return tlsConfig, nil
}
// NewTLSClientConfig returns a TLS config for a client connection
// If caPath is empty, system CAs will be used
func NewTLSClientConfig(caPath string) (*tls.Config, error) {
roots, err := loadRoots(caPath)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{RootCAs: roots}
setTLSDefaults(tlsConfig)
return tlsConfig, nil
}
func loadRoots(caPath string) (*x509.CertPool, error) {
if caPath == "" {
return nil, nil
}
roots := x509.NewCertPool()
pem, err := os.ReadFile(filepath.Clean(caPath))
if err != nil {
return nil, fmt.Errorf("error reading %s: %s", caPath, err)
}
ok := roots.AppendCertsFromPEM(pem)
if !ok {
return nil, fmt.Errorf("could not read root certs: %s", err)
}
return roots, nil
}
// NewHTTPSTransport returns an HTTP transport configured using tls.Config
func NewHTTPSTransport(cc *tls.Config) *http.Transport {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cc,
MaxIdleConnsPerHost: 25,
}
return tr
}
// Package uniq keeps track of "things" that are either "todo" or "done". Multiple
// identical events will only be processed once.
package uniq
// U keeps track of items to be done.
type U struct {
u map[string]item
}
type item struct {
state int // either todo or done
f func() error // function to be executed.
}
// New returns a new initialized U.
func New() U { return U{u: make(map[string]item)} }
// Set sets function f in U under key. If the key already exists it is not overwritten.
func (u U) Set(key string, f func() error) {
if _, ok := u.u[key]; ok {
return
}
u.u[key] = item{todo, f}
}
// Unset removes the key.
func (u U) Unset(key string) {
delete(u.u, key)
}
// ForEach iterates over u and executes f for each element that is 'todo' and sets it to 'done'.
func (u U) ForEach() error {
for k, v := range u.u {
if v.state == todo {
v.f()
}
v.state = done
u.u[k] = v
}
return nil
}
const (
todo = 1
done = 2
)
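// Illustrative usage sketch (not part of the original source): the ready
// plugin below uses this pattern keyed on the listen address so that a server
// is only started once per address; startServer is hypothetical.
//
//	u := New()
//	u.Set(":8181", func() error { return startServer(":8181") })
//	u.Set(":8181", func() error { return startServer(":8181") }) // ignored, key exists
//	u.ForEach() // runs the function once and marks the key as done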
// Package up is used to run a function for some duration. If a new function is added while a previous run is
// still ongoing, nothing new will be executed.
package up
import (
"sync"
"time"
)
// Probe is used to run a single Func until it returns nil (indicating a target is healthy). If a Func
// is already in progress no new one will be added, i.e. there is always a maximum of one check in flight.
//
// There is a tradeoff to be made in figuring out quickly that an upstream is healthy and not doing much work
// (sending queries) to find that out. Having some kind of exp. backoff here won't help much, because you don't want
// to backoff too much. You then also need random queries to be performed every so often to quickly detect a working
// upstream. In the end we just send a query every 0.5 second to check the upstream. This hopefully strikes a balance
// between getting information about the upstream state quickly and not doing too much work. Note that 0.5s is still an
// eternity in DNS, so we may actually want to shorten it.
type Probe struct {
sync.Mutex
inprogress int
interval time.Duration
}
// Func is used to determine if a target is alive. If so this function must return nil.
type Func func() error
// New returns a pointer to an initialized Probe.
func New() *Probe { return &Probe{} }
// Do will probe target, if a probe is already in progress this is a noop.
func (p *Probe) Do(f Func) {
p.Lock()
if p.inprogress != idle {
p.Unlock()
return
}
p.inprogress = active
interval := p.interval
p.Unlock()
// Past the lock. Now run f for as long as it returns an error. Once nil is returned
// we return from the goroutine and we can accept another Func to run.
go func() {
i := 1
for {
if err := f(); err == nil {
break
}
time.Sleep(interval)
p.Lock()
if p.inprogress == stop {
p.Unlock()
return
}
p.Unlock()
i++
}
p.Lock()
p.inprogress = idle
p.Unlock()
}()
}
// Stop stops the probing.
func (p *Probe) Stop() {
p.Lock()
p.inprogress = stop
p.Unlock()
}
// Start will initialize the probe manager, after which probes can be initiated with Do.
func (p *Probe) Start(interval time.Duration) {
p.Lock()
p.interval = interval
p.Unlock()
}
const (
idle = iota
active
stop
)
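// Illustrative usage sketch (not part of the original source): this mirrors
// how Proxy wires its health checking earlier in this file set; checkUpstream
// and addr are hypothetical.
//
//	p := New()
//	p.Start(500 * time.Millisecond) // interval between retries of a failing check
//	defer p.Stop()
//
//	p.Do(func() error { return checkUpstream(addr) }) // no-op if a probe is already running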
// Package upstream abstracts upstream lookups so that plugins can handle them in a unified way.
package upstream
import (
"context"
"fmt"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Upstream is used to resolve CNAME or other external targets via CoreDNS itself.
type Upstream struct{}
// New creates a new Upstream to resolve names using the coredns process.
func New() *Upstream { return &Upstream{} }
// Lookup routes lookups to ourselves so that they follow the plugin chain *again*, but with a (possibly) new query. As
// we are doing the query against ourselves again, there is no actual new hop; as such RFC 6891 does not apply and we
// need the EDNS0 options present in the *original* query to be present here too.
func (u *Upstream) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
server, ok := ctx.Value(dnsserver.Key{}).(*dnsserver.Server)
if !ok {
return nil, fmt.Errorf("no full server is running")
}
req := state.NewWithQuestion(name, typ)
nw := nonwriter.New(state.W)
server.ServeDNS(ctx, nw, req.Req)
return nw.Msg, nil
}
// Package plugin provides some types and functions common among plugin.
package plugin
import (
"context"
"errors"
"fmt"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
)
type (
// Plugin is a middle layer which represents the traditional
// idea of plugin: it chains one Handler to the next by being
// passed the next Handler in the chain.
Plugin func(Handler) Handler
// Handler is like dns.Handler except ServeDNS may return an rcode
// and/or error.
//
// If ServeDNS writes to the response body, it should return a status
// code. CoreDNS assumes *no* reply has yet been written if the status
// code is one of the following:
//
// * SERVFAIL (dns.RcodeServerFailure)
//
// * REFUSED (dns.RcodeRefused)
//
// * FORMERR (dns.RcodeFormatError)
//
// * NOTIMP (dns.RcodeNotImplemented)
//
// All other response codes signal other handlers above it that the
// response message is already written, and that they should not write
// to it also.
//
// If ServeDNS encounters an error, it should return the error value
// so it can be logged by the designated error-handling plugin.
//
// If writing a response after calling another ServeDNS method, the
// returned rcode SHOULD be used when writing the response.
//
// If handling errors after calling another ServeDNS method, the
// returned error value SHOULD be logged or handled accordingly.
//
// Otherwise, return values should be propagated down the plugin
// chain by returning them unchanged.
Handler interface {
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
Name() string
}
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler
// documentation for more information.
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
)
// ServeDNS implements the Handler interface.
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(ctx, w, r)
}
// Name implements the Handler interface.
func (f HandlerFunc) Name() string { return "handlerfunc" }
// Error returns err with 'plugin/name: ' prefixed to it.
func Error(name string, err error) error { return fmt.Errorf("%s/%s: %w", "plugin", name, err) }
// NextOrFailure calls next.ServeDNS when next is not nil, otherwise it will return a SERVFAIL (dns.RcodeServerFailure) and a `no next plugin found` error.
func NextOrFailure(name string, next Handler, ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if next != nil {
if span := ot.SpanFromContext(ctx); span != nil {
child := span.Tracer().StartSpan(next.Name(), ot.ChildOf(span.Context()))
defer child.Finish()
ctx = ot.ContextWithSpan(ctx, child)
}
return next.ServeDNS(ctx, w, r)
}
return dns.RcodeServerFailure, Error(name, errors.New("no next plugin found"))
}
// ClientWrite returns true if the response has been written to the client.
// Each plugin should adhere to this protocol.
func ClientWrite(rcode int) bool {
switch rcode {
case dns.RcodeServerFailure:
fallthrough
case dns.RcodeRefused:
fallthrough
case dns.RcodeFormatError:
fallthrough
case dns.RcodeNotImplemented:
return false
}
return true
}
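// Illustrative sketch (not part of the original source): the skeleton of a
// plugin that follows the protocol described on Handler - do its own work,
// then hand the request to the next plugin. The example type and its next
// field are hypothetical.
//
//	type example struct{ next plugin.Handler }
//
//	func (e example) Name() string { return "example" }
//
//	func (e example) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
//		// ... inspect or rewrite r here ...
//		return plugin.NextOrFailure(e.Name(), e.next, ctx, w, r)
//	}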
// Namespace is the namespace used for the metrics.
const Namespace = "coredns"
// TimeBuckets is based on Prometheus client_golang prometheus.DefBuckets
var TimeBuckets = prometheus.ExponentialBuckets(0.00025, 2, 16) // from 0.25ms to 8 seconds
// SlimTimeBuckets is low cardinality set of duration buckets.
var SlimTimeBuckets = prometheus.ExponentialBuckets(0.00025, 10, 5) // from 0.25ms to 2.5 seconds
// NativeHistogramBucketFactor controls the resolution of Prometheus native histogram buckets.
// See: https://pkg.go.dev/github.com/prometheus/client_golang@v1.19.0/prometheus#section-readme
var NativeHistogramBucketFactor = 1.05
// ErrOnce is returned when a plugin doesn't support multiple setups per server.
var ErrOnce = errors.New("this plugin can only be used once per Server Block")
// Package pprof implements a debug endpoint for getting profiles using the
// go pprof tooling.
package pprof
import (
"net"
"net/http"
pp "net/http/pprof"
"runtime"
"github.com/coredns/coredns/plugin/pkg/reuseport"
)
type handler struct {
addr string
rateBloc int
ln net.Listener
mux *http.ServeMux
}
func (h *handler) Startup() error {
// Reloading the plugin without changing the listening address results
// in an error unless we reuse the port because Startup is called for
// new handlers before Shutdown is called for the old ones.
ln, err := reuseport.Listen("tcp", h.addr)
if err != nil {
log.Errorf("Failed to start pprof handler: %s", err)
return err
}
h.ln = ln
h.mux = http.NewServeMux()
h.mux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) {
http.Redirect(rw, req, path+"/", http.StatusFound)
})
h.mux.HandleFunc(path+"/", pp.Index)
h.mux.HandleFunc(path+"/cmdline", pp.Cmdline)
h.mux.HandleFunc(path+"/profile", pp.Profile)
h.mux.HandleFunc(path+"/symbol", pp.Symbol)
h.mux.HandleFunc(path+"/trace", pp.Trace)
runtime.SetBlockProfileRate(h.rateBloc)
go func() {
http.Serve(h.ln, h.mux)
}()
return nil
}
func (h *handler) Shutdown() error {
if h.ln != nil {
return h.ln.Close()
}
return nil
}
const (
path = "/debug/pprof"
)
package pprof
import (
"net"
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
var log = clog.NewWithPlugin("pprof")
const defaultAddr = "localhost:6053"
func init() { plugin.Register("pprof", setup) }
func setup(c *caddy.Controller) error {
h := &handler{addr: defaultAddr}
i := 0
for c.Next() {
if i > 0 {
return plugin.Error("pprof", plugin.ErrOnce)
}
i++
args := c.RemainingArgs()
if len(args) == 1 {
h.addr = args[0]
_, _, e := net.SplitHostPort(h.addr)
if e != nil {
return plugin.Error("pprof", c.Errf("%v", e))
}
}
if len(args) > 1 {
return plugin.Error("pprof", c.ArgErr())
}
for c.NextBlock() {
switch c.Val() {
case "block":
args := c.RemainingArgs()
if len(args) > 1 {
return plugin.Error("pprof", c.ArgErr())
}
h.rateBloc = 1
if len(args) > 0 {
t, err := strconv.Atoi(args[0])
if err != nil {
return plugin.Error("pprof", c.Errf("property '%s' invalid integer value '%v'", "block", args[0]))
}
h.rateBloc = t
}
default:
return plugin.Error("pprof", c.Errf("unknown property '%s'", c.Val()))
}
}
}
c.OnStartup(h.Startup)
c.OnShutdown(h.Shutdown)
return nil
}
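// Illustrative Corefile sketch (not part of the original source): serve the
// pprof endpoint on a non-default address and enable block profiling.
//
//	. {
//	    pprof localhost:6060 {
//	        block
//	    }
//	}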
package quic
import (
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() {
caddy.RegisterPlugin("quic", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
err := parseQuic(c)
if err != nil {
return plugin.Error("quic", err)
}
return nil
}
func parseQuic(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
// Skip the "quic" directive itself
c.Next()
// Get any arguments on the "quic" line
args := c.RemainingArgs()
if len(args) > 0 {
return c.ArgErr()
}
// Process all nested directives in the block
for c.NextBlock() {
switch c.Val() {
case "max_streams":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
val, err := strconv.Atoi(args[0])
if err != nil {
return c.Errf("invalid max_streams value '%s': %v", args[0], err)
}
if val <= 0 {
return c.Errf("max_streams must be a positive integer: %d", val)
}
if config.MaxQUICStreams != nil {
return c.Err("max_streams already defined for this server block")
}
config.MaxQUICStreams = &val
case "worker_pool_size":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
val, err := strconv.Atoi(args[0])
if err != nil {
return c.Errf("invalid worker_pool_size value '%s': %v", args[0], err)
}
if val <= 0 {
return c.Errf("worker_pool_size must be a positive integer: %d", val)
}
if config.MaxQUICWorkerPoolSize != nil {
return c.Err("worker_pool_size already defined for this server block")
}
config.MaxQUICWorkerPoolSize = &val
default:
return c.Errf("unknown property '%s'", c.Val())
}
}
return nil
}
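// Example Corefile usage (illustrative values; parseQuic above only accepts the
// max_streams and worker_pool_size properties, both positive integers):
//
//	quic://.:853 {
//	    quic {
//	        max_streams 256
//	        worker_pool_size 1024
//	    }
//	}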
package ready
import (
"sort"
"strings"
"sync"
)
// list is a structure that holds the plugins that signal readiness for this server block.
type list struct {
sync.RWMutex
rs []Readiness
names []string
// keepReadiness indicates whether the readiness status of plugins should be retained
// after they have been confirmed as ready. When set to false, the plugin readiness
// status will be reset to nil to conserve resources, assuming ready plugins don't
// need continuous monitoring.
keepReadiness bool
}
// Reset resets l
func (l *list) Reset() {
l.Lock()
defer l.Unlock()
l.rs = nil
l.names = nil
}
// Append adds a new readiness to l.
func (l *list) Append(r Readiness, name string) {
l.Lock()
defer l.Unlock()
l.rs = append(l.rs, r)
l.names = append(l.names, name)
}
// Ready returns true when all plugins are ready. If the returned value is false, the string
// contains a comma-separated list of the plugins that are not ready.
func (l *list) Ready() (bool, string) {
l.RLock()
defer l.RUnlock()
ok := true
s := []string{}
for i, r := range l.rs {
if r == nil {
continue
}
if r.Ready() {
if !l.keepReadiness {
l.rs[i] = nil
}
continue
}
ok = false
s = append(s, l.names[i])
}
if ok {
return true, ""
}
sort.Strings(s)
return false, strings.Join(s, ",")
}
// Package ready is used to signal readiness of the CoreDNS process. Once all
// plugins have reported in, the plugin signals readiness by returning a 200
// OK on the HTTP handler (on port 8181). If not ready yet, the handler
// returns a 503.
package ready
import (
"io"
"net"
"net/http"
"sync"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/uniq"
)
var (
log = clog.NewWithPlugin("ready")
plugins = &list{}
uniqAddr = uniq.New()
)
type ready struct {
Addr string
sync.RWMutex
ln net.Listener
done bool
mux *http.ServeMux
}
func (rd *ready) onStartup() error {
ln, err := reuseport.Listen("tcp", rd.Addr)
if err != nil {
return err
}
rd.Lock()
rd.ln = ln
rd.mux = http.NewServeMux()
rd.done = true
rd.Unlock()
rd.mux.HandleFunc("/ready", func(w http.ResponseWriter, _ *http.Request) {
rd.Lock()
defer rd.Unlock()
if !rd.done {
w.WriteHeader(http.StatusServiceUnavailable)
io.WriteString(w, "Shutting down")
return
}
ready, notReadyPlugins := plugins.Ready()
if ready {
w.WriteHeader(http.StatusOK)
io.WriteString(w, http.StatusText(http.StatusOK))
return
}
log.Infof("Plugins not ready: %q", notReadyPlugins)
w.WriteHeader(http.StatusServiceUnavailable)
io.WriteString(w, notReadyPlugins)
})
go func() { http.Serve(rd.ln, rd.mux) }()
return nil
}
func (rd *ready) onFinalShutdown() error {
rd.Lock()
defer rd.Unlock()
if !rd.done {
return nil
}
uniqAddr.Unset(rd.Addr)
rd.ln.Close()
rd.done = false
return nil
}
package ready
import (
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("ready", setup) }
func setup(c *caddy.Controller) error {
addr, monType, err := parse(c)
if err != nil {
return plugin.Error("ready", err)
}
if monType == monitorTypeContinuously {
plugins.keepReadiness = true
} else {
plugins.keepReadiness = false
}
rd := &ready{Addr: addr}
uniqAddr.Set(addr, rd.onStartup)
c.OnStartup(func() error { uniqAddr.Set(addr, rd.onStartup); return nil })
c.OnRestartFailed(func() error { uniqAddr.Set(addr, rd.onStartup); return nil })
c.OnStartup(func() error { return uniqAddr.ForEach() })
c.OnRestartFailed(func() error { return uniqAddr.ForEach() })
c.OnStartup(func() error {
plugins.Reset()
for _, p := range dnsserver.GetConfig(c).Handlers() {
if r, ok := p.(Readiness); ok {
plugins.Append(r, p.Name())
}
}
return nil
})
c.OnRestartFailed(func() error {
for _, p := range dnsserver.GetConfig(c).Handlers() {
if r, ok := p.(Readiness); ok {
plugins.Append(r, p.Name())
}
}
return nil
})
c.OnRestart(rd.onFinalShutdown)
c.OnFinalShutdown(rd.onFinalShutdown)
return nil
}
// monitorType represents the type of monitoring behavior for the readiness plugin.
type monitorType string
const (
// monitorTypeUntilReady indicates the monitoring should continue until the system is ready.
monitorTypeUntilReady monitorType = "until-ready"
// monitorTypeContinuously indicates the monitoring should continue indefinitely.
monitorTypeContinuously monitorType = "continuously"
)
func parse(c *caddy.Controller) (string, monitorType, error) {
addr := ":8181"
monType := monitorTypeUntilReady
i := 0
for c.Next() {
if i > 0 {
return "", "", plugin.ErrOnce
}
i++
args := c.RemainingArgs()
switch len(args) {
case 0:
case 1:
addr = args[0]
if _, _, e := net.SplitHostPort(addr); e != nil {
return "", "", e
}
default:
return "", "", c.ArgErr()
}
for c.NextBlock() {
switch c.Val() {
case "monitor":
args := c.RemainingArgs()
if len(args) != 1 {
return "", "", c.ArgErr()
}
var err error
monType, err = parseMonitorType(c, args[0])
if err != nil {
return "", "", err
}
}
}
}
return addr, monType, nil
}
func parseMonitorType(c *caddy.Controller, arg string) (monitorType, error) {
switch arg {
case "until-ready":
return monitorTypeUntilReady, nil
case "continuously":
return monitorTypeContinuously, nil
default:
return "", c.Errf("monitor type '%s' not supported", arg)
}
}
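// Example Corefile usage (illustrative; the default address is :8181 and the monitor
// property accepts "until-ready" or "continuously", per parse above):
//
//	. {
//	    ready 127.0.0.1:8181 {
//	        monitor continuously
//	    }
//	}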
package plugin
import "github.com/coredns/caddy"
// Register registers your plugin with CoreDNS and allows it to be called when the server is running.
func Register(name string, action caddy.SetupFunc) {
caddy.RegisterPlugin(name, caddy.Plugin{
ServerType: "dns",
Action: action,
})
}
// Package reload periodically checks if the Corefile has changed, and reloads if so.
package reload
import (
"bytes"
"crypto/sha512"
"encoding/hex"
"encoding/json"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/caddy/caddyfile"
"github.com/prometheus/client_golang/prometheus"
)
const (
unused = 0
maybeUsed = 1
used = 2
)
type reload struct {
dur time.Duration
u int
mtx sync.RWMutex
quit chan bool
}
func (r *reload) setUsage(u int) {
r.mtx.Lock()
defer r.mtx.Unlock()
r.u = u
}
func (r *reload) usage() int {
r.mtx.RLock()
defer r.mtx.RUnlock()
return r.u
}
func (r *reload) setInterval(i time.Duration) {
r.mtx.Lock()
defer r.mtx.Unlock()
r.dur = i
}
func (r *reload) interval() time.Duration {
r.mtx.RLock()
defer r.mtx.RUnlock()
return r.dur
}
func parse(corefile caddy.Input) ([]byte, error) {
serverBlocks, err := caddyfile.Parse(corefile.Path(), bytes.NewReader(corefile.Body()), nil)
if err != nil {
return nil, err
}
return json.Marshal(serverBlocks)
}
func hook(event caddy.EventName, info any) error {
if event != caddy.InstanceStartupEvent {
return nil
}
// if reload is removed from the Corefile, then the hook
// is still registered but setup is never called again
// so we need a flag to tell us not to reload
if r.usage() == unused {
return nil
}
// this should be an instance. ok to panic if not
instance := info.(*caddy.Instance)
parsedCorefile, err := parse(instance.Caddyfile())
if err != nil {
return err
}
sha512sum := sha512.Sum512(parsedCorefile)
log.Infof("Running configuration SHA512 = %x\n", sha512sum)
go func() {
tick := time.NewTicker(r.interval())
defer tick.Stop()
for {
select {
case <-tick.C:
corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType())
if err != nil {
continue
}
parsedCorefile, err := parse(corefile)
if err != nil {
log.Warningf("Corefile parse failed: %s", err)
continue
}
s := sha512.Sum512(parsedCorefile)
if s != sha512sum {
reloadInfo.Delete(prometheus.Labels{"hash": "sha512", "value": hex.EncodeToString(sha512sum[:])})
// Let's not try to restart with the same file, even though it is wrong.
sha512sum = s
// assume the plugin will not be reloaded unless it appears in the next config file;
// the usage status will be reset in setup if the plugin does appear in the config file
r.setUsage(maybeUsed)
// If shutdown is in progress, avoid attempting a restart.
if shutdownRequested(r.quit) {
return
}
_, err := instance.Restart(corefile)
reloadInfo.WithLabelValues("sha512", hex.EncodeToString(sha512sum[:])).Set(1)
if err != nil {
log.Errorf("Corefile changed but reload failed: %s", err)
failedCount.Add(1)
continue
}
// we are done; if the plugin was not marked as used, then it is not used.
if r.usage() == maybeUsed {
r.setUsage(unused)
}
return
}
case <-r.quit:
return
}
}
}()
return nil
}
// shutdownRequested reports whether a shutdown has been requested via the quit channel.
// It helps with unit testing of the shutdown gate logic.
func shutdownRequested(quit <-chan bool) bool {
select {
case <-quit:
return true
default:
return false
}
}
package reload
import (
"fmt"
"math/rand"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
var log = clog.NewWithPlugin("reload")
func init() { plugin.Register("reload", setup) }
// the reload info is global to the whole application, whatever the number of reloads.
// it is used to transmit data between Setup and the start of the hook called 'onInstanceStartup'.
// the QUIT channel is never changed, on purpose.
// WARNING: this data may be out of sync after a failed attempt to reload the Corefile.
var (
r = reload{dur: defaultInterval, u: unused, quit: make(chan bool, 1)}
once, shutOnce sync.Once
)
func setup(c *caddy.Controller) error {
c.Next() // 'reload'
args := c.RemainingArgs()
if len(args) > 2 {
return plugin.Error("reload", c.ArgErr())
}
i := defaultInterval
if len(args) > 0 {
d, err := time.ParseDuration(args[0])
if err != nil {
return plugin.Error("reload", err)
}
i = d
}
if i < minInterval {
return plugin.Error("reload", fmt.Errorf("interval value must be greater or equal to %v", minInterval))
}
j := defaultJitter
if len(args) > 1 {
d, err := time.ParseDuration(args[1])
if err != nil {
return plugin.Error("reload", err)
}
j = d
}
if j < minJitter {
return plugin.Error("reload", fmt.Errorf("jitter value must be greater or equal to %v", minJitter))
}
if j > i/2 {
j = i / 2
}
jitter := time.Duration(rand.Int63n(j.Nanoseconds()) - (j.Nanoseconds() / 2))
i = i + jitter
// prepare info for next onInstanceStartup event
r.setInterval(i)
r.setUsage(used)
once.Do(func() {
caddy.RegisterEventHook("reload", hook)
})
// re-register on final shutdown as the instance will most likely have changed
shutOnce.Do(func() {
c.OnFinalShutdown(func() error {
r.quit <- true
return nil
})
})
return nil
}
const (
minJitter = 1 * time.Second
minInterval = 2 * time.Second
defaultInterval = 30 * time.Second
defaultJitter = 15 * time.Second
)
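// Example Corefile usage (illustrative values; the interval and jitter must respect the
// minimums above, and setup caps the jitter at half the interval):
//
//	. {
//	    reload 10s 5s
//	}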
package rewrite
import (
"context"
"fmt"
"strings"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
type classRule struct {
fromClass uint16
toClass uint16
NextAction string
}
// newClassRule creates a class matching rule
func newClassRule(nextAction string, args ...string) (Rule, error) {
var from, to uint16
var ok bool
if from, ok = dns.StringToClass[strings.ToUpper(args[0])]; !ok {
return nil, fmt.Errorf("invalid class %q", strings.ToUpper(args[0]))
}
if to, ok = dns.StringToClass[strings.ToUpper(args[1])]; !ok {
return nil, fmt.Errorf("invalid class %q", strings.ToUpper(args[1]))
}
return &classRule{from, to, nextAction}, nil
}
// Rewrite rewrites the current request.
func (rule *classRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if rule.fromClass > 0 && rule.toClass > 0 {
if state.Req.Question[0].Qclass == rule.fromClass {
state.Req.Question[0].Qclass = rule.toClass
return nil, RewriteDone
}
}
return nil, RewriteIgnored
}
// Mode returns the processing mode.
func (rule *classRule) Mode() string { return rule.NextAction }
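// Example Corefile usage (illustrative): rewrite the class of CHAOS queries to IN
// before further processing.
//
//	rewrite class CH IN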
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// UpstreamInt wraps the Upstream API for dependency injection during testing
type UpstreamInt interface {
Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
}
// cnameTargetRule is cname target rewrite rule.
type cnameTargetRule struct {
rewriteType string
paramFromTarget string
paramToTarget string
nextAction string
Upstream UpstreamInt // Upstream for looking up external names during the resolution process.
}
// cnameTargetRuleWithReqState is cname target rewrite rule state
type cnameTargetRuleWithReqState struct {
rule cnameTargetRule
state request.Request
ctx context.Context
}
func (r *cnameTargetRule) getFromAndToTarget(inputCName string) (from string, to string) {
switch r.rewriteType {
case ExactMatch:
return r.paramFromTarget, r.paramToTarget
case PrefixMatch:
if after, ok := strings.CutPrefix(inputCName, r.paramFromTarget); ok {
return inputCName, r.paramToTarget + after
}
case SuffixMatch:
if before, ok := strings.CutSuffix(inputCName, r.paramFromTarget); ok {
return inputCName, before + r.paramToTarget
}
case SubstringMatch:
if strings.Contains(inputCName, r.paramFromTarget) {
return inputCName, strings.ReplaceAll(inputCName, r.paramFromTarget, r.paramToTarget)
}
case RegexMatch:
pattern := regexp.MustCompile(r.paramFromTarget)
regexGroups := pattern.FindStringSubmatch(inputCName)
if len(regexGroups) == 0 {
return "", ""
}
substitution := r.paramToTarget
for groupIndex, groupValue := range regexGroups {
groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}"
substitution = strings.ReplaceAll(substitution, groupIndexStr, groupValue)
}
return inputCName, substitution
}
return "", ""
}
func (r *cnameTargetRuleWithReqState) RewriteResponse(res *dns.Msg, rr dns.RR) {
// logic to rewrite the cname target of dns response
switch rr.Header().Rrtype {
case dns.TypeCNAME:
// rename the target of the cname response
if cname, ok := rr.(*dns.CNAME); ok {
fromTarget, toTarget := r.rule.getFromAndToTarget(cname.Target)
if cname.Target == fromTarget {
// create an upstream request for the new target with the same qtype
r.state.Req.Question[0].Name = toTarget
// upRes can be nil if the internal query path didn't write a response
// (e.g. a plugin returned a success rcode without writing, dropped the query,
// or the context was canceled). Guard upRes before dereferencing.
upRes, err := r.rule.Upstream.Lookup(r.ctx, r.state, toTarget, r.state.Req.Question[0].Qtype)
if err != nil {
log.Errorf("upstream lookup failed: %v", err)
return
}
if upRes == nil {
log.Errorf("upstream lookup returned nil")
return
}
var newAnswer []dns.RR
// iterate over the original response
// and add the rewritten CNAME record to the new answer
for _, rr := range res.Answer {
if cname, ok := rr.(*dns.CNAME); ok {
// change the target name in the response
cname.Target = toTarget
newAnswer = append(newAnswer, rr)
}
}
// iterate over upstream response received
for _, rr := range upRes.Answer {
if rr.Header().Name == toTarget {
newAnswer = append(newAnswer, rr)
}
}
res.Answer = newAnswer
// if not propagated, the truncated response might get cached,
// and it will be impossible to resolve the full response
res.Truncated = upRes.Truncated
}
}
}
}
func newCNAMERule(nextAction string, args ...string) (Rule, error) {
var rewriteType string
var paramFromTarget, paramToTarget string
if len(args) == 3 {
rewriteType = (strings.ToLower(args[0]))
switch rewriteType {
case ExactMatch:
case PrefixMatch:
case SuffixMatch:
case SubstringMatch:
case RegexMatch:
default:
return nil, fmt.Errorf("unknown cname rewrite type: %s", rewriteType)
}
paramFromTarget, paramToTarget = strings.ToLower(args[1]), strings.ToLower(args[2])
} else if len(args) == 2 {
rewriteType = ExactMatch
paramFromTarget, paramToTarget = strings.ToLower(args[0]), strings.ToLower(args[1])
} else {
return nil, fmt.Errorf("too few (%d) arguments for a cname rule", len(args))
}
rule := cnameTargetRule{
rewriteType: rewriteType,
paramFromTarget: paramFromTarget,
paramToTarget: paramToTarget,
nextAction: nextAction,
Upstream: upstream.New(),
}
return &rule, nil
}
// Rewrite rewrites the current request.
func (r *cnameTargetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if r != nil && len(r.rewriteType) > 0 && len(r.paramFromTarget) > 0 && len(r.paramToTarget) > 0 {
return ResponseRules{&cnameTargetRuleWithReqState{
rule: *r,
state: state,
ctx: ctx,
}}, RewriteDone
}
return nil, RewriteIgnored
}
// Mode returns the processing mode.
func (r *cnameTargetRule) Mode() string { return r.nextAction }
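// Example Corefile usage (illustrative host names): rewrite the target of a CNAME
// answer and re-resolve the new target upstream, as implemented above.
//
//	rewrite cname exact alias.example.com. canonical.example.org.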
package rewrite
import (
"context"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// edns0LocalRule is a rewrite rule for EDNS0_LOCAL options.
type edns0LocalRule struct {
mode string
action string
code uint16
data []byte
revert bool
}
// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable.
type edns0VariableRule struct {
mode string
action string
code uint16
variable string
revert bool
}
// edns0NsidRule is a rewrite rule for EDNS0_NSID options.
type edns0NsidRule struct {
mode string
action string
revert bool
}
type edns0SetResponseRule struct {
code uint16
}
func (r *edns0SetResponseRule) RewriteResponse(res *dns.Msg, _ dns.RR) {
ednsOpt := res.IsEdns0()
for idx, opt := range ednsOpt.Option {
if opt.Option() == r.code {
ednsOpt.Option = append(ednsOpt.Option[:idx], ednsOpt.Option[idx+1:]...)
return
}
}
}
type edns0ReplaceResponseRule[T dns.EDNS0] struct {
code uint16
source T
}
func (r *edns0ReplaceResponseRule[T]) RewriteResponse(res *dns.Msg, _ dns.RR) {
ednsOpt := res.IsEdns0()
for idx, opt := range ednsOpt.Option {
if opt.Option() == r.code {
ednsOpt.Option[idx] = r.source
return
}
}
}
// setupEdns0Opt will retrieve the EDNS0 OPT or create it if it does not exist.
func setupEdns0Opt(r *dns.Msg) *dns.OPT {
o := r.IsEdns0()
if o == nil {
r.SetEdns0(4096, false)
o = r.IsEdns0()
}
return o
}
func unsetEdns0Option(opt *dns.OPT, code uint16) {
var newOpts []dns.EDNS0
for _, o := range opt.Option {
if o.Option() != code {
newOpts = append(newOpts, o)
}
}
opt.Option = newOpts
}
// Rewrite will alter the request EDNS0 NSID option
func (rule *edns0NsidRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
o := setupEdns0Opt(state.Req)
if rule.action == Unset {
unsetEdns0Option(o, dns.EDNS0NSID)
return nil, RewriteDone
}
var resp ResponseRules
for _, s := range o.Option {
if e, ok := s.(*dns.EDNS0_NSID); ok {
if rule.action == Replace || rule.action == Set {
if rule.revert {
old := *e
resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_NSID]{code: e.Code, source: &old})
}
e.Nsid = "" // make sure it is empty for request
return resp, RewriteDone
}
}
}
// add option if not found
if rule.action == Append || rule.action == Set {
o.Option = append(o.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""})
if rule.revert {
resp = append(resp, &edns0SetResponseRule{code: dns.EDNS0NSID})
}
return resp, RewriteDone
}
return nil, RewriteIgnored
}
// Mode returns the processing mode.
func (rule *edns0NsidRule) Mode() string { return rule.mode }
// Rewrite will alter the request EDNS0 local options.
func (rule *edns0LocalRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
o := setupEdns0Opt(state.Req)
if rule.action == Unset {
unsetEdns0Option(o, rule.code)
return nil, RewriteDone
}
var resp ResponseRules
for _, s := range o.Option {
if e, ok := s.(*dns.EDNS0_LOCAL); ok {
if rule.code == e.Code {
if rule.action == Replace || rule.action == Set {
if rule.revert {
old := *e
resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_LOCAL]{code: rule.code, source: &old})
}
e.Data = rule.data
return resp, RewriteDone
}
}
}
}
// add option if not found
if rule.action == Append || rule.action == Set {
o.Option = append(o.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: rule.data})
if rule.revert {
resp = append(resp, &edns0SetResponseRule{code: rule.code})
}
return resp, RewriteDone
}
return nil, RewriteIgnored
}
// Mode returns the processing mode.
func (rule *edns0LocalRule) Mode() string { return rule.mode }
// newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args
func newEdns0Rule(mode string, args ...string) (Rule, error) {
if len(args) < 2 {
return nil, fmt.Errorf("too few arguments for an EDNS0 rule")
}
ruleType := strings.ToLower(args[0])
action := strings.ToLower(args[1])
switch action {
case Append:
case Replace:
case Set:
case Unset:
return newEdns0UnsetRule(mode, action, ruleType, args...)
default:
return nil, fmt.Errorf("invalid action: %q", action)
}
// Extract "revert" parameter.
var revert bool
if args[len(args)-1] == "revert" {
revert = true
args = args[:len(args)-1]
}
switch ruleType {
case "local":
if len(args) != 4 {
return nil, fmt.Errorf("EDNS0 local rules require three or four args")
}
// Check for variable option.
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
return newEdns0VariableRule(mode, action, args[2], args[3], revert)
}
return newEdns0LocalRule(mode, action, args[2], args[3], revert)
case "nsid":
if len(args) != 2 {
return nil, fmt.Errorf("EDNS0 NSID rules can accept no more than one arg")
}
return &edns0NsidRule{mode: mode, action: action, revert: revert}, nil
case "subnet":
if len(args) != 4 {
return nil, fmt.Errorf("EDNS0 subnet rules require three or four args")
}
return newEdns0SubnetRule(mode, action, args[2], args[3], revert)
default:
return nil, fmt.Errorf("invalid rule type %q", ruleType)
}
}
func newEdns0UnsetRule(mode string, action string, ruleType string, args ...string) (Rule, error) {
switch ruleType {
case "local":
if len(args) != 3 {
return nil, fmt.Errorf("local unset action requires exactly two arguments")
}
return newEdns0LocalRule(mode, action, args[2], "", false)
case "nsid":
if len(args) != 2 {
return nil, fmt.Errorf("nsid unset action requires exactly one argument")
}
return &edns0NsidRule{mode, action, false}, nil
case "subnet":
if len(args) != 2 {
return nil, fmt.Errorf("subnet unset action requires exactly one argument")
}
return &edns0SubnetRule{mode, 0, 0, action, false}, nil
default:
return nil, fmt.Errorf("invalid rule type %q", ruleType)
}
}
func newEdns0LocalRule(mode, action, code, data string, revert bool) (*edns0LocalRule, error) {
c, err := strconv.ParseUint(code, 0, 16)
if err != nil {
return nil, err
}
decoded := []byte(data)
if strings.HasPrefix(data, "0x") {
decoded, err = hex.DecodeString(data[2:])
if err != nil {
return nil, err
}
}
// Add this code to the ones the server supports.
edns.SetSupportedOption(uint16(c))
return &edns0LocalRule{mode: mode, action: action, code: uint16(c), data: decoded, revert: revert}, nil
}
// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution
func newEdns0VariableRule(mode, action, code, variable string, revert bool) (*edns0VariableRule, error) {
c, err := strconv.ParseUint(code, 0, 16)
if err != nil {
return nil, err
}
// Validate the variable name.
if !isValidVariable(variable) {
return nil, fmt.Errorf("unsupported variable name %q", variable)
}
// Add this code to the ones the server supports.
edns.SetSupportedOption(uint16(c))
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable, revert: revert}, nil
}
// ruleData returns the data specified by the variable.
func (rule *edns0VariableRule) ruleData(ctx context.Context, state request.Request) ([]byte, error) {
switch rule.variable {
case queryName:
return []byte(state.QName()), nil
case queryType:
return uint16ToWire(state.QType()), nil
case clientIP:
return ipToWire(state.Family(), state.IP())
case serverIP:
return ipToWire(state.Family(), state.LocalIP())
case clientPort:
return portToWire(state.Port())
case serverPort:
return portToWire(state.LocalPort())
case protocol:
return []byte(state.Proto()), nil
}
fetcher := metadata.ValueFunc(ctx, rule.variable[1:len(rule.variable)-1])
if fetcher != nil {
value := fetcher()
if len(value) > 0 {
return []byte(value), nil
}
}
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
}
// Rewrite will alter the request EDNS0 local options with specified variables.
func (rule *edns0VariableRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
data, err := rule.ruleData(ctx, state)
if err != nil || data == nil {
return nil, RewriteIgnored
}
var resp ResponseRules
o := setupEdns0Opt(state.Req)
for _, s := range o.Option {
if e, ok := s.(*dns.EDNS0_LOCAL); ok {
if rule.code == e.Code {
if rule.action == Replace || rule.action == Set {
if rule.revert {
old := *e
resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_LOCAL]{code: rule.code, source: &old})
}
e.Data = data
return resp, RewriteDone
}
return nil, RewriteIgnored
}
}
}
// add option if not found
if rule.action == Append || rule.action == Set {
o.Option = append(o.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: data})
if rule.revert {
resp = append(resp, &edns0SetResponseRule{code: rule.code})
}
return resp, RewriteDone
}
return nil, RewriteIgnored
}
// Mode returns the processing mode.
func (rule *edns0VariableRule) Mode() string { return rule.mode }
func isValidVariable(variable string) bool {
switch variable {
case
queryName,
queryType,
clientIP,
clientPort,
protocol,
serverIP,
serverPort:
return true
}
// we cannot validate the labels of metadata - but we can verify it has the syntax of a label
if strings.HasPrefix(variable, "{") && strings.HasSuffix(variable, "}") && metadata.IsLabel(variable[1:len(variable)-1]) {
return true
}
return false
}
// edns0SubnetRule is a rewrite rule for EDNS0 subnet options.
type edns0SubnetRule struct {
mode string
v4BitMaskLen uint8
v6BitMaskLen uint8
action string
revert bool
}
func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string, revert bool) (*edns0SubnetRule, error) {
v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16)
if err != nil {
return nil, err
}
// validate V4 length
if v4Len > net.IPv4len*8 {
return nil, fmt.Errorf("invalid IPv4 bit mask length %d", v4Len)
}
v6Len, err := strconv.ParseUint(v6BitMaskLen, 0, 16)
if err != nil {
return nil, err
}
// validate V6 length
if v6Len > net.IPv6len*8 {
return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len)
}
return &edns0SubnetRule{mode: mode, action: action,
v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len), revert: revert}, nil
}
// fillEcsData sets the subnet data into the ecs option
func (rule *edns0SubnetRule) fillEcsData(state request.Request, ecs *dns.EDNS0_SUBNET) error {
family := state.Family()
if (family != 1) && (family != 2) {
return fmt.Errorf("unable to fill data for EDNS0 subnet due to invalid IP family")
}
ecs.Family = uint16(family)
ecs.SourceScope = 0
ipAddr := state.IP()
switch family {
case 1:
ipv4Mask := net.CIDRMask(int(rule.v4BitMaskLen), 32)
ipv4Addr := net.ParseIP(ipAddr)
ecs.SourceNetmask = rule.v4BitMaskLen
ecs.Address = ipv4Addr.Mask(ipv4Mask).To4()
case 2:
ipv6Mask := net.CIDRMask(int(rule.v6BitMaskLen), 128)
ipv6Addr := net.ParseIP(ipAddr)
ecs.SourceNetmask = rule.v6BitMaskLen
ecs.Address = ipv6Addr.Mask(ipv6Mask).To16()
}
return nil
}
// Rewrite will alter the request EDNS0 subnet option.
func (rule *edns0SubnetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
o := setupEdns0Opt(state.Req)
if rule.action == Unset {
unsetEdns0Option(o, dns.EDNS0SUBNET)
return nil, RewriteDone
}
var resp ResponseRules
for _, s := range o.Option {
if e, ok := s.(*dns.EDNS0_SUBNET); ok {
if rule.action == Replace || rule.action == Set {
if rule.revert {
old := *e
resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_SUBNET]{code: e.Code, source: &old})
}
if rule.fillEcsData(state, e) == nil {
return resp, RewriteDone
}
}
return nil, RewriteIgnored
}
}
// add option if not found
if rule.action == Append || rule.action == Set {
opt := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
if rule.fillEcsData(state, opt) == nil {
o.Option = append(o.Option, opt)
if rule.revert {
resp = append(resp, &edns0SetResponseRule{code: dns.EDNS0SUBNET})
}
return resp, RewriteDone
}
}
return nil, RewriteIgnored
}
// Mode returns the processing mode
func (rule *edns0SubnetRule) Mode() string { return rule.mode }
// These are all defined actions.
const (
Replace = "replace"
Set = "set"
Append = "append"
Unset = "unset"
)
// Supported local EDNS0 variables
const (
queryName = "{qname}"
queryType = "{qtype}"
clientIP = "{client_ip}"
clientPort = "{client_port}"
protocol = "{protocol}"
serverIP = "{server_ip}"
serverPort = "{server_port}"
)
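// Example Corefile usage (illustrative option code and values; the accepted actions,
// variables and subnet mask lengths follow newEdns0Rule above):
//
//	rewrite edns0 local set 0xffee 0x1234
//	rewrite edns0 local set 0xffee {client_ip}
//	rewrite edns0 subnet set 24 56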
//go:build gofuzz
package rewrite
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes rewrite.
func Fuzz(data []byte) int {
c := caddy.NewTestController("dns", "rewrite edns0 subnet set 24 56")
rules, err := rewriteParse(c)
if err != nil {
return 0
}
r := Rewrite{Rules: rules}
return fuzz.Do(r, data)
}
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// stringRewriter rewrites a string
type stringRewriter interface {
rewriteString(src string) string
}
// regexStringRewriter can be used to rewrite strings by regex pattern.
// it contains all the information required to detect and execute a rewrite
// on a string.
type regexStringRewriter struct {
pattern *regexp.Regexp
replacement string
}
var _ stringRewriter = ®exStringRewriter{}
func newStringRewriter(pattern *regexp.Regexp, replacement string) stringRewriter {
return ®exStringRewriter{pattern, replacement}
}
func (r *regexStringRewriter) rewriteString(src string) string {
regexGroups := r.pattern.FindStringSubmatch(src)
if len(regexGroups) == 0 {
return src
}
s := r.replacement
for groupIndex, groupValue := range regexGroups {
groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}"
s = strings.ReplaceAll(s, groupIndexStr, groupValue)
}
return s
}
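// As a worked example (illustrative pattern and input, not from the original source):
// a rewriter built from the pattern `(.*)-internal\.example\.org\.` with replacement
// "{1}.example.org." turns "foo-internal.example.org." into "foo.example.org.",
// since {1} is substituted with the first capture group; non-matching input is
// returned unchanged.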
// remapStringRewriter maps a dedicated string to another string.
// It also maps the domain of a subdomain.
type remapStringRewriter struct {
orig string
replacement string
}
var _ stringRewriter = &remapStringRewriter{}
func newRemapStringRewriter(orig, replacement string) stringRewriter {
return &remapStringRewriter{orig, replacement}
}
func (r *remapStringRewriter) rewriteString(src string) string {
if src == r.orig {
return r.replacement
}
if strings.HasSuffix(src, "."+r.orig) {
return src[0:len(src)-len(r.orig)] + r.replacement
}
return src
}
// suffixStringRewriter maps a dedicated suffix string to another string
type suffixStringRewriter struct {
suffix string
replacement string
}
var _ stringRewriter = &suffixStringRewriter{}
func newSuffixStringRewriter(orig, replacement string) stringRewriter {
return &suffixStringRewriter{orig, replacement}
}
func (r *suffixStringRewriter) rewriteString(src string) string {
if before, ok := strings.CutSuffix(src, r.suffix); ok {
return before + r.replacement
}
return src
}
// nameRewriterResponseRule maps a record name according to a stringRewriter.
type nameRewriterResponseRule struct {
stringRewriter
}
func (r *nameRewriterResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
rr.Header().Name = r.rewriteString(rr.Header().Name)
}
// valueRewriterResponseRule maps a record value according to a stringRewriter.
type valueRewriterResponseRule struct {
stringRewriter
}
func (r *valueRewriterResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
value := getRecordValueForRewrite(rr)
if value != "" {
new := r.rewriteString(value)
if new != value {
setRewrittenRecordValue(rr, new)
}
}
}
const (
// ExactMatch matches only on exact match of the name in the question section of a request
ExactMatch = "exact"
// PrefixMatch matches when the name begins with the matching string
PrefixMatch = "prefix"
// SuffixMatch matches when the name ends with the matching string
SuffixMatch = "suffix"
// SubstringMatch matches on partial match of the name in the question section of a request
SubstringMatch = "substring"
// RegexMatch matches when the name in the question section of a request matches a regular expression
RegexMatch = "regex"
// AnswerMatch matches an answer rewrite
AnswerMatch = "answer"
// AutoMatch matches the auto name answer rewrite
AutoMatch = "auto"
// NameMatch matches the name answer rewrite
NameMatch = "name"
// ValueMatch matches the value answer rewrite
ValueMatch = "value"
)
type nameRuleBase struct {
nextAction string
auto bool
replacement string
static ResponseRules
}
func newNameRuleBase(nextAction string, auto bool, replacement string, staticResponses ResponseRules) nameRuleBase {
return nameRuleBase{
nextAction: nextAction,
auto: auto,
replacement: replacement,
static: staticResponses,
}
}
// responseRuleFor dynamically creates, in auto mode, response rewriters for name and value
// that revert the mapping done by the name rewrite rule, which can be found in the state.
func (rule *nameRuleBase) responseRuleFor(state request.Request) (ResponseRules, Result) {
if !rule.auto {
return rule.static, RewriteDone
}
rewriter := newRemapStringRewriter(state.Req.Question[0].Name, state.Name())
rules := ResponseRules{
&nameRewriterResponseRule{rewriter},
&valueRewriterResponseRule{rewriter},
}
return append(rules, rule.static...), RewriteDone
}
// Mode returns the processing nextAction
func (rule *nameRuleBase) Mode() string { return rule.nextAction }
// exactNameRule rewrites the current request based upon exact match of the name
// in the question section of the request.
type exactNameRule struct {
nameRuleBase
from string
}
func newExactNameRule(nextAction string, orig, replacement string, answers ResponseRules) Rule {
return &exactNameRule{
newNameRuleBase(nextAction, true, replacement, answers),
orig,
}
}
func (rule *exactNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if rule.from == state.Name() {
state.Req.Question[0].Name = rule.replacement
return rule.responseRuleFor(state)
}
return nil, RewriteIgnored
}
// prefixNameRule rewrites the current request when the name begins with the matching string.
type prefixNameRule struct {
nameRuleBase
prefix string
}
func newPrefixNameRule(nextAction string, auto bool, prefix, replacement string, answers ResponseRules) Rule {
return &prefixNameRule{
newNameRuleBase(nextAction, auto, replacement, answers),
prefix,
}
}
func (rule *prefixNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if after, ok := strings.CutPrefix(state.Name(), rule.prefix); ok {
state.Req.Question[0].Name = rule.replacement + after
return rule.responseRuleFor(state)
}
return nil, RewriteIgnored
}
// suffixNameRule rewrites the current request when the name ends with the matching string.
type suffixNameRule struct {
nameRuleBase
suffix string
}
func newSuffixNameRule(nextAction string, auto bool, suffix, replacement string, answers ResponseRules) Rule {
var rules ResponseRules
if auto {
// for a suffix rewriter, better standard response rewrites can be done
// simply by using the original suffix/replacement in the opposite order
rewriter := newSuffixStringRewriter(replacement, suffix)
rules = ResponseRules{
&nameRewriterResponseRule{rewriter},
&valueRewriterResponseRule{rewriter},
}
}
return &suffixNameRule{
newNameRuleBase(nextAction, false, replacement, append(rules, answers...)),
suffix,
}
}
func (rule *suffixNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if before, ok := strings.CutSuffix(state.Name(), rule.suffix); ok {
state.Req.Question[0].Name = before + rule.replacement
return rule.responseRuleFor(state)
}
return nil, RewriteIgnored
}
// substringNameRule rewrites the current request based upon partial match of the
// name in the question section of the request.
type substringNameRule struct {
nameRuleBase
substring string
}
func newSubstringNameRule(nextAction string, auto bool, substring, replacement string, answers ResponseRules) Rule {
return &substringNameRule{
newNameRuleBase(nextAction, auto, replacement, answers),
substring,
}
}
func (rule *substringNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if strings.Contains(state.Name(), rule.substring) {
state.Req.Question[0].Name = strings.ReplaceAll(state.Name(), rule.substring, rule.replacement)
return rule.responseRuleFor(state)
}
return nil, RewriteIgnored
}
// regexNameRule rewrites the current request when the name in the question
// section of the request matches a regular expression.
type regexNameRule struct {
nameRuleBase
pattern *regexp.Regexp
}
func newRegexNameRule(nextAction string, auto bool, pattern *regexp.Regexp, replacement string, answers ResponseRules) Rule {
return ®exNameRule{
newNameRuleBase(nextAction, auto, replacement, answers),
pattern,
}
}
func (rule *regexNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
regexGroups := rule.pattern.FindStringSubmatch(state.Name())
if len(regexGroups) == 0 {
return nil, RewriteIgnored
}
s := rule.replacement
for groupIndex, groupValue := range regexGroups {
groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}"
s = strings.ReplaceAll(s, groupIndexStr, groupValue)
}
state.Req.Question[0].Name = s
return rule.responseRuleFor(state)
}
// newNameRule creates a name matching rule based on exact, partial, or regex match
func newNameRule(nextAction string, args ...string) (Rule, error) {
var matchType, rewriteQuestionFrom, rewriteQuestionTo string
if len(args) < 2 {
return nil, fmt.Errorf("too few arguments for a name rule")
}
if len(args) == 2 {
matchType = ExactMatch
rewriteQuestionFrom = plugin.Name(args[0]).Normalize()
rewriteQuestionTo = plugin.Name(args[1]).Normalize()
}
if len(args) >= 3 {
matchType = strings.ToLower(args[0])
if matchType == RegexMatch {
rewriteQuestionFrom = args[1]
rewriteQuestionTo = args[2]
} else {
rewriteQuestionFrom = plugin.Name(args[1]).Normalize()
rewriteQuestionTo = plugin.Name(args[2]).Normalize()
}
}
if matchType == ExactMatch || matchType == SuffixMatch {
if !hasClosingDot(rewriteQuestionFrom) {
rewriteQuestionFrom = rewriteQuestionFrom + "."
}
if !hasClosingDot(rewriteQuestionTo) {
rewriteQuestionTo = rewriteQuestionTo + "."
}
}
var err error
var answers ResponseRules
auto := false
if len(args) > 3 {
auto, answers, err = parseAnswerRules(matchType, args[3:])
if err != nil {
return nil, err
}
}
switch matchType {
case ExactMatch:
if _, err := isValidRegexPattern(rewriteQuestionTo, rewriteQuestionFrom); err != nil {
return nil, err
}
return newExactNameRule(nextAction, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case PrefixMatch:
return newPrefixNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case SuffixMatch:
return newSuffixNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case SubstringMatch:
return newSubstringNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case RegexMatch:
rewriteQuestionFromPattern, err := isValidRegexPattern(rewriteQuestionFrom, rewriteQuestionTo)
if err != nil {
return nil, err
}
rewriteQuestionTo := plugin.Name(args[2]).Normalize()
return newRegexNameRule(nextAction, auto, rewriteQuestionFromPattern, rewriteQuestionTo, answers), nil
default:
return nil, fmt.Errorf("name rule supports only exact, prefix, suffix, substring, and regex name matching, received: %s", matchType)
}
}
func parseAnswerRules(name string, args []string) (auto bool, rules ResponseRules, err error) {
auto = false
arg := 0
nameRules := 0
last := ""
if len(args) < 2 {
return false, nil, fmt.Errorf("invalid arguments for %s rule", name)
}
for arg < len(args) {
if last == "" && args[arg] != AnswerMatch {
if last == "" {
return false, nil, fmt.Errorf("exceeded the number of arguments for a non-answer rule argument for %s rule", name)
}
return false, nil, fmt.Errorf("exceeded the number of arguments for %s answer rule for %s rule", last, name)
}
if args[arg] == AnswerMatch {
arg++
}
if len(args)-arg == 0 {
return false, nil, fmt.Errorf("type missing for answer rule for %s rule", name)
}
last = args[arg]
arg++
switch last {
case AutoMatch:
auto = true
continue
case NameMatch:
if len(args)-arg < 2 {
return false, nil, fmt.Errorf("%s answer rule for %s rule: 2 arguments required", last, name)
}
rewriteAnswerFrom := args[arg]
rewriteAnswerTo := args[arg+1]
rewriteAnswerFromPattern, err := isValidRegexPattern(rewriteAnswerFrom, rewriteAnswerTo)
rewriteAnswerTo = plugin.Name(rewriteAnswerTo).Normalize()
if err != nil {
return false, nil, fmt.Errorf("%s answer rule for %s rule: %s", last, name, err)
}
rules = append(rules, &nameRewriterResponseRule{newStringRewriter(rewriteAnswerFromPattern, rewriteAnswerTo)})
arg += 2
nameRules++
case ValueMatch:
if len(args)-arg < 2 {
return false, nil, fmt.Errorf("%s answer rule for %s rule: 2 arguments required", last, name)
}
rewriteAnswerFrom := args[arg]
rewriteAnswerTo := args[arg+1]
rewriteAnswerFromPattern, err := isValidRegexPattern(rewriteAnswerFrom, rewriteAnswerTo)
rewriteAnswerTo = plugin.Name(rewriteAnswerTo).Normalize()
if err != nil {
return false, nil, fmt.Errorf("%s answer rule for %s rule: %s", last, name, err)
}
rules = append(rules, &valueRewriterResponseRule{newStringRewriter(rewriteAnswerFromPattern, rewriteAnswerTo)})
arg += 2
default:
return false, nil, fmt.Errorf("invalid type %q for answer rule for %s rule", last, name)
}
}
if auto && nameRules > 0 {
return false, nil, fmt.Errorf("auto name answer rule cannot be combined with explicit name anwer rules")
}
return auto, rules, nil
}
// hasClosingDot returns true if s has a closing dot at the end.
func hasClosingDot(s string) bool {
return strings.HasSuffix(s, ".")
}
// getSubExprUsage returns the number of subexpressions used in s.
func getSubExprUsage(s string) int {
subExprUsage := 0
for i := range 101 {
if strings.Contains(s, "{"+strconv.Itoa(i)+"}") {
subExprUsage++
}
}
return subExprUsage
}
// isValidRegexPattern returns a regular expression for pattern matching or errors, if any.
func isValidRegexPattern(rewriteFrom, rewriteTo string) (*regexp.Regexp, error) {
rewriteFromPattern, err := regexp.Compile(rewriteFrom)
if err != nil {
return nil, fmt.Errorf("invalid regex matching pattern: %s", rewriteFrom)
}
if getSubExprUsage(rewriteTo) > rewriteFromPattern.NumSubexp() {
return nil, fmt.Errorf("the rewrite regex pattern (%s) uses more subexpressions than its corresponding matching regex pattern (%s)", rewriteTo, rewriteFrom)
}
return rewriteFromPattern, nil
}
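// Example Corefile usage (illustrative domain names; see newNameRule and
// parseAnswerRules above for the accepted forms):
//
//	rewrite name example.org example.com
//	rewrite name suffix .internal.example.org. .example.org. answer auto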
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
type rcodeResponseRule struct {
old int
new int
}
func (r *rcodeResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
if r.old == res.Rcode {
res.Rcode = r.new
}
}
type rcodeRuleBase struct {
nextAction string
response rcodeResponseRule
}
func newRCodeRuleBase(nextAction string, old, new int) rcodeRuleBase {
return rcodeRuleBase{
nextAction: nextAction,
response: rcodeResponseRule{old: old, new: new},
}
}
func (rule *rcodeRuleBase) responseRule(match bool) (ResponseRules, Result) {
if match {
return ResponseRules{&rule.response}, RewriteDone
}
return nil, RewriteIgnored
}
// Mode returns the processing nextAction
func (rule *rcodeRuleBase) Mode() string { return rule.nextAction }
type exactRCodeRule struct {
rcodeRuleBase
From string
}
type prefixRCodeRule struct {
rcodeRuleBase
Prefix string
}
type suffixRCodeRule struct {
rcodeRuleBase
Suffix string
}
type substringRCodeRule struct {
rcodeRuleBase
Substring string
}
type regexRCodeRule struct {
rcodeRuleBase
Pattern *regexp.Regexp
}
// Rewrite rewrites the current request based upon exact match of the name
// in the question section of the request.
func (rule *exactRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(rule.From == state.Name())
}
// Rewrite rewrites the current request when the name begins with the matching string.
func (rule *prefixRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.HasPrefix(state.Name(), rule.Prefix))
}
// Rewrite rewrites the current request when the name ends with the matching string.
func (rule *suffixRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.HasSuffix(state.Name(), rule.Suffix))
}
// Rewrite rewrites the current request based upon partial match of the
// name in the question section of the request.
func (rule *substringRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.Contains(state.Name(), rule.Substring))
}
// Rewrite rewrites the current request when the name in the question
// section of the request matches a regular expression.
func (rule *regexRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(len(rule.Pattern.FindStringSubmatch(state.Name())) != 0)
}
// newRCodeRule creates a name matching rule based on exact, partial, or regex match
func newRCodeRule(nextAction string, args ...string) (Rule, error) {
if len(args) < 3 {
return nil, fmt.Errorf("too few (%d) arguments for a rcode rule", len(args))
}
var oldStr, newStr string
if len(args) == 3 {
oldStr, newStr = args[1], args[2]
}
if len(args) == 4 {
oldStr, newStr = args[2], args[3]
}
old, valid := isValidRCode(oldStr)
if !valid {
return nil, fmt.Errorf("invalid matching RCODE '%s' for a rcode rule", oldStr)
}
new, valid := isValidRCode(newStr)
if !valid {
return nil, fmt.Errorf("invalid replacement RCODE '%s' for a rcode rule", newStr)
}
if len(args) == 4 {
switch strings.ToLower(args[0]) {
case ExactMatch:
return &exactRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case PrefixMatch:
return &prefixRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case SuffixMatch:
return &suffixRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case SubstringMatch:
return &substringRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case RegexMatch:
regexPattern, err := regexp.Compile(args[1])
if err != nil {
return nil, fmt.Errorf("invalid regex pattern in a rcode rule: %s", args[1])
}
return ®exRCodeRule{
newRCodeRuleBase(nextAction, old, new),
regexPattern,
}, nil
default:
return nil, fmt.Errorf("rcode rule supports only exact, prefix, suffix, substring, and regex name matching")
}
}
if len(args) > 4 {
return nil, fmt.Errorf("many few arguments for a rcode rule")
}
return &exactRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[0]).Normalize(),
}, nil
}
// isValidRCode returns the numeric RCODE and true if v is a valid RCODE value.
func isValidRCode(v string) (int, bool) {
i, err := strconv.ParseUint(v, 10, 32)
// try parsing integer based rcode
if err == nil && i <= 23 {
return int(i), true
}
if RCodeInt, ok := dns.StringToRcode[strings.ToUpper(v)]; ok {
return RCodeInt, true
}
return 0, false
}
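// Example Corefile usage (illustrative name; RCODEs may be given by name or by their
// numeric value, per isValidRCode above):
//
//	rewrite rcode exact service.example.com NXDOMAIN NOERROR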
package rewrite
import (
"github.com/miekg/dns"
)
// RevertPolicy controls the overall reverting process
type RevertPolicy interface {
DoRevert() bool
DoQuestionRestore() bool
}
type revertPolicy struct {
noRevert bool
noRestore bool
}
func (p revertPolicy) DoRevert() bool {
return !p.noRevert
}
func (p revertPolicy) DoQuestionRestore() bool {
return !p.noRestore
}
// NoRevertPolicy disables all response rewrite rules
func NoRevertPolicy() RevertPolicy {
return revertPolicy{true, false}
}
// NoRestorePolicy disables the question restoration during the response rewrite
func NoRestorePolicy() RevertPolicy {
return revertPolicy{false, true}
}
// NewRevertPolicy creates a new reverter policy by dynamically specifying all
// options.
func NewRevertPolicy(noRevert, noRestore bool) RevertPolicy {
return revertPolicy{noRestore: noRestore, noRevert: noRevert}
}
// ResponseRule contains a rule to rewrite a response with.
type ResponseRule interface {
RewriteResponse(res *dns.Msg, rr dns.RR)
}
// ResponseRules describes an ordered list of response rules to apply
// after a name rewrite
type ResponseRules = []ResponseRule
// ResponseReverter reverses the operations done on the question section of a packet.
// This is needed because the client will otherwise disregard the response, i.e.
// dig will complain with ';; Question section mismatch: got example.org/HINFO/IN'
type ResponseReverter struct {
dns.ResponseWriter
originalQuestion dns.Question
ResponseRules ResponseRules
revertPolicy RevertPolicy
}
// NewResponseReverter returns a pointer to a new ResponseReverter.
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg, policy RevertPolicy) *ResponseReverter {
return &ResponseReverter{
ResponseWriter: w,
originalQuestion: r.Question[0],
revertPolicy: policy,
}
}
// WriteMsg restores the original question, applies any response rules, and calls the underlying ResponseWriter's WriteMsg method.
func (r *ResponseReverter) WriteMsg(res1 *dns.Msg) error {
// Deep copy 'res' so as to not (e.g.) rewrite a message that's also stored in the cache.
res := res1.Copy()
if r.revertPolicy.DoQuestionRestore() {
res.Question[0] = r.originalQuestion
}
if len(r.ResponseRules) > 0 {
for _, rr := range res.Ns {
r.rewriteResourceRecord(res, rr)
}
for _, rr := range res.Answer {
r.rewriteResourceRecord(res, rr)
}
for _, rr := range res.Extra {
r.rewriteResourceRecord(res, rr)
}
}
return r.ResponseWriter.WriteMsg(res)
}
func (r *ResponseReverter) rewriteResourceRecord(res *dns.Msg, rr dns.RR) {
// The reverting rules need to be applied in reverse order.
for i := len(r.ResponseRules) - 1; i >= 0; i-- {
r.ResponseRules[i].RewriteResponse(res, rr)
}
}
// Write is a wrapper that calls the underlying ResponseWriter's Write method.
func (r *ResponseReverter) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
return n, err
}
func getRecordValueForRewrite(rr dns.RR) (name string) {
switch rr.Header().Rrtype {
case dns.TypeSRV:
return rr.(*dns.SRV).Target
case dns.TypeMX:
return rr.(*dns.MX).Mx
case dns.TypeCNAME:
return rr.(*dns.CNAME).Target
case dns.TypeNS:
return rr.(*dns.NS).Ns
case dns.TypeDNAME:
return rr.(*dns.DNAME).Target
case dns.TypeNAPTR:
return rr.(*dns.NAPTR).Replacement
case dns.TypeSOA:
return rr.(*dns.SOA).Ns
case dns.TypePTR:
return rr.(*dns.PTR).Ptr
default:
return ""
}
}
func setRewrittenRecordValue(rr dns.RR, value string) {
switch rr.Header().Rrtype {
case dns.TypeSRV:
rr.(*dns.SRV).Target = value
case dns.TypeMX:
rr.(*dns.MX).Mx = value
case dns.TypeCNAME:
rr.(*dns.CNAME).Target = value
case dns.TypeNS:
rr.(*dns.NS).Ns = value
case dns.TypeDNAME:
rr.(*dns.DNAME).Target = value
case dns.TypeNAPTR:
rr.(*dns.NAPTR).Replacement = value
case dns.TypeSOA:
rr.(*dns.SOA).Ns = value
case dns.TypePTR:
rr.(*dns.PTR).Ptr = value
}
}
package rewrite
import (
"context"
"fmt"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Result is the result of a rewrite
type Result int
const (
// RewriteIgnored is returned when rewrite is not done on request.
RewriteIgnored Result = iota
// RewriteDone is returned when rewrite is done on request.
RewriteDone
)
// These are the defined processing modes.
const (
// Stop processing should stop after completing this rule
Stop = "stop"
// Continue processing should continue to next rule
Continue = "continue"
)
// Rewrite is a plugin to rewrite requests internally before being handled.
type Rewrite struct {
Next plugin.Handler
Rules []Rule
RevertPolicy
}
// ServeDNS implements the plugin.Handler interface.
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
if rw.RevertPolicy == nil {
rw.RevertPolicy = NewRevertPolicy(false, false)
}
wr := NewResponseReverter(w, r, rw.RevertPolicy)
state := request.Request{W: w, Req: r}
for _, rule := range rw.Rules {
respRules, result := rule.Rewrite(ctx, state)
if result == RewriteDone {
if _, ok := dns.IsDomainName(state.Req.Question[0].Name); !ok {
err := fmt.Errorf("invalid name after rewrite: %s", state.Req.Question[0].Name)
state.Req.Question[0] = wr.originalQuestion
return dns.RcodeServerFailure, err
}
wr.ResponseRules = append(wr.ResponseRules, respRules...)
if rule.Mode() == Stop {
if !rw.DoRevert() {
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
rcode, err := plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
if plugin.ClientWrite(rcode) {
return rcode, err
}
// The next plugins didn't write a response, so write one now with the ResponseReverter.
// If server.ServeDNS does this then it will create an answer mismatch.
res := new(dns.Msg).SetRcode(r, rcode)
state.SizeAndDo(res)
wr.WriteMsg(res)
// return success, so server does not write a second error response to client
return dns.RcodeSuccess, err
}
}
}
if !rw.DoRevert() || len(wr.ResponseRules) == 0 {
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
}
// Name implements the Handler interface.
func (rw Rewrite) Name() string { return "rewrite" }
// Rule describes a rewrite rule.
type Rule interface {
// Rewrite rewrites the current request.
Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result)
// Mode returns the processing mode stop or continue.
Mode() string
}
func newRule(args ...string) (Rule, error) {
if len(args) == 0 {
return nil, fmt.Errorf("no rule type specified for rewrite")
}
arg0 := strings.ToLower(args[0])
var ruleType string
var expectNumArgs, startArg int
mode := Stop
switch arg0 {
case Continue:
mode = Continue
if len(args) < 2 {
return nil, fmt.Errorf("continue rule must begin with a rule type")
}
ruleType = strings.ToLower(args[1])
expectNumArgs = len(args) - 1
startArg = 2
case Stop:
if len(args) < 2 {
return nil, fmt.Errorf("stop rule must begin with a rule type")
}
ruleType = strings.ToLower(args[1])
expectNumArgs = len(args) - 1
startArg = 2
default:
// for backward compatibility
ruleType = arg0
expectNumArgs = len(args)
startArg = 1
}
switch ruleType {
case "answer":
return nil, fmt.Errorf("response rewrites must begin with a name rule")
case "name":
return newNameRule(mode, args[startArg:]...)
case "class":
if expectNumArgs != 3 {
return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
}
return newClassRule(mode, args[startArg:]...)
case "type":
if expectNumArgs != 3 {
return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
}
return newTypeRule(mode, args[startArg:]...)
case "edns0":
return newEdns0Rule(mode, args[startArg:]...)
case "ttl":
return newTTLRule(mode, args[startArg:]...)
case "cname":
return newCNAMERule(mode, args[startArg:]...)
case "rcode":
return newRCodeRule(mode, args[startArg:]...)
default:
return nil, fmt.Errorf("invalid rule type %q", args[0])
}
}
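// Example Corefile usage (illustrative): by default a matching rule stops further rule
// processing; prefixing a rule with "continue" lets subsequent rules run as well, as
// parsed by newRule above. The type rule shown is assumed to be available in this plugin.
//
//	rewrite continue type ANY HINFO
//	rewrite stop name example.org example.com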
package rewrite
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("rewrite", setup) }
func setup(c *caddy.Controller) error {
rewrites, err := rewriteParse(c)
if err != nil {
return plugin.Error("rewrite", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Rewrite{Next: next, Rules: rewrites}
})
return nil
}
func rewriteParse(c *caddy.Controller) ([]Rule, error) {
var rules []Rule
for c.Next() {
args := c.RemainingArgs()
if len(args) < 2 {
// Handle rules defined as nested instructions, i.e. the ones enclosed in curly brackets.
for c.NextBlock() {
args = append(args, c.Val())
}
}
rule, err := newRule(args...)
if err != nil {
return nil, err
}
rules = append(rules, rule)
}
return rules, nil
}
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
type ttlResponseRule struct {
minTTL uint32
maxTTL uint32
}
func (r *ttlResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
if rr.Header().Ttl < r.minTTL {
rr.Header().Ttl = r.minTTL
} else if rr.Header().Ttl > r.maxTTL {
rr.Header().Ttl = r.maxTTL
}
}
type ttlRuleBase struct {
nextAction string
response ttlResponseRule
}
func newTTLRuleBase(nextAction string, minTtl, maxTtl uint32) ttlRuleBase {
return ttlRuleBase{
nextAction: nextAction,
response: ttlResponseRule{minTTL: minTtl, maxTTL: maxTtl},
}
}
func (rule *ttlRuleBase) responseRule(match bool) (ResponseRules, Result) {
if match {
return ResponseRules{&rule.response}, RewriteDone
}
return nil, RewriteIgnored
}
// Mode returns the processing nextAction
func (rule *ttlRuleBase) Mode() string { return rule.nextAction }
type exactTTLRule struct {
ttlRuleBase
From string
}
type prefixTTLRule struct {
ttlRuleBase
Prefix string
}
type suffixTTLRule struct {
ttlRuleBase
Suffix string
}
type substringTTLRule struct {
ttlRuleBase
Substring string
}
type regexTTLRule struct {
ttlRuleBase
Pattern *regexp.Regexp
}
// Rewrite rewrites the current request based upon exact match of the name
// in the question section of the request.
func (rule *exactTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(rule.From == state.Name())
}
// Rewrite rewrites the current request when the name begins with the matching string.
func (rule *prefixTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.HasPrefix(state.Name(), rule.Prefix))
}
// Rewrite rewrites the current request when the name ends with the matching string.
func (rule *suffixTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.HasSuffix(state.Name(), rule.Suffix))
}
// Rewrite rewrites the current request based upon partial match of the
// name in the question section of the request.
func (rule *substringTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.Contains(state.Name(), rule.Substring))
}
// Rewrite rewrites the current request when the name in the question
// section of the request matches a regular expression.
func (rule *regexTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(len(rule.Pattern.FindStringSubmatch(state.Name())) != 0)
}
// newTTLRule creates a name matching rule based on exact, partial, or regex match
func newTTLRule(nextAction string, args ...string) (Rule, error) {
if len(args) < 2 {
return nil, fmt.Errorf("too few (%d) arguments for a ttl rule", len(args))
}
var s string
if len(args) == 2 {
s = args[1]
}
if len(args) == 3 {
s = args[2]
}
minTtl, maxTtl, valid := isValidTTL(s)
if !valid {
return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s)
}
if len(args) == 3 {
switch strings.ToLower(args[0]) {
case ExactMatch:
return &exactTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case PrefixMatch:
return &prefixTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case SuffixMatch:
return &suffixTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case SubstringMatch:
return &substringTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case RegexMatch:
regexPattern, err := regexp.Compile(args[1])
if err != nil {
return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1])
}
return &regexTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
regexPattern,
}, nil
default:
return nil, fmt.Errorf("ttl rule supports only exact, prefix, suffix, substring, and regex name matching")
}
}
if len(args) > 3 {
return nil, fmt.Errorf("many few arguments for a ttl rule")
}
return &exactTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[0]).Normalize(),
}, nil
}
// isValidTTL parses v as a single TTL or a "min-max" range and returns the minimum TTL, the maximum TTL and whether v is valid.
func isValidTTL(v string) (uint32, uint32, bool) {
s := strings.Split(v, "-")
if len(s) == 1 {
i, err := strconv.ParseUint(s[0], 10, 32)
if err != nil {
return 0, 0, false
}
return uint32(i), uint32(i), true
}
if len(s) == 2 {
var min, max uint64
var err error
if s[0] == "" {
min = 0
} else {
min, err = strconv.ParseUint(s[0], 10, 32)
if err != nil {
return 0, 0, false
}
}
if s[1] == "" {
if s[0] == "" {
// explicitly reject ttl directive "-" that would otherwise be interpreted
// as 0-2147483647 which is pretty useless
return 0, 0, false
}
max = 2147483647
} else {
max, err = strconv.ParseUint(s[1], 10, 32)
if err != nil {
return 0, 0, false
}
}
if min > max {
// reject invalid range
return 0, 0, false
}
return uint32(min), uint32(max), true
}
return 0, 0, false
}
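// A few illustrative inputs for isValidTTL (a sketch; the results follow
// directly from the parsing above):
//
//	isValidTTL("300")     // 300, 300, true
//	isValidTTL("60-3600") // 60, 3600, true
//	isValidTTL("-3600")   // 0, 3600, true
//	isValidTTL("300-")    // 300, 2147483647, true
//	isValidTTL("-")       // 0, 0, false (rejected)
//	isValidTTL("10-5")    // 0, 0, false (min > max)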
package rewrite
import (
"context"
"fmt"
"strings"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// typeRule is a type rewrite rule.
type typeRule struct {
fromType uint16
toType uint16
nextAction string
}
func newTypeRule(nextAction string, args ...string) (Rule, error) {
var from, to uint16
var ok bool
if from, ok = dns.StringToType[strings.ToUpper(args[0])]; !ok {
return nil, fmt.Errorf("invalid type %q", strings.ToUpper(args[0]))
}
if to, ok = dns.StringToType[strings.ToUpper(args[1])]; !ok {
return nil, fmt.Errorf("invalid type %q", strings.ToUpper(args[1]))
}
return &typeRule{from, to, nextAction}, nil
}
// Rewrite rewrites the current request.
func (rule *typeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
if rule.fromType > 0 && rule.toType > 0 {
if state.QType() == rule.fromType {
state.Req.Question[0].Qtype = rule.toType
return nil, RewriteDone
}
}
return nil, RewriteIgnored
}
// Mode returns the processing mode.
func (rule *typeRule) Mode() string { return rule.nextAction }
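// Illustrative only, a sketch using the constructor above:
//
//	r, _ := newTypeRule("stop", "ANY", "HINFO") // rewrites QTYPE ANY to HINFO before the next plugin runs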
package rewrite
import (
"encoding/binary"
"fmt"
"net"
"strconv"
)
// ipToWire writes an IP address in wire/binary format: 4 bytes for IPv4 (family 1) or 16 bytes for IPv6 (family 2).
func ipToWire(family int, ipAddr string) ([]byte, error) {
switch family {
case 1:
return net.ParseIP(ipAddr).To4(), nil
case 2:
return net.ParseIP(ipAddr).To16(), nil
}
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
}
// uint16ToWire writes a uint16 to wire/binary format (2 bytes, big endian).
func uint16ToWire(data uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, data)
return buf
}
// portToWire writes port to wire/binary format, 2 bytes
func portToWire(portStr string) ([]byte, error) {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
return uint16ToWire(uint16(port)), nil
}
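// Illustrative outputs for the helpers above (big endian; the values are
// examples only):
//
//	uint16ToWire(53)           // []byte{0x00, 0x35}
//	portToWire("853")          // []byte{0x03, 0x55}, nil
//	ipToWire(1, "192.0.2.1")   // []byte{192, 0, 2, 1}, nil
//	ipToWire(2, "2001:db8::1") // 16 bytes, nil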
package root
import (
"os"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
var log = clog.NewWithPlugin("root")
func init() { plugin.Register("root", setup) }
func setup(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
for c.Next() {
if !c.NextArg() {
return plugin.Error("root", c.ArgErr())
}
config.Root = c.Val()
}
// Check if root path exists
_, err := os.Stat(config.Root)
if err != nil {
if !os.IsNotExist(err) {
return plugin.Error("root", c.Errf("unable to access root path '%s': %v", config.Root, err))
}
// Allow this, because the folder might appear later.
// But make sure the user knows!
log.Warningf("Root path does not exist: %s", config.Root)
}
return nil
}
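// An illustrative Corefile stanza accepted by this setup function; the zone
// and path are placeholders, not taken from the original source:
//
//	example.org {
//	    root /etc/coredns/zones
//	}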
// Package route53 implements a plugin that returns resource records
// from AWS route53.
package route53
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/route53"
"github.com/aws/aws-sdk-go-v2/service/route53/types"
"github.com/miekg/dns"
)
// Route53 is a plugin that returns RR from AWS route53.
type Route53 struct {
Next plugin.Handler
Fall fall.F
zoneNames []string
client route53Client
upstream *upstream.Upstream
refresh time.Duration
zMu sync.RWMutex
zones zones
}
type zone struct {
id string
z *file.Zone
dns string
}
type zones map[string][]*zone
// New reads from the keys map, which uses domain names as its keys and hosted
// zone ID lists as its values, validates that each domain name/zone ID pair
// exists, and returns a new *Route53. In addition, upstream is used for
// resolving CNAMEs recursively. It returns an error if it cannot verify any
// given domain name/zone ID pair.
func New(ctx context.Context, c route53Client, keys map[string][]string, refresh time.Duration) (*Route53, error) {
zones := make(map[string][]*zone, len(keys))
zoneNames := make([]string, 0, len(keys))
for dns, hostedZoneIDs := range keys {
for _, hostedZoneID := range hostedZoneIDs {
_, err := c.ListHostedZonesByName(ctx, &route53.ListHostedZonesByNameInput{
DNSName: aws.String(dns),
HostedZoneId: aws.String(hostedZoneID),
})
if err != nil {
return nil, err
}
if _, ok := zones[dns]; !ok {
zoneNames = append(zoneNames, dns)
}
zones[dns] = append(zones[dns], &zone{id: hostedZoneID, dns: dns, z: file.NewZone(dns, "")})
}
}
return &Route53{
client: c,
zoneNames: zoneNames,
zones: zones,
upstream: upstream.New(),
refresh: refresh,
}, nil
}
// Run executes the first update and spins up an update forever-loop.
// It returns an error if the first update fails.
func (h *Route53) Run(ctx context.Context) error {
if err := h.updateZones(ctx); err != nil {
return err
}
go func() {
timer := time.NewTimer(h.refresh)
defer timer.Stop()
for {
timer.Reset(h.refresh)
select {
case <-ctx.Done():
log.Debugf("Breaking out of Route53 update loop for %v: %v", h.zoneNames, ctx.Err())
return
case <-timer.C:
if err := h.updateZones(ctx); err != nil && ctx.Err() == nil /* Don't log error if ctx expired. */ {
log.Errorf("Failed to update zones %v: %v", h.zoneNames, err)
}
}
}
}()
return nil
}
// ServeDNS implements the plugin.Handler.ServeDNS.
func (h *Route53) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
zName := plugin.Zones(h.zoneNames).Matches(qname)
if zName == "" {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
z, ok := h.zones[zName]
if !ok || z == nil {
return dns.RcodeServerFailure, nil
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
var result file.Result
for _, hostedZone := range z {
h.zMu.RLock()
m.Answer, m.Ns, m.Extra, result = hostedZone.z.Lookup(ctx, state, qname)
h.zMu.RUnlock()
// Take the answer if it's non-empty OR if another record type exists for
// this name (NODATA).
if len(m.Answer) != 0 || result == file.NoData {
break
}
}
if len(m.Answer) == 0 && result != file.NoData && h.Fall.Through(qname) {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
switch result {
case file.Success:
case file.NoData:
case file.NameError:
m.Rcode = dns.RcodeNameError
case file.Delegation:
m.Authoritative = false
case file.ServerFailure:
return dns.RcodeServerFailure, nil
}
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
const escapeSeq = "\\"
// maybeUnescape parses s and converts escaped ASCII codepoints (in octal) back
// to their ASCII representation.
//
// From AWS docs:
//
// "If the domain name includes any characters other than a to z, 0 to 9, -
// (hyphen), or _ (underscore), Route 53 API actions return the characters as
// escape codes."
//
// For our purposes (and with respect to RFC 1035), we'll fish for a-z, 0-9,
// '-', '.' and '*' as the leftmost character (for wildcards), and return an
// error for everything else.
//
// Example:
//
// `\\052.example.com.` -> `*.example.com`
// `\\137.example.com.` -> error ('_' is not valid)
func maybeUnescape(s string) (string, error) {
var outSb strings.Builder
for {
i := strings.Index(s, escapeSeq)
if i < 0 {
return outSb.String() + s, nil
}
outSb.WriteString(s[:i])
li, ri := i+len(escapeSeq), i+len(escapeSeq)+3
if ri > len(s) {
return "", fmt.Errorf("invalid escape sequence: '%s%s'", escapeSeq, s[li:])
}
// Parse `\\xxx` in base 8 (2nd arg) and attempt to fit into
// 8-bit result (3rd arg).
n, err := strconv.ParseInt(s[li:ri], 8, 8)
if err != nil {
return "", fmt.Errorf("invalid escape sequence: '%s%s'", escapeSeq, s[li:ri])
}
r := rune(n)
switch {
case r >= rune('a') && r <= rune('z'): // Route53 converts everything to lowercase.
case r >= rune('0') && r <= rune('9'):
case r == rune('*'):
if outSb.Len() != 0 {
return "", errors.New("`*' only supported as wildcard (leftmost label)")
}
case r == rune('-'):
case r == rune('.'):
default:
return "", fmt.Errorf("invalid character: %s%#03o", escapeSeq, r)
}
outSb.WriteString(string(r))
s = s[i+len(escapeSeq)+3:]
}
}
func updateZoneFromRRS(rrs *types.ResourceRecordSet, z *file.Zone) error {
for _, rr := range rrs.ResourceRecords {
n, err := maybeUnescape(aws.ToString(rrs.Name))
if err != nil {
return fmt.Errorf("failed to unescape `%s' name: %v", aws.ToString(rrs.Name), err)
}
v, err := maybeUnescape(aws.ToString(rr.Value))
if err != nil {
return fmt.Errorf("failed to unescape `%s' value: %v", aws.ToString(rr.Value), err)
}
// Assemble RFC 1035 conforming record to pass into dns scanner.
rfc1035 := fmt.Sprintf("%s %d IN %s %s", n, aws.ToInt64(rrs.TTL), rrs.Type, v)
r, err := dns.NewRR(rfc1035)
if err != nil {
return fmt.Errorf("failed to parse resource record: %v", err)
}
z.Insert(r)
}
return nil
}
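// For example (values illustrative), a resource record set with name
// "example.org.", TTL 300, type A and value "192.0.2.1" is assembled as
//
//	"example.org. 300 IN A 192.0.2.1"
//
// and handed to dns.NewRR before insertion into the zone.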
// updateZones re-queries resource record sets for each zone and updates the
// zone object.
// It returns an error if any zone errored out, but waits for the other zones
// to complete first.
func (h *Route53) updateZones(ctx context.Context) error {
errc := make(chan error)
defer close(errc)
for zName, z := range h.zones {
go func(zName string, z []*zone) {
var err error
defer func() {
errc <- err
}()
for i, hostedZone := range z {
newZ := file.NewZone(zName, "")
newZ.Upstream = h.upstream
in := &route53.ListResourceRecordSetsInput{
HostedZoneId: aws.String(hostedZone.id),
MaxItems: aws.Int32(1000),
}
complete := false
var out *route53.ListResourceRecordSetsOutput
for out, err = h.client.ListResourceRecordSets(ctx, in); !complete; out, err = h.client.ListResourceRecordSets(ctx, in) {
if err != nil {
err = fmt.Errorf("failed to list resource records for %v:%v from route53: %v", zName, hostedZone.id, err)
return
}
for _, rrs := range out.ResourceRecordSets {
if err := updateZoneFromRRS(&rrs, newZ); err != nil {
// Maybe unsupported record type. Log and carry on.
log.Warningf("Failed to process resource record set: %v", err)
}
}
if out.IsTruncated {
in.StartRecordName = out.NextRecordName
in.StartRecordType = out.NextRecordType
in.StartRecordIdentifier = out.NextRecordIdentifier
} else {
complete = true
}
}
h.zMu.Lock()
(*z[i]).z = newZ
h.zMu.Unlock()
}
}(zName, z)
}
// Collect errors (if any). This also syncs on completion of all zone
// updates.
var errs []string
for range len(h.zones) {
err := <-errc
if err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) != 0 {
return fmt.Errorf("errors updating zones: %v", errs)
}
return nil
}
// Name implements plugin.Handler.Name.
func (h *Route53) Name() string { return "route53" }
package route53
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/route53"
)
var log = clog.NewWithPlugin("route53")
func init() { plugin.Register("route53", setup) }
// exposed for testing
type route53Client interface {
ActivateKeySigningKey(ctx context.Context, params *route53.ActivateKeySigningKeyInput, optFns ...func(*route53.Options)) (*route53.ActivateKeySigningKeyOutput, error)
ListHostedZonesByName(ctx context.Context, params *route53.ListHostedZonesByNameInput, optFns ...func(*route53.Options)) (*route53.ListHostedZonesByNameOutput, error)
ListResourceRecordSets(ctx context.Context, params *route53.ListResourceRecordSetsInput, optFns ...func(*route53.Options)) (*route53.ListResourceRecordSetsOutput, error)
}
var f = func(ctx context.Context, cfgOpts []func(*config.LoadOptions) error, clientOpts []func(*route53.Options)) (route53Client, error) {
cfg, err := config.LoadDefaultConfig(ctx, cfgOpts...)
if err != nil {
return nil, err
}
// If no region is specified, retrieve one from IMDS (SDK v1 used the AWS global partition as a fallback, v2 doesn't)
if cfg.Region == "" {
imdsClient := imds.NewFromConfig(cfg)
region, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{})
if err != nil {
return nil, fmt.Errorf("failed to get region from IMDS: %w", err)
}
cfg.Region = region.Region
}
return route53.NewFromConfig(cfg, clientOpts...), nil
}
func setup(c *caddy.Controller) error {
for c.Next() {
keyPairs := map[string]struct{}{}
keys := map[string][]string{}
// The route53 plugin attempts to load AWS credentials following the default SDK chain.
// Configuration is loaded in this order:
// * Static AWS keys set in Corefile (deprecated)
// * Environment Variables
// * Shared Credentials file
// * Shared Configuration file (if AWS_SDK_LOAD_CONFIG is set to truthy value)
// * EC2 Instance Metadata (credentials only)
cfgOpts := []func(*config.LoadOptions) error{}
clientOpts := []func(*route53.Options){}
var fall fall.F
refresh := time.Duration(1) * time.Minute // default update frequency to 1 minute
args := c.RemainingArgs()
for i := range args {
parts := strings.SplitN(args[i], ":", 2)
if len(parts) != 2 {
return plugin.Error("route53", c.Errf("invalid zone %q", args[i]))
}
dns, hostedZoneID := parts[0], parts[1]
if dns == "" || hostedZoneID == "" {
return plugin.Error("route53", c.Errf("invalid zone %q", args[i]))
}
if _, ok := keyPairs[args[i]]; ok {
return plugin.Error("route53", c.Errf("conflict zone %q", args[i]))
}
keyPairs[args[i]] = struct{}{}
keys[dns] = append(keys[dns], hostedZoneID)
}
for c.NextBlock() {
switch c.Val() {
case "aws_access_key":
v := c.RemainingArgs()
if len(v) < 2 {
return plugin.Error("route53", c.Errf("invalid access key: '%v'", v))
}
cfgOpts = append(cfgOpts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(v[0], v[1], "")))
log.Warningf("Save aws_access_key in Corefile has been deprecated, please use other authentication methods instead")
case "aws_endpoint":
if c.NextArg() {
clientOpts = append(clientOpts, func(o *route53.Options) {
o.BaseEndpoint = aws.String(c.Val())
})
} else {
return plugin.Error("route53", c.ArgErr())
}
case "upstream":
c.RemainingArgs() // eats args
case "credentials":
if c.NextArg() {
cfgOpts = append(cfgOpts, config.WithSharedConfigProfile(c.Val()))
} else {
return c.ArgErr()
}
if c.NextArg() {
sharedConfigFiles := []string{c.Val()}
// If AWS_SDK_LOAD_CONFIG is set also load ~/.aws/config to stay consistent
// with default SDK behavior.
if ok, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG")); ok {
sharedConfigFiles = append(sharedConfigFiles, config.DefaultSharedConfigFilename())
}
cfgOpts = append(cfgOpts, config.WithSharedConfigFiles(sharedConfigFiles))
}
case "fallthrough":
fall.SetZonesFromArgs(c.RemainingArgs())
case "refresh":
if c.NextArg() {
refreshStr := c.Val()
_, err := strconv.Atoi(refreshStr)
if err == nil {
refreshStr = c.Val() + "s"
}
refresh, err = time.ParseDuration(refreshStr)
if err != nil {
return plugin.Error("route53", c.Errf("Unable to parse duration: %v", err))
}
if refresh <= 0 {
return plugin.Error("route53", c.Errf("refresh interval must be greater than 0: %q", refreshStr))
}
} else {
return plugin.Error("route53", c.ArgErr())
}
default:
return plugin.Error("route53", c.Errf("unknown property %q", c.Val()))
}
}
ctx, cancel := context.WithCancel(context.Background())
client, err := f(ctx, cfgOpts, clientOpts)
if err != nil {
cancel()
return plugin.Error("route53", c.Errf("failed to create route53 client: %v", err))
}
h, err := New(ctx, client, keys, refresh)
if err != nil {
cancel()
return plugin.Error("route53", c.Errf("failed to create route53 plugin: %v", err))
}
h.Fall = fall
if err := h.Run(ctx); err != nil {
cancel()
return plugin.Error("route53", c.Errf("failed to initialize route53 plugin: %v", err))
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
h.Next = next
return h
})
c.OnShutdown(func() error { cancel(); return nil })
}
return nil
}
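// An illustrative Corefile stanza matching the parsing above; the zone name,
// hosted zone ID and profile name are placeholders:
//
//	route53 example.org.:Z1Z2Z3Z4EXAMPLE {
//	    credentials default
//	    refresh 5m
//	    fallthrough
//	}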
// Package secondary implements a secondary plugin.
package secondary
import "github.com/coredns/coredns/plugin/file"
// Secondary implements a secondary plugin that allows CoreDNS to retrieve (via AXFR)
// zone information from a primary server.
type Secondary struct {
file.File
}
// Name implements the Handler interface.
func (s Secondary) Name() string { return "secondary" }
package secondary
import (
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/upstream"
)
var log = clog.NewWithPlugin("secondary")
func init() { plugin.Register("secondary", setup) }
func setup(c *caddy.Controller) error {
zones, err := secondaryParse(c)
if err != nil {
return plugin.Error("secondary", err)
}
// Add startup functions to retrieve the zone and keep it up to date.
for i := range zones.Names {
n := zones.Names[i]
z := zones.Z[n]
if len(z.TransferFrom) > 0 {
c.OnStartup(func() error {
z.StartupOnce.Do(func() {
go func() {
dur := time.Millisecond * 250
max := time.Second * 10
for {
err := z.TransferIn()
if err == nil {
break
}
log.Warningf("All '%s' masters failed to transfer, retrying in %s: %s", n, dur.String(), err)
time.Sleep(dur)
dur <<= 1 // double the duration
if dur > max {
dur = max
}
}
z.Update()
}()
})
return nil
})
}
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Secondary{file.File{Next: next, Zones: zones}}
})
return nil
}
func secondaryParse(c *caddy.Controller) (file.Zones, error) {
z := make(map[string]*file.Zone)
names := []string{}
for c.Next() {
if c.Val() == "secondary" {
// secondary [origin]
origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
for i := range origins {
z[origins[i]] = file.NewZone(origins[i], "stdin")
names = append(names, origins[i])
}
hasTransfer := false
for c.NextBlock() {
var f []string
switch c.Val() {
case "transfer":
var err error
f, err = parse.TransferIn(c)
if err != nil {
return file.Zones{}, err
}
hasTransfer = true
default:
return file.Zones{}, c.Errf("unknown property '%s'", c.Val())
}
for _, origin := range origins {
if f != nil {
z[origin].TransferFrom = append(z[origin].TransferFrom, f...)
}
z[origin].Upstream = upstream.New()
}
}
if !hasTransfer {
return file.Zones{}, c.Err("secondary zones require a transfer from property")
}
}
}
return file.Zones{Z: z, Names: names}, nil
}
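// An illustrative Corefile stanza for this parser. The address is a
// placeholder and the "from" keyword is assumed to be what parse.TransferIn
// expects:
//
//	secondary example.org {
//	    transfer from 10.0.0.1
//	}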
package sign
import (
"github.com/miekg/dns"
)
func (p Pair) signRRs(rrs []dns.RR, signerName string, ttl, incep, expir uint32) (*dns.RRSIG, error) {
rrsig := &dns.RRSIG{
Hdr: dns.RR_Header{Rrtype: dns.TypeRRSIG, Ttl: ttl},
Algorithm: p.Public.Algorithm,
SignerName: signerName,
KeyTag: p.KeyTag,
OrigTtl: ttl,
Inception: incep,
Expiration: expir,
}
e := rrsig.Sign(p.Private, rrs)
return rrsig, e
}
package sign
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/miekg/dns"
)
// write writes out the zone file to a temporary file which is then moved into the correct place.
func (s *Signer) write(z *file.Zone) error {
f, err := os.CreateTemp(s.directory, "signed-")
if err != nil {
return err
}
if err := write(f, z); err != nil {
f.Close()
return err
}
f.Close()
return os.Rename(f.Name(), filepath.Join(s.directory, s.signedfile))
}
func write(w io.Writer, z *file.Zone) error {
if _, err := io.WriteString(w, z.SOA.String()); err != nil {
return err
}
w.Write([]byte("\n")) // RR Stringer() method doesn't include newline, which ends the RR in a zone file, write that here.
for _, rr := range z.SIGSOA {
io.WriteString(w, rr.String())
w.Write([]byte("\n"))
}
for _, rr := range z.NS {
io.WriteString(w, rr.String())
w.Write([]byte("\n"))
}
for _, rr := range z.SIGNS {
io.WriteString(w, rr.String())
w.Write([]byte("\n"))
}
err := z.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error {
for _, r := range e.All() {
io.WriteString(w, r.String())
w.Write([]byte("\n"))
}
return nil
})
return err
}
// Parse parses the zone in filename and returns a new Zone or an error. This
// is similar to the Parse function in the *file* plugin. However when parsing
// the record types DNSKEY, RRSIG, CDNSKEY and CDS are *not* included in the returned
// zone (if encountered).
func Parse(f io.Reader, origin, fileName string) (*file.Zone, error) {
zp := dns.NewZoneParser(f, dns.Fqdn(origin), fileName)
zp.SetIncludeAllowed(true)
z := file.NewZone(origin, fileName)
seenSOA := false
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
switch rr.(type) {
case *dns.DNSKEY, *dns.RRSIG, *dns.CDNSKEY, *dns.CDS:
continue
case *dns.SOA:
seenSOA = true
if err := z.Insert(rr); err != nil {
return nil, err
}
default:
if err := z.Insert(rr); err != nil {
return nil, err
}
}
}
if !seenSOA {
return nil, fmt.Errorf("file %q has no SOA record", fileName)
}
if err := zp.Err(); err != nil {
return nil, err
}
return z, nil
}
package sign
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/miekg/dns"
"golang.org/x/crypto/ed25519"
)
// Pair holds DNSSEC key information, both the public and private components are stored here.
type Pair struct {
Public *dns.DNSKEY
KeyTag uint16
Private crypto.Signer
}
// keyParse reads the public and private key from disk.
func keyParse(c *caddy.Controller) ([]Pair, error) {
if !c.NextArg() {
return nil, c.ArgErr()
}
pairs := []Pair{}
config := dnsserver.GetConfig(c)
switch c.Val() {
case "file":
ks := c.RemainingArgs()
if len(ks) == 0 {
return nil, c.ArgErr()
}
for _, k := range ks {
base := k
// Kmiek.nl.+013+26205.key, handle .private or without extension: Kmiek.nl.+013+26205
if strings.HasSuffix(k, ".key") {
base = k[:len(k)-4]
}
if strings.HasSuffix(k, ".private") {
base = k[:len(k)-8]
}
if !filepath.IsAbs(base) && config.Root != "" {
base = filepath.Join(config.Root, base)
}
pair, err := readKeyPair(base+".key", base+".private")
if err != nil {
return nil, err
}
pairs = append(pairs, pair)
}
case "directory":
return nil, fmt.Errorf("directory: not implemented")
}
return pairs, nil
}
func readKeyPair(public, private string) (Pair, error) {
rk, err := os.Open(filepath.Clean(public))
if err != nil {
return Pair{}, err
}
b, err := io.ReadAll(rk)
if err != nil {
return Pair{}, err
}
dnskey, err := dns.NewRR(string(b))
if err != nil {
return Pair{}, err
}
if _, ok := dnskey.(*dns.DNSKEY); !ok {
return Pair{}, fmt.Errorf("RR in %q is not a DNSKEY: %d", public, dnskey.Header().Rrtype)
}
ksk := dnskey.(*dns.DNSKEY).Flags&(1<<8) == (1<<8) && dnskey.(*dns.DNSKEY).Flags&1 == 1
if !ksk {
return Pair{}, fmt.Errorf("DNSKEY in %q is not a CSK/KSK", public)
}
rp, err := os.Open(filepath.Clean(private))
if err != nil {
return Pair{}, err
}
privkey, err := dnskey.(*dns.DNSKEY).ReadPrivateKey(rp, private)
if err != nil {
return Pair{}, err
}
switch signer := privkey.(type) {
case *ecdsa.PrivateKey:
return Pair{Public: dnskey.(*dns.DNSKEY), KeyTag: dnskey.(*dns.DNSKEY).KeyTag(), Private: signer}, nil
case ed25519.PrivateKey:
return Pair{Public: dnskey.(*dns.DNSKEY), KeyTag: dnskey.(*dns.DNSKEY).KeyTag(), Private: signer}, nil
case *rsa.PrivateKey:
return Pair{Public: dnskey.(*dns.DNSKEY), KeyTag: dnskey.(*dns.DNSKEY).KeyTag(), Private: signer}, nil
default:
return Pair{}, fmt.Errorf("unsupported algorithm %s", signer)
}
}
// keyTag returns the key tags of the keys in ps as a formatted string.
func keyTag(ps []Pair) string {
if len(ps) == 0 {
return ""
}
var sb strings.Builder
for _, p := range ps {
sb.WriteString(strconv.Itoa(int(p.KeyTag)) + ",")
}
s := sb.String()
return s[:len(s)-1]
}
package sign
import (
"slices"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/miekg/dns"
)
// names returns the elements of the zone in nsec order.
func names(origin string, z *file.Zone) []string {
// There will also be apex records other than NS and SOA (which are kept
// separate), as we are adding DNSKEY and CDS/CDNSKEY records in the apex
// *before* we sign.
n := []string{}
z.AuthWalk(func(e *tree.Elem, _ map[uint16][]dns.RR, auth bool) error {
if !auth {
return nil
}
n = append(n, e.Name())
return nil
})
return n
}
// NSEC returns an NSEC record according to name, next, ttl and bitmap. Note that the bitmap is sorted before use.
func NSEC(name, next string, ttl uint32, bitmap []uint16) *dns.NSEC {
slices.Sort(bitmap)
return &dns.NSEC{
Hdr: dns.RR_Header{Name: name, Ttl: ttl, Rrtype: dns.TypeNSEC, Class: dns.ClassINET},
NextDomain: next,
TypeBitMap: bitmap,
}
}
package sign
import (
"fmt"
"math/rand"
"path/filepath"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("sign", setup) }
func setup(c *caddy.Controller) error {
sign, err := parse(c)
if err != nil {
return plugin.Error("sign", err)
}
c.OnStartup(sign.OnStartup)
c.OnStartup(func() error {
for _, signer := range sign.signers {
go signer.refresh(durationRefreshHours)
}
return nil
})
c.OnShutdown(func() error {
for _, signer := range sign.signers {
close(signer.stop)
}
return nil
})
// Don't call AddPlugin, *sign* is not a plugin.
return nil
}
func parse(c *caddy.Controller) (*Sign, error) {
sign := &Sign{}
config := dnsserver.GetConfig(c)
for c.Next() {
if !c.NextArg() {
return nil, c.ArgErr()
}
dbfile := c.Val()
if !filepath.IsAbs(dbfile) && config.Root != "" {
dbfile = filepath.Join(config.Root, dbfile)
}
// Validate dbfile token to avoid infinite signing loops caused by invalid paths
if strings.ContainsRune(dbfile, '\uFFFD') {
return nil, fmt.Errorf("dbfile %q contains invalid characters", dbfile)
}
origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
signers := make([]*Signer, len(origins))
for i := range origins {
signers[i] = &Signer{
dbfile: dbfile,
origin: origins[i],
jitterIncep: time.Duration(float32(durationInceptionJitter) * rand.Float32()),
jitterExpir: time.Duration(float32(durationExpirationDayJitter) * rand.Float32()),
directory: "/var/lib/coredns",
stop: make(chan struct{}),
signedfile: fmt.Sprintf("db.%ssigned", origins[i]), // origins[i] is a fqdn, so it ends with a dot, hence %ssigned.
}
}
for c.NextBlock() {
switch c.Val() {
case "key":
pairs, err := keyParse(c)
if err != nil {
return sign, err
}
for i := range signers {
for _, p := range pairs {
p.Public.Header().Name = signers[i].origin
}
signers[i].keys = append(signers[i].keys, pairs...)
}
case "directory":
dir := c.RemainingArgs()
if len(dir) == 0 || len(dir) > 1 {
return sign, fmt.Errorf("can only be one argument after %q", "directory")
}
if !filepath.IsAbs(dir[0]) && config.Root != "" {
dir[0] = filepath.Join(config.Root, dir[0])
}
for i := range signers {
signers[i].directory = dir[0]
signers[i].signedfile = fmt.Sprintf("db.%ssigned", signers[i].origin)
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
sign.signers = append(sign.signers, signers...)
}
return sign, nil
}
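// An illustrative Corefile stanza matching the parsing above; the file names
// and directory are placeholders:
//
//	sign db.example.org example.org {
//	    key file /etc/coredns/keys/Kexample.org.+013+45330
//	    directory /var/lib/coredns/signed
//	}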
// Package sign implements a zone signer as a plugin.
package sign
import (
"path/filepath"
"time"
)
// Sign contains signers that sign the zones files.
type Sign struct {
signers []*Signer
}
// OnStartup scans all signers and signs or resigns zones if needed.
func (s *Sign) OnStartup() error {
for _, signer := range s.signers {
why := signer.resign()
if why == nil {
log.Infof("Skipping signing zone %q in %q: signatures are valid", signer.origin, filepath.Join(signer.directory, signer.signedfile))
continue
}
go signAndLog(signer, why)
}
return nil
}
// Various duration constants for signing of the zones.
const (
durationExpireDays = 7 * 24 * time.Hour // max time allowed before expiration
durationResignDays = 6 * 24 * time.Hour // if the last sign happened this long ago, sign again
durationSignatureExpireDays = 32 * 24 * time.Hour // sign for 32 days
durationRefreshHours = 5 * time.Hour // check zones every 5 hours
durationInceptionJitter = -18 * time.Hour // default max jitter for the inception
durationExpirationDayJitter = 5 * 24 * time.Hour // default max jitter for the expiration
durationSignatureInceptionHours = -3 * time.Hour // -(2+1) hours, be sure to catch daylight saving time and such, jitter is subtracted
)
const timeFmt = "2006-01-02T15:04:05.000Z07:00"
package sign
import (
"fmt"
"io"
"os"
"path/filepath"
"time"
"github.com/coredns/coredns/plugin/file"
"github.com/coredns/coredns/plugin/file/tree"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("sign")
// Signer holds the data needed to sign a zone file.
type Signer struct {
keys []Pair
origin string
dbfile string
directory string
jitterIncep time.Duration
jitterExpir time.Duration
signedfile string
stop chan struct{}
}
// Sign signs a zone file according to the parameters in s.
func (s *Signer) Sign(now time.Time) (*file.Zone, error) {
rd, err := os.Open(s.dbfile)
if err != nil {
return nil, err
}
defer rd.Close()
z, err := Parse(rd, s.origin, s.dbfile)
if err != nil {
return nil, err
}
mttl := z.SOA.Minttl
ttl := z.SOA.Header().Ttl
inception, expiration := lifetime(now, s.jitterIncep, s.jitterExpir)
z.SOA.Serial = uint32(now.Unix())
for _, pair := range s.keys {
pair.Public.Header().Ttl = ttl // set TTL on key so it matches the RRSIG.
z.Insert(pair.Public)
z.Insert(pair.Public.ToDS(dns.SHA1).ToCDS())
z.Insert(pair.Public.ToDS(dns.SHA256).ToCDS())
z.Insert(pair.Public.ToCDNSKEY())
}
names := names(s.origin, z)
ln := len(names)
for _, pair := range s.keys {
rrsig, err := pair.signRRs([]dns.RR{z.SOA}, s.origin, ttl, inception, expiration)
if err != nil {
return nil, err
}
z.Insert(rrsig)
// The apex NS set may not be present if RRs have been discarded because the origin doesn't match.
if len(z.NS) > 0 {
rrsig, err = pair.signRRs(z.NS, s.origin, ttl, inception, expiration)
if err != nil {
return nil, err
}
z.Insert(rrsig)
}
}
// We are walking the tree in the same direction, so names[] can be used here to indicate the next element.
i := 1
err = z.AuthWalk(func(e *tree.Elem, zrrs map[uint16][]dns.RR, auth bool) error {
if !auth {
return nil
}
if e.Name() == s.origin {
nsec := NSEC(e.Name(), names[(ln+i)%ln], mttl, append(e.Types(), dns.TypeNS, dns.TypeSOA, dns.TypeRRSIG, dns.TypeNSEC))
z.Insert(nsec)
} else {
nsec := NSEC(e.Name(), names[(ln+i)%ln], mttl, append(e.Types(), dns.TypeRRSIG, dns.TypeNSEC))
z.Insert(nsec)
}
for t, rrs := range zrrs {
// RRSIGs are not signed and NS records are not signed because we are never authoritative for them.
// The zone's apex nameserver records are not kept in this tree and are signed separately.
if t == dns.TypeRRSIG || t == dns.TypeNS {
continue
}
for _, pair := range s.keys {
rrsig, err := pair.signRRs(rrs, s.origin, rrs[0].Header().Ttl, inception, expiration)
if err != nil {
return err
}
e.Insert(rrsig)
}
}
i++
return nil
})
return z, err
}
// resign checks if the signed zone exists, or needs resigning.
func (s *Signer) resign() error {
signedfile := filepath.Join(s.directory, s.signedfile)
rd, err := os.Open(filepath.Clean(signedfile))
if err != nil {
// Any error opening the signed zone (including it not existing yet) means it needs (re)signing.
return err
}
defer rd.Close()
now := time.Now().UTC()
return resign(rd, now)
}
// resign will scan rd and check the signature on the SOA record. We will resign on the basis
// of 2 conditions:
// * either the inception is more than 6 days ago, or
// * we only have 1 week left on the signature
//
// All SOA signatures will be checked. If no SOA RRSIG is found in the first
// 100 records, we will resign the zone.
func resign(rd io.Reader, now time.Time) (why error) {
zp := dns.NewZoneParser(rd, ".", "resign")
zp.SetIncludeAllowed(true)
i := 0
for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
switch x := rr.(type) {
case *dns.RRSIG:
if x.TypeCovered != dns.TypeSOA {
continue
}
incep, _ := time.Parse("20060102150405", dns.TimeToString(x.Inception))
// If too long ago, resign.
if now.Sub(incep) >= 0 && now.Sub(incep) > durationResignDays {
return fmt.Errorf("inception %q was more than: %s ago from %s: %s", incep.Format(timeFmt), durationResignDays, now.Format(timeFmt), now.Sub(incep))
}
// Inception hasn't even started yet.
if now.Sub(incep) < 0 {
return fmt.Errorf("inception %q date is in the future: %s", incep.Format(timeFmt), now.Sub(incep))
}
expire, _ := time.Parse("20060102150405", dns.TimeToString(x.Expiration))
if expire.Sub(now) < durationExpireDays {
return fmt.Errorf("expiration %q is less than: %s away from %s: %s", expire.Format(timeFmt), durationExpireDays, now.Format(timeFmt), expire.Sub(now))
}
}
i++
if i > 100 {
// 100 is a random number. A SOA record should be the first in the zonefile, but RFC 1035 doesn't actually mandate this. So it could
// be 3rd or even later. The number 100 looks crazy high enough that it will catch all weird zones, but not high enough to keep the CPU
// busy with parsing all the time.
return fmt.Errorf("no SOA RRSIG found in first 100 records")
}
}
return zp.Err()
}
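// Worked out with the constants above (illustrative): signatures are created
// with a 32 day lifetime, and a resign is triggered once the inception is
// more than 6 days old or the expiration is less than 7 days away. In steady
// state a zone is therefore resigned roughly every 6 days, well before the
// 32 day signature lifetime runs out.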
func signAndLog(s *Signer, why error) {
now := time.Now().UTC()
z, err := s.Sign(now)
log.Infof("Signing %q because %s", s.origin, why)
if err != nil {
log.Warningf("Error signing %q with key tags %q in %s: %s, next: %s", s.origin, keyTag(s.keys), time.Since(now), err, now.Add(durationRefreshHours).Format(timeFmt))
return
}
if err := s.write(z); err != nil {
log.Warningf("Error signing %q: failed to move zone file into place: %s", s.origin, err)
return
}
log.Infof("Successfully signed zone %q in %q with key tags %q and %d SOA serial, elapsed %f, next: %s", s.origin, filepath.Join(s.directory, s.signedfile), keyTag(s.keys), z.SOA.Serial, time.Since(now).Seconds(), now.Add(durationRefreshHours).Format(timeFmt))
}
// refresh checks every val if some zones need to be resigned.
func (s *Signer) refresh(val time.Duration) {
tick := time.NewTicker(val)
defer tick.Stop()
for {
select {
case <-s.stop:
return
case <-tick.C:
why := s.resign()
if why == nil {
continue
}
signAndLog(s, why)
}
}
}
func lifetime(now time.Time, jitterInception, jitterExpiration time.Duration) (uint32, uint32) {
incep := uint32(now.Add(durationSignatureInceptionHours).Add(jitterInception).Unix())
expir := uint32(now.Add(durationSignatureExpireDays).Add(jitterExpiration).Unix())
return incep, expir
}
package template
import (
"regexp"
"strconv"
gotmpl "text/template"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/miekg/dns"
)
func init() { plugin.Register("template", setupTemplate) }
func setupTemplate(c *caddy.Controller) error {
handler, err := templateParse(c)
if err != nil {
return plugin.Error("template", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
handler.Next = next
return handler
})
return nil
}
func templateParse(c *caddy.Controller) (handler Handler, err error) {
handler.Templates = make([]template, 0)
for c.Next() {
if !c.NextArg() {
return handler, c.ArgErr()
}
class, ok := dns.StringToClass[c.Val()]
if !ok {
return handler, c.Errf("invalid query class %s", c.Val())
}
if !c.NextArg() {
return handler, c.ArgErr()
}
qtype, ok := dns.StringToType[c.Val()]
if !ok {
return handler, c.Errf("invalid RR class %s", c.Val())
}
zones := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
handler.Zones = append(handler.Zones, zones...)
t := template{qclass: class, qtype: qtype, zones: zones}
t.regex = make([]*regexp.Regexp, 0)
templatePrefix := ""
t.answer = make([]*gotmpl.Template, 0)
t.upstream = upstream.New()
for c.NextBlock() {
switch c.Val() {
case "match":
args := c.RemainingArgs()
if len(args) == 0 {
return handler, c.ArgErr()
}
for _, regex := range args {
r, err := regexp.Compile(regex)
if err != nil {
return handler, c.Errf("could not parse regex: %s, %v", regex, err)
}
templatePrefix = templatePrefix + regex + " "
t.regex = append(t.regex, r)
}
case "answer":
args := c.RemainingArgs()
if len(args) == 0 {
return handler, c.ArgErr()
}
for _, answer := range args {
tmpl, err := newTemplate("answer", answer)
if err != nil {
return handler, c.Errf("could not compile template: %s, %v", c.Val(), err)
}
t.answer = append(t.answer, tmpl)
}
case "additional":
args := c.RemainingArgs()
if len(args) == 0 {
return handler, c.ArgErr()
}
for _, additional := range args {
tmpl, err := newTemplate("additional", additional)
if err != nil {
return handler, c.Errf("could not compile template: %s, %v\n", c.Val(), err)
}
t.additional = append(t.additional, tmpl)
}
case "authority":
args := c.RemainingArgs()
if len(args) == 0 {
return handler, c.ArgErr()
}
for _, authority := range args {
tmpl, err := newTemplate("authority", authority)
if err != nil {
return handler, c.Errf("could not compile template: %s, %v\n", c.Val(), err)
}
t.authority = append(t.authority, tmpl)
}
case "rcode":
if !c.NextArg() {
return handler, c.ArgErr()
}
rcode, ok := dns.StringToRcode[c.Val()]
if !ok {
return handler, c.Errf("unknown rcode %s", c.Val())
}
t.rcode = rcode
case "ederror":
args := c.RemainingArgs()
if len(args) != 1 && len(args) != 2 {
return handler, c.ArgErr()
}
code, err := strconv.ParseUint(args[0], 10, 16)
if err != nil {
return handler, c.Errf("error parsing extended DNS error code %s, %v\n", c.Val(), err)
}
if len(args) == 2 {
t.ederror = &ederror{code: uint16(code), reason: args[1]}
} else {
t.ederror = &ederror{code: uint16(code)}
}
case "fallthrough":
t.fall.SetZonesFromArgs(c.RemainingArgs())
case "upstream":
// remove soon
c.RemainingArgs()
default:
return handler, c.ArgErr()
}
}
if len(t.regex) == 0 {
t.regex = append(t.regex, regexp.MustCompile(".*"))
}
handler.Templates = append(handler.Templates, t)
}
return handler, nil
}
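// An illustrative Corefile stanza accepted by this parser; the zone, regex and
// template text are placeholders:
//
//	template IN A example.org {
//	    match ^ip-(?P<a>[0-9]+)-(?P<b>[0-9]+)[.]example[.]org[.]$
//	    answer "{{ .Name }} 60 IN A 10.0.{{ .Group.a }}.{{ .Group.b }}"
//	    fallthrough
//	}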
package template
import (
"bytes"
"context"
"regexp"
"strconv"
gotmpl "text/template"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Handler is a plugin handler that takes a query and templates a response.
type Handler struct {
Zones []string
Next plugin.Handler
Templates []template
}
type template struct {
zones []string
rcode int
regex []*regexp.Regexp
answer []*gotmpl.Template
additional []*gotmpl.Template
authority []*gotmpl.Template
qclass uint16
qtype uint16
ederror *ederror
fall fall.F
upstream Upstreamer
}
type ederror struct {
code uint16
reason string
}
// Upstreamer looks up targets of CNAME templates
type Upstreamer interface {
Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
}
type templateData struct {
Zone string
Name string
Regex string
Match []string
Group map[string]string
Class string
Type string
Message *dns.Msg
Question *dns.Question
Remote string
md map[string]metadata.Func
}
func (data *templateData) Meta(metaName string) string {
if data.md == nil {
return ""
}
if f, ok := data.md[metaName]; ok {
return f()
}
return ""
}
// ServeDNS implements the plugin.Handler interface.
func (h Handler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
zone := plugin.Zones(h.Zones).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
for _, template := range h.Templates {
data, match, fthrough := template.match(ctx, state)
if !match {
if !fthrough {
return dns.RcodeServerFailure, nil
}
continue
}
templateMatchesCount.WithLabelValues(metrics.WithServer(ctx), data.Zone, metrics.WithView(ctx), data.Class, data.Type).Inc()
if template.rcode == dns.RcodeServerFailure {
return template.rcode, nil
}
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
msg.Rcode = template.rcode
for _, answer := range template.answer {
rr, err := executeRRTemplate(metrics.WithServer(ctx), metrics.WithView(ctx), "answer", answer, data)
if err != nil {
return dns.RcodeServerFailure, err
}
msg.Answer = append(msg.Answer, rr)
if template.upstream != nil && (state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA) && rr.Header().Rrtype == dns.TypeCNAME {
if up, err := template.upstream.Lookup(ctx, state, rr.(*dns.CNAME).Target, state.QType()); err == nil && up != nil {
msg.Truncated = up.Truncated
msg.Answer = append(msg.Answer, up.Answer...)
}
}
}
for _, additional := range template.additional {
rr, err := executeRRTemplate(metrics.WithServer(ctx), metrics.WithView(ctx), "additional", additional, data)
if err != nil {
return dns.RcodeServerFailure, err
}
msg.Extra = append(msg.Extra, rr)
}
for _, authority := range template.authority {
rr, err := executeRRTemplate(metrics.WithServer(ctx), metrics.WithView(ctx), "authority", authority, data)
if err != nil {
return dns.RcodeServerFailure, err
}
msg.Ns = append(msg.Ns, rr)
}
if template.ederror != nil {
msg = msg.SetEdns0(4096, true)
ede := dns.EDNS0_EDE{InfoCode: template.ederror.code, ExtraText: template.ederror.reason}
msg.IsEdns0().Option = append(msg.IsEdns0().Option, &ede)
}
w.WriteMsg(msg)
return template.rcode, nil
}
return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r)
}
// Name implements the plugin.Handler interface.
func (h Handler) Name() string { return "template" }
func executeRRTemplate(server, view, section string, template *gotmpl.Template, data *templateData) (dns.RR, error) {
buffer := &bytes.Buffer{}
err := template.Execute(buffer, data)
if err != nil {
templateFailureCount.WithLabelValues(server, data.Zone, view, data.Class, data.Type, section, template.Tree.Root.String()).Inc()
return nil, err
}
rr, err := dns.NewRR(buffer.String())
if err != nil {
templateRRFailureCount.WithLabelValues(server, data.Zone, view, data.Class, data.Type, section, template.Tree.Root.String()).Inc()
return rr, err
}
return rr, nil
}
func newTemplate(name, text string) (*gotmpl.Template, error) {
funcMap := gotmpl.FuncMap{
"parseInt": strconv.ParseUint,
}
return gotmpl.New(name).Funcs(funcMap).Parse(text)
}
func (t template) match(ctx context.Context, state request.Request) (*templateData, bool, bool) {
q := state.Req.Question[0]
data := &templateData{md: metadata.ValueFuncs(ctx), Remote: state.IP()}
zone := plugin.Zones(t.zones).Matches(state.Name())
if zone == "" {
return data, false, true
}
if t.qclass != dns.ClassANY && q.Qclass != dns.ClassANY && q.Qclass != t.qclass {
return data, false, true
}
if t.qtype != dns.TypeANY && q.Qtype != dns.TypeANY && q.Qtype != t.qtype {
return data, false, true
}
for _, regex := range t.regex {
if !regex.MatchString(state.Name()) {
continue
}
data.Zone = zone
data.Regex = regex.String()
data.Name = state.Name()
data.Question = &q
data.Message = state.Req
if q.Qclass != dns.ClassANY {
data.Class = dns.ClassToString[q.Qclass]
} else {
data.Class = dns.ClassToString[t.qclass]
}
if q.Qtype != dns.TypeANY {
data.Type = dns.TypeToString[q.Qtype]
} else {
data.Type = dns.TypeToString[t.qtype]
}
matches := regex.FindStringSubmatch(state.Name())
data.Match = make([]string, len(matches))
data.Group = make(map[string]string)
groupNames := regex.SubexpNames()
for i, m := range matches {
data.Match[i] = m
data.Group[strconv.Itoa(i)] = m
}
for i, m := range matches {
if len(groupNames[i]) > 0 {
data.Group[groupNames[i]] = m
}
}
return data, true, false
}
return data, false, t.fall.Through(state.Name())
}
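// Illustrative only: with the regex `(?P<x>[a-z]+)\.example\.org\.` and the
// query name "foo.example.org.", the data built above would contain
//
//	data.Match // ["foo.example.org.", "foo"]
//	data.Group // {"0": "foo.example.org.", "1": "foo", "x": "foo"}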
package test
import (
"os"
"path/filepath"
"testing"
)
// TempFile creates a temporary file on disk and returns its name and a cleanup function to remove it later.
func TempFile(dir, content string) (string, func(), error) {
f, err := os.CreateTemp(dir, "go-test-tmpfile")
if err != nil {
return "", nil, err
}
f.Close() // only the name is needed; os.WriteFile below opens the file again
if err := os.WriteFile(f.Name(), []byte(content), 0644); err != nil {
return "", nil, err
}
rmFunc := func() { os.Remove(f.Name()) }
return f.Name(), rmFunc, nil
}
// WritePEMFiles creates a tmp dir with ca.pem, cert.pem, and key.pem
func WritePEMFiles(t *testing.T) (string, error) {
t.Helper()
tempDir := t.TempDir()
data := `-----BEGIN CERTIFICATE-----
MIIC9zCCAd+gAwIBAgIJALGtqdMzpDemMA0GCSqGSIb3DQEBCwUAMBIxEDAOBgNV
BAMMB2t1YmUtY2EwHhcNMTYxMDE5MTU1NDI0WhcNNDQwMzA2MTU1NDI0WjASMRAw
DgYDVQQDDAdrdWJlLWNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA
pa4Wu/WkpJNRr8pMVE6jjwzNUOx5mIyoDr8WILSxVQcEeyVPPmAqbmYXtVZO11p9
jTzoEqF7Kgts3HVYGCk5abqbE14a8Ru/DmV5avU2hJ/NvSjtNi/O+V6SzCbg5yR9
lBR53uADDlzuJEQT9RHq7A5KitFkx4vUcXnjOQCbDogWFoYuOgNEwJPy0Raz3NJc
ViVfDqSJ0QHg02kCOMxcGFNRQ9F5aoW7QXZXZXD0tn3wLRlu4+GYyqt8fw5iNdLJ
t79yKp8I+vMTmMPz4YKUO+eCl5EY10Qs7wvoG/8QNbjH01BRN3L8iDT2WfxdvjTu
1RjPxFL92i+B7HZO7jGLfQIDAQABo1AwTjAdBgNVHQ4EFgQUZTrg+Xt87tkxDhlB
gKk9FdTOW3IwHwYDVR0jBBgwFoAUZTrg+Xt87tkxDhlBgKk9FdTOW3IwDAYDVR0T
BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEApB7JFVrZpGSOXNO3W7SlN6OCPXv9
C7rIBc8rwOrzi2mZWcBmWheQrqBo8xHif2rlFNVQxtq3JcQ8kfg/m1fHeQ/Ygzel
Z+U1OqozynDySBZdNn9i+kXXgAUCqDPp3hEQWe0os/RRpIwo9yOloBxdiX6S0NIf
VB8n8kAynFPkH7pYrGrL1HQgDFCSfa4tUJ3+9sppnCu0pNtq5AdhYx9xFb2sn+8G
xGbtCkhVk2VQ+BiCWnjYXJ6ZMzabP7wiOFDP9Pvr2ik22PRItsW/TLfHFXM1jDmc
I1rs/VUGKzcJGVIWbHrgjP68CTStGAvKgbsTqw7aLXTSqtPw88N9XVSyRg==
-----END CERTIFICATE-----`
path := filepath.Join(tempDir, "ca.pem")
if err := os.WriteFile(path, []byte(data), 0644); err != nil {
return "", err
}
data = `-----BEGIN CERTIFICATE-----
MIICozCCAYsCCQCRlf5BrvPuqjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDDAdr
dWJlLWNhMB4XDTE2MTAxOTE2MDUxOFoXDTE3MTAxOTE2MDUxOFowFTETMBEGA1UE
AwwKa3ViZS1hZG1pbjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMTw
a7wCFoiCad/N53aURfjrme+KR7FS0yf5Ur9OR/oM3BoS9stYu5Flzr35oL5T6t5G
c2ey78mUs/Cs07psnjUdKH55bDpJSdG7zW9mXNyeLwIefFcj/38SS5NBSotmLo8u
scJMGXeQpCQtfVuVJSP2bfU5u5d0KTLSg/Cor6UYonqrRB82HbOuuk8Wjaww4VHo
nCq7X8o948V6HN5ZibQOgMMo+nf0wORREHBjvwc4W7ewbaTcfoe1VNAo/QnkqxTF
ueMb2HxgghArqQSK8b44O05V0zrde25dVnmnte6sPjcV0plqMJ37jViISxsOPUFh
/ZW7zbIM/7CMcDekCiECAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAYZE8OxwRR7GR
kdd5aIriDwWfcl56cq5ICyx87U8hAZhBxk46a6a901LZPzt3xKyWIFQSRj/NYiQ+
/thjGLZI2lhkVgYtyAD4BNxDiuppQSCbkjY9tLVDdExGttEVN7+UYDWJBHy6X16Y
xSG9FE3Dvp9LI89Nq8E3dRh+Q8wu52q9HaQXjS5YtzQOtDFKPBkihXu/c6gEHj4Y
bZVk8rFiH8/CvcQxAuvNI3VVCFUKd2LeQtqwYQQ//qoiuA15krTq5Ut9eXJ8zxAw
zhDEPP4FhY+Sz+y1yWirphl7A1aZwhXVPcfWIGqpQ3jzNwUeocbH27kuLh+U4hQo
qeg10RdFnw==
-----END CERTIFICATE-----`
path = filepath.Join(tempDir, "cert.pem")
if err := os.WriteFile(path, []byte(data), 0644); err != nil {
return "", err
}
data = `-----BEGIN RSA PRIVATE KEY-----
MIIEpgIBAAKCAQEAxPBrvAIWiIJp383ndpRF+OuZ74pHsVLTJ/lSv05H+gzcGhL2
y1i7kWXOvfmgvlPq3kZzZ7LvyZSz8KzTumyeNR0ofnlsOklJ0bvNb2Zc3J4vAh58
VyP/fxJLk0FKi2Yujy6xwkwZd5CkJC19W5UlI/Zt9Tm7l3QpMtKD8KivpRiieqtE
HzYds666TxaNrDDhUeicKrtfyj3jxXoc3lmJtA6Awyj6d/TA5FEQcGO/Bzhbt7Bt
pNx+h7VU0Cj9CeSrFMW54xvYfGCCECupBIrxvjg7TlXTOt17bl1Weae17qw+NxXS
mWownfuNWIhLGw49QWH9lbvNsgz/sIxwN6QKIQIDAQABAoIBAQDCXq9V7ZGjxWMN
OkFaLVkqJg3V91puztoMt+xNV8t+JTcOnOzrIXZuOFbl9PwLHPPP0SSRkm9LOvKl
dU26zv0OWureeKSymia7U2mcqyC3tX+bzc7WinbeSYZBnc0e7AjD1EgpBcaU1TLL
agIxY3A2oD9CKmrVPhZzTIZf/XztqTYjhvs5I2kBeT0imdYGpXkdndRyGX4I5/JQ
fnp3Czj+AW3zX7RvVnXOh4OtIAcfoG9xoNyD5LOSlJkkX0MwTS8pEBeZA+A4nb+C
ivjnOSgXWD+liisI+LpBgBbwYZ/E49x5ghZYrJt8QXSk7Bl/+UOyv6XZAm2mev6j
RLAZtoABAoGBAP2P+1PoKOwsk+d/AmHqyTCUQm0UG18LOLB/5PyWfXs/6caDmdIe
DZWeZWng1jUQLEadmoEw/CBY5+tPfHlzwzMNhT7KwUfIDQCIBoS7dzHYnwrJ3VZh
qYA05cuGHAAHqwb6UWz3y6Pa4AEVSHX6CM83CAi9jdWZ1rdZybWG+qYBAoGBAMbV
FsR/Ft+tK5ALgXGoG83TlmxzZYuZ1SnNje1OSdCQdMFCJB10gwoaRrw1ICzi40Xk
ydJwV1upGz1om9ReDAD1zQM9artmQx6+TVLiVPALuARdZE70+NrA6w3ZvxUgJjdN
ngvXUr+8SdvaYUAwFu7BulfJlwXjUS711hHW/KQhAoGBALY41QuV2mLwHlLNie7I
hlGtGpe9TXZeYB0nrG6B0CfU5LJPPSotguG1dXhDpm138/nDpZeWlnrAqdsHwpKd
yPhVjR51I7XsZLuvBdA50Q03egSM0c4UXXXPjh1XgaPb3uMi3YWMBwL4ducQXoS6
bb5M9C8j2lxZNF+L3VPhbxwBAoGBAIEWDvX7XKpTDxkxnxRfA84ZNGusb5y2fsHp
Bd+vGBUj8+kUO8Yzwm9op8vA4ebCVrMl2jGZZd3IaDryE1lIxZpJ+pPD5+tKdQEc
o67P6jz+HrYWu+zW9klvPit71qasfKMi7Rza6oo4f+sQWFsH3ZucgpJD+pyD/Ez0
pcpnPRaBAoGBANT/xgHBfIWt4U2rtmRLIIiZxKr+3mGnQdpA1J2BCh+/6AvrEx//
E/WObVJXDnBdViu0L9abE9iaTToBVri4cmlDlZagLuKVR+TFTCN/DSlVZTDkqkLI
8chzqtkH6b2b2R73hyRysWjsomys34ma3mEEPTX/aXeAF2MSZ/EWT9yL
-----END RSA PRIVATE KEY-----`
path = filepath.Join(tempDir, "key.pem")
if err := os.WriteFile(path, []byte(data), 0644); err != nil {
return "", err
}
return tempDir, nil
}
package test
import (
"context"
"fmt"
"sort"
"strings"
"github.com/miekg/dns"
)
type sect int
const (
// Answer is the answer section in an Msg.
Answer sect = iota
// Ns is the authoritative section in an Msg.
Ns
// Extra is the additional section in an Msg.
Extra
)
// RRSet represents a list of RRs.
type RRSet []dns.RR
func (p RRSet) Len() int { return len(p) }
func (p RRSet) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p RRSet) Less(i, j int) bool { return p[i].String() < p[j].String() }
// Case represents a test case that encapsulates various data from a query and response.
// Note that if the TTL of a record is 303 we don't compare the TTL.
type Case struct {
Qname string
Qtype uint16
Rcode int
Do bool
CheckingDisabled bool
RecursionAvailable bool
AuthenticatedData bool
Authoritative bool
Truncated bool
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
Error error
}
// Msg returns a *dns.Msg embedded in c.
func (c Case) Msg() *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(c.Qname), c.Qtype)
if c.Do {
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetDo()
o.SetUDPSize(4096)
m.Extra = []dns.RR{o}
}
return m
}
// A returns an A record from rr. It panics on errors.
func A(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
// AAAA returns an AAAA record from rr. It panics on errors.
func AAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
// CNAME returns a CNAME record from rr. It panics on errors.
func CNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
// DNAME returns a DNAME record from rr. It panics on errors.
func DNAME(rr string) *dns.DNAME { r, _ := dns.NewRR(rr); return r.(*dns.DNAME) }
// SRV returns a SRV record from rr. It panics on errors.
func SRV(rr string) *dns.SRV { r, _ := dns.NewRR(rr); return r.(*dns.SRV) }
// SOA returns a SOA record from rr. It panics on errors.
func SOA(rr string) *dns.SOA { r, _ := dns.NewRR(rr); return r.(*dns.SOA) }
// NS returns an NS record from rr. It panics on errors.
func NS(rr string) *dns.NS { r, _ := dns.NewRR(rr); return r.(*dns.NS) }
// PTR returns a PTR record from rr. It panics on errors.
func PTR(rr string) *dns.PTR { r, _ := dns.NewRR(rr); return r.(*dns.PTR) }
// TXT returns a TXT record from rr. It panics on errors.
func TXT(rr string) *dns.TXT { r, _ := dns.NewRR(rr); return r.(*dns.TXT) }
// CAA returns a CAA record from rr. It panics on errors.
func CAA(rr string) *dns.CAA { r, _ := dns.NewRR(rr); return r.(*dns.CAA) }
// HINFO returns a HINFO record from rr. It panics on errors.
func HINFO(rr string) *dns.HINFO { r, _ := dns.NewRR(rr); return r.(*dns.HINFO) }
// MX returns an MX record from rr. It panics on errors.
func MX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }
// RRSIG returns an RRSIG record from rr. It panics on errors.
func RRSIG(rr string) *dns.RRSIG { r, _ := dns.NewRR(rr); return r.(*dns.RRSIG) }
// NSEC returns an NSEC record from rr. It panics on errors.
func NSEC(rr string) *dns.NSEC { r, _ := dns.NewRR(rr); return r.(*dns.NSEC) }
// DNSKEY returns a DNSKEY record from rr. It panics on errors.
func DNSKEY(rr string) *dns.DNSKEY { r, _ := dns.NewRR(rr); return r.(*dns.DNSKEY) }
// DS returns a DS record from rr. It panics on errors.
func DS(rr string) *dns.DS { r, _ := dns.NewRR(rr); return r.(*dns.DS) }
// NAPTR returns a NAPTR record from rr. It panics on errors.
func NAPTR(rr string) *dns.NAPTR { r, _ := dns.NewRR(rr); return r.(*dns.NAPTR) }
// OPT returns an OPT record with UDP buffer size set to bufsize and the DO bit set to do.
func OPT(bufsize int, do bool) *dns.OPT {
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetVersion(0)
o.SetUDPSize(uint16(bufsize))
if do {
o.SetDo()
}
return o
}
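// A minimal sketch of how these helpers are typically combined in a test; the
// record and TTL are placeholders:
//
//	tc := Case{
//	    Qname: "example.org.", Qtype: dns.TypeA,
//	    Answer: []dns.RR{A("example.org. 303 IN A 192.0.2.1")},
//	}
//	m := tc.Msg() // query message to send to the handler under test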
// Header tests if the header in resp matches the header as defined in tc.
func Header(tc Case, resp *dns.Msg) error {
if resp.Rcode != tc.Rcode {
return fmt.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
}
if len(resp.Answer) != len(tc.Answer) {
return fmt.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
}
if len(resp.Ns) != len(tc.Ns) {
return fmt.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
}
if len(resp.Extra) != len(tc.Extra) {
return fmt.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
}
return nil
}
// Section tests if the section in tc matches rr.
func Section(tc Case, sec sect, rr []dns.RR) error {
section := []dns.RR{}
switch sec {
case 0:
section = tc.Answer
case 1:
section = tc.Ns
case 2:
section = tc.Extra
}
for i, a := range rr {
if a.Header().Name != section[i].Header().Name {
return fmt.Errorf("RR %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name)
}
// 303 signals: don't care what the ttl is.
if section[i].Header().Ttl != 303 && a.Header().Ttl != section[i].Header().Ttl {
if _, ok := section[i].(*dns.OPT); !ok {
// not an OPT RR (for OPT we check the edns0 bufsize in the *dns.OPT case below), so the TTLs must match
return fmt.Errorf("RR %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl)
}
}
if a.Header().Rrtype != section[i].Header().Rrtype {
return fmt.Errorf("RR %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype)
}
switch x := a.(type) {
case *dns.SRV:
if x.Priority != section[i].(*dns.SRV).Priority {
return fmt.Errorf("RR %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority)
}
if x.Weight != section[i].(*dns.SRV).Weight {
return fmt.Errorf("RR %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight)
}
if x.Port != section[i].(*dns.SRV).Port {
return fmt.Errorf("RR %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port)
}
if x.Target != section[i].(*dns.SRV).Target {
return fmt.Errorf("RR %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target)
}
case *dns.RRSIG:
if x.TypeCovered != section[i].(*dns.RRSIG).TypeCovered {
return fmt.Errorf("RR %d should have a TypeCovered of %d, but has %d", i, section[i].(*dns.RRSIG).TypeCovered, x.TypeCovered)
}
if x.Labels != section[i].(*dns.RRSIG).Labels {
return fmt.Errorf("RR %d should have a Labels of %d, but has %d", i, section[i].(*dns.RRSIG).Labels, x.Labels)
}
if x.SignerName != section[i].(*dns.RRSIG).SignerName {
return fmt.Errorf("RR %d should have a SignerName of %s, but has %s", i, section[i].(*dns.RRSIG).SignerName, x.SignerName)
}
case *dns.NSEC:
if x.NextDomain != section[i].(*dns.NSEC).NextDomain {
return fmt.Errorf("RR %d should have a NextDomain of %s, but has %s", i, section[i].(*dns.NSEC).NextDomain, x.NextDomain)
}
// TypeBitMap
case *dns.A:
if x.A.String() != section[i].(*dns.A).A.String() {
return fmt.Errorf("RR %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String())
}
case *dns.AAAA:
if x.AAAA.String() != section[i].(*dns.AAAA).AAAA.String() {
return fmt.Errorf("RR %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String())
}
case *dns.TXT:
actualTxt := strings.Join(x.Txt, "")
expectedTxt := strings.Join(section[i].(*dns.TXT).Txt, "")
if actualTxt != expectedTxt {
return fmt.Errorf("RR %d should have a TXT value of %q, but has %q", i, expectedTxt, actualTxt)
}
case *dns.HINFO:
if x.Cpu != section[i].(*dns.HINFO).Cpu {
return fmt.Errorf("RR %d should have a Cpu of %s, but has %s", i, section[i].(*dns.HINFO).Cpu, x.Cpu)
}
if x.Os != section[i].(*dns.HINFO).Os {
return fmt.Errorf("RR %d should have a Os of %s, but has %s", i, section[i].(*dns.HINFO).Os, x.Os)
}
case *dns.SOA:
tt := section[i].(*dns.SOA)
if x.Ns != tt.Ns {
return fmt.Errorf("SOA nameserver should be %q, but is %q", tt.Ns, x.Ns)
}
case *dns.PTR:
tt := section[i].(*dns.PTR)
if x.Ptr != tt.Ptr {
return fmt.Errorf("PTR ptr should be %q, but is %q", tt.Ptr, x.Ptr)
}
case *dns.CNAME:
tt := section[i].(*dns.CNAME)
if x.Target != tt.Target {
return fmt.Errorf("CNAME target should be %q, but is %q", tt.Target, x.Target)
}
case *dns.MX:
tt := section[i].(*dns.MX)
if x.Mx != tt.Mx {
return fmt.Errorf("MX Mx should be %q, but is %q", tt.Mx, x.Mx)
}
if x.Preference != tt.Preference {
return fmt.Errorf("MX Preference should be %q, but is %q", tt.Preference, x.Preference)
}
case *dns.NS:
tt := section[i].(*dns.NS)
if x.Ns != tt.Ns {
return fmt.Errorf("NS nameserver should be %q, but is %q", tt.Ns, x.Ns)
}
case *dns.OPT:
tt := section[i].(*dns.OPT)
if x.UDPSize() != tt.UDPSize() {
return fmt.Errorf("OPT UDPSize should be %d, but is %d", tt.UDPSize(), x.UDPSize())
}
if x.Do() != tt.Do() {
return fmt.Errorf("OPT DO should be %t, but is %t", tt.Do(), x.Do())
}
}
}
return nil
}
// CNAMEOrder makes sure that CNAMES do not appear after their target records.
func CNAMEOrder(res *dns.Msg) error {
for i, c := range res.Answer {
if c.Header().Rrtype != dns.TypeCNAME {
continue
}
for _, a := range res.Answer[:i] {
if a.Header().Name != c.(*dns.CNAME).Target {
continue
}
return fmt.Errorf("CNAME found after target record")
}
}
return nil
}
// SortAndCheck sorts resp and then checks the header and the three sections against the test case in tc.
func SortAndCheck(resp *dns.Msg, tc Case) error {
sort.Sort(RRSet(resp.Answer))
sort.Sort(RRSet(resp.Ns))
sort.Sort(RRSet(resp.Extra))
if err := Header(tc, resp); err != nil {
return err
}
if err := Section(tc, Answer, resp.Answer); err != nil {
return err
}
if err := Section(tc, Ns, resp.Ns); err != nil {
return err
}
return Section(tc, Extra, resp.Extra)
}
// ErrorHandler returns a Handler that returns ServerFailure error when called.
func ErrorHandler() Handler {
return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(m)
return dns.RcodeServerFailure, nil
})
}
// NextHandler returns a Handler that returns rcode and err.
func NextHandler(rcode int, err error) Handler {
return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return rcode, err
})
}
// Copied here to prevent an import cycle, so that we can define the above handlers.
type (
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error.
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
// Handler interface defines a plugin.
Handler interface {
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
Name() string
}
)
// ServeDNS implements the Handler interface.
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(ctx, w, r)
}
// Name implements the Handler interface.
func (f HandlerFunc) Name() string { return "handlerfunc" }
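// The helpers above are typically combined in a plugin test: build a Case with the expected
// answer, run a handler against a recorder, and hand the recorded reply to SortAndCheck.
// The sketch below is illustrative only: it uses a stub handler built from HandlerFunc
// instead of a real plugin, and lives in a hypothetical package.
package sketch

import (
	"context"
	"testing"

	"github.com/coredns/coredns/plugin/pkg/dnstest"
	"github.com/coredns/coredns/plugin/test"

	"github.com/miekg/dns"
)

func TestSortAndCheckSketch(t *testing.T) {
	tc := test.Case{
		Qname: "example.org.", Qtype: dns.TypeA,
		Answer: []dns.RR{
			// TTL 303 means: don't compare the TTL of this record.
			test.A("example.org. 303 IN A 127.0.0.1"),
		},
	}
	// Stub handler that answers every query with the expected A record.
	handler := test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
		m := new(dns.Msg)
		m.SetReply(r)
		m.Answer = []dns.RR{test.A("example.org. 3600 IN A 127.0.0.1")}
		w.WriteMsg(m)
		return dns.RcodeSuccess, nil
	})
	rec := dnstest.NewRecorder(&test.ResponseWriter{})
	if _, err := handler.ServeDNS(context.TODO(), rec, tc.Msg()); err != nil {
		t.Fatalf("expected no error from the stub handler, got %v", err)
	}
	if err := test.SortAndCheck(rec.Msg, tc); err != nil {
		t.Error(err)
	}
}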
package test
import (
"net"
"github.com/miekg/dns"
)
// ResponseWriter is useful for writing tests. It uses fixed values for the client: the remote
// defaults to 10.240.0.1 on port 40212 (override the IP via RemoteIP), and the local address is
// always 127.0.0.1 on port 53.
type ResponseWriter struct {
TCP bool // if TCP is true we return a TCP connection instead of a UDP one.
RemoteIP string
Zone string
}
// LocalAddr returns the local address, 127.0.0.1:53 (UDP, TCP if t.TCP is true).
func (t *ResponseWriter) LocalAddr() net.Addr {
ip := net.ParseIP("127.0.0.1")
port := 53
if t.TCP {
return &net.TCPAddr{IP: ip, Port: port, Zone: ""}
}
return &net.UDPAddr{IP: ip, Port: port, Zone: ""}
}
// RemoteAddr returns the remote address, defaults to 10.240.0.1:40212 (UDP, or TCP if t.TCP is true).
func (t *ResponseWriter) RemoteAddr() net.Addr {
remoteIP := "10.240.0.1"
if t.RemoteIP != "" {
remoteIP = t.RemoteIP
}
ip := net.ParseIP(remoteIP)
port := 40212
if t.TCP {
return &net.TCPAddr{IP: ip, Port: port, Zone: t.Zone}
}
return &net.UDPAddr{IP: ip, Port: port, Zone: t.Zone}
}
// Network implements dns.ResponseWriter interface.
func (t *ResponseWriter) Network() string { return "" }
// WriteMsg implements dns.ResponseWriter interface.
func (t *ResponseWriter) WriteMsg(m *dns.Msg) error { return nil }
// Write implements dns.ResponseWriter interface.
func (t *ResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil }
// Close implements dns.ResponseWriter interface.
func (t *ResponseWriter) Close() error { return nil }
// TsigStatus implements dns.ResponseWriter interface.
func (t *ResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly implements dns.ResponseWriter interface.
func (t *ResponseWriter) TsigTimersOnly(bool) {}
// Hijack implements dns.ResponseWriter interface.
func (t *ResponseWriter) Hijack() {}
// ResponseWriter6 is like ResponseWriter, but with fixed IPv6 addresses: the remote is always
// fe80::42:ff:feca:4c65 on port 40212 and the local address is always ::1 on port 53.
type ResponseWriter6 struct {
ResponseWriter
}
// LocalAddr returns the local address, always ::1, port 53 (UDP, or TCP if t.TCP is true).
func (t *ResponseWriter6) LocalAddr() net.Addr {
if t.TCP {
return &net.TCPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""}
}
return &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""}
}
// RemoteAddr returns the remote address, always fe80::42:ff:feca:4c65 port 40212 (UDP, or TCP if t.TCP is true).
func (t *ResponseWriter6) RemoteAddr() net.Addr {
if t.TCP {
return &net.TCPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""}
}
return &net.UDPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""}
}
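// Both writers are handed in wherever a dns.ResponseWriter is expected, so code that inspects
// the client or server address can be exercised without opening any sockets. A minimal sketch,
// placed in a hypothetical test package:
package sketch

import (
	"net"
	"testing"

	"github.com/coredns/coredns/plugin/test"
)

func TestFixedTestAddresses(t *testing.T) {
	w := &test.ResponseWriter{TCP: true}
	if _, ok := w.RemoteAddr().(*net.TCPAddr); !ok {
		t.Fatal("expected a TCP remote address when TCP is set")
	}
	if got := w.RemoteAddr().String(); got != "10.240.0.1:40212" {
		t.Fatalf("expected the fixed remote 10.240.0.1:40212, got %s", got)
	}
	w6 := &test.ResponseWriter6{}
	if got := w6.LocalAddr().String(); got != "[::1]:53" {
		t.Fatalf("expected the fixed local [::1]:53, got %s", got)
	}
}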
// Adapted by Miek Gieben for CoreDNS testing.
//
// License from prom2json
// Copyright 2014 Prometheus Team
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package test contains helper functions for writing plugin tests.
// For example to scrape a target and inspect the variables.
// Basic usage:
//
// result := Scrape("http://localhost:9153/metrics")
// v := MetricValue("coredns_cache_capacity", result)
package test
import (
"fmt"
"io"
"mime"
"net/http"
"strconv"
"github.com/matttproud/golang_protobuf_extensions/pbutil"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
)
type (
// MetricFamily holds a prometheus metric.
MetricFamily struct {
Name string `json:"name"`
Help string `json:"help"`
Type string `json:"type"`
Metrics []any `json:"metrics,omitempty"` // Either metric or summary.
}
// metric is for all "single value" metrics.
metric struct {
Labels map[string]string `json:"labels,omitempty"`
Value string `json:"value"`
}
summary struct {
Labels map[string]string `json:"labels,omitempty"`
Quantiles map[string]string `json:"quantiles,omitempty"`
Count string `json:"count"`
Sum string `json:"sum"`
}
histogram struct {
Labels map[string]string `json:"labels,omitempty"`
Buckets map[string]string `json:"buckets,omitempty"`
Count string `json:"count"`
Sum string `json:"sum"`
}
)
// Scrape scrapes the metrics endpoint at url and returns all metric families as a []*MetricFamily.
func Scrape(url string) []*MetricFamily {
mfChan := make(chan *dto.MetricFamily, 1024)
go fetchMetricFamilies(url, mfChan)
result := []*MetricFamily{}
for mf := range mfChan {
result = append(result, newMetricFamily(mf))
}
return result
}
// ScrapeMetricAsInt returns the sum of all metrics collected for the given name and label.
// If a metric value is not numeric it is counted as 0; if no matching metric is found, nometricvalue is returned.
func ScrapeMetricAsInt(addr string, name string, label string, nometricvalue int) int {
valueToInt := func(m metric) int {
v := m.Value
r, err := strconv.Atoi(v)
if err != nil {
return 0
}
return r
}
met := Scrape(fmt.Sprintf("http://%s/metrics", addr))
found := false
tot := 0
for _, mf := range met {
if mf.Name == name {
// Sum all metrics available
for _, m := range mf.Metrics {
if label == "" {
tot += valueToInt(m.(metric))
found = true
continue
}
for _, v := range m.(metric).Labels {
if v == label {
tot += valueToInt(m.(metric))
found = true
}
}
}
}
}
if !found {
return nometricvalue
}
return tot
}
// MetricValue returns the value associated with name as a string as well as the labels.
// It only returns the first metric of the slice.
func MetricValue(name string, mfs []*MetricFamily) (string, map[string]string) {
for _, mf := range mfs {
if mf.Name == name {
// Only works with Gauge and Counter...
return mf.Metrics[0].(metric).Value, mf.Metrics[0].(metric).Labels
}
}
return "", nil
}
// MetricValueLabel returns the value for name *and* label *value*.
func MetricValueLabel(name, label string, mfs []*MetricFamily) (string, map[string]string) {
// bit hacky is this really handy...?
for _, mf := range mfs {
if mf.Name == name {
for _, m := range mf.Metrics {
for _, v := range m.(metric).Labels {
if v == label {
return m.(metric).Value, m.(metric).Labels
}
}
}
}
}
return "", nil
}
func newMetricFamily(dtoMF *dto.MetricFamily) *MetricFamily {
mf := &MetricFamily{
Name: dtoMF.GetName(),
Help: dtoMF.GetHelp(),
Type: dtoMF.GetType().String(),
Metrics: make([]any, len(dtoMF.GetMetric())),
}
for i, m := range dtoMF.GetMetric() {
if dtoMF.GetType() == dto.MetricType_SUMMARY {
mf.Metrics[i] = summary{
Labels: makeLabels(m),
Quantiles: makeQuantiles(m),
Count: strconv.FormatUint(m.GetSummary().GetSampleCount(), 10),
Sum: fmt.Sprint(m.GetSummary().GetSampleSum()),
}
} else if dtoMF.GetType() == dto.MetricType_HISTOGRAM {
mf.Metrics[i] = histogram{
Labels: makeLabels(m),
Buckets: makeBuckets(m),
Count: strconv.FormatUint(m.GetHistogram().GetSampleCount(), 10),
Sum: fmt.Sprint(m.GetHistogram().GetSampleSum()),
}
} else {
mf.Metrics[i] = metric{
Labels: makeLabels(m),
Value: fmt.Sprint(value(m)),
}
}
}
return mf
}
func value(m *dto.Metric) float64 {
if m.GetGauge() != nil {
return m.GetGauge().GetValue()
}
if m.GetCounter() != nil {
return m.GetCounter().GetValue()
}
if m.GetUntyped() != nil {
return m.GetUntyped().GetValue()
}
return 0.
}
func makeLabels(m *dto.Metric) map[string]string {
result := map[string]string{}
for _, lp := range m.GetLabel() {
result[lp.GetName()] = lp.GetValue()
}
return result
}
func makeQuantiles(m *dto.Metric) map[string]string {
result := map[string]string{}
for _, q := range m.GetSummary().GetQuantile() {
result[fmt.Sprint(q.GetQuantile())] = fmt.Sprint(q.GetValue())
}
return result
}
func makeBuckets(m *dto.Metric) map[string]string {
result := map[string]string{}
for _, b := range m.GetHistogram().GetBucket() {
result[fmt.Sprint(b.GetUpperBound())] = strconv.FormatUint(b.GetCumulativeCount(), 10)
}
return result
}
func fetchMetricFamilies(url string, ch chan<- *dto.MetricFamily) {
defer close(ch)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return
}
req.Header.Add("Accept", acceptHeader)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return
}
mediatype, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err == nil && mediatype == "application/vnd.google.protobuf" &&
params["encoding"] == "delimited" &&
params["proto"] == "io.prometheus.client.MetricFamily" {
for {
mf := &dto.MetricFamily{}
if _, err = pbutil.ReadDelimited(resp.Body, mf); err != nil {
if err == io.EOF {
break
}
return
}
ch <- mf
}
} else {
// We could do further content-type checks here, but the
// fallback for now will anyway be the text format
// version 0.0.4, so just go for it and see if it works.
var parser expfmt.TextParser
metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
if err != nil {
return
}
for _, mf := range metricFamilies {
ch <- mf
}
}
}
const acceptHeader = `application/vnd.google.protobuf;proto=io.prometheus.client.MetricFamily;encoding=delimited;q=0.7,text/plain;version=0.0.4;q=0.3`
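// Taken together, Scrape, MetricValue and ScrapeMetricAsInt let an integration test assert on
// exported Prometheus metrics. A sketch of typical use; the address and metric names are example
// values and assume a running CoreDNS with the prometheus plugin enabled.
package test

import "fmt"

func ExampleScrape() {
	mfs := Scrape("http://localhost:9153/metrics")
	// First metric of the family, as a string, plus its labels.
	value, labels := MetricValue("coredns_cache_entries", mfs)
	fmt.Println(value, labels)
	// Sum a counter over all its label values; -1 is returned when the metric is absent.
	total := ScrapeMetricAsInt("localhost:9153", "coredns_dns_requests_total", "", -1)
	fmt.Println(total)
}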
package timeouts
import (
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/durations"
)
func init() { plugin.Register("timeouts", setup) }
func setup(c *caddy.Controller) error {
err := parseTimeouts(c)
if err != nil {
return plugin.Error("timeouts", err)
}
return nil
}
func parseTimeouts(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
for c.Next() {
args := c.RemainingArgs()
if len(args) > 0 {
return plugin.Error("timeouts", c.ArgErr())
}
b := 0
for c.NextBlock() {
block := c.Val()
timeoutArgs := c.RemainingArgs()
if len(timeoutArgs) != 1 {
return c.ArgErr()
}
timeout, err := durations.NewDurationFromArg(timeoutArgs[0])
if err != nil {
return c.Err(err.Error())
}
if timeout < (1*time.Second) || timeout > (24*time.Hour) {
return c.Errf("timeout provided '%s' needs to be between 1 second and 24 hours", timeout)
}
switch block {
case "read":
config.ReadTimeout = timeout
case "write":
config.WriteTimeout = timeout
case "idle":
config.IdleTimeout = timeout
default:
return c.Errf("unknown option: '%s'", block)
}
b++
}
if b == 0 {
return plugin.Error("timeouts", c.Err("timeouts block with no timeouts specified"))
}
}
return nil
}
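// A sketch of the Corefile fragment the parser above accepts, driven through a test controller;
// the duration values are examples only.
package timeouts

import (
	"testing"

	"github.com/coredns/caddy"
)

func TestParseTimeoutsSketch(t *testing.T) {
	c := caddy.NewTestController("dns", `timeouts {
		read 30s
		write 1m
		idle 10m
	}`)
	if err := parseTimeouts(c); err != nil {
		t.Fatalf("expected the timeouts block to parse, got %v", err)
	}
}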
package tls
import (
ctls "crypto/tls"
"path/filepath"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/tls"
)
func init() { plugin.Register("tls", setup) }
func setup(c *caddy.Controller) error {
err := parseTLS(c)
if err != nil {
return plugin.Error("tls", err)
}
return nil
}
func parseTLS(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
if config.TLSConfig != nil {
return plugin.Error("tls", c.Errf("TLS already configured for this server instance"))
}
for c.Next() {
args := c.RemainingArgs()
if len(args) < 2 || len(args) > 3 {
return plugin.Error("tls", c.ArgErr())
}
clientAuth := ctls.NoClientCert
for c.NextBlock() {
switch c.Val() {
case "client_auth":
authTypeArgs := c.RemainingArgs()
if len(authTypeArgs) != 1 {
return c.ArgErr()
}
switch authTypeArgs[0] {
case "nocert":
clientAuth = ctls.NoClientCert
case "request":
clientAuth = ctls.RequestClientCert
case "require":
clientAuth = ctls.RequireAnyClientCert
case "verify_if_given":
clientAuth = ctls.VerifyClientCertIfGiven
case "require_and_verify":
clientAuth = ctls.RequireAndVerifyClientCert
default:
return c.Errf("unknown authentication type '%s'", authTypeArgs[0])
}
default:
return c.Errf("unknown option '%s'", c.Val())
}
}
for i := range args {
if !filepath.IsAbs(args[i]) && config.Root != "" {
args[i] = filepath.Join(config.Root, args[i])
}
}
tls, err := tls.NewTLSConfigFromArgs(args...)
if err != nil {
return err
}
tls.ClientAuth = clientAuth
// NewTLSConfigFromArgs only sets RootCAs, so we need to let ClientCAs refer to it.
tls.ClientCAs = tls.RootCAs
config.TLSConfig = tls
}
return nil
}
package trace
import (
clog "github.com/coredns/coredns/plugin/pkg/log"
)
// loggerAdapter is a simple adapter around the plugin logger that implements the io.Writer and
// ddtrace.Logger interfaces, so that errors from span reporters are logged as warnings.
type loggerAdapter struct {
clog.P
}
func (l *loggerAdapter) Write(p []byte) (n int, err error) {
l.Warning(string(p))
return len(p), nil
}
func (l *loggerAdapter) Log(msg string) {
l.Warning(msg)
}
package trace
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("trace", setup) }
func setup(c *caddy.Controller) error {
t, err := traceParse(c)
if err != nil {
return plugin.Error("trace", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
t.Next = next
return t
})
c.OnStartup(t.OnStartup)
c.OnShutdown(t.OnShutdown)
return nil
}
func traceParse(c *caddy.Controller) (*trace, error) {
var (
tr = &trace{every: 1, serviceName: defServiceName}
err error
)
cfg := dnsserver.GetConfig(c)
if len(cfg.ListenHosts) > 0 && cfg.ListenHosts[0] != "" {
tr.serviceEndpoint = cfg.ListenHosts[0] + ":" + cfg.Port
}
for c.Next() { // trace
var err error
args := c.RemainingArgs()
switch len(args) {
case 0:
tr.EndpointType, tr.Endpoint, err = normalizeEndpoint(defEpType, "")
case 1:
tr.EndpointType, tr.Endpoint, err = normalizeEndpoint(defEpType, args[0])
case 2:
epType := strings.ToLower(args[0])
tr.EndpointType, tr.Endpoint, err = normalizeEndpoint(epType, args[1])
default:
err = c.ArgErr()
}
if err != nil {
return tr, err
}
for c.NextBlock() {
switch c.Val() {
case "every":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
tr.every, err = strconv.ParseUint(args[0], 10, 64)
if err != nil {
return nil, err
}
case "service":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
tr.serviceName = args[0]
case "client_server":
args := c.RemainingArgs()
if len(args) > 1 {
return nil, c.ArgErr()
}
tr.clientServer = true
if len(args) == 1 {
tr.clientServer, err = strconv.ParseBool(args[0])
}
if err != nil {
return nil, err
}
case "datadog_analytics_rate":
args := c.RemainingArgs()
if len(args) > 1 {
return nil, c.ArgErr()
}
tr.datadogAnalyticsRate = 0
if len(args) == 1 {
tr.datadogAnalyticsRate, err = strconv.ParseFloat(args[0], 64)
}
if err != nil {
return nil, err
}
if tr.datadogAnalyticsRate > 1 || tr.datadogAnalyticsRate < 0 {
return nil, fmt.Errorf("datadog analytics rate must be between 0 and 1, '%f' is not supported", tr.datadogAnalyticsRate)
}
case "zipkin_max_backlog_size":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
tr.zipkinMaxBacklogSize, err = strconv.Atoi(args[0])
if err != nil {
return nil, err
}
case "zipkin_max_batch_size":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
tr.zipkinMaxBatchSize, err = strconv.Atoi(args[0])
if err != nil {
return nil, err
}
case "zipkin_max_batch_interval":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
tr.zipkinMaxBatchInterval, err = time.ParseDuration(args[0])
if err != nil {
return nil, err
}
}
}
}
return tr, err
}
func normalizeEndpoint(epType, ep string) (string, string, error) {
if _, ok := supportedProviders[epType]; !ok {
return "", "", fmt.Errorf("tracing endpoint type '%s' is not supported", epType)
}
if ep == "" {
ep = supportedProviders[epType]
}
if epType == "zipkin" {
if !strings.Contains(ep, "http") {
ep = "http://" + ep + "/api/v2/spans"
}
}
return epType, ep, nil
}
var supportedProviders = map[string]string{
"zipkin": "localhost:9411",
"datadog": "localhost:8126",
}
const (
defEpType = "zipkin"
defServiceName = "coredns"
)
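// A sketch of the directive the parser above accepts, driven through a test controller; the
// endpoint and option values are examples only.
package trace

import (
	"testing"

	"github.com/coredns/caddy"
)

func TestTraceParseSketch(t *testing.T) {
	c := caddy.NewTestController("dns", `trace zipkin localhost:9411 {
		every 100
		service coredns-edge
		client_server
	}`)
	tr, err := traceParse(c)
	if err != nil {
		t.Fatalf("expected the trace stanza to parse, got %v", err)
	}
	if tr.every != 100 || !tr.clientServer || tr.serviceName != "coredns-edge" {
		t.Fatalf("unexpected parse result: %+v", tr)
	}
	// A bare host:port is expanded to the full Zipkin span endpoint.
	if tr.Endpoint != "http://localhost:9411/api/v2/spans" {
		t.Fatalf("unexpected endpoint: %s", tr.Endpoint)
	}
}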
// Package trace implements OpenTracing-based tracing
package trace
import (
"context"
"fmt"
stdlog "log"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/dnstest"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode"
_ "github.com/coredns/coredns/plugin/pkg/trace" // Plugin the trace package.
"github.com/coredns/coredns/request"
"github.com/DataDog/dd-trace-go/v2/ddtrace/ext"
"github.com/DataDog/dd-trace-go/v2/ddtrace/tracer"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
otext "github.com/opentracing/opentracing-go/ext"
otlog "github.com/opentracing/opentracing-go/log"
zipkinot "github.com/openzipkin-contrib/zipkin-go-opentracing"
"github.com/openzipkin/zipkin-go"
zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
)
const (
defaultTopLevelSpanName = "servedns"
metaTraceIdKey = "trace/traceid"
)
var log = clog.NewWithPlugin("trace")
type traceTags struct {
Name string
Type string
Rcode string
Proto string
Remote string
}
var tagByProvider = map[string]traceTags{
"default": {
Name: "coredns.io/name",
Type: "coredns.io/type",
Rcode: "coredns.io/rcode",
Proto: "coredns.io/proto",
Remote: "coredns.io/remote",
},
"datadog": {
Name: "coredns.io@name",
Type: "coredns.io@type",
Rcode: "coredns.io@rcode",
Proto: "coredns.io@proto",
Remote: "coredns.io@remote",
},
}
type trace struct {
count uint64 // must stay the first field so 64-bit atomic access is aligned (see the sync/atomic docs)
Next plugin.Handler
Endpoint string
EndpointType string
zipkinTracer ot.Tracer
serviceEndpoint string
serviceName string
clientServer bool
every uint64
datadogAnalyticsRate float64
zipkinMaxBacklogSize int
zipkinMaxBatchSize int
zipkinMaxBatchInterval time.Duration
Once sync.Once
tagSet traceTags
}
func (t *trace) Tracer() ot.Tracer {
return t.zipkinTracer
}
// OnStartup sets up the tracer
func (t *trace) OnStartup() error {
var err error
t.Once.Do(func() {
switch t.EndpointType {
case "zipkin":
err = t.setupZipkin()
case "datadog":
tracer.Start(
tracer.WithAgentAddr(t.Endpoint),
tracer.WithDebugMode(clog.D.Value()),
tracer.WithGlobalTag(ext.SpanTypeDNS, true),
tracer.WithService(t.serviceName),
tracer.WithAnalyticsRate(t.datadogAnalyticsRate),
tracer.WithLogger(&loggerAdapter{log}),
)
t.tagSet = tagByProvider["datadog"]
default:
err = fmt.Errorf("unknown endpoint type: %s", t.EndpointType)
}
})
return err
}
// OnShutdown cleans up the tracer
func (t *trace) OnShutdown() error {
if t.EndpointType == "datadog" {
tracer.Stop()
}
return nil
}
func (t *trace) setupZipkin() error {
var opts []zipkinhttp.ReporterOption
opts = append(opts, zipkinhttp.Logger(stdlog.New(&loggerAdapter{log}, "", 0)))
if t.zipkinMaxBacklogSize != 0 {
opts = append(opts, zipkinhttp.MaxBacklog(t.zipkinMaxBacklogSize))
}
if t.zipkinMaxBatchSize != 0 {
opts = append(opts, zipkinhttp.BatchSize(t.zipkinMaxBatchSize))
}
if t.zipkinMaxBatchInterval != 0 {
opts = append(opts, zipkinhttp.BatchInterval(t.zipkinMaxBatchInterval))
}
reporter := zipkinhttp.NewReporter(t.Endpoint, opts...)
recorder, err := zipkin.NewEndpoint(t.serviceName, t.serviceEndpoint)
if err != nil {
log.Warningf("build Zipkin endpoint found err: %v", err)
}
tracer, err := zipkin.NewTracer(
reporter,
zipkin.WithLocalEndpoint(recorder),
zipkin.WithSharedSpans(t.clientServer),
)
if err != nil {
return err
}
t.zipkinTracer = zipkinot.Wrap(tracer)
t.tagSet = tagByProvider["default"]
return err
}
// Name implements the Handler interface.
func (t *trace) Name() string { return "trace" }
// ServeDNS implements the plugin.Handle interface.
func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
shouldTrace := false
if t.every > 0 {
queryNr := atomic.AddUint64(&t.count, 1)
if queryNr%t.every == 0 {
shouldTrace = true
}
}
if t.EndpointType == "datadog" {
return t.serveDNSDatadog(ctx, w, r, shouldTrace)
}
return t.serveDNSZipkin(ctx, w, r, shouldTrace)
}
func (t *trace) serveDNSDatadog(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, shouldTrace bool) (int, error) {
if !shouldTrace {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
span, spanCtx := tracer.StartSpanFromContext(ctx, defaultTopLevelSpanName)
defer span.Finish()
metadata.SetValueFunc(ctx, metaTraceIdKey, func() string { return span.Context().TraceID() })
req := request.Request{W: w, Req: r}
rw := dnstest.NewRecorder(w)
status, err := plugin.NextOrFailure(t.Name(), t.Next, spanCtx, rw, r)
t.setDatadogSpanTags(span, req, rw, status, err)
return status, err
}
func (t *trace) serveDNSZipkin(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, shouldTrace bool) (int, error) {
span := ot.SpanFromContext(ctx)
if !shouldTrace || span != nil {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
var spanCtx ot.SpanContext
if val := ctx.Value(dnsserver.HTTPRequestKey{}); val != nil {
if httpReq, ok := val.(*http.Request); ok {
spanCtx, _ = t.Tracer().Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(httpReq.Header))
}
}
req := request.Request{W: w, Req: r}
span = t.Tracer().StartSpan(defaultTopLevelSpanName, otext.RPCServerOption(spanCtx))
defer span.Finish()
if spanCtx, ok := span.Context().(zipkinot.SpanContext); ok {
metadata.SetValueFunc(ctx, metaTraceIdKey, func() string { return spanCtx.TraceID.String() })
}
rw := dnstest.NewRecorder(w)
ctx = ot.ContextWithSpan(ctx, span)
status, err := plugin.NextOrFailure(t.Name(), t.Next, ctx, rw, r)
t.setZipkinSpanTags(span, req, rw, status, err)
return status, err
}
// setDatadogSpanTags sets span tags using DataDog v2 API
func (t *trace) setDatadogSpanTags(span *tracer.Span, req request.Request, rw *dnstest.Recorder, status int, err error) {
span.SetTag(t.tagSet.Name, req.Name())
span.SetTag(t.tagSet.Type, req.Type())
span.SetTag(t.tagSet.Proto, req.Proto())
span.SetTag(t.tagSet.Remote, req.IP())
rc := rw.Rcode
if !plugin.ClientWrite(status) {
rc = status
}
span.SetTag(t.tagSet.Rcode, rcode.ToString(rc))
if err != nil {
span.SetTag("error.message", err.Error())
span.SetTag("error", true)
span.SetTag("error.type", "dns_error")
}
}
// setZipkinSpanTags sets span tags for Zipkin/OpenTracing spans
func (t *trace) setZipkinSpanTags(span ot.Span, req request.Request, rw *dnstest.Recorder, status int, err error) {
span.SetTag(t.tagSet.Name, req.Name())
span.SetTag(t.tagSet.Type, req.Type())
span.SetTag(t.tagSet.Proto, req.Proto())
span.SetTag(t.tagSet.Remote, req.IP())
rc := rw.Rcode
if !plugin.ClientWrite(status) {
// when no response was written, fallback to status returned from next plugin as this status
// is actually used as rcode of DNS response
// see https://github.com/coredns/coredns/blob/master/core/dnsserver/server.go#L318
rc = status
}
span.SetTag(t.tagSet.Rcode, rcode.ToString(rc))
if err != nil {
// Use OpenTracing error handling
otext.Error.Set(span, true)
span.LogFields(otlog.Event("error"), otlog.Error(err))
}
}
package transfer
import (
"fmt"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/miekg/dns"
)
// Notify will send notifies to the IP addresses of all configured "to" hosts. The string zone must be lowercased.
func (t *Transfer) Notify(zone string) error {
if t == nil { // t might be nil, mostly expected in tests, so intercept and do a noop in that case
return nil
}
m := new(dns.Msg)
m.SetNotify(zone)
c := new(dns.Client)
x := longestMatch(t.xfrs, zone)
if x == nil {
// return without error if there is no matching zone
return nil
}
var err1 error
for _, t := range x.to {
if t == "*" {
continue
}
if err := sendNotify(c, m, t); err != nil {
err1 = err
}
}
log.Debugf("Sent notifies for zone %q to %v", zone, x.to)
return err1 // this only captures the last error
}
func sendNotify(c *dns.Client, m *dns.Msg, s string) error {
var err error
var ret *dns.Msg
code := dns.RcodeServerFailure
for range 3 {
ret, _, err = c.Exchange(m, s)
if err != nil {
continue
}
code = ret.Rcode
if code == dns.RcodeSuccess {
return nil
}
}
if err != nil {
return fmt.Errorf("notify for zone %q was not accepted by %q: %q", m.Question[0].Name, s, err)
}
return fmt.Errorf("notify for zone %q was not accepted by %q: rcode was %q", m.Question[0].Name, s, rcode.ToString(code))
}
package transfer
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/transport"
)
func init() {
caddy.RegisterPlugin("transfer", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
t, err := parseTransfer(c)
if err != nil {
return plugin.Error("transfer", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
t.Next = next
return t
})
c.OnStartup(func() error {
config := dnsserver.GetConfig(c)
t.tsigSecret = config.TsigSecret
// find all plugins that implement Transferer and add them to Transferers
plugins := config.Handlers()
for _, pl := range plugins {
tr, ok := pl.(Transferer)
if !ok {
continue
}
t.Transferers = append(t.Transferers, tr)
}
return nil
})
return nil
}
func parseTransfer(c *caddy.Controller) (*Transfer, error) {
t := &Transfer{}
for c.Next() {
x := &xfr{}
x.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
for c.NextBlock() {
switch c.Val() {
case "to":
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
for _, host := range args {
if host == "*" {
x.to = append(x.to, host)
continue
}
normalized, err := parse.HostPort(host, transport.Port)
if err != nil {
return nil, err
}
x.to = append(x.to, normalized)
}
default:
return nil, plugin.Error("transfer", c.Errf("unknown property %q", c.Val()))
}
}
if len(x.to) == 0 {
return nil, plugin.Error("transfer", c.Err("'to' is required"))
}
t.xfrs = append(t.xfrs, x)
}
return t, nil
}
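// A sketch of the stanza the parser above accepts: a zone from the arguments (or the server
// block) and one or more "to" hosts, where "*" allows transfers to anyone. The zone and address
// are example values.
package transfer

import (
	"testing"

	"github.com/coredns/caddy"
)

func TestParseTransferSketch(t *testing.T) {
	c := caddy.NewTestController("dns", `transfer example.org {
		to 10.1.2.3 *
	}`)
	tr, err := parseTransfer(c)
	if err != nil {
		t.Fatalf("expected the transfer stanza to parse, got %v", err)
	}
	if len(tr.xfrs) != 1 || len(tr.xfrs[0].to) != 2 {
		t.Fatalf("unexpected parse result: %+v", tr.xfrs)
	}
	// Hosts without a port are normalized to port 53.
	if tr.xfrs[0].to[0] != "10.1.2.3:53" {
		t.Fatalf("unexpected 'to' host: %s", tr.xfrs[0].to[0])
	}
}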
package transfer
import (
"context"
"errors"
"net"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("transfer")
// Transfer is a plugin that handles zone transfers.
type Transfer struct {
Transferers []Transferer // List of plugins that implement Transferer
xfrs []*xfr
tsigSecret map[string]string
Next plugin.Handler
}
type xfr struct {
Zones []string
to []string
}
// Transferer may be implemented by plugins to enable zone transfers
type Transferer interface {
// Transfer returns a channel to which it writes responses to the transfer request.
// If the plugin is not authoritative for the zone, it should immediately return the
// transfer.ErrNotAuthoritative error. This is important; otherwise the transfer plugin could
// serve data from plugin X when it should transfer the data from plugin Y.
//
// If serial is 0, handle as an AXFR request. Transfer should send all records
// in the zone to the channel. The SOA should be written to the channel first, followed
// by all other records, including all NS + glue records. The implementation is also responsible
// for sending the last SOA record (to signal end of the transfer). This plugin will just grab
// these records and send them back to the requester; little validation is done.
//
// If serial is not 0, it will be handled as an IXFR request. If the serial is equal to or greater (newer) than
// the current serial for the zone, send a single SOA record to the channel and then close it.
// If the serial is less (older) than the current serial for the zone, perform an AXFR fallback
// by proceeding as if an AXFR was requested (as above).
Transfer(zone string, serial uint32) (<-chan []dns.RR, error)
}
var (
// ErrNotAuthoritative is returned by Transfer() when the plugin is not authoritative for the zone.
ErrNotAuthoritative = errors.New("not authoritative for zone")
)
// ServeDNS implements the plugin.Handler interface.
func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if state.QType() != dns.TypeAXFR && state.QType() != dns.TypeIXFR {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
if state.Proto() != "tcp" {
return dns.RcodeRefused, nil
}
x := longestMatch(t.xfrs, state.QName())
if x == nil {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
if !x.allowed(state) {
// write msg here, so logging will pick it up
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(m)
return 0, nil
}
// Get serial from request if this is an IXFR.
var serial uint32
if state.QType() == dns.TypeIXFR {
if len(r.Ns) != 1 {
return dns.RcodeServerFailure, nil
}
soa, ok := r.Ns[0].(*dns.SOA)
if !ok {
return dns.RcodeServerFailure, nil
}
serial = soa.Serial
}
// Get a receiving channel from the first Transferer plugin that returns one.
var pchan <-chan []dns.RR
var err error
for _, p := range t.Transferers {
pchan, err = p.Transfer(state.QName(), serial)
if err == ErrNotAuthoritative {
// plugin was not authoritative for the zone, try next plugin
continue
}
if err != nil {
return dns.RcodeServerFailure, err
}
break
}
if pchan == nil {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
// Send response to client
ch := make(chan *dns.Envelope)
tr := new(dns.Transfer)
if r.IsTsig() != nil {
tr.TsigSecret = t.tsigSecret
}
errCh := make(chan error)
go func() {
if err := tr.Out(w, r, ch); err != nil {
errCh <- err
}
close(errCh)
}()
rrs := []dns.RR{}
l := 0
var soa *dns.SOA
for records := range pchan {
if x, ok := records[0].(*dns.SOA); ok && soa == nil {
soa = x
}
rrs = append(rrs, records...)
if len(rrs) > 500 {
select {
case ch <- &dns.Envelope{RR: rrs}:
case err := <-errCh:
// Client errored; drain pchan to avoid blocking the producer goroutine.
go func() {
for range pchan {
}
}()
return dns.RcodeServerFailure, err
}
l += len(rrs)
rrs = []dns.RR{}
}
}
// If we are here holding only a single SOA (len(rrs) == 1 and soa != nil), the zone is up to date and
// this is the single-SOA IXFR answer. We haven't sent anything on ch yet, so it can be closed (and
// waited for), and we only need to return the SOA to the client.
if len(rrs) == 1 && soa != nil { // soa should never be nil...
close(ch)
err := <-errCh
if err != nil {
return dns.RcodeServerFailure, err
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = []dns.RR{soa}
w.WriteMsg(m)
log.Infof("Outgoing noop, incremental transfer for up to date zone %q to %s for %d SOA serial", state.QName(), state.IP(), soa.Serial)
return 0, nil
}
if len(rrs) > 0 {
ch <- &dns.Envelope{RR: rrs}
l += len(rrs)
}
close(ch) // Even though we close the channel here, we still have
err = <-errCh // to wait before we can return and close the connection.
if err != nil {
return dns.RcodeServerFailure, err
}
logserial := uint32(0)
if soa != nil {
logserial = soa.Serial
}
log.Infof("Outgoing transfer of %d records of zone %q to %s for %d SOA serial", l, state.QName(), state.IP(), logserial)
return 0, nil
}
func (x xfr) allowed(state request.Request) bool {
for _, h := range x.to {
if h == "*" {
return true
}
to, _, err := net.SplitHostPort(h)
if err != nil {
return false
}
// If remote IP matches we accept. TODO(): make this work with ranges
if to == state.IP() {
return true
}
}
return false
}
// Find the first transfer instance for which the queried zone is the longest match. When nothing
// is found nil is returned.
func longestMatch(xfrs []*xfr, name string) *xfr {
// TODO(xxx): optimize and make it a map (or maps)
var x *xfr
zone := "" // longest zone match wins
for _, xfr := range xfrs {
if z := plugin.Zones(xfr.Zones).Matches(name); z != "" {
if z > zone {
zone = z
x = xfr
}
}
}
return x
}
// Name implements the Handler interface.
func (Transfer) Name() string { return "transfer" }
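// The Transferer contract documented above (SOA first, SOA again to close, and a single SOA for
// an up-to-date IXFR) can be satisfied by a very small in-memory implementation. This is an
// illustrative sketch only; the transferExample type and its zone are hypothetical.
package transfer

import "github.com/miekg/dns"

type transferExample struct {
	soa *dns.SOA
	rrs []dns.RR // the rest of the zone, NS and glue records included
}

// Transfer implements Transferer for the single zone "example.org.".
func (e *transferExample) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
	if zone != "example.org." {
		return nil, ErrNotAuthoritative
	}
	ch := make(chan []dns.RR)
	go func() {
		defer close(ch)
		if serial != 0 && serial >= e.soa.Serial {
			// IXFR with a serial that is current (or newer): answer with just the SOA.
			ch <- []dns.RR{e.soa}
			return
		}
		// AXFR, or IXFR falling back to AXFR: SOA first, then everything else, then the SOA again.
		ch <- []dns.RR{e.soa}
		ch <- e.rrs
		ch <- []dns.RR{e.soa}
	}()
	return ch, nil
}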
package tsig
import (
"bufio"
"fmt"
"io"
"os"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
func init() {
caddy.RegisterPlugin(pluginName, caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
t, err := parse(c)
if err != nil {
return plugin.Error(pluginName, err)
}
config := dnsserver.GetConfig(c)
config.TsigSecret = t.secrets
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
t.Next = next
return t
})
return nil
}
func parse(c *caddy.Controller) (*TSIGServer, error) {
t := &TSIGServer{
secrets: make(map[string]string),
types: defaultQTypes,
}
for i := 0; c.Next(); i++ {
if i > 0 {
return nil, plugin.ErrOnce
}
t.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
for c.NextBlock() {
switch c.Val() {
case "secret":
args := c.RemainingArgs()
if len(args) != 2 {
return nil, c.ArgErr()
}
k := plugin.Name(args[0]).Normalize()
if _, exists := t.secrets[k]; exists {
return nil, fmt.Errorf("key %q redefined", k)
}
t.secrets[k] = args[1]
case "secrets":
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
f, err := os.Open(args[0])
if err != nil {
return nil, err
}
secrets, err := parseKeyFile(f)
if err != nil {
return nil, err
}
for k, s := range secrets {
if _, exists := t.secrets[k]; exists {
return nil, fmt.Errorf("key %q redefined", k)
}
t.secrets[k] = s
}
case "require":
t.types = qTypes{}
args := c.RemainingArgs()
if len(args) == 0 {
return nil, c.ArgErr()
}
if args[0] == "all" {
t.all = true
continue
}
if args[0] == "none" {
continue
}
for _, str := range args {
qt, ok := dns.StringToType[str]
if !ok {
return nil, c.Errf("unknown query type '%s'", str)
}
t.types[qt] = struct{}{}
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
}
return t, nil
}
func parseKeyFile(f io.Reader) (map[string]string, error) {
secrets := make(map[string]string)
s := bufio.NewScanner(f)
for s.Scan() {
fields := strings.Fields(s.Text())
if len(fields) == 0 {
continue
}
if fields[0] != "key" {
return nil, fmt.Errorf("unexpected token %q", fields[0])
}
if len(fields) < 2 {
return nil, fmt.Errorf("expected key name %q", s.Text())
}
key := strings.Trim(fields[1], "\"{")
if len(key) == 0 {
return nil, fmt.Errorf("expected key name %q", s.Text())
}
key = plugin.Name(key).Normalize()
if _, ok := secrets[key]; ok {
return nil, fmt.Errorf("key %q redefined", key)
}
key:
for s.Scan() {
fields := strings.Fields(s.Text())
if len(fields) == 0 {
continue
}
switch fields[0] {
case "algorithm":
continue
case "secret":
if len(fields) < 2 {
return nil, fmt.Errorf("expected secret key %q", s.Text())
}
secret := strings.Trim(fields[1], "\";")
if len(secret) == 0 {
return nil, fmt.Errorf("expected secret key %q", s.Text())
}
secrets[key] = secret
case "}":
fallthrough
case "};":
break key
default:
return nil, fmt.Errorf("unexpected token %q", fields[0])
}
}
if _, ok := secrets[key]; !ok {
return nil, fmt.Errorf("expected secret for key %q", key)
}
}
return secrets, nil
}
var defaultQTypes = qTypes{}
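// parseKeyFile understands the BIND-style "key" clause. A short example of the accepted format;
// the key name and secret are made-up values.
package tsig

import (
	"fmt"
	"strings"
)

func Example_keyFile() {
	keyfile := `key "axfr." {
	algorithm hmac-sha256;
	secret "NoTaReAlSecretNoTaReAlSecret+dummy=";
};`
	secrets, err := parseKeyFile(strings.NewReader(keyfile))
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(secrets["axfr."])
	// Output: NoTaReAlSecretNoTaReAlSecret+dummy=
}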
package tsig
import (
"context"
"encoding/binary"
"encoding/hex"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// TSIGServer verifies tsig status and adds tsig to responses
type TSIGServer struct {
Zones []string
secrets map[string]string // [key-name]secret
types qTypes
all bool
Next plugin.Handler
}
type qTypes map[uint16]struct{}
// Name implements plugin.Handler
func (t TSIGServer) Name() string { return pluginName }
// ServeDNS implements plugin.Handler
func (t *TSIGServer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
var err error
state := request.Request{Req: r, W: w}
if z := plugin.Zones(t.Zones).Matches(state.Name()); z == "" {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
var tsigRR = r.IsTsig()
rcode := dns.RcodeSuccess
if !t.tsigRequired(state.QType()) && tsigRR == nil {
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
}
if tsigRR == nil {
log.Debugf("rejecting '%s' request without TSIG\n", dns.TypeToString[state.QType()])
rcode = dns.RcodeRefused
}
// wrap the response writer so the response will be TSIG signed.
w = &restoreTsigWriter{w, r, tsigRR}
tsigStatus := w.TsigStatus()
if tsigStatus != nil {
log.Debugf("TSIG validation failed: %v %v", dns.TypeToString[state.QType()], tsigStatus)
rcode = dns.RcodeNotAuth
switch tsigStatus {
case dns.ErrSecret:
tsigRR.Error = dns.RcodeBadKey
case dns.ErrTime:
tsigRR.Error = dns.RcodeBadTime
default:
tsigRR.Error = dns.RcodeBadSig
}
resp := new(dns.Msg).SetRcode(r, rcode)
w.WriteMsg(resp)
return dns.RcodeSuccess, nil
}
// Strip the TSIG RR; the next and subsequent plugins will not see the TSIG RRs.
// This violates forwarding cases (RFC 8945 5.5). See the Bugs section in README.md.
if len(r.Extra) > 1 {
r.Extra = r.Extra[0 : len(r.Extra)-1]
} else {
r.Extra = []dns.RR{}
}
if rcode == dns.RcodeSuccess {
rcode, err = plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
if err != nil {
log.Errorf("request handler returned an error: %v\n", err)
}
}
// If the next plugin has not already written a response, write one now so it gets TSIG signed.
if !plugin.ClientWrite(rcode) {
resp := new(dns.Msg).SetRcode(r, rcode)
w.WriteMsg(resp)
}
return dns.RcodeSuccess, nil
}
func (t *TSIGServer) tsigRequired(qtype uint16) bool {
if t.all {
return true
}
if _, ok := t.types[qtype]; ok {
return true
}
return false
}
// restoreTsigWriter implements dns.ResponseWriter and adds a TSIG RR to a response
type restoreTsigWriter struct {
dns.ResponseWriter
req *dns.Msg // original request excluding TSIG if it has one
reqTSIG *dns.TSIG // original TSIG
}
// WriteMsg adds a TSIG RR to the response
func (r *restoreTsigWriter) WriteMsg(m *dns.Msg) error {
// Make sure the response has an EDNS OPT RR if the request had it.
// Otherwise ScrubWriter would append it *after* TSIG, making it a non-compliant DNS message.
state := request.Request{Req: r.req, W: r.ResponseWriter}
state.SizeAndDo(m)
repTSIG := m.IsTsig()
if r.reqTSIG != nil && repTSIG == nil {
repTSIG = new(dns.TSIG)
repTSIG.Hdr = dns.RR_Header{Name: r.reqTSIG.Hdr.Name, Rrtype: dns.TypeTSIG, Class: dns.ClassANY}
repTSIG.Algorithm = r.reqTSIG.Algorithm
repTSIG.OrigId = m.Id
repTSIG.Error = r.reqTSIG.Error
repTSIG.MAC = r.reqTSIG.MAC
repTSIG.MACSize = r.reqTSIG.MACSize
if repTSIG.Error == dns.RcodeBadTime {
// per RFC 8945 5.2.3. client time goes into TimeSigned, server time in OtherData, OtherLen = 6 ...
repTSIG.TimeSigned = r.reqTSIG.TimeSigned
b := make([]byte, 8)
// TimeSigned is network byte order.
binary.BigEndian.PutUint64(b, uint64(time.Now().Unix()))
// truncate to 48 least significant bits (network order 6 rightmost bytes)
repTSIG.OtherData = hex.EncodeToString(b[2:])
repTSIG.OtherLen = 6
}
m.Extra = append(m.Extra, repTSIG)
}
return r.ResponseWriter.WriteMsg(m)
}
const pluginName = "tsig"
package view
import (
"context"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/request"
)
// Metadata implements the metadata.Provider interface.
func (v *View) Metadata(ctx context.Context, state request.Request) context.Context {
metadata.SetValueFunc(ctx, "view/name", func() string {
return v.viewName
})
return ctx
}
package view
import (
"context"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/expression"
"github.com/expr-lang/expr"
)
func init() { plugin.Register("view", setup) }
func setup(c *caddy.Controller) error {
cond, err := parse(c)
if err != nil {
return plugin.Error("view", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
cond.Next = next
return cond
})
return nil
}
func parse(c *caddy.Controller) (*View, error) {
v := new(View)
i := 0
for c.Next() {
i++
if i > 1 {
return nil, plugin.ErrOnce
}
args := c.RemainingArgs()
if len(args) != 1 {
return nil, c.ArgErr()
}
v.viewName = args[0]
for c.NextBlock() {
switch c.Val() {
case "expr":
args := c.RemainingArgs()
prog, err := expr.Compile(strings.Join(args, " "), expr.Env(expression.DefaultEnv(context.Background(), nil)), expr.DisableBuiltin("type"))
if err != nil {
return v, err
}
v.progs = append(v.progs, prog)
continue
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
}
return v, nil
}
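// A sketch of the stanza the parser above accepts: a view name plus one or more expr lines; all
// expressions must evaluate to true for a query to be routed to this server block. The view name
// and CIDR are example values.
package view

import (
	"testing"

	"github.com/coredns/caddy"
)

func TestViewParseSketch(t *testing.T) {
	c := caddy.NewTestController("dns", `view internal {
		expr incidr(client_ip(), '10.0.0.0/8')
	}`)
	v, err := parse(c)
	if err != nil {
		t.Fatalf("expected the view stanza to parse, got %v", err)
	}
	if v.ViewName() != "internal" || len(v.progs) != 1 {
		t.Fatalf("unexpected parse result: %+v", v)
	}
}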
package view
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/expression"
"github.com/coredns/coredns/request"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/vm"
"github.com/miekg/dns"
)
// View is a plugin that enables configuring expression-based advanced routing
type View struct {
progs []*vm.Program
viewName string
Next plugin.Handler
}
// Filter implements dnsserver.Viewer. It returns true if all View rules evaluate to true for the given state.
func (v *View) Filter(ctx context.Context, state *request.Request) bool {
env := expression.DefaultEnv(ctx, state)
for _, prog := range v.progs {
result, err := expr.Run(prog, env)
if err != nil {
return false
}
if b, ok := result.(bool); ok && b {
continue
}
// anything other than a boolean true result is considered false
return false
}
return true
}
// ViewName implements dnsserver.Viewer. It returns the view name
func (v *View) ViewName() string { return v.viewName }
// Name implements the Handler interface
func (*View) Name() string { return "view" }
// ServeDNS implements the Handler interface.
func (v *View) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return plugin.NextOrFailure(v.Name(), v.Next, ctx, w, r)
}
//go:build gofuzz
package whoami
import (
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes whoami.
func Fuzz(data []byte) int {
w := Whoami{}
return fuzz.Do(w, data)
}
package whoami
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() { plugin.Register("whoami", setup) }
func setup(c *caddy.Controller) error {
c.Next() // 'whoami'
if c.NextArg() {
return plugin.Error("whoami", c.ArgErr())
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Whoami{}
})
return nil
}
// Package whoami implements a plugin that returns details about the resolving
// client querying it.
package whoami
import (
"context"
"net"
"strconv"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
const name = "whoami"
// Whoami is a plugin that returns your IP address, port and the protocol used for connecting
// to CoreDNS.
type Whoami struct{}
// ServeDNS implements the plugin.Handler interface.
func (wh Whoami) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
a := new(dns.Msg)
a.SetReply(r)
a.Authoritative = true
ip := state.IP()
var rr dns.RR
switch state.Family() {
case 1:
rr = new(dns.A)
rr.(*dns.A).Hdr = dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: state.QClass()}
rr.(*dns.A).A = net.ParseIP(ip).To4()
case 2:
rr = new(dns.AAAA)
rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeAAAA, Class: state.QClass()}
rr.(*dns.AAAA).AAAA = net.ParseIP(ip)
}
srv := new(dns.SRV)
srv.Hdr = dns.RR_Header{Name: "_" + state.Proto() + "." + state.QName(), Rrtype: dns.TypeSRV, Class: state.QClass()}
if state.QName() == "." {
srv.Hdr.Name = "_" + state.Proto() + state.QName()
}
port, _ := strconv.ParseUint(state.Port(), 10, 16)
srv.Port = uint16(port)
srv.Target = "."
a.Extra = []dns.RR{rr, srv}
w.WriteMsg(a)
return 0, nil
}
// Name implements the Handler interface.
func (wh Whoami) Name() string { return name }
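// The response layout produced above (an A or AAAA record plus an SRV record, both in the
// additional section) can be observed by driving Whoami directly with the test helpers. A sketch,
// placed in a hypothetical test package:
package sketch

import (
	"context"
	"testing"

	"github.com/coredns/coredns/plugin/pkg/dnstest"
	"github.com/coredns/coredns/plugin/test"
	"github.com/coredns/coredns/plugin/whoami"

	"github.com/miekg/dns"
)

func TestWhoamiSketch(t *testing.T) {
	m := new(dns.Msg)
	m.SetQuestion("example.org.", dns.TypeA)
	rec := dnstest.NewRecorder(&test.ResponseWriter{})
	if _, err := (whoami.Whoami{}).ServeDNS(context.TODO(), rec, m); err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if len(rec.Msg.Extra) != 2 {
		t.Fatalf("expected an address record and an SRV record, got %d extra RRs", len(rec.Msg.Extra))
	}
	srv := rec.Msg.Extra[1].(*dns.SRV)
	// The fake test client always connects over UDP from 10.240.0.1:40212.
	if srv.Hdr.Name != "_udp.example.org." || srv.Port != 40212 {
		t.Fatalf("unexpected SRV record: %s", srv)
	}
}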
package request
import (
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/miekg/dns"
)
func supportedOptions(o []dns.EDNS0) []dns.EDNS0 {
var supported = make([]dns.EDNS0, 0, 3)
// For as long as possible try to avoid looking up in the map, because that needs an RLock.
for _, opt := range o {
switch code := opt.Option(); code {
case dns.EDNS0NSID:
fallthrough
case dns.EDNS0EXPIRE:
fallthrough
case dns.EDNS0COOKIE:
fallthrough
case dns.EDNS0TCPKEEPALIVE:
fallthrough
case dns.EDNS0PADDING:
supported = append(supported, opt)
default:
if edns.SupportedOption(code) {
supported = append(supported, opt)
}
}
}
return supported
}
// Package request abstracts a client's request so that all plugins will handle them in a unified way.
package request
import (
"net"
"strings"
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/miekg/dns"
)
// Request contains some connection state and is useful in plugins.
type Request struct {
Req *dns.Msg
W dns.ResponseWriter
// Optional lowercased zone of this query.
Zone string
// Cache size after first call to Size or Do. If size is zero nothing has been cached yet.
// Both Size and Do set these values (and cache them).
size uint16 // UDP buffer size, or 64K in case of TCP.
do bool // DNSSEC OK value
// Caches
family int8 // transport's family.
name string // lowercase qname.
ip string // client's ip.
port string // client's port.
localPort string // server's port.
localIP string // server's ip.
}
// NewWithQuestion returns a new request based on the old, but with a new question
// section in the request.
func (r *Request) NewWithQuestion(name string, typ uint16) Request {
req1 := Request{W: r.W, Req: r.Req.Copy()}
req1.Req.Question[0] = dns.Question{Name: dns.Fqdn(name), Qclass: dns.ClassINET, Qtype: typ}
return req1
}
// IP gets the (remote) IP address of the client making the request.
func (r *Request) IP() string {
if r.ip != "" {
return r.ip
}
ip, _, err := net.SplitHostPort(r.W.RemoteAddr().String())
if err != nil {
r.ip = r.W.RemoteAddr().String()
return r.ip
}
r.ip = ip
return r.ip
}
// LocalIP gets the (local) IP address of server handling the request.
func (r *Request) LocalIP() string {
if r.localIP != "" {
return r.localIP
}
ip, _, err := net.SplitHostPort(r.W.LocalAddr().String())
if err != nil {
r.localIP = r.W.LocalAddr().String()
return r.localIP
}
r.localIP = ip
return r.localIP
}
// Port gets the (remote) port of the client making the request.
func (r *Request) Port() string {
if r.port != "" {
return r.port
}
_, port, err := net.SplitHostPort(r.W.RemoteAddr().String())
if err != nil {
r.port = "0"
return r.port
}
r.port = port
return r.port
}
// LocalPort gets the local port of the server handling the request.
func (r *Request) LocalPort() string {
if r.localPort != "" {
return r.localPort
}
_, port, err := net.SplitHostPort(r.W.LocalAddr().String())
if err != nil {
r.localPort = "0"
return r.localPort
}
r.localPort = port
return r.localPort
}
// RemoteAddr returns the net.Addr of the client that sent the current request.
func (r *Request) RemoteAddr() string { return r.W.RemoteAddr().String() }
// LocalAddr returns the net.Addr of the server handling the current request.
func (r *Request) LocalAddr() string { return r.W.LocalAddr().String() }
// Proto gets the protocol used as the transport. This will be udp or tcp.
func (r *Request) Proto() string {
if _, ok := r.W.RemoteAddr().(*net.UDPAddr); ok {
return "udp"
}
if _, ok := r.W.RemoteAddr().(*net.TCPAddr); ok {
return "tcp"
}
return "udp"
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func (r *Request) Family() int {
if r.family != 0 {
return int(r.family)
}
var a net.IP
ip := r.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
r.family = 1
return 1
}
r.family = 2
return 2
}
// Do returns true if the request has the DO (DNSSEC OK) bit set.
func (r *Request) Do() bool {
if r.size != 0 {
return r.do
}
r.Size()
return r.do
}
// Len returns the length in bytes of the request.
func (r *Request) Len() int { return r.Req.Len() }
// Size returns the buffer size *advertised* in the request's OPT record,
// or, when the request came in over TCP, the maximum allowed size of 64K.
func (r *Request) Size() int {
if r.size != 0 {
return int(r.size)
}
size := uint16(0)
if o := r.Req.IsEdns0(); o != nil {
r.do = o.Do()
size = o.UDPSize()
}
// normalize size
size = edns.Size(r.Proto(), size)
r.size = size
return int(size)
}
// SizeAndDo adds an OPT record that reflects the intent from the request.
// The returned bool indicates whether an OPT record was found and normalised.
func (r *Request) SizeAndDo(m *dns.Msg) bool {
o := r.Req.IsEdns0()
if o == nil {
return false
}
if mo := m.IsEdns0(); mo != nil {
mo.Hdr.Name = "."
mo.Hdr.Rrtype = dns.TypeOPT
mo.SetVersion(0)
mo.SetUDPSize(o.UDPSize())
mo.Hdr.Ttl &= 0xff00 // clear flags
// Assume if the message m has options set, they are OK and represent what an upstream can do.
if o.Do() {
mo.SetDo()
}
return true
}
// Reuse the request's OPT record and tack it to m.
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetVersion(0)
o.Hdr.Ttl &= 0xff00 // clear flags
if len(o.Option) > 0 {
o.Option = supportedOptions(o.Option)
}
m.Extra = append(m.Extra, o)
return true
}
// Scrub scrubs the reply message so that it will fit the client's buffer. It will first
// check if the reply fits without compression and then *with* compression.
// Note, the TC bit will be set regardless of protocol, even TCP messages will
// get the bit; the client should then retry with pigeons.
func (r *Request) Scrub(reply *dns.Msg) *dns.Msg {
reply.Truncate(r.Size())
if reply.Compress {
return reply
}
if r.Proto() == "udp" {
rl := reply.Len()
// Last ditch attempt to avoid fragmentation, if the size is bigger than the v4/v6 UDP fragmentation
// limit and sent via UDP compress it (in the hope we go under that limit). Limits taken from NSD:
//
// .., 1480 (EDNS/IPv4), 1220 (EDNS/IPv6), or the advertised EDNS buffer size if that is
// smaller than the EDNS default.
// See: https://open.nlnetlabs.nl/pipermail/nsd-users/2011-November/001278.html
if rl > 1480 && r.Family() == 1 {
reply.Compress = true
}
if rl > 1220 && r.Family() == 2 {
reply.Compress = true
}
}
return reply
}
// Type returns the type of the question as a string. If the request is malformed the empty string is returned.
func (r *Request) Type() string {
if r.Req == nil {
return ""
}
if len(r.Req.Question) == 0 {
return ""
}
return dns.Type(r.Req.Question[0].Qtype).String()
}
// QType returns the type of the question as a uint16. If the request is malformed
// 0 is returned.
func (r *Request) QType() uint16 {
if r.Req == nil {
return 0
}
if len(r.Req.Question) == 0 {
return 0
}
return r.Req.Question[0].Qtype
}
// Name returns the name of the question in the request. Note that
// this name will always have a closing dot and will be lower cased. After a call to Name
// the value will be cached. To clear this caching call Clear.
// If the request is malformed the root zone is returned.
func (r *Request) Name() string {
if r.name != "" {
return r.name
}
if r.Req == nil {
r.name = "."
return "."
}
if len(r.Req.Question) == 0 {
r.name = "."
return "."
}
r.name = strings.ToLower(dns.Name(r.Req.Question[0].Name).String())
return r.name
}
// QName returns the name of the question in the request.
// If the request is malformed the root zone is returned.
func (r *Request) QName() string {
if r.Req == nil {
return "."
}
if len(r.Req.Question) == 0 {
return "."
}
return dns.Name(r.Req.Question[0].Name).String()
}
// Class returns the class of the question in the request.
// If the request is malformed the empty string is returned.
func (r *Request) Class() string {
if r.Req == nil {
return ""
}
if len(r.Req.Question) == 0 {
return ""
}
return dns.Class(r.Req.Question[0].Qclass).String()
}
// QClass returns the class of the question in the request.
// If the request is malformed, 0 is returned.
func (r *Request) QClass() uint16 {
if r.Req == nil {
return 0
}
if len(r.Req.Question) == 0 {
return 0
}
return r.Req.Question[0].Qclass
}
// Clear clears all cached values from the request.
func (r *Request) Clear() {
r.name = ""
r.ip = ""
r.localIP = ""
r.port = ""
r.localPort = ""
r.family = 0
r.size = 0
r.do = false
}
// Match checks if the reply matches the qname and qtype from the request; it returns
// false when they don't match.
func (r *Request) Match(reply *dns.Msg) bool {
if len(reply.Question) != 1 {
return false
}
if !reply.Response {
return false
}
if strings.ToLower(reply.Question[0].Name) != r.Name() {
return false
}
if reply.Question[0].Qtype != r.QType() {
return false
}
return true
}
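// matchOrDrop is an illustrative helper (the name is hypothetical, not part of this
// package): callers that forward r.Req elsewhere typically use Match to discard replies
// whose question section does not line up with the original request.
func matchOrDrop(r *Request, reply *dns.Msg) *dns.Msg {
	if reply == nil || !r.Match(reply) {
		return nil // treat a mismatched reply as if no reply was received
	}
	return reply
}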
package request
import "github.com/miekg/dns"
// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer.
type ScrubWriter struct {
dns.ResponseWriter
req *dns.Msg // original request
}
// NewScrubWriter returns a new and initialized ScrubWriter.
func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter { return &ScrubWriter{w, req} }
// WriteMsg overrides the default implementation of the underlying dns.ResponseWriter: it
// scrubs the message m and then writes it to the client.
func (s *ScrubWriter) WriteMsg(m *dns.Msg) error {
state := Request{Req: s.req, W: s.ResponseWriter}
state.SizeAndDo(m)
state.Scrub(m)
return s.ResponseWriter.WriteMsg(m)
}
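// exampleScrubbedChain is a usage sketch (hypothetical, not part of this package): a
// server or plugin can wrap the writer it was handed so that anything the next handler
// writes is automatically scrubbed to the client's advertised buffer size.
func exampleScrubbedChain(next dns.Handler) dns.Handler {
	return dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		next.ServeDNS(NewScrubWriter(req, w), req)
	})
}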
//go:build gofuzz
package test
// Fuzz fuzzes a corefile.
func Fuzz(data []byte) int {
_, _, _, err := CoreDNSServerAndPorts(string(data))
if err != nil {
return 1
}
return 0
}
package test
import (
"sync"
"github.com/coredns/caddy"
_ "github.com/coredns/coredns/core" // Hook in CoreDNS.
"github.com/coredns/coredns/core/dnsserver"
_ "github.com/coredns/coredns/core/plugin" // Load all managed plugins in github.com/coredns/coredns.
)
var mu sync.Mutex
// CoreDNSServer returns a CoreDNS test server. It just takes a normal Corefile as input.
func CoreDNSServer(corefile string) (*caddy.Instance, error) {
mu.Lock()
defer mu.Unlock()
caddy.Quiet = true
dnsserver.Quiet = true
return caddy.Start(NewInput(corefile))
}
// CoreDNSServerStop stops a server.
func CoreDNSServerStop(i *caddy.Instance) { i.Stop() }
// CoreDNSServerPorts returns the ports the instance is listening on. The integer k indicates
// which ServerListener you want.
func CoreDNSServerPorts(i *caddy.Instance, k int) (udp, tcp string) {
srvs := i.Servers()
if len(srvs) < k+1 {
return "", ""
}
u := srvs[k].LocalAddr()
t := srvs[k].Addr()
if u != nil {
udp = u.String()
}
if t != nil {
tcp = t.String()
}
return
}
// CoreDNSServerAndPorts combines CoreDNSServer and CoreDNSServerPorts to start a CoreDNS
// server and returns the udp and tcp ports of the first instance.
func CoreDNSServerAndPorts(corefile string) (i *caddy.Instance, udp, tcp string, err error) {
i, err = CoreDNSServer(corefile)
if err != nil {
return nil, "", "", err
}
udp, tcp = CoreDNSServerPorts(i, 0)
return i, udp, tcp, nil
}
// Input implements the caddy.Input interface and acts as an easy way to use a string as a Corefile.
type Input struct {
corefile []byte
}
// NewInput returns a pointer to Input, containing the corefile string as input.
func NewInput(corefile string) *Input {
return &Input{corefile: []byte(corefile)}
}
// Body implements the Input interface.
func (i *Input) Body() []byte { return i.corefile }
// Path implements the Input interface.
func (i *Input) Path() string { return "Corefile" }
// ServerType implements the Input interface.
func (i *Input) ServerType() string { return "dns" }