package dnsserver
import (
"fmt"
"net"
"strings"
)
// zoneAddr describes one served zone: the transport, zone name and port it is
// reachable on, plus an optional bind address. The type is comparable and is
// used as a map key by the overlap validation.
type zoneAddr struct {
	Zone      string
	Port      string
	Transport string // dns, tls or grpc
	Address   string // bind address; empty when the zone is not bound to a specific address
}

// String renders z as "<transport>://<zone>:<port>", followed by
// " on <address>" when a bind address is set.
func (z zoneAddr) String() string {
	base := z.Transport + "://" + z.Zone + ":" + z.Port
	if z.Address == "" {
		return base
	}
	return base + " on " + z.Address
}
// SplitProtocolHostPort splits a fully formed address such as "dns://[::1]:53"
// into its protocol, IP and port.
//
// The "<protocol>://" prefix is optional; without it the returned protocol is
// empty. An address containing more than one "://" separator is rejected.
// Errors from net.SplitHostPort are returned unchanged, together with the
// protocol that was parsed.
func SplitProtocolHostPort(address string) (protocol string, ip string, port string, err error) {
	const sep = "://"

	scheme, rest, hasScheme := strings.Cut(address, sep)
	if !hasScheme {
		// No scheme at all: the whole input is host:port.
		ip, port, err = net.SplitHostPort(address)
		return "", ip, port, err
	}
	if strings.Contains(rest, sep) {
		// A second separator means the value is not a plain address.
		return "", "", "", fmt.Errorf("provided value is not in an address format : %s", address)
	}
	ip, port, err = net.SplitHostPort(rest)
	return scheme, ip, port, err
}
// zoneOverlap tracks the zoneAddr values registered so far so that exact
// duplicates and bound/unbound overlaps can be detected.
type zoneOverlap struct {
	registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
	unboundOverlap map[zoneAddr]zoneAddr // keyed by the unbound ("no bind") form of a zoneAddr; the value is the zoneAddr as registered
}
// newOverlapZone returns an empty zoneOverlap with both lookup maps allocated.
func newOverlapZone() *zoneOverlap {
	zo := new(zoneOverlap)
	zo.registeredAddr = make(map[zoneAddr]zoneAddr)
	zo.unboundOverlap = make(map[zoneAddr]zoneAddr)
	return zo
}
// registerAndCheck validates z against everything registered so far and, when
// there is no conflict, records it for future checks.
//
// existingZone is non-nil when the exact same zoneAddr was already registered;
// overlappingZone is non-nil when z overlaps a registered one. An unbound
// address is considered to overlap all bound addresses for the same zone and
// port. When either result is non-nil, z is NOT registered.
func (zo *zoneOverlap) registerAndCheck(z zoneAddr) (existingZone *zoneAddr, overlappingZone *zoneAddr) {
	existingZone, overlappingZone = zo.check(z)
	if existingZone == nil && overlappingZone == nil {
		// No conflict: remember z under its own key and under its unbound form.
		zo.registeredAddr[z] = z
		zo.unboundOverlap[z.unbound()] = z
	}
	return existingZone, overlappingZone
}
// check reports whether z duplicates or overlaps an already registered
// zoneAddr. It never modifies the registry.
//
// existingZone points to a copy of the identical registered entry, if any.
// overlappingZone points to a copy of the conflicting entry when z and a
// registered entry share transport, zone and port and exactly one of them is
// unbound. Both results are nil when there is no conflict.
func (zo *zoneOverlap) check(z zoneAddr) (existingZone *zoneAddr, overlappingZone *zoneAddr) {
	if registered, found := zo.registeredAddr[z]; found {
		// The very same zoneAddr is already known.
		return &registered, nil
	}

	wildcard := z.unbound()
	holder, taken := zo.unboundOverlap[wildcard]
	if !taken {
		// Nothing registered for this transport/zone/port at all.
		return nil, nil
	}
	if z.Address == "" {
		// z is unbound while a bound zoneAddr for the same zone+port exists.
		return nil, &holder
	}
	if _, found := zo.registeredAddr[wildcard]; found {
		// z is bound while the unbound zoneAddr for the same zone+port exists.
		return nil, &wildcard
	}
	// Only other bound addresses are registered: no overlap.
	return nil, nil
}
// unbound returns a copy of z with the bind address cleared; zone, port and
// transport are kept.
func (z zoneAddr) unbound() zoneAddr {
	// z is a value receiver, so this only changes the local copy.
	z.Address = ""
	return z
}
package dnsserver
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
)
// Config is the configuration for a single server.
type Config struct {
	// The zone of the site.
	Zone string
	// One or several hostnames to bind the server to.
	// Defaults to a single empty string that denotes the wildcard address.
	ListenHosts []string
	// The port to listen on.
	Port string
	// The number of servers that will listen on one port.
	// By default, one server will be running.
	NumSockets int
	// Root points to a base directory where we find user defined "things".
	// The first consumer is the file plugin, which looks for zone files in this place.
	Root string
	// Debug controls the panic/recover mechanism that is enabled by default.
	Debug bool
	// Stacktrace controls whether a stacktrace is included in the log written by the
	// recover mechanism; it is disabled by default.
	Stacktrace bool
	// The transport we implement, normally just "dns" over TCP/UDP, but could be
	// DNS-over-TLS or DNS-over-gRPC.
	Transport string
	// If this function is not nil it will be used to inspect and validate
	// HTTP requests. Although this isn't referenced in-tree, external plugins
	// may depend on it.
	HTTPRequestValidateFunc func(*http.Request) bool
	// FilterFuncs is used to further filter access
	// to this handler. E.g. to limit access to a reverse zone
	// on a non-octet boundary, i.e. /17
	FilterFuncs []FilterFunc
	// ViewName is the name of the Viewer plugin defined in the Config.
	ViewName string
	// TLSConfig when listening for encrypted connections (gRPC, DNS-over-TLS).
	TLSConfig *tls.Config
	// MaxQUICStreams defines the maximum number of concurrent QUIC streams for a QUIC server.
	// This is nil if not specified, allowing for a default to be used.
	MaxQUICStreams *int
	// MaxQUICWorkerPoolSize defines the size of the worker pool for processing QUIC streams.
	// This is nil if not specified, allowing for a default to be used.
	MaxQUICWorkerPoolSize *int
	// Timeouts for TCP, TLS and HTTPS servers.
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
	IdleTimeout  time.Duration
	// TSIG secrets, [name]key.
	TsigSecret map[string]string
	// Plugin stack.
	Plugin []plugin.Plugin
	// Compiled plugin stack.
	pluginChain plugin.Handler
	// Plugins interested in announcing that they exist, so other plugins can call methods
	// on them, should register themselves here. The name should be the name as returned by the
	// Handler's Name method.
	registry map[string]plugin.Handler
	// firstConfigInBlock is used to reference the first config in a server block, for the
	// purpose of sharing a single instance of each plugin among all zones in a server block.
	firstConfigInBlock *Config
	// metaCollector references the first MetadataCollector plugin, if one exists.
	metaCollector MetadataCollector
}
// FilterFunc is a function that filters requests for a Config. It receives the
// request context and the request and returns true when the request passes
// the filter.
type FilterFunc func(context.Context, *request.Request) bool
// keyForConfig builds the "<block index>:<key index>" string that identifies a
// config during setup time.
func keyForConfig(blocIndex int, blocKeyIndex int) string {
	// Sprint adds no space around the ":" because one neighbour is always a string.
	return fmt.Sprint(blocIndex, ":", blocKeyIndex)
}
// GetConfig returns the Config that corresponds to the server block and key
// the controller c is currently positioned on.
//
// If no Config has been saved for that position, a new one that listens on the
// wildcard address is created, saved and returned. This should only happen
// during tests, because directive actions typically skip the server blocks
// where the configs are made.
//
// It panics if c's context is not a *dnsContext.
func GetConfig(c *caddy.Controller) *Config {
	dctx := c.Context().(*dnsContext)
	key := keyForConfig(c.ServerBlockIndex, c.ServerBlockKeyIndex)

	cfg, found := dctx.keysToConfigs[key]
	if !found {
		cfg = &Config{ListenHosts: []string{""}}
		dctx.saveConfig(key, cfg)
	}
	return cfg
}
package dnsserver
import (
"net"
"net/http"
"github.com/miekg/dns"
)
// DoHWriter is a dns.ResponseWriter that adds more specific LocalAddr and RemoteAddr methods.
// It does not send anything itself: written responses are stored in Msg.
type DoHWriter struct {
	// raddr is the remote's address. This can be optionally set.
	raddr net.Addr
	// laddr is our address. This can be optionally set.
	laddr net.Addr
	// request is the HTTP request we're currently handling.
	request *http.Request
	// Msg is a response to be written to the client.
	Msg *dns.Msg
}
// WriteMsg stores the message to be written to the client in d.Msg.
// It always returns nil.
func (d *DoHWriter) WriteMsg(m *dns.Msg) error {
	d.Msg = m
	return nil
}
// Write unpacks the wire-format message in b and stores it in d.Msg as the
// response to be written to the client. d.Msg is replaced even when unpacking
// fails. It returns len(b) together with any unpack error.
func (d *DoHWriter) Write(b []byte) (int, error) {
	msg := new(dns.Msg)
	d.Msg = msg
	err := msg.Unpack(b)
	return len(b), err
}
// RemoteAddr returns the remote address; it is nil when none was set.
func (d *DoHWriter) RemoteAddr() net.Addr {
	return d.raddr
}
// LocalAddr returns the local address; it is nil when none was set.
func (d *DoHWriter) LocalAddr() net.Addr {
	return d.laddr
}
// Network is a no-op implementation; it always returns the empty string.
func (d *DoHWriter) Network() string {
	return ""
}
// Request returns the HTTP request this writer was created for.
func (d *DoHWriter) Request() *http.Request {
	return d.request
}
// Close is a no-op implementation; it always returns nil.
func (d *DoHWriter) Close() error {
	return nil
}
// TsigStatus is a no-op implementation; it always returns nil.
func (d *DoHWriter) TsigStatus() error {
	return nil
}
// TsigTimersOnly is a no-op implementation; the argument is ignored.
func (d *DoHWriter) TsigTimersOnly(_ bool) {}
// Hijack is a no-op implementation.
func (d *DoHWriter) Hijack() {}
package dnsserver
import (
"fmt"
"regexp"
"sort"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
)
// rfc1035PreferredSyntax matches fully qualified names whose labels follow the
// RFC 1035 preferred syntax: each label starts with a letter, ends with a
// letter or digit, and contains only letters, digits and hyphens in between.
// It is compiled once at package initialization instead of on every call.
var rfc1035PreferredSyntax = regexp.MustCompile(`^(([A-Za-z]([A-Za-z0-9-]*[A-Za-z0-9])?)\.)+$`)

// checkZoneSyntax reports whether zone matches the RFC 1035 preferred syntax.
// The root zone and all reverse zones always return true even though they
// technically don't meet the preferred syntax.
func checkZoneSyntax(zone string) bool {
	if zone == "." || dnsutil.IsReverse(zone) != 0 {
		return true
	}
	return rfc1035PreferredSyntax.MatchString(zone)
}
// startUpZones creates the text that we show when starting up, one line per
// zone in sorted order, for example:
//
//	grpc://example.com.:1055
//	example.com.:1053 on 127.0.0.1
//
// protocol is prepended verbatim to every line. addr is the listen address of
// the server; when it names a specific IP, that IP is appended as
// " on <ip>" so that one can differentiate between all active listeners.
// Zones that do not follow the RFC 1035 preferred syntax are preceded by a
// warning line.
func startUpZones(protocol, addr string, zones map[string][]*Config) string {
	// Map iteration order is random; sort the zone names for stable output.
	keys := make([]string, 0, len(zones))
	for k := range zones {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// addr is the same for every zone, so split it into IP and port only once.
	_, ip, port, err := SplitProtocolHostPort(addr)

	s := ""
	for _, zone := range keys {
		if !checkZoneSyntax(zone) {
			s += fmt.Sprintf("Warning: Domain %q does not follow RFC1035 preferred syntax\n", zone)
		}
		switch {
		case err != nil:
			// this should not happen, but we need to take care of it anyway
			s += fmt.Sprintln(protocol + zone + ":" + addr)
		case ip == "":
			s += fmt.Sprintln(protocol + zone + ":" + port)
		default:
			// the server is listening on a specific address, make it visible in the log
			s += fmt.Sprintln(protocol + zone + ":" + port + " on " + ip)
		}
	}
	return s
}
package dnsserver
import (
"encoding/binary"
"errors"
"net"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// DoQWriter writes DNS responses to a QUIC stream and implements the
// dns.ResponseWriter interface from Go DNS.
type DoQWriter struct {
	localAddr  net.Addr     // returned by LocalAddr
	remoteAddr net.Addr     // returned by RemoteAddr
	stream     *quic.Stream // stream the response is written to; Write and Close fail when it is nil
	Msg        *dns.Msg
}
// Write prefixes b with its 2-byte length (see AddPrefix) and writes the
// result to the QUIC stream. It returns whatever the stream's Write returns
// for the prefixed buffer, or an error when no stream is set.
func (w *DoQWriter) Write(b []byte) (int, error) {
	if w.stream != nil {
		return w.stream.Write(AddPrefix(b))
	}
	return 0, errors.New("stream is nil")
}
// WriteMsg packs m, writes it to the stream with its length prefix and then
// closes the stream. The first error encountered (pack, write or close) is
// returned; the stream is not closed when packing or writing fails.
func (w *DoQWriter) WriteMsg(m *dns.Msg) error {
	packed, err := m.Pack()
	if err != nil {
		return err
	}
	if _, err := w.Write(packed); err != nil {
		return err
	}
	return w.Close()
}
// Close sends the STREAM FIN signal by closing the stream.
// The server MUST send the response(s) on the same stream and MUST
// indicate, after the last response, through the STREAM FIN
// mechanism that no further data will be sent on that stream.
// See https://www.rfc-editor.org/rfc/rfc9250#section-4.2-7
//
// It returns an error when no stream is set.
func (w *DoQWriter) Close() error {
	if w.stream != nil {
		return w.stream.Close()
	}
	return errors.New("stream is nil")
}
// AddPrefix returns a new slice holding the big-endian 2-byte length of b
// followed by a copy of b. The length is converted to uint16, so lengths
// above 65535 wrap around.
func AddPrefix(b []byte) (m []byte) {
	m = make([]byte, 0, 2+len(b))
	m = binary.BigEndian.AppendUint16(m, uint16(len(b)))
	return append(m, b...)
}
// These methods implement the dns.ResponseWriter interface from Go DNS.

// TsigStatus is a no-op; it always returns nil.
func (w *DoQWriter) TsigStatus() error { return nil }

// TsigTimersOnly is a no-op; the argument is ignored.
func (w *DoQWriter) TsigTimersOnly(b bool) {}

// Hijack is a no-op.
func (w *DoQWriter) Hijack() {}

// LocalAddr returns the local address stored in the writer.
func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr }

// RemoteAddr returns the remote address stored in the writer.
func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr }

// Network is a no-op; it always returns the empty string.
func (w *DoQWriter) Network() string { return "" }
package dnsserver
import (
"fmt"
"net"
"time"
"github.com/coredns/caddy"
"github.com/coredns/caddy/caddyfile"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// serverType is the name under which this server type is registered with caddy.
const serverType = "dns"
func init() {
caddy.RegisterServerType(serverType, caddy.ServerType{
Directives: func() []string { return Directives },
DefaultInput: func() caddy.Input {
return caddy.CaddyfileInput{
Filepath: "Corefile",
Contents: []byte(".:" + Port + " {\nwhoami\nlog\n}\n"),
ServerTypeName: serverType,
}
},
NewContext: newContext,
})
}
// newContext returns a fresh dnsContext with an empty key-to-config map.
// The caddy instance is not used.
func newContext(i *caddy.Instance) caddy.Context {
	dctx := new(dnsContext)
	dctx.keysToConfigs = make(map[string]*Config)
	return dctx
}
// dnsContext holds all site configs created while the Corefile is processed.
type dnsContext struct {
	// keysToConfigs maps a key built by keyForConfig to its config.
	keysToConfigs map[string]*Config
	// configs is the master list of all site configs.
	configs []*Config
}
// saveConfig records cfg under key for lookups and appends it to the master
// list of configs.
func (h *dnsContext) saveConfig(key string, cfg *Config) {
	h.keysToConfigs[key] = cfg
	h.configs = append(h.configs, cfg)
}
// Compile-time check to ensure *dnsContext implements the caddy.Context interface.
var _ caddy.Context = &dnsContext{}
// InspectServerBlocks makes sure that everything checks out before
// executing directives and otherwise prepares the directives to
// be parsed and executed.
//
// For every server block it normalizes the keys to the
// "<transport>://<zone>:<port>" form produced by zoneAddr.String, filling in
// the transport's default port when a key has none, and saves one Config per
// resulting key. The (possibly extended) server blocks are returned; an error
// is returned when a key cannot be split into host and port.
func (h *dnsContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
	// Normalize and check all the zone names.
	for ib, s := range serverBlocks {
		// Walk the s.Keys and expand any reverse address in their proper DNS in-addr zones. If the expansion leads to
		// more than one reverse zone, replace the current value and add the rest to s.Keys.
		zoneAddrs := []zoneAddr{}
		// The range expression is evaluated once, so keys appended to s.Keys below are not visited by this loop.
		for ik, k := range s.Keys {
			trans, k1 := parse.Transport(k) // get rid of any dns:// or other scheme.
			hosts, port, err := plugin.SplitHostPort(k1)
			// We need to make this a fully qualified domain name to catch all errors here and not later when
			// plugin.Normalize is called again on these strings, with the prime difference being that the domain
			// name is fully qualified. This was found by fuzzing where "ȶ" is deemed OK, but "ȶ." is not (might be a
			// bug in miekg/dns actually). But here we were checking ȶ, which is OK, and later we barf in ȶ. leading to
			// "index out of range".
			for ih := range hosts {
				_, _, err := plugin.SplitHostPort(dns.Fqdn(hosts[ih]))
				if err != nil {
					return nil, err
				}
			}
			// err here is the result of splitting the original key above; the loop's err is a separate, shadowing variable.
			if err != nil {
				return nil, err
			}
			// No explicit port: fall back to the default port of the transport.
			if port == "" {
				switch trans {
				case transport.DNS:
					port = Port
				case transport.TLS:
					port = transport.TLSPort
				case transport.QUIC:
					port = transport.QUICPort
				case transport.GRPC:
					port = transport.GRPCPort
				case transport.HTTPS:
					port = transport.HTTPSPort
				}
			}
			if len(hosts) > 1 {
				s.Keys[ik] = hosts[0] + ":" + port // replace for the first
				for _, h := range hosts[1:] { // add the rest
					s.Keys = append(s.Keys, h+":"+port)
				}
			}
			// One zoneAddr per host, so zoneAddrs ends up with as many entries as s.Keys.
			for i := range hosts {
				zoneAddrs = append(zoneAddrs, zoneAddr{Zone: dns.Fqdn(hosts[i]), Port: port, Transport: trans})
			}
		}
		serverBlocks[ib].Keys = s.Keys // important to save back the new keys that are potentially created here.
		var firstConfigInBlock *Config
		// Every key is overwritten with the string form of the zoneAddr at the same index.
		for ik := range s.Keys {
			za := zoneAddrs[ik]
			s.Keys[ik] = za.String()
			// Save the config to our master list, and key it for lookups.
			cfg := &Config{
				Zone:        za.Zone,
				ListenHosts: []string{""},
				Port:        za.Port,
				Transport:   za.Transport,
			}
			// Set reference to the first config in the current block.
			// This is used later by MakeServers to share a single plugin list
			// for all zones in a server block.
			if ik == 0 {
				firstConfigInBlock = cfg
			}
			cfg.firstConfigInBlock = firstConfigInBlock
			keyConfig := keyForConfig(ib, ik)
			h.saveConfig(keyConfig, cfg)
		}
	}
	return serverBlocks, nil
}
// MakeServers uses the newly-created site configs to create and return a list
// of server instances.
//
// It propagates the shared parameters inside every server block, groups the
// configs by listen address, builds the servers for each group, attaches the
// filters of any Viewer plugin to its config, and finally validates that
// zones and listen addresses do not overlap. The first error encountered is
// returned.
func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
	// Copy parameters from the first config in the block to all other configs in the same block.
	propagateConfigParams(h.configs)

	// Map (group) each config to a bind address.
	byAddr, err := groupConfigsByListenAddr(h.configs)
	if err != nil {
		return nil, err
	}

	// Create the servers for each group.
	var all []caddy.Server
	for listenAddr, cfgs := range byAddr {
		made, err := makeServersForGroup(listenAddr, cfgs)
		if err != nil {
			return nil, err
		}
		all = append(all, made...)
	}

	// For each server config, look for View Filter plugins. Filters are added
	// in the plugin.cfg order for a consistent filter func evaluation order.
	for _, cfg := range h.configs {
		for _, name := range Directives {
			viewer, isViewer := cfg.registry[name].(Viewer)
			if !isViewer {
				continue
			}
			if cfg.ViewName != "" {
				return nil, fmt.Errorf("multiple views defined in server block")
			}
			cfg.ViewName = viewer.ViewName()
			cfg.FilterFuncs = append(cfg.FilterFuncs, viewer.Filter)
		}
	}

	// Verify that there is no overlap on the zones and listen addresses
	// for unfiltered server configs.
	if err := h.validateZonesAndListeningAddresses(); err != nil {
		return nil, err
	}
	return all, nil
}
// AddPlugin adds a plugin to a site's plugin stack by appending m to c.Plugin.
func (c *Config) AddPlugin(m plugin.Plugin) {
	c.Plugin = append(c.Plugin, m)
}
// registerHandler adds a handler to a site's handler registration, keyed by
// the handler's Name. Handlers use this to announce that they exist to other
// plugins. The registry is created on first use, and a handler registered
// under an already used name replaces the previous one.
func (c *Config) registerHandler(h plugin.Handler) {
	reg := c.registry
	if reg == nil {
		reg = make(map[string]plugin.Handler)
		c.registry = reg
	}
	// Just overwrite...
	reg[h.Name()] = h
}
// Handler returns the plugin handler that has been added to the config under its name,
// or nil when no handler is registered under that name.
// This is useful to inspect if a certain plugin is active in this server.
// Note that this is order dependent and the order is defined in directives.go, i.e. if your plugin
// comes before the plugin you are checking; it will not be there (yet).
func (c *Config) Handler(name string) plugin.Handler {
	// Indexing a nil or empty map yields the zero value, i.e. a nil Handler.
	return c.registry[name]
}
// Handlers returns a slice of plugins that have been registered, in the order
// of Directives. This can be used to inspect and interact with registered
// plugins but cannot be used to remove or add plugins. It returns nil when
// nothing has been registered yet.
// Note that this is order dependent and the order is defined in directives.go, i.e. if your plugin
// comes before the plugin you are checking; it will not be there (yet).
func (c *Config) Handlers() []plugin.Handler {
	if c.registry == nil {
		return nil
	}
	ordered := make([]plugin.Handler, 0, len(c.registry))
	for _, directive := range Directives {
		if handler := c.Handler(directive); handler != nil {
			ordered = append(ordered, handler)
		}
	}
	return ordered
}
// validateZonesAndListeningAddresses verifies that no two unfiltered configs
// serve the same or an overlapping transport/zone/port/address combination.
//
// Configs without filter funcs are registered as they are checked, so later
// configs cannot overlap them. Configs with filter funcs are only checked
// against the registered (unfiltered) ones and are never registered
// themselves. The first conflict is returned as an error.
func (h *dnsContext) validateZonesAndListeningAddresses() error {
	overlaps := newOverlapZone()
	for _, cfg := range h.configs {
		// Unfiltered configs are checked and registered; filtered ones are only checked.
		probe := overlaps.registerAndCheck
		if len(cfg.FilterFuncs) > 0 {
			probe = overlaps.check
		}
		for _, host := range cfg.ListenHosts {
			candidate := zoneAddr{Transport: cfg.Transport, Zone: cfg.Zone, Address: host, Port: cfg.Port}
			existing, overlapping := probe(candidate)
			if existing != nil {
				return fmt.Errorf("cannot serve %s - it is already defined", candidate.String())
			}
			if overlapping != nil {
				return fmt.Errorf("cannot serve %s - zone overlap listener capacity with %v", candidate.String(), overlapping.String())
			}
		}
	}
	return nil
}
// propagateConfigParams copies the necessary parameters from the first config
// in the block to all other configs in the same block. Doing this results in
// zones sharing the same plugin instances and settings as other zones in
// the same block.
//
// Configs that have no first config recorded (for example those created by
// the GetConfig fallback rather than by InspectServerBlocks) are left
// untouched instead of causing a nil-pointer dereference.
func propagateConfigParams(configs []*Config) {
	for _, c := range configs {
		first := c.firstConfigInBlock
		if first == nil {
			// Nothing to propagate from.
			continue
		}
		c.Plugin = first.Plugin
		c.ListenHosts = first.ListenHosts
		c.Debug = first.Debug
		c.Stacktrace = first.Stacktrace
		c.NumSockets = first.NumSockets
		// Fork TLSConfig for each encrypted connection.
		// Clone returns nil for a nil *tls.Config.
		c.TLSConfig = first.TLSConfig.Clone()
		c.ReadTimeout = first.ReadTimeout
		c.WriteTimeout = first.WriteTimeout
		c.IdleTimeout = first.IdleTimeout
		c.TsigSecret = first.TsigSecret
	}
}
// groupConfigsByListenAddr groups site configs by their listen
// (bind) address, so sites that use the same listener can be served
// on the same server instance. The returned map is keyed by
// "<transport>://<resolved host:port>" and holds the site configs for that
// address; a config appears once per listen host.
// This function does NOT vet the configs to ensure they are compatible.
// An error is returned when a listen address cannot be resolved.
func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
	byAddr := make(map[string][]*Config)
	for _, cfg := range configs {
		for _, host := range cfg.ListenHosts {
			hostPort := net.JoinHostPort(host, cfg.Port)
			resolved, err := net.ResolveTCPAddr("tcp", hostPort)
			if err != nil {
				return nil, err
			}
			key := cfg.Transport + "://" + resolved.String()
			byAddr[key] = append(byAddr[key], cfg)
		}
	}
	return byAddr, nil
}
// makeServersForGroup creates servers for a specific transport and group.
// The transport is taken from addr. It creates as many servers as specified
// by the NumSockets setting of the first config in the group; if NumSockets
// is not specified, one server is created by default. An address whose
// transport matches none of the known transports yields no servers.
func makeServersForGroup(addr string, group []*Config) ([]caddy.Server, error) {
	// that is impossible, but better to check
	if len(group) == 0 {
		return nil, fmt.Errorf("no configs for group defined")
	}

	// create one server by default if no NumSockets specified
	numSockets := 1
	if n := group[0].NumSockets; n > 0 {
		numSockets = n
	}

	var servers []caddy.Server
	for range numSockets {
		var (
			srv caddy.Server
			err error
		)
		// switch on addr
		tr, _ := parse.Transport(addr)
		switch tr {
		case transport.DNS:
			srv, err = NewServer(addr, group)
		case transport.TLS:
			srv, err = NewServerTLS(addr, group)
		case transport.QUIC:
			srv, err = NewServerQUIC(addr, group)
		case transport.GRPC:
			srv, err = NewServergRPC(addr, group)
		case transport.HTTPS:
			srv, err = NewServerHTTPS(addr, group)
		default:
			// unknown transport: nothing to create
			continue
		}
		if err != nil {
			return nil, err
		}
		servers = append(servers, srv)
	}
	return servers, nil
}
// DefaultPort is the default port; it equals transport.Port.
const DefaultPort = transport.Port
// These "soft defaults" are configurable by
// command line flags, etc.
var (
	// Port is the port we listen on by default. It starts out as DefaultPort.
	Port = DefaultPort
	// GracefulTimeout is the maximum duration of a graceful shutdown.
	GracefulTimeout time.Duration
)
// Package dnsserver implements all the interfaces from Caddy, so that CoreDNS can be a servertype plugin.
package dnsserver
import (
"context"
"fmt"
"net"
"runtime"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/trace"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
)
// Server represents an instance of a server, which serves
// DNS requests at a particular address (host and port). A
// server is capable of serving numerous zones on
// the same address and the listener may be stopped for
// graceful termination (POSIX only).
type Server struct {
	Addr         string                // Address we listen on
	server       [2]*dns.Server        // 0 is a net.Listener, 1 is a net.PacketConn (a *UDPConn) in our case.
	m            sync.Mutex            // protects the servers
	zones        map[string][]*Config  // configs keyed by their zone name
	dnsWg        sync.WaitGroup        // used to wait on outstanding connections
	graceTimeout time.Duration         // the maximum duration of a graceful shutdown
	trace        trace.Trace           // the trace plugin for the server
	debug        bool                  // disable recover()
	stacktrace   bool                  // enable stacktrace in recover error log
	classChaos   bool                  // allow non-INET class queries
	idleTimeout  time.Duration         // Idle timeout for TCP
	readTimeout  time.Duration         // Read timeout for TCP
	writeTimeout time.Duration         // Write timeout for TCP
	tsigSecret   map[string]string     // TSIG secrets collected from all configs, passed to the dns.Server instances
}
// MetadataCollector is a plugin that can retrieve metadata functions from all metadata providing plugins.
// Collect receives the request context and the request and returns the context to use from then on.
type MetadataCollector interface {
	Collect(context.Context, request.Request) context.Context
}
// NewServer returns a new CoreDNS server and compiles all plugins in to it. By default CH class
// queries are blocked unless a plugin listed in EnableChaos is loaded.
//
// addr is stored as the server's address and group holds the configs served on it. For every
// config the plugin list is compiled into a handler chain, which is stored on the config
// together with its handler registry and metadata collector. Timeouts start from built-in
// defaults and are overridden by any non-zero timeout found in the group. The returned error
// is always nil.
func NewServer(addr string, group []*Config) (*Server, error) {
	s := &Server{
		Addr:         addr,
		zones:        make(map[string][]*Config),
		graceTimeout: 5 * time.Second,
		idleTimeout:  10 * time.Second,
		readTimeout:  3 * time.Second,
		writeTimeout: 5 * time.Second,
		tsigSecret:   make(map[string]string),
	}
	// We have to bound our wg with one increment
	// to prevent a "race condition" that is hard-coded
	// into sync.WaitGroup.Wait() - basically, an add
	// with a positive delta must be guaranteed to
	// occur before Wait() is called on the wg.
	// In a way, this kind of acts as a safety barrier.
	s.dnsWg.Add(1)
	for _, site := range group {
		// Debug is sticky: one config with Debug set enables it for the whole server.
		if site.Debug {
			s.debug = true
			log.D.Set()
		}
		// Not sticky: the value of the last config in the group wins.
		s.stacktrace = site.Stacktrace
		// append the config to the zone's configs
		s.zones[site.Zone] = append(s.zones[site.Zone], site)
		// set timeouts; only non-zero values override what is already set
		if site.ReadTimeout != 0 {
			s.readTimeout = site.ReadTimeout
		}
		if site.WriteTimeout != 0 {
			s.writeTimeout = site.WriteTimeout
		}
		if site.IdleTimeout != 0 {
			s.idleTimeout = site.IdleTimeout
		}
		// copy tsig secrets
		for key, secret := range site.TsigSecret {
			s.tsigSecret[key] = secret
		}
		// compile custom plugin for everything; each plugin wraps the handler
		// built so far, so the list is walked from the last plugin to the first
		var stack plugin.Handler
		for i := len(site.Plugin) - 1; i >= 0; i-- {
			stack = site.Plugin[i](stack)
			// register the *handler* also
			site.registerHandler(stack)
			// If the current plugin is a MetadataCollector, bookmark it for later use. This loop traverses the plugin
			// list backwards, so the first MetadataCollector plugin wins.
			if mdc, ok := stack.(MetadataCollector); ok {
				site.metaCollector = mdc
			}
			if s.trace == nil && stack.Name() == "trace" {
				// we have to stash away the plugin, not the
				// Tracer object, because the Tracer won't be initialized yet
				if t, ok := stack.(trace.Trace); ok {
					s.trace = t
				}
			}
			// Unblock CH class queries when any of these plugins are loaded.
			if _, ok := EnableChaos[stack.Name()]; ok {
				s.classChaos = true
			}
		}
		// stack is nil here when the config has no plugins.
		site.pluginChain = stack
	}
	if !s.debug {
		// When reloading we need to explicitly disable debug logging if it is now disabled.
		log.D.Clear()
	}
	return s, nil
}
// Compile-time check to ensure *Server implements the caddy.GracefulServer interface.
var _ caddy.GracefulServer = &Server{}
// Serve starts the server with an existing listener. It blocks until the server stops.
// The TCP dns.Server is created with the server's TSIG secrets and timeouts and is stored
// under the mutex; every request is handed to ServeDNS with a context carrying the server
// (Key) and a loop counter of 0 (LoopKey).
// This implements caddy.TCPServer interface.
func (s *Server) Serve(l net.Listener) error {
	handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		base := context.WithValue(context.Background(), Key{}, s)
		s.ServeDNS(context.WithValue(base, LoopKey{}, 0), w, r)
	})
	idle := func() time.Duration {
		return s.idleTimeout
	}

	s.m.Lock()
	s.server[tcp] = &dns.Server{
		Listener:      l,
		Net:           "tcp",
		TsigSecret:    s.tsigSecret,
		MaxTCPQueries: tcpMaxQueries,
		ReadTimeout:   s.readTimeout,
		WriteTimeout:  s.writeTimeout,
		IdleTimeout:   idle,
		Handler:       handler,
	}
	s.m.Unlock()

	return s.server[tcp].ActivateAndServe()
}
// ServePacket starts the server with an existing packetconn. It blocks until the server stops.
// The UDP dns.Server is stored under the mutex; every request is handed to ServeDNS with a
// context carrying the server (Key) and a loop counter of 0 (LoopKey).
// This implements caddy.UDPServer interface.
func (s *Server) ServePacket(p net.PacketConn) error {
	handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		base := context.WithValue(context.Background(), Key{}, s)
		s.ServeDNS(context.WithValue(base, LoopKey{}, 0), w, r)
	})

	s.m.Lock()
	s.server[udp] = &dns.Server{
		PacketConn: p,
		Net:        "udp",
		Handler:    handler,
		TsigSecret: s.tsigSecret,
	}
	s.m.Unlock()

	return s.server[udp].ActivateAndServe()
}
// Listen implements caddy.TCPServer interface. It opens a TCP listener on the
// server address with the leading "dns://" stripped.
func (s *Server) Listen() (net.Listener, error) {
	// s.Addr is expected to start with the DNS transport scheme; cut it off by length.
	bind := s.Addr[len(transport.DNS+"://"):]
	ln, err := reuseport.Listen("tcp", bind)
	if err != nil {
		return nil, err
	}
	return ln, nil
}
// WrapListener implements caddy.GracefulServer interface. It returns ln unchanged.
func (s *Server) WrapListener(ln net.Listener) net.Listener {
	return ln
}
// ListenPacket implements caddy.UDPServer interface. It opens a UDP packet
// connection on the server address with the leading "dns://" stripped.
func (s *Server) ListenPacket() (net.PacketConn, error) {
	// s.Addr is expected to start with the DNS transport scheme; cut it off by length.
	bind := s.Addr[len(transport.DNS+"://"):]
	pc, err := reuseport.ListenPacket("udp", bind)
	if err != nil {
		return nil, err
	}
	return pc, nil
}
// Stop stops the server. It blocks until the server is
// totally stopped. On POSIX systems, it will wait for
// connections to close (up to a max timeout of a few
// seconds); on Windows it will close the listener
// immediately.
//
// Both underlying servers are always shut down. If more than one shutdown
// fails, the first error is returned; a later successful shutdown no longer
// hides an earlier failure.
// This implements Caddy.Stopper interface.
func (s *Server) Stop() (err error) {
	if runtime.GOOS != "windows" {
		// force connections to close after timeout
		done := make(chan struct{})
		go func() {
			s.dnsWg.Done() // decrement our initial increment used as a barrier
			s.dnsWg.Wait()
			close(done)
		}()
		// Wait for remaining connections to finish or
		// force them all to close after timeout
		select {
		case <-time.After(s.graceTimeout):
		case <-done:
		}
	}

	// Close the listener now; this stops the server without delay
	s.m.Lock()
	defer s.m.Unlock()
	for _, s1 := range s.server {
		// We might not have started and initialized the full set of servers
		if s1 == nil {
			continue
		}
		if shutdownErr := s1.Shutdown(); shutdownErr != nil && err == nil {
			err = shutdownErr
		}
	}
	return err
}
// Address together with Stop() implement caddy.GracefulServer.
// It returns the address the server listens on (s.Addr).
func (s *Server) Address() string { return s.Addr }
// ServeDNS is the entry point for every request to the address that
// is bound to. It acts as a multiplexer for the requests zonename as
// defined in the request so that the correct zone
// (configuration and plugin stack) will handle the request.
//
// The query name is lower-cased and matched against the configured zones from
// the most specific suffix to the least specific one. The first config of a
// matching zone whose filter funcs all pass handles the request. DS queries
// are routed to the last matching config found while walking up the name
// (normally the parent zone). If nothing matched, the root zone "." is tried,
// and finally the request is answered with REFUSED.
func (s *Server) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
	// The default dns.Mux checks the question section size, but we have our
	// own mux here. Check if we have a question section. If not drop them here.
	if r == nil || len(r.Question) == 0 {
		errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
		return
	}
	if !s.debug {
		defer func() {
			// In case the user doesn't enable error plugin, we still
			// need to make sure that we stay alive up here
			if rec := recover(); rec != nil {
				if s.stacktrace {
					log.Errorf("Recovered from panic in server: %q %v\n%s", s.Addr, rec, string(debug.Stack()))
				} else {
					log.Errorf("Recovered from panic in server: %q %v", s.Addr, rec)
				}
				vars.Panic.Inc()
				errorAndMetricsFunc(s.Addr, w, r, dns.RcodeServerFailure)
			}
		}()
	}
	// Non-INET class queries are refused unless classChaos has been enabled.
	if !s.classChaos && r.Question[0].Qclass != dns.ClassINET {
		errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
		return
	}
	if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once.
		w.WriteMsg(m)
		return
	}
	// Wrap the response writer in a ScrubWriter so we automatically make the reply fit in the client's buffer.
	w = request.NewScrubWriter(r, w)
	q := strings.ToLower(r.Question[0].Name)
	var (
		off       int     // offset in q of the suffix currently being matched
		end       bool    // set when there are no more labels to strip
		dshandler *Config // config remembered for a DS query
	)
	// Try q itself first, then strip one leading label per iteration.
	for {
		if z, ok := s.zones[q[off:]]; ok {
			for _, h := range z {
				if h.pluginChain == nil { // zone defined, but has not got any plugins
					errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
					return
				}
				if h.metaCollector != nil {
					// Collect metadata now, so it can be used before we send a request down the plugin chain.
					ctx = h.metaCollector.Collect(ctx, request.Request{Req: r, W: w})
				}
				// If all filter funcs pass, use this config.
				if passAllFilterFuncs(ctx, h.FilterFuncs, &request.Request{Req: r, W: w}) {
					if h.ViewName != "" {
						// if there was a view defined for this Config, set the view name in the context
						ctx = context.WithValue(ctx, ViewKey{}, h.ViewName)
					}
					if r.Question[0].Qtype != dns.TypeDS {
						rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
						// The plugin chain did not write a reply for this rcode, so send an error response ourselves.
						if !plugin.ClientWrite(rcode) {
							errorFunc(s.Addr, w, r, rcode)
						}
						return
					}
					// The type is DS, keep the handler, but keep on searching as maybe we are serving
					// the parent as well and the DS should be routed to it - this will probably *misroute* DS
					// queries to a possibly grand parent, but there is no way for us to know at this point
					// if there is an actual delegation from grandparent -> parent -> zone.
					// In all fairness: direct DS queries should not be needed.
					dshandler = h
				}
			}
		}
		off, end = dns.NextLabel(q, off)
		if end {
			break
		}
	}
	if r.Question[0].Qtype == dns.TypeDS && dshandler != nil && dshandler.pluginChain != nil {
		// DS request, and we found a zone, use the handler for the query.
		rcode, _ := dshandler.pluginChain.ServeDNS(ctx, w, r)
		if !plugin.ClientWrite(rcode) {
			errorFunc(s.Addr, w, r, rcode)
		}
		return
	}
	// Wildcard match, if we have found nothing try the root zone as a last resort.
	if z, ok := s.zones["."]; ok {
		for _, h := range z {
			// Unlike above, a root config without plugins is skipped rather than refused.
			if h.pluginChain == nil {
				continue
			}
			if h.metaCollector != nil {
				// Collect metadata now, so it can be used before we send a request down the plugin chain.
				ctx = h.metaCollector.Collect(ctx, request.Request{Req: r, W: w})
			}
			// If all filter funcs pass, use this config.
			if passAllFilterFuncs(ctx, h.FilterFuncs, &request.Request{Req: r, W: w}) {
				if h.ViewName != "" {
					// if there was a view defined for this Config, set the view name in the context
					ctx = context.WithValue(ctx, ViewKey{}, h.ViewName)
				}
				rcode, _ := h.pluginChain.ServeDNS(ctx, w, r)
				if !plugin.ClientWrite(rcode) {
					errorFunc(s.Addr, w, r, rcode)
				}
				return
			}
		}
	}
	// Still here? Error out with REFUSED.
	errorAndMetricsFunc(s.Addr, w, r, dns.RcodeRefused)
}
// passAllFilterFuncs reports whether every entry of filterFuncs accepts req.
// Filters are evaluated in slice order and evaluation stops at the first one
// that returns false; an empty or nil slice therefore yields true.
func passAllFilterFuncs(ctx context.Context, filterFuncs []FilterFunc, req *request.Request) bool {
	for i := range filterFuncs {
		if accepted := filterFuncs[i](ctx, req); !accepted {
			return false
		}
	}
	return true
}
// OnStartupComplete prints the text produced by startUpZones for this
// server's address and zones. Nothing is printed when Quiet is set or when
// that text is empty.
func (s *Server) OnStartupComplete() {
	if Quiet {
		return
	}
	if banner := startUpZones("", s.Addr, s.zones); banner != "" {
		fmt.Print(banner)
	}
}
// Tracer returns the tracer exposed by the server's trace field, or nil when
// no trace is configured.
func (s *Server) Tracer() ot.Tracer {
	if s.trace != nil {
		return s.trace.Tracer()
	}
	return nil
}
// errorFunc responds to a DNS request with an error. It builds a reply to r
// carrying rcode rc, lets request.Request.SizeAndDo adjust it for the client,
// and writes it to w. The server argument is unused and the WriteMsg error is
// discarded.
func errorFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
	reply := new(dns.Msg)
	reply.SetRcode(r, rc)

	req := request.Request{W: w, Req: r}
	req.SizeAndDo(reply)

	w.WriteMsg(reply)
}
// errorAndMetricsFunc responds to a DNS request with an error, like
// errorFunc, and additionally records the reply through vars.Report before
// writing it. The report is tagged vars.Dropped, uses the textual form of rc,
// an empty plugin name, the reply length and the current time.
func errorAndMetricsFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int) {
	reply := new(dns.Msg)
	reply.SetRcode(r, rc)

	req := request.Request{W: w, Req: r}
	req.SizeAndDo(reply)

	vars.Report(server, req, vars.Dropped, "", rcode.ToString(rc), "" /* plugin */, reply.Len(), time.Now())
	w.WriteMsg(reply)
}
// Indexes and limits used by the server implementations in this package.
const (
// tcp is the index of the TCP dns.Server inside Server.server (ServerTLS.Serve stores its server at s.server[tcp]).
tcp = 0
// udp is presumably the matching index for the UDP dns.Server — its use is not visible in this part of the file.
udp = 1
// NOTE(review): -1 looks like "no per-connection query limit"; confirm against the dns.Server field that consumes it.
tcpMaxQueries = -1
)
// Context key types. Each is an empty struct used only as a context.WithValue key.
type (
// Key is the context key for the current server added to the context.
// The transports store their *Server under it before calling ServeDNS.
Key struct{}
// LoopKey is the context key to detect server wide loops.
// The transports initialise its value to 0 for every incoming query.
LoopKey struct{}
// ViewKey is the context key for the current view, if defined.
// ServeDNS stores the matching Config's ViewName under it.
ViewKey struct{}
)
// EnableChaos is a map with plugin names for which we should open CH class queries as we block these by default.
// It is used as a set: only the keys matter, the values are empty structs.
var EnableChaos = map[string]struct{}{
"chaos": {},
"forward": {},
"proxy": {},
}
// Quiet mode will not show any informative output on initialization.
// The OnStartupComplete methods of the servers return early when it is true.
var Quiet bool
package dnsserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
"github.com/opentracing/opentracing-go"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
// ServergRPC represents an instance of a DNS-over-gRPC server.
type ServergRPC struct {
// Server is the embedded core DNS server; Query forwards requests to its ServeDNS.
*Server
// UnimplementedDnsServiceServer is embedded from the generated pb package and is left nil by NewServergRPC.
*pb.UnimplementedDnsServiceServer
// grpcServer is created in Serve and stopped by Stop/Shutdown.
grpcServer *grpc.Server
// listenAddr is the listener address recorded in Serve and reported as the local address of responses.
listenAddr net.Addr
// tlsConfig, when non-nil, makes Serve wrap the listener in TLS.
tlsConfig *tls.Config
}
// NewServergRPC returns a new CoreDNS gRPC server built on top of NewServer
// for addr and group. The TLS configuration is taken from the zone configs
// (the last one visited wins) and, when present, is restricted to HTTP/2.
func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
	base, err := NewServer(addr, group)
	if err != nil {
		return nil, err
	}

	// The *tls* plugin must make sure that multiple conflicting
	// TLS configuration returns an error: it can only be specified once.
	var tlsCfg *tls.Config
	for _, configs := range base.zones {
		for i := range configs {
			// Should we error if some configs *don't* have TLS?
			tlsCfg = configs[i].TLSConfig
		}
	}

	// gRPC requires http/2; without it in NextProtos the upgrade won't happen.
	if tlsCfg != nil {
		tlsCfg.NextProtos = []string{"h2"}
	}

	grpcSrv := &ServergRPC{Server: base, tlsConfig: tlsCfg}
	return grpcSrv, nil
}
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface.
// The value is discarded; the declaration exists only so the build fails if a method is missing.
var _ caddy.GracefulServer = &ServergRPC{}
// Serve implements caddy.TCPServer interface. It records the listener
// address, creates the gRPC server (with an OpenTracing interceptor when a
// tracer is configured), registers the DNS service on it and serves on l,
// wrapping l in TLS first when a TLS config is present.
func (s *ServergRPC) Serve(l net.Listener) error {
	s.m.Lock()
	s.listenAddr = l.Addr()
	s.m.Unlock()

	if s.Tracer() == nil {
		s.grpcServer = grpc.NewServer()
	} else {
		// Only create spans for calls that already carry a parent span.
		spanFilter := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool {
			return parentSpanCtx != nil
		}
		interceptor := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(spanFilter))
		s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(interceptor))
	}
	pb.RegisterDnsServiceServer(s.grpcServer, s)

	if s.tlsConfig != nil {
		l = tls.NewListener(l, s.tlsConfig)
	}
	return s.grpcServer.Serve(l)
}
// ServePacket implements caddy.UDPServer interface. The gRPC server does not
// serve packets, so this does nothing and returns nil.
func (s *ServergRPC) ServePacket(p net.PacketConn) error {
	return nil
}
// Listen implements caddy.TCPServer interface. It opens a TCP listener on
// s.Addr with the leading "<grpc scheme>://" characters cut off.
func (s *ServergRPC) Listen() (net.Listener, error) {
	hostPort := s.Addr[len(transport.GRPC+"://"):]
	ln, err := reuseport.Listen("tcp", hostPort)
	if err == nil {
		return ln, nil
	}
	return nil, err
}
// ListenPacket implements caddy.UDPServer interface. The gRPC server opens no
// packet connection, so both results are nil.
func (s *ServergRPC) ListenPacket() (net.PacketConn, error) {
	return nil, nil
}
// OnStartupComplete prints the text produced by startUpZones for this
// server, using the gRPC scheme as prefix. Nothing is printed when Quiet is
// set or when that text is empty.
func (s *ServergRPC) OnStartupComplete() {
	if Quiet {
		return
	}
	if banner := startUpZones(transport.GRPC+"://", s.Addr, s.zones); banner != "" {
		fmt.Print(banner)
	}
}
// Stop gracefully stops the gRPC server, if one has been created, while
// holding the server mutex. It always returns nil.
func (s *ServergRPC) Stop() (err error) {
	s.m.Lock()
	defer s.m.Unlock()

	if s.grpcServer == nil {
		return nil
	}
	s.grpcServer.GracefulStop()
	return nil
}
// Query is the gRPC entry point. It unpacks the DNS message carried by in,
// requires a TCP peer in ctx, runs the message through ServeDNS with a
// gRPCresponse as writer, and returns whatever message the writer holds
// afterwards, packed into a DnsPacket.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
	query := new(dns.Msg)
	if err := query.Unpack(in.GetMsg()); err != nil {
		return nil, err
	}

	client, found := peer.FromContext(ctx)
	if !found {
		return nil, errors.New("no peer in gRPC context")
	}
	tcpAddr, isTCP := client.Addr.(*net.TCPAddr)
	if !isTCP {
		return nil, fmt.Errorf("no TCP peer in gRPC context: %v", client.Addr)
	}

	// The writer starts out holding the query; WriteMsg replaces it with the reply.
	rw := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: tcpAddr, Msg: query}

	dnsCtx := context.WithValue(context.WithValue(ctx, Key{}, s.Server), LoopKey{}, 0)
	s.ServeDNS(dnsCtx, rw, query)

	wire, err := rw.Msg.Pack()
	if err != nil {
		return nil, err
	}
	return &pb.DnsPacket{Msg: wire}, nil
}
// Shutdown stops the gRPC server immediately (non gracefully), if one has
// been created. It always returns nil.
func (s *ServergRPC) Shutdown() error {
	if srv := s.grpcServer; srv != nil {
		srv.Stop()
	}
	return nil
}
// gRPCresponse is the response writer handed to ServeDNS by ServergRPC.Query.
// Instead of writing to a connection it keeps the last message written so
// Query can pack it into the gRPC reply.
type gRPCresponse struct {
// localAddr is returned by LocalAddr.
localAddr net.Addr
// remoteAddr is returned by RemoteAddr.
remoteAddr net.Addr
// Msg holds the query initially and is replaced by Write/WriteMsg.
Msg *dns.Msg
}
// Write does not send anything. It parses b into a fresh dns.Msg stored in
// r.Msg so Query can pick it up and return it as a protobuf. It reports
// len(b) as written together with any unpack error.
func (r *gRPCresponse) Write(b []byte) (int, error) {
	parsed := new(dns.Msg)
	r.Msg = parsed
	err := parsed.Unpack(b)
	return len(b), err
}
// The methods below implement the rest of the dns.ResponseWriter interface
// from Go DNS.

// Close does nothing and returns nil.
func (r *gRPCresponse) Close() error {
	return nil
}

// TsigStatus always returns nil.
func (r *gRPCresponse) TsigStatus() error {
	return nil
}

// TsigTimersOnly ignores its argument.
func (r *gRPCresponse) TsigTimersOnly(b bool) {
}

// Hijack does nothing.
func (r *gRPCresponse) Hijack() {
}

// LocalAddr returns the address stored when the writer was created.
func (r *gRPCresponse) LocalAddr() net.Addr {
	return r.localAddr
}

// RemoteAddr returns the peer address stored when the writer was created.
func (r *gRPCresponse) RemoteAddr() net.Addr {
	return r.remoteAddr
}

// Network returns the empty string.
func (r *gRPCresponse) Network() string {
	return ""
}

// WriteMsg keeps m as the response instead of sending it and returns nil.
func (r *gRPCresponse) WriteMsg(m *dns.Msg) error {
	r.Msg = m
	return nil
}
package dnsserver
import (
"context"
"crypto/tls"
"fmt"
stdlog "log"
"net"
"net/http"
"strconv"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/doh"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
type ServerHTTPS struct {
// Server is the embedded core DNS server; ServeHTTP forwards requests to its ServeDNS.
*Server
// httpsServer is the HTTP server created in NewServerHTTPS with this ServerHTTPS as handler.
httpsServer *http.Server
// listenAddr is the listener address recorded in Serve and used as the local address of responses.
listenAddr net.Addr
// tlsConfig, when non-nil, makes Serve wrap the listener in TLS.
tlsConfig *tls.Config
// validRequest decides whether an HTTP request is handled; rejected requests get a 404.
validRequest func(*http.Request) bool
}
// loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server.
// It carries no state.
type loggerAdapter struct {
}
// Write logs p as a debug message through the CoreDNS logger and reports the
// whole slice as written. It never returns an error.
func (l *loggerAdapter) Write(p []byte) (n int, err error) {
	text := string(p)
	clog.Debug(text)
	return len(p), nil
}
// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH).
// ServerHTTPS.ServeHTTP stores the *http.Request under it before calling ServeDNS.
type HTTPRequestKey struct{}
// NewServerHTTPS returns a new CoreDNS HTTPS server built on top of NewServer
// for addr and group. The TLS configuration and the request validator are
// taken from the zone configs (the last one visited wins in each case); when
// no validator is configured, requests are accepted only on doh.Path.
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
	base, err := NewServer(addr, group)
	if err != nil {
		return nil, err
	}

	// The *tls* plugin must make sure that multiple conflicting
	// TLS configuration returns an error: it can only be specified once.
	var tlsCfg *tls.Config
	for _, configs := range base.zones {
		for i := range configs {
			// Should we error if some configs *don't* have TLS?
			tlsCfg = configs[i].TLSConfig
		}
	}

	// http/2 is recommended when using DoH. It has to be listed in
	// NextProtos or the upgrade won't happen.
	if tlsCfg != nil {
		tlsCfg.NextProtos = []string{"h2", "http/1.1"}
	}

	// Use a custom request validation func or use the standard DoH path check.
	var accept func(*http.Request) bool
	for _, configs := range base.zones {
		for i := range configs {
			accept = configs[i].HTTPRequestValidateFunc
		}
	}
	if accept == nil {
		accept = func(r *http.Request) bool { return r.URL.Path == doh.Path }
	}

	httpSrv := &http.Server{
		ReadTimeout:  base.readTimeout,
		WriteTimeout: base.writeTimeout,
		IdleTimeout:  base.idleTimeout,
		ErrorLog:     stdlog.New(&loggerAdapter{}, "", 0),
	}
	sh := &ServerHTTPS{
		Server:       base,
		tlsConfig:    tlsCfg,
		httpsServer:  httpSrv,
		validRequest: accept,
	}
	httpSrv.Handler = sh

	return sh, nil
}
// Compile-time check to ensure ServerHTTPS implements the caddy.GracefulServer interface.
// The value is discarded; the declaration exists only so the build fails if a method is missing.
var _ caddy.GracefulServer = &ServerHTTPS{}
// Serve implements caddy.TCPServer interface. It records the listener
// address under the server mutex and serves HTTP on l, wrapping l in TLS
// first when a TLS config is present.
func (s *ServerHTTPS) Serve(l net.Listener) error {
	s.m.Lock()
	s.listenAddr = l.Addr()
	s.m.Unlock()

	if cfg := s.tlsConfig; cfg != nil {
		l = tls.NewListener(l, cfg)
	}
	return s.httpsServer.Serve(l)
}
// ServePacket implements caddy.UDPServer interface. The HTTPS server does not
// serve packets, so this does nothing and returns nil.
func (s *ServerHTTPS) ServePacket(p net.PacketConn) error {
	return nil
}
// Listen implements caddy.TCPServer interface. It opens a TCP listener on
// s.Addr with the leading "<https scheme>://" characters cut off.
func (s *ServerHTTPS) Listen() (net.Listener, error) {
	hostPort := s.Addr[len(transport.HTTPS+"://"):]
	ln, err := reuseport.Listen("tcp", hostPort)
	if err == nil {
		return ln, nil
	}
	return nil, err
}
// ListenPacket implements caddy.UDPServer interface. The HTTPS server opens
// no packet connection, so both results are nil.
func (s *ServerHTTPS) ListenPacket() (net.PacketConn, error) {
	return nil, nil
}
// OnStartupComplete prints the text produced by startUpZones for this
// server, using the HTTPS scheme as prefix. Nothing is printed when Quiet is
// set or when that text is empty.
func (s *ServerHTTPS) OnStartupComplete() {
	if Quiet {
		return
	}
	if banner := startUpZones(transport.HTTPS+"://", s.Addr, s.zones); banner != "" {
		fmt.Print(banner)
	}
}
// Stop shuts the HTTP server down, if there is one, while holding the server
// mutex. The result of http.Server.Shutdown is discarded and nil is always
// returned.
func (s *ServerHTTPS) Stop() error {
	s.m.Lock()
	defer s.m.Unlock()

	if srv := s.httpsServer; srv != nil {
		srv.Shutdown(context.Background())
	}
	return nil
}
// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the plugin
// chain, converts it back and write it to the client.
//
// Responses and their HTTP status (each one is also counted via countResponse):
//   - 404 when validRequest rejects the request;
//   - 400 when the request cannot be converted to a DNS message;
//   - 500 when the plugin chain produced no message, or the message cannot be packed;
//   - 200 with the packed DNS message otherwise.
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if !s.validRequest(r) {
		http.Error(w, "", http.StatusNotFound)
		s.countResponse(http.StatusNotFound)
		return
	}

	msg, err := doh.RequestToMsg(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		s.countResponse(http.StatusBadRequest)
		return
	}

	// Create a DoHWriter with the correct addresses in it. Parsing the remote
	// address is best effort: on failure the writer simply carries an empty
	// IP and port 0.
	h, p, _ := net.SplitHostPort(r.RemoteAddr)
	port, _ := strconv.Atoi(p)
	dw := &DoHWriter{
		laddr:   s.listenAddr,
		raddr:   &net.TCPAddr{IP: net.ParseIP(h), Port: port},
		request: r,
	}

	// We just call the normal chain handler - all error handling is done there.
	// We should expect a packet to be returned that we can send to the client.
	ctx := context.WithValue(context.Background(), Key{}, s.Server)
	ctx = context.WithValue(ctx, LoopKey{}, 0)
	ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
	s.ServeDNS(ctx, dw, msg)

	// See section 4.2.1 of RFC 8484.
	// We are using code 500 to indicate an unexpected situation when the chain
	// handler has not provided any response message.
	if dw.Msg == nil {
		http.Error(w, "No response", http.StatusInternalServerError)
		s.countResponse(http.StatusInternalServerError)
		return
	}

	// A message that cannot be packed must not be reported as a success:
	// without this check the client would receive 200 with an empty body.
	buf, err := dw.Msg.Pack()
	if err != nil {
		http.Error(w, "Failed to pack response", http.StatusInternalServerError)
		s.countResponse(http.StatusInternalServerError)
		return
	}

	// The cache lifetime advertised to HTTP caches follows the minimal TTL
	// of the DNS response.
	mt, _ := response.Typify(dw.Msg, time.Now().UTC())
	age := dnsutil.MinimalTTL(dw.Msg, mt)

	w.Header().Set("Content-Type", doh.MimeType)
	w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", uint32(age.Seconds())))
	w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
	w.WriteHeader(http.StatusOK)
	s.countResponse(http.StatusOK)

	w.Write(buf)
}
// countResponse increments the HTTPS response counter labelled with the
// server address and the decimal HTTP status code.
func (s *ServerHTTPS) countResponse(status int) {
	label := strconv.Itoa(status)
	vars.HTTPSResponsesCount.WithLabelValues(s.Addr, label).Inc()
}
// Shutdown stops the server by calling http.Server.Shutdown with a
// background context, if there is an HTTP server. Unlike Stop it does not
// take the server mutex. The Shutdown result is discarded and nil is always
// returned.
func (s *ServerHTTPS) Shutdown() error {
	if srv := s.httpsServer; srv != nil {
		srv.Shutdown(context.Background())
	}
	return nil
}
package dnsserver
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"github.com/coredns/coredns/plugin/metrics/vars"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// DoQ application error codes and default QUIC limits.
const (
// DoQCodeNoError is used when the connection or stream needs to be
// closed, but there is no error to signal.
// Counted by countResponse under the label "0x0".
DoQCodeNoError quic.ApplicationErrorCode = 0
// DoQCodeInternalError signals that the DoQ implementation encountered
// an internal error and is incapable of pursuing the transaction or the
// connection.
// Counted by countResponse under the label "0x1".
DoQCodeInternalError quic.ApplicationErrorCode = 1
// DoQCodeProtocolError signals that the DoQ implementation encountered
// a protocol error and is forcibly aborting the connection.
// Counted by countResponse under the label "0x2".
DoQCodeProtocolError quic.ApplicationErrorCode = 2
// DefaultMaxQUICStreams is the default maximum number of concurrent QUIC streams
// on a per-connection basis. RFC 9250 (DNS-over-QUIC) does not require a high
// concurrent-stream limit; normal stub or recursive resolvers open only a handful
// of streams in parallel. This default (256) is a safe upper bound.
// NewServerQUIC applies it to both MaxIncomingStreams and MaxIncomingUniStreams.
DefaultMaxQUICStreams = 256
// DefaultQUICStreamWorkers is the default number of workers for processing QUIC streams.
// NewServerQUIC uses it as the capacity of the streamProcessPool channel.
DefaultQUICStreamWorkers = 1024
)
// ServerQUIC represents an instance of a DNS-over-QUIC server.
type ServerQUIC struct {
// Server is the embedded core DNS server; streams are forwarded to its ServeDNS.
*Server
// listenAddr is the QUIC listener address recorded in ServePacket.
listenAddr net.Addr
// tlsConfig is handed to quic.Listen in ListenPacket.
tlsConfig *tls.Config
// quicConfig is handed to quic.Listen in ListenPacket.
quicConfig *quic.Config
// quicListener is created in ListenPacket and closed by Stop.
quicListener *quic.Listener
// maxStreams is the per-connection stream limit chosen in NewServerQUIC.
maxStreams int
// streamProcessPool is a buffered channel used as a semaphore: one slot is held per stream being processed.
streamProcessPool chan struct{}
}
// NewServerQUIC returns a new CoreDNS QUIC server built on top of NewServer
// for addr and group. The TLS configuration is taken from the zone configs
// (the last one visited wins) and restricted to the "doq" protocol. Stream
// and worker limits come from the first config of group when it sets them,
// otherwise from DefaultMaxQUICStreams and DefaultQUICStreamWorkers.
func NewServerQUIC(addr string, group []*Config) (*ServerQUIC, error) {
	base, err := NewServer(addr, group)
	if err != nil {
		return nil, err
	}

	// The *tls* plugin must make sure that multiple conflicting
	// TLS configuration returns an error: it can only be specified once.
	var tlsCfg *tls.Config
	for _, configs := range base.zones {
		for i := range configs {
			// Should we error if some configs *don't* have TLS?
			tlsCfg = configs[i].TLSConfig
		}
	}
	if tlsCfg != nil {
		tlsCfg.NextProtos = []string{"doq"}
	}

	streamLimit := DefaultMaxQUICStreams
	workerLimit := DefaultQUICStreamWorkers
	if len(group) > 0 && group[0] != nil {
		first := group[0]
		if first.MaxQUICStreams != nil {
			streamLimit = *first.MaxQUICStreams
		}
		if first.MaxQUICWorkerPoolSize != nil {
			workerLimit = *first.MaxQUICWorkerPoolSize
		}
	}

	quicCfg := &quic.Config{
		MaxIdleTimeout:        base.idleTimeout,
		MaxIncomingStreams:    int64(streamLimit),
		MaxIncomingUniStreams: int64(streamLimit),
		// Enable 0-RTT by default for all connections on the server-side.
		Allow0RTT: true,
	}

	quicSrv := &ServerQUIC{
		Server:            base,
		tlsConfig:         tlsCfg,
		quicConfig:        quicCfg,
		maxStreams:        streamLimit,
		streamProcessPool: make(chan struct{}, workerLimit),
	}
	return quicSrv, nil
}
// ServePacket implements caddy.UDPServer interface. The packet connection p
// itself is not used here: the method records the address of the QUIC
// listener created by ListenPacket under the server mutex and then runs the
// accept loop.
func (s *ServerQUIC) ServePacket(p net.PacketConn) error {
	s.m.Lock()
	s.listenAddr = s.quicListener.Addr()
	s.m.Unlock()

	return s.ServeQUIC()
}
// ServeQUIC accepts QUIC connections in a loop and serves each one in its
// own goroutine. It returns the first Accept error; before returning, the
// connection is closed with DoQCodeNoError when the error is expected and
// with DoQCodeInternalError otherwise.
func (s *ServerQUIC) ServeQUIC() error {
	for {
		conn, err := s.quicListener.Accept(context.Background())
		if err == nil {
			go s.serveQUICConnection(conn)
			continue
		}

		code := DoQCodeInternalError
		if s.isExpectedErr(err) {
			code = DoQCodeNoError
		}
		s.closeQUICConn(conn, code)
		return err
	}
}
// serveQUICConnection handles a new QUIC connection. It accepts streams until
// AcceptStream fails and hands each stream to serveQUICStream in a goroutine,
// with the number of concurrently running goroutines bounded by
// streamProcessPool. A nil conn is ignored.
func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) {
	if conn == nil {
		return
	}

	for {
		// In DoQ, one query consumes one stream.
		// The client MUST select the next available client-initiated bidirectional
		// stream for each subsequent query on a QUIC connection.
		stream, err := conn.AcceptStream(context.Background())
		if err != nil {
			code := DoQCodeInternalError
			if s.isExpectedErr(err) {
				code = DoQCodeNoError
			}
			s.closeQUICConn(conn, code)
			return
		}

		// Bounded worker pool: take a slot (this may block) and give it
		// back once the stream has been served.
		s.streamProcessPool <- struct{}{}
		go func() {
			defer func() { <-s.streamProcessPool }()
			s.serveQUICStream(stream, conn)
		}()
	}
}
// serveQUICStream reads one DoQ message from stream, validates it and runs it
// through ServeDNS with a DoQWriter bound to the stream. Read, unpack and
// validation failures close the whole connection with DoQCodeProtocolError; a
// nil stream closes it with DoQCodeInternalError. A nil conn is ignored.
func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
	switch {
	case conn == nil:
		return
	case stream == nil:
		s.closeQUICConn(conn, DoQCodeInternalError)
		return
	}

	wire, readErr := readDOQMessage(stream)

	// io.EOF does not really mean that there's any error, it is just
	// the STREAM FIN indicating that there will be no data to read
	// anymore from this stream.
	if !(readErr == nil || readErr == io.EOF) {
		s.closeQUICConn(conn, DoQCodeProtocolError)
		return
	}

	query := new(dns.Msg)
	if unpackErr := query.Unpack(wire); unpackErr != nil {
		clog.Debugf("unpacking quic packet: %s", unpackErr)
		s.closeQUICConn(conn, DoQCodeProtocolError)
		return
	}

	if !validRequest(query) {
		// If a peer encounters such an error condition, it is considered a
		// fatal error. It SHOULD forcibly abort the connection using QUIC's
		// CONNECTION_CLOSE mechanism and SHOULD use the DoQ error code
		// DOQ_PROTOCOL_ERROR.
		// See https://www.rfc-editor.org/rfc/rfc9250#section-4.3.3-3
		s.closeQUICConn(conn, DoQCodeProtocolError)
		return
	}

	writer := &DoQWriter{
		localAddr:  conn.LocalAddr(),
		remoteAddr: conn.RemoteAddr(),
		stream:     stream,
		Msg:        query,
	}

	dnsCtx := context.WithValue(context.WithValue(stream.Context(), Key{}, s.Server), LoopKey{}, 0)
	s.ServeDNS(dnsCtx, writer, query)
	s.countResponse(DoQCodeNoError)
}
// ListenPacket implements caddy.UDPServer interface. It opens a UDP socket on
// s.Addr (with the leading "<quic scheme>://" characters cut off) and starts
// a QUIC listener on top of it, stored in s.quicListener under the server
// mutex. On success the UDP socket is returned; on failure the error is
// returned and no socket is left open.
func (s *ServerQUIC) ListenPacket() (net.PacketConn, error) {
	p, err := reuseport.ListenPacket("udp", s.Addr[len(transport.QUIC+"://"):])
	if err != nil {
		return nil, err
	}

	s.m.Lock()
	defer s.m.Unlock()

	s.quicListener, err = quic.Listen(p, s.tlsConfig, s.quicConfig)
	if err != nil {
		// The caller only gets the error back, so nobody else could release
		// the socket. The Close error is irrelevant next to the Listen error.
		p.Close()
		return nil, err
	}
	return p, nil
}
// OnStartupComplete prints the text produced by startUpZones for this
// server, using the QUIC scheme as prefix. Nothing is printed when Quiet is
// set or when that text is empty.
func (s *ServerQUIC) OnStartupComplete() {
	if Quiet {
		return
	}
	if banner := startUpZones(transport.QUIC+"://", s.Addr, s.zones); banner != "" {
		fmt.Print(banner)
	}
}
// Stop stops the server non-gracefully by closing the QUIC listener while
// holding the server mutex. It returns the listener's Close error, or nil
// when no listener exists.
func (s *ServerQUIC) Stop() error {
	s.m.Lock()
	defer s.m.Unlock()

	if s.quicListener == nil {
		return nil
	}
	return s.quicListener.Close()
}
// Serve implements caddy.TCPServer interface. The QUIC server does not serve
// TCP listeners, so this does nothing and returns nil.
func (s *ServerQUIC) Serve(l net.Listener) error {
	return nil
}
// Listen implements caddy.TCPServer interface. The QUIC server opens no TCP
// listener, so both results are nil.
func (s *ServerQUIC) Listen() (net.Listener, error) {
	return nil, nil
}
// closeQUICConn quietly closes the QUIC connection with the given
// application error code. A failure to close is only logged at debug level.
// Every code except DoQCodeNoError is counted here. A nil conn is ignored and
// not counted.
func (s *ServerQUIC) closeQUICConn(conn *quic.Conn, code quic.ApplicationErrorCode) {
	if conn == nil {
		return
	}

	clog.Debugf("closing quic conn %s with code %d", conn.LocalAddr(), code)
	if closeErr := conn.CloseWithError(code, ""); closeErr != nil {
		clog.Debugf("failed to close quic connection with code %d: %s", code, closeErr)
	}

	// DoQCodeNoError metrics are already registered after s.ServeDNS()
	if code == DoQCodeNoError {
		return
	}
	s.countResponse(code)
}
// validRequest checks for protocol errors in the unpacked DNS message and
// reports whether the message is acceptable.
// See https://www.rfc-editor.org/rfc/rfc9250.html#name-protocol-errors
//
// Of the error conditions listed by the RFC, two are checked:
//
//  1. a message with a non-zero Message ID;
//  2. a message containing the edns-tcp-keepalive EDNS(0) Option [RFC7828].
//
// Two are consciously not checked:
//
//  3. a missing STREAM FIN after the request: validating it would mean
//     waiting for the FIN before processing the message;
//  4. a "replayable" transaction in 0-RTT data: the information necessary
//     to validate this is not exposed by quic-go.
func validRequest(req *dns.Msg) (ok bool) {
	if req.Id != 0 {
		return false
	}

	opt := req.IsEdns0()
	if opt == nil {
		return true
	}
	for i := range opt.Option {
		if opt.Option[i].Option() != dns.EDNS0TCPKEEPALIVE {
			continue
		}
		clog.Debug("client sent EDNS0 TCP keepalive option")
		return false
	}
	return true
}
// readDOQMessage reads a DNS over QUIC (DOQ) message from the given stream
// and returns the message bytes.
// Drafts of the RFC9250 did not require the 2-byte prefixed message length.
// Thus, we are only supporting the official version (DoQ v1).
//
// Errors:
//   - the error of the underlying read (io.EOF for an empty stream) when the
//     2-octet length prefix cannot be read completely;
//   - an error when the prefix announces a length of 0;
//   - an error wrapping io.ErrUnexpectedEOF (or the underlying read error)
//     when the stream ends before the announced number of bytes arrived.
//
// On any error the returned slice is nil. In particular a truncated body is
// never reported as io.EOF, so callers that tolerate io.EOF cannot mistake a
// truncated message for a complete one.
func readDOQMessage(r io.Reader) ([]byte, error) {
	// All DNS messages (queries and responses) sent over DoQ connections MUST
	// be encoded as a 2-octet length field followed by the message content as
	// specified in [RFC1035].
	// See https://www.rfc-editor.org/rfc/rfc9250.html#section-4.2-4
	var sizeBuf [2]byte
	if _, err := io.ReadFull(r, sizeBuf[:]); err != nil {
		return nil, err
	}

	size := binary.BigEndian.Uint16(sizeBuf[:])
	if size == 0 {
		return nil, fmt.Errorf("message size is 0: probably unsupported DoQ version")
	}

	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		// A client or server receives a STREAM FIN before receiving all the bytes
		// for a message indicated in the 2-octet length field.
		// See https://www.rfc-editor.org/rfc/rfc9250#section-4.3.3-2.2
		//
		// io.ReadFull reports a FIN right after the prefix as plain io.EOF;
		// at this point that is a truncation, not a clean end of stream.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, fmt.Errorf("message size does not match 2-byte prefix: %w", err)
	}
	return buf, nil
}
// isExpectedErr reports whether err is one of the errors this implementation
// expects to see during normal operation: a closed listener, an application
// error with code 2, or an idle timeout. A nil error is not expected.
func (s *ServerQUIC) isExpectedErr(err error) bool {
	switch {
	case err == nil:
		return false

	// This error is returned when the QUIC listener was closed by us. As
	// graceful shutdown is not implemented, the connection will be abruptly
	// closed but there is no error to signal.
	case errors.Is(err, quic.ErrServerClosed):
		return true
	}

	// This error happens when the connection was closed due to a DoQ
	// protocol error but there's still something to read in the closed stream.
	// For example, when the message was sent without the prefixed length.
	var appErr *quic.ApplicationError
	if errors.As(err, &appErr) && appErr.ErrorCode == 2 {
		return true
	}

	// When a connection hits the idle timeout, quic.AcceptStream() returns
	// an IdleTimeoutError. In this case, we should just drop the connection
	// with DoQCodeNoError.
	var idleErr *quic.IdleTimeoutError
	return errors.As(err, &idleErr)
}
// countResponse increments the QUIC response counter labelled with the
// server address and the hexadecimal form of code. Codes other than the
// three DoQ codes defined in this package are not counted.
func (s *ServerQUIC) countResponse(code quic.ApplicationErrorCode) {
	var label string
	switch code {
	case DoQCodeNoError:
		label = "0x0"
	case DoQCodeInternalError:
		label = "0x1"
	case DoQCodeProtocolError:
		label = "0x2"
	default:
		return
	}
	vars.QUICResponsesCount.WithLabelValues(s.Addr, label).Inc()
}
package dnsserver
import (
"context"
"crypto/tls"
"fmt"
"net"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// ServerTLS represents an instance of a DNS-over-TLS server.
type ServerTLS struct {
// Server is the embedded core DNS server; incoming queries are forwarded to its ServeDNS.
*Server
// tlsConfig, when non-nil, makes Serve wrap the listener in TLS.
tlsConfig *tls.Config
}
// NewServerTLS returns a new CoreDNS TLS server built on top of NewServer for
// addr and group. The TLS configuration is taken from the zone configs (the
// last one visited wins) and may be nil.
func NewServerTLS(addr string, group []*Config) (*ServerTLS, error) {
	base, err := NewServer(addr, group)
	if err != nil {
		return nil, err
	}

	// The *tls* plugin must make sure that multiple conflicting
	// TLS configuration returns an error: it can only be specified once.
	var tlsCfg *tls.Config
	for _, configs := range base.zones {
		for i := range configs {
			// Should we error if some configs *don't* have TLS?
			tlsCfg = configs[i].TLSConfig
		}
	}

	tlsSrv := &ServerTLS{Server: base, tlsConfig: tlsCfg}
	return tlsSrv, nil
}
// Compile-time check to ensure ServerTLS implements the caddy.GracefulServer interface.
// The value is discarded; the declaration exists only so the build fails if a method is missing.
var _ caddy.GracefulServer = &ServerTLS{}
// Serve implements caddy.TCPServer interface. Under the server mutex it wraps
// l in TLS (when a TLS config is present) and installs a "tcp-tls" dns.Server
// in the TCP slot of s.server; it then releases the mutex and serves.
func (s *ServerTLS) Serve(l net.Listener) error {
	s.m.Lock()

	if s.tlsConfig != nil {
		l = tls.NewListener(l, s.tlsConfig)
	}

	// Every query gets a fresh context carrying the server and a zeroed
	// loop counter.
	handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		ctx := context.WithValue(context.Background(), Key{}, s.Server)
		s.ServeDNS(context.WithValue(ctx, LoopKey{}, 0), w, r)
	})
	idle := func() time.Duration {
		return s.idleTimeout
	}

	// Only fill out the TCP server for this one.
	s.server[tcp] = &dns.Server{
		Listener:      l,
		Net:           "tcp-tls",
		MaxTCPQueries: tlsMaxQueries,
		ReadTimeout:   s.readTimeout,
		WriteTimeout:  s.writeTimeout,
		IdleTimeout:   idle,
		Handler:       handler,
	}

	s.m.Unlock()

	return s.server[tcp].ActivateAndServe()
}
// ServePacket implements caddy.UDPServer interface. The TLS server does not
// serve packets, so this does nothing and returns nil.
func (s *ServerTLS) ServePacket(p net.PacketConn) error {
	return nil
}
// Listen implements caddy.TCPServer interface. It opens a TCP listener on
// s.Addr with the leading "<tls scheme>://" characters cut off.
func (s *ServerTLS) Listen() (net.Listener, error) {
	hostPort := s.Addr[len(transport.TLS+"://"):]
	ln, err := reuseport.Listen("tcp", hostPort)
	if err == nil {
		return ln, nil
	}
	return nil, err
}
// ListenPacket implements caddy.UDPServer interface. The TLS server opens no
// packet connection, so both results are nil.
func (s *ServerTLS) ListenPacket() (net.PacketConn, error) {
	return nil, nil
}
// OnStartupComplete prints the text produced by startUpZones for this
// server, using the TLS scheme as prefix. Nothing is printed when Quiet is
// set or when that text is empty.
func (s *ServerTLS) OnStartupComplete() {
	if Quiet {
		return
	}
	if banner := startUpZones(transport.TLS+"://", s.Addr, s.zones); banner != "" {
		fmt.Print(banner)
	}
}
const (
// tlsMaxQueries is passed as MaxTCPQueries to the dns.Server created in ServerTLS.Serve.
// NOTE(review): -1 presumably disables the per-connection query limit — confirm in the miekg/dns Server docs.
tlsMaxQueries = -1
)
// Package coremain contains the functions for starting CoreDNS.
package coremain
import (
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"go.uber.org/automaxprocs/maxprocs"
)
// init configures the caddy package defaults for CoreDNS, derives the version
// information, and registers the command line flags and Corefile loaders.
func init() {
caddy.DefaultConfigFile = "Corefile"
caddy.Quiet = true // don't show init stuff from caddy
setVersion()
// Command line flags; their values are read by Run after flag.Parse.
flag.StringVar(&conf, "conf", "", "Corefile to load (default \""+caddy.DefaultConfigFile+"\")")
flag.BoolVar(&plugins, "plugins", false, "List installed plugins")
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
flag.BoolVar(&version, "version", false, "Show version")
flag.BoolVar(&dnsserver.Quiet, "quiet", false, "Quiet mode (no initialization output)")
// Corefile sources: the -conf flag loader, with the working-directory Corefile registered as the default loader.
caddy.RegisterCaddyfileLoader("flag", caddy.LoaderFunc(confLoader))
caddy.SetDefaultCaddyfileLoader("default", caddy.LoaderFunc(defaultLoader))
// Both spellings of the port flag write to the same variable, dnsserver.Port.
flag.StringVar(&dnsserver.Port, serverType+".port", dnsserver.DefaultPort, "Default port")
flag.StringVar(&dnsserver.Port, "p", dnsserver.DefaultPort, "Default port")
caddy.AppName = CoreName
caddy.AppVersion = CoreVersion
}
// Run is CoreDNS's main() function. It parses the command line, handles the
// informational flags (-version, -plugins), sets GOMAXPROCS via automaxprocs,
// loads the Corefile, starts the server instance and then blocks until that
// instance is done. Fatal problems are reported through mustLogFatal.
func Run() {
	caddy.TrapSignals()
	flag.Parse()

	if extra := flag.Args(); len(extra) > 0 {
		mustLogFatal(fmt.Errorf("extra command line arguments: %s", extra))
	}

	log.SetOutput(os.Stdout)
	log.SetFlags(LogFlags)

	// Informational flags print and exit without starting a server.
	switch {
	case version:
		showVersion()
		os.Exit(0)
	case plugins:
		fmt.Println(caddy.DescribePlugins())
		os.Exit(0)
	}

	// Failing to adjust GOMAXPROCS is not fatal.
	if _, err := maxprocs.Set(maxprocs.Logger(log.Printf)); err != nil {
		log.Println("[WARNING] Failed to set GOMAXPROCS:", err)
	}

	// Get Corefile input
	corefile, err := caddy.LoadCaddyfile(serverType)
	if err != nil {
		mustLogFatal(err)
	}

	// Start your engines
	instance, err := caddy.Start(corefile)
	if err != nil {
		mustLogFatal(err)
	}

	if !dnsserver.Quiet {
		showVersion()
	}

	// Twiddle your thumbs
	instance.Wait()
}
// mustLogFatal logs args with log.Fatal, which terminates the process.
// Unless this process is an upgrade, the log output is first redirected to
// stderr so that a user who is still watching sees the message even when the
// process log was not enabled. During an upgrade the user might not be there
// anymore, so the message only goes to the process log.
func mustLogFatal(args ...interface{}) {
	upgrading := caddy.IsUpgrade()
	if !upgrading {
		log.SetOutput(os.Stderr)
	}
	log.Fatal(args...)
}
// confLoader loads the Caddyfile named by the -conf flag. It returns nil
// input and nil error when the flag is empty, reads from standard input when
// the flag is "stdin", and otherwise reads the (cleaned) file path.
func confLoader(serverType string) (caddy.Input, error) {
	switch conf {
	case "":
		return nil, nil
	case "stdin":
		return caddy.CaddyfileFromPipe(os.Stdin, serverType)
	}

	data, err := os.ReadFile(filepath.Clean(conf))
	if err != nil {
		return nil, err
	}

	input := caddy.CaddyfileInput{
		Contents:       data,
		Filepath:       conf,
		ServerTypeName: serverType,
	}
	return input, nil
}
// defaultLoader loads the file named by caddy.DefaultConfigFile, resolved
// against the current working directory. A missing file is not an error: it
// yields nil input and nil error. Any other read error is returned.
func defaultLoader(serverType string) (caddy.Input, error) {
	data, err := os.ReadFile(caddy.DefaultConfigFile)
	switch {
	case err == nil:
		return caddy.CaddyfileInput{
			Contents:       data,
			Filepath:       caddy.DefaultConfigFile,
			ServerTypeName: serverType,
		}, nil
	case os.IsNotExist(err):
		return nil, nil
	default:
		return nil, err
	}
}
// showVersion prints the version and release lines to standard output. For a
// development build with a non-empty git short stat it also prints that stat
// and the list of modified files.
func showVersion() {
	fmt.Print(versionString())
	fmt.Print(releaseString())

	if !devBuild || gitShortStat == "" {
		return
	}
	fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
}
// versionString returns "<caddy.AppName>-<caddy.AppVersion>" followed by a
// newline.
func versionString() string {
	name, ver := caddy.AppName, caddy.AppVersion
	return fmt.Sprintf("%s-%s\n", name, ver)
}
// releaseString returns the release information related to CoreDNS version,
// terminated by a newline:
//
//	<OS>/<ARCH>, <go version>, <commit>
//
// e.g.,
//
//	linux/amd64, go1.8.3, a6d2d7b5
func releaseString() string {
	platform := runtime.GOOS + "/" + runtime.GOARCH
	return fmt.Sprintf("%s, %s, %s\n", platform, runtime.Version(), GitCommit)
}
// setVersion figures out the version information based on variables set by
// -ldflags. It sets devBuild, and sets appVersion only when tag information
// is available:
//
//   - a development build with a nearest tag gets
//     "<nearest tag> (+<commit> <build date>)";
//   - otherwise a build with an exact tag gets that tag.
//
// In both cases a leading "v" is stripped from the tag.
func setVersion() {
	// A development build is one that's not at a tag or has uncommitted changes
	devBuild = gitTag == "" || gitShortStat != ""

	// Each case needs a non-empty tag variable, so appVersion keeps its
	// default when -ldflags supplied neither.
	switch {
	case devBuild && gitNearestTag != "":
		nearest := strings.TrimPrefix(gitNearestTag, "v")
		appVersion = fmt.Sprintf("%s (+%s %s)", nearest, GitCommit, buildDate)
	case gitTag != "":
		appVersion = strings.TrimPrefix(gitTag, "v")
	}
}
// Flags that control program flow or startup
var (
// conf is the value of the -conf flag: the Corefile to load; "stdin" makes confLoader read standard input.
conf string
// version is the value of the -version flag; Run prints the version and exits when it is set.
version bool
// plugins is the value of the -plugins flag; Run prints the plugin description and exits when it is set.
plugins bool
// LogFlags are initially set to 0 for no extra output
// Run passes the value to log.SetFlags.
LogFlags int
)
// Build information obtained with the help of -ldflags
var (
appVersion = "(untracked dev build)" // inferred at startup
devBuild = true // inferred at startup
buildDate string // date -u
gitTag string // git describe --exact-match HEAD 2> /dev/null
gitNearestTag string // git describe --abbrev=0 --tags HEAD
gitShortStat string // git diff-index --shortstat
gitFilesModified string // git diff-index --name-only HEAD
// GitCommit contains the commit where we built CoreDNS from.
GitCommit string
)
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.27.1
// protoc v3.19.4
// source: dns.proto
package pb
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
// Compile-time version checks between this generated file and the protoimpl
// runtime; the constants themselves are discarded.
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// DnsPacket is the generated Go type for the coredns.dns.DnsPacket message.
// Its only exported field is Msg, a bytes field with field number 1.
type DnsPacket struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Msg []byte `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"`
}
// Reset replaces *x with an empty DnsPacket and, when protoimpl.UnsafeEnabled
// is true, stores the message info for this type in its message state.
func (x *DnsPacket) Reset() {
*x = DnsPacket{}
if protoimpl.UnsafeEnabled {
mi := &file_dns_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns the result of protoimpl.X.MessageStringOf for x.
func (x *DnsPacket) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage is an empty marker method.
func (*DnsPacket) ProtoMessage() {}
// ProtoReflect returns the protoreflect.Message view of x. On the unsafe path
// with a non-nil x, the message info is stored in the message state if it has
// not been stored yet; otherwise the view is built with mi.MessageOf.
func (x *DnsPacket) ProtoReflect() protoreflect.Message {
mi := &file_dns_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DnsPacket.ProtoReflect.Descriptor instead.
func (*DnsPacket) Descriptor() ([]byte, []int) {
return file_dns_proto_rawDescGZIP(), []int{0}
}
// GetMsg returns x.Msg, or nil when x itself is nil.
func (x *DnsPacket) GetMsg() []byte {
if x != nil {
return x.Msg
}
return nil
}
// File_dns_proto is the file descriptor for dns.proto. It is assigned by
// file_dns_proto_init.
var File_dns_proto protoreflect.FileDescriptor
// file_dns_proto_rawDesc holds the raw descriptor bytes that
// file_dns_proto_init hands to the type builder. That function sets this
// variable to nil once the descriptor has been built.
var file_dns_proto_rawDesc = []byte{
0x0a, 0x09, 0x64, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0b, 0x63, 0x6f, 0x72,
0x65, 0x64, 0x6e, 0x73, 0x2e, 0x64, 0x6e, 0x73, 0x22, 0x1d, 0x0a, 0x09, 0x44, 0x6e, 0x73, 0x50,
0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x18, 0x01, 0x20, 0x01,
0x28, 0x0c, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x32, 0x45, 0x0a, 0x0a, 0x44, 0x6e, 0x73, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x37, 0x0a, 0x05, 0x51, 0x75, 0x65, 0x72, 0x79, 0x12, 0x16,
0x2e, 0x63, 0x6f, 0x72, 0x65, 0x64, 0x6e, 0x73, 0x2e, 0x64, 0x6e, 0x73, 0x2e, 0x44, 0x6e, 0x73,
0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x64, 0x6e, 0x73,
0x2e, 0x64, 0x6e, 0x73, 0x2e, 0x44, 0x6e, 0x73, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x42, 0x06,
0x5a, 0x04, 0x2e, 0x3b, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
// file_dns_proto_rawDescData starts out as the raw descriptor and is replaced
// by its gzip-compressed form the first time file_dns_proto_rawDescGZIP runs;
// file_dns_proto_rawDescOnce guards that replacement.
var (
file_dns_proto_rawDescOnce sync.Once
file_dns_proto_rawDescData = file_dns_proto_rawDesc
)
// file_dns_proto_rawDescGZIP returns the descriptor data after compressing it
// with protoimpl.X.CompressGZIP. The compression runs at most once.
func file_dns_proto_rawDescGZIP() []byte {
file_dns_proto_rawDescOnce.Do(func() {
file_dns_proto_rawDescData = protoimpl.X.CompressGZIP(file_dns_proto_rawDescData)
})
return file_dns_proto_rawDescData
}
// file_dns_proto_msgTypes has one entry, used for DnsPacket (index 0).
var file_dns_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
// file_dns_proto_goTypes lists the Go types passed to the type builder in
// file_dns_proto_init, which sets this variable to nil afterwards.
var file_dns_proto_goTypes = []interface{}{
(*DnsPacket)(nil), // 0: coredns.dns.DnsPacket
}
// file_dns_proto_depIdxs is the dependency index table passed to the type
// builder in file_dns_proto_init, which sets this variable to nil afterwards.
var file_dns_proto_depIdxs = []int32{
0, // 0: coredns.dns.DnsService.Query:input_type -> coredns.dns.DnsPacket
0, // 1: coredns.dns.DnsService.Query:output_type -> coredns.dns.DnsPacket
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
// init builds the dns.proto descriptor when the package is loaded.
func init() { file_dns_proto_init() }
// file_dns_proto_init builds the file descriptor and type information for
// dns.proto and stores the result in File_dns_proto. It returns immediately
// when File_dns_proto is already set. After building it sets the raw
// descriptor, Go type and dependency index tables to nil.
func file_dns_proto_init() {
if File_dns_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
// Without unsafe support, install an exporter that returns pointers to
// the unexported bookkeeping fields of DnsPacket by index.
file_dns_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DnsPacket); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
// x is a local type used only to obtain this package's path via reflection.
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_dns_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_dns_proto_goTypes,
DependencyIndexes: file_dns_proto_depIdxs,
MessageInfos: file_dns_proto_msgTypes,
}.Build()
File_dns_proto = out.File
file_dns_proto_rawDesc = nil
file_dns_proto_goTypes = nil
file_dns_proto_depIdxs = nil
}
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.19.4
// source: dns.proto
package pb
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
// The build fails if the grpc package does not define SupportPackageIsVersion7.
const _ = grpc.SupportPackageIsVersion7
// DnsServiceClient is the client API for DnsService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type DnsServiceClient interface {
Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error)
}
// dnsServiceClient implements DnsServiceClient on top of a
// grpc.ClientConnInterface.
type dnsServiceClient struct {
cc grpc.ClientConnInterface
}
// NewDnsServiceClient returns a DnsServiceClient that issues its calls on cc.
func NewDnsServiceClient(cc grpc.ClientConnInterface) DnsServiceClient {
return &dnsServiceClient{cc}
}
// Query invokes the unary method /coredns.dns.DnsService/Query with in and
// returns the decoded reply, or nil and the error reported by Invoke.
func (c *dnsServiceClient) Query(ctx context.Context, in *DnsPacket, opts ...grpc.CallOption) (*DnsPacket, error) {
out := new(DnsPacket)
err := c.cc.Invoke(ctx, "/coredns.dns.DnsService/Query", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DnsServiceServer is the server API for DnsService service.
// All implementations must embed UnimplementedDnsServiceServer
// for forward compatibility
type DnsServiceServer interface {
Query(context.Context, *DnsPacket) (*DnsPacket, error)
mustEmbedUnimplementedDnsServiceServer()
}
// UnimplementedDnsServiceServer must be embedded to have forward compatible implementations.
type UnimplementedDnsServiceServer struct {
}
// Query is the default implementation; it returns a codes.Unimplemented status error.
func (UnimplementedDnsServiceServer) Query(context.Context, *DnsPacket) (*DnsPacket, error) {
return nil, status.Errorf(codes.Unimplemented, "method Query not implemented")
}
// mustEmbedUnimplementedDnsServiceServer is an empty marker method.
func (UnimplementedDnsServiceServer) mustEmbedUnimplementedDnsServiceServer() {}
// UnsafeDnsServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to DnsServiceServer will
// result in compilation errors.
type UnsafeDnsServiceServer interface {
mustEmbedUnimplementedDnsServiceServer()
}
// RegisterDnsServiceServer registers srv on s using DnsService_ServiceDesc.
func RegisterDnsServiceServer(s grpc.ServiceRegistrar, srv DnsServiceServer) {
s.RegisterService(&DnsService_ServiceDesc, srv)
}
// _DnsService_Query_Handler is the handler registered for the Query method.
// It decodes the request into a DnsPacket with dec and calls srv's Query,
// either directly or, when interceptor is non-nil, through the interceptor.
// srv must implement DnsServiceServer; the type assertion panics otherwise.
func _DnsService_Query_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DnsPacket)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DnsServiceServer).Query(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/coredns.dns.DnsService/Query",
}
// handler is what the interceptor invokes to reach the real implementation.
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DnsServiceServer).Query(ctx, req.(*DnsPacket))
}
return interceptor(ctx, in, info, handler)
}
// DnsService_ServiceDesc is the grpc.ServiceDesc for DnsService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
//
// It declares one unary method, Query, and no streams.
var DnsService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "coredns.dns.DnsService",
HandlerType: (*DnsServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Query",
Handler: _DnsService_Query_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "dns.proto",
}
package plugin
import (
"context"
"fmt"
"math"
"net"
"github.com/coredns/coredns/plugin/etcd/msg"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// maxCnameChainLength bounds CNAME chasing in A, AAAA and TXT: once more than
// this many previous records have been collected, further CNAMEs are skipped.
const maxCnameChainLength = 10
// A returns A records from Backend or an error.
//
// Services are obtained through checkForApex. A service whose HostType is
// dns.TypeA yields an A record (one per distinct Host). A service whose
// HostType is dns.TypeCNAME is followed: recursively through A when the target
// is a subdomain of zone, through b.Lookup otherwise. previousRecords carries
// the CNAMEs collected so far and is used to stop duplicate CNAMEs and overly
// long chains. truncated is set when a followed lookup reported truncation.
// Errors from followed lookups are not returned; the CNAME is skipped instead.
func A(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) {
services, err := checkForApex(ctx, b, zone, state, opt)
if err != nil {
return nil, false, err
}
// dup remembers the hosts for which an A record was already emitted.
dup := make(map[string]struct{})
for _, serv := range services {
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
// x CNAME x is a direct loop, don't add those
// in etcd/skydns w.x CNAME x is also direct loop due to the "recursive" nature of search results
continue
}
newRecord := serv.NewCNAME(state.QName(), serv.Host)
if len(previousRecords) > maxCnameChainLength {
// don't add it, and just continue
continue
}
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
// The target lives in our zone: resolve it with a recursive call for
// the same query type.
state1 := state.NewWithQuestion(serv.Host, state.QType())
state1.Zone = zone
nextRecords, tc, err := A(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
if err == nil {
// Not only have we found something we should add the CNAME and the IP addresses.
if len(nextRecords) > 0 {
records = append(records, newRecord)
records = append(records, nextRecords...)
}
}
if tc {
truncated = true
}
continue
}
// This means we can not complete the CNAME, try to look else where.
target := newRecord.Target
// Lookup
m1, e1 := b.Lookup(ctx, state, target, state.QType())
if e1 != nil {
continue
}
if m1.Truncated {
truncated = true
}
// Len(m1.Answer) > 0 here is well?
records = append(records, newRecord)
records = append(records, m1.Answer...)
continue
case dns.TypeA:
if _, ok := dup[serv.Host]; !ok {
dup[serv.Host] = struct{}{}
records = append(records, serv.NewA(state.QName(), ip))
}
case dns.TypeAAAA:
// nada
}
}
return records, truncated, nil
}
// AAAA returns AAAA records from Backend or an error.
//
// Services are obtained through checkForApex. A service whose HostType is
// dns.TypeAAAA yields an AAAA record (one per distinct Host). A service whose
// HostType is dns.TypeCNAME is followed: recursively through AAAA when the
// target is a subdomain of zone, through b.Lookup otherwise. previousRecords
// carries the CNAMEs collected so far and is used to stop duplicate CNAMEs and
// overly long chains. truncated is set when a followed lookup reported
// truncation. Errors from followed lookups are not returned; the CNAME is
// skipped instead.
func AAAA(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) {
services, err := checkForApex(ctx, b, zone, state, opt)
if err != nil {
return nil, false, err
}
// dup remembers the hosts for which an AAAA record was already emitted.
dup := make(map[string]struct{})
for _, serv := range services {
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
// Try to resolve as CNAME if it's not an IP, but only if we don't create loops.
if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
// x CNAME x is a direct loop, don't add those
// in etcd/skydns w.x CNAME x is also direct loop due to the "recursive" nature of search results
continue
}
newRecord := serv.NewCNAME(state.QName(), serv.Host)
if len(previousRecords) > maxCnameChainLength {
// don't add it, and just continue
continue
}
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
// The target lives in our zone: resolve it with a recursive call for
// the same query type.
state1 := state.NewWithQuestion(serv.Host, state.QType())
state1.Zone = zone
nextRecords, tc, err := AAAA(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
if err == nil {
// Not only have we found something we should add the CNAME and the IP addresses.
if len(nextRecords) > 0 {
records = append(records, newRecord)
records = append(records, nextRecords...)
}
}
if tc {
truncated = true
}
continue
}
// This means we can not complete the CNAME, try to look else where.
target := newRecord.Target
m1, e1 := b.Lookup(ctx, state, target, state.QType())
if e1 != nil {
continue
}
if m1.Truncated {
truncated = true
}
// Len(m1.Answer) > 0 here is well?
records = append(records, newRecord)
records = append(records, m1.Answer...)
continue
// both here again
case dns.TypeA:
// nada
case dns.TypeAAAA:
if _, ok := dup[serv.Host]; !ok {
dup[serv.Host] = struct{}{}
records = append(records, serv.NewAAAA(state.QName(), ip))
}
}
}
return records, truncated, nil
}
// SRV returns SRV records from the Backend.
// If the Target is not a name but an IP address, a name is created on the fly.
//
// records holds the SRV records and extra holds the address records for their
// targets. Weights are normalised per priority: each record gets
// floor(100 * weight / sum of weights in that priority), with a minimum of 1;
// a service weight of 0 counts as 100. Services with Port -1 are skipped.
func SRV(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
services, err := b.Services(ctx, state, false, opt)
if err != nil {
return nil, nil, err
}
// dup de-duplicates emitted records; lookup remembers targets already resolved.
dup := make(map[item]struct{})
lookup := make(map[string]struct{})
// Looping twice to get the right weight vs priority. This might break because we may drop duplicate SRV records later on.
// w maps a priority to the sum of the weights of all services with that priority.
w := make(map[int]int)
for _, serv := range services {
weight := 100
if serv.Weight != 0 {
weight = serv.Weight
}
if _, ok := w[serv.Priority]; !ok {
w[serv.Priority] = weight
continue
}
w[serv.Priority] += weight
}
for _, serv := range services {
// Don't add the entry if the port is -1 (invalid). The kubernetes plugin uses port -1 when a service/endpoint
// does not have any declared ports.
if serv.Port == -1 {
continue
}
w1 := 100.0 / float64(w[serv.Priority])
if serv.Weight == 0 {
w1 *= 100
} else {
w1 *= float64(serv.Weight)
}
weight := uint16(math.Floor(w1))
// weight should be at least 1
if weight == 0 {
weight = 1
}
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
srv := serv.NewSRV(state.QName(), weight)
records = append(records, srv)
// Resolve each distinct target only once.
if _, ok := lookup[srv.Target]; ok {
break
}
lookup[srv.Target] = struct{}{}
if !dns.IsSubDomain(zone, srv.Target) {
// External target: ask the backend's Lookup for both address types.
m1, e1 := b.Lookup(ctx, state, srv.Target, dns.TypeA)
if e1 == nil {
extra = append(extra, m1.Answer...)
}
m1, e1 = b.Lookup(ctx, state, srv.Target, dns.TypeAAAA)
if e1 == nil {
// If we have seen CNAME's we *assume* that they are already added.
for _, a := range m1.Answer {
if _, ok := a.(*dns.CNAME); !ok {
extra = append(extra, a)
}
}
}
break
}
// Internal name, we should have some info on them, either v4 or v6
// Clients expect a complete answer, because we are a recursor in their view.
state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
addr, _, e1 := A(ctx, b, zone, state1, nil, opt)
if e1 == nil {
extra = append(extra, addr...)
}
// TODO(miek): AAAA as well here.
case dns.TypeA, dns.TypeAAAA:
// The host is an address: point the SRV at a name derived from the
// service key and publish the address under that name in extra.
addr := serv.Host
serv.Host = msg.Domain(serv.Key)
srv := serv.NewSRV(state.QName(), weight)
if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok {
records = append(records, srv)
}
if ok := isDuplicate(dup, srv.Target, addr, 0); !ok {
extra = append(extra, newAddress(serv, srv.Target, ip, what))
}
}
}
return records, extra, nil
}
// MX returns MX records from the Backend. If the Target is not a name but an IP address, a name is created on the fly.
//
// Only services with Mail set are considered. records holds the MX records and
// extra holds the address records for their targets.
func MX(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
services, err := b.Services(ctx, state, false, opt)
if err != nil {
return nil, nil, err
}
// dup de-duplicates emitted records; lookup remembers targets already resolved.
dup := make(map[item]struct{})
lookup := make(map[string]struct{})
for _, serv := range services {
if !serv.Mail {
continue
}
what, ip := serv.HostType()
switch what {
case dns.TypeCNAME:
mx := serv.NewMX(state.QName())
records = append(records, mx)
// Resolve each distinct target only once.
if _, ok := lookup[mx.Mx]; ok {
break
}
lookup[mx.Mx] = struct{}{}
if !dns.IsSubDomain(zone, mx.Mx) {
// External target: ask the backend's Lookup for both address types.
m1, e1 := b.Lookup(ctx, state, mx.Mx, dns.TypeA)
if e1 == nil {
extra = append(extra, m1.Answer...)
}
m1, e1 = b.Lookup(ctx, state, mx.Mx, dns.TypeAAAA)
if e1 == nil {
// If we have seen CNAME's we *assume* that they are already added.
for _, a := range m1.Answer {
if _, ok := a.(*dns.CNAME); !ok {
extra = append(extra, a)
}
}
}
break
}
// Internal name
state1 := state.NewWithQuestion(mx.Mx, dns.TypeA)
addr, _, e1 := A(ctx, b, zone, state1, nil, opt)
if e1 == nil {
extra = append(extra, addr...)
}
// TODO(miek): AAAA as well here.
case dns.TypeA, dns.TypeAAAA:
// The host is an address: point the MX at a name derived from the
// service key and publish the address under that name in extra.
addr := serv.Host
serv.Host = msg.Domain(serv.Key)
mx := serv.NewMX(state.QName())
if ok := isDuplicate(dup, mx.Mx, "", mx.Preference); !ok {
records = append(records, mx)
}
// Fake port to be 0 for address...
if ok := isDuplicate(dup, serv.Host, addr, 0); !ok {
extra = append(extra, newAddress(serv, serv.Host, ip, what))
}
}
}
return records, extra, nil
}
// CNAME returns CNAME records from the backend or an error.
//
// Only the first service returned by the backend is considered, and it
// produces a record only when its Host does not parse as an IP address.
func CNAME(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
	services, err := b.Services(ctx, state, true, opt)
	if err != nil {
		return nil, err
	}
	if len(services) == 0 {
		return records, nil
	}

	first := services[0]
	if net.ParseIP(first.Host) != nil {
		// An address cannot be the target of a CNAME.
		return records, nil
	}
	return append(records, first.NewCNAME(state.QName(), first.Host)), nil
}
// TXT returns TXT records from Backend or an error.
//
// A service whose HostType is dns.TypeTXT yields a TXT record (one per
// distinct Text). A service whose HostType is dns.TypeCNAME is followed:
// recursively through TXT when the target is a subdomain of zone, through
// b.Lookup otherwise. previousRecords carries the CNAMEs collected so far and
// is used to stop duplicate CNAMEs and overly long chains. truncated is set
// when any followed lookup reported truncation. Errors from followed lookups
// are not returned; the CNAME is skipped instead.
func TXT(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) {
	services, err := b.Services(ctx, state, false, opt)
	if err != nil {
		return nil, false, err
	}

	// dup remembers the texts for which a TXT record was already emitted.
	dup := make(map[string]struct{})

	for _, serv := range services {
		what, _ := serv.HostType()

		switch what {
		case dns.TypeCNAME:
			if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) {
				// x CNAME x is a direct loop, don't add those
				// in etcd/skydns w.x CNAME x is also direct loop due to the "recursive" nature of search results
				continue
			}

			newRecord := serv.NewCNAME(state.QName(), serv.Host)
			if len(previousRecords) > maxCnameChainLength {
				// don't add it, and just continue
				continue
			}
			if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
				continue
			}
			if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) {
				// The target lives in our zone: resolve it with a recursive
				// call for the same query type.
				state1 := state.NewWithQuestion(serv.Host, state.QType())
				state1.Zone = zone
				nextRecords, tc, err := TXT(ctx, b, zone, state1, append(previousRecords, newRecord), opt)
				if tc {
					truncated = true
				}
				if err == nil {
					// Not only have we found something we should add the CNAME and the TXT records.
					if len(nextRecords) > 0 {
						records = append(records, newRecord)
						records = append(records, nextRecords...)
					}
				}
				continue
			}
			// This means we can not complete the CNAME, try to look else where.
			target := newRecord.Target
			// Lookup
			m1, e1 := b.Lookup(ctx, state, target, state.QType())
			if e1 != nil {
				continue
			}
			// Propagate truncation of the external answer, as A and AAAA do,
			// so the caller knows the record set may be incomplete.
			if m1.Truncated {
				truncated = true
			}
			// Len(m1.Answer) > 0 here is well?
			records = append(records, newRecord)
			records = append(records, m1.Answer...)
			continue

		case dns.TypeTXT:
			if _, ok := dup[serv.Text]; !ok {
				dup[serv.Text] = struct{}{}
				records = append(records, serv.NewTXT(state.QName()))
			}
		}
	}

	return records, truncated, nil
}
// PTR returns the PTR records from the backend, only services that have a
// domain name as host are included. Each distinct host yields one record.
func PTR(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) {
	services, err := b.Reverse(ctx, state, true, opt)
	if err != nil {
		return nil, err
	}

	seen := make(map[string]struct{})
	for _, serv := range services {
		// Hosts that parse as an IP address are not names; skip them.
		if net.ParseIP(serv.Host) != nil {
			continue
		}
		if _, already := seen[serv.Host]; already {
			continue
		}
		seen[serv.Host] = struct{}{}
		records = append(records, serv.NewPTR(state.QName(), serv.Host))
	}
	return records, nil
}
// NS returns NS records from the backend.
//
// records holds one NS record per distinct name server name, and extra holds
// an address record for every service found. A service whose host is a name
// instead of an IP address makes NS fail with an error.
func NS(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) {
	// The NS record for this zone lives in a special place, ns.dns.<zone>,
	// so the question name is swapped for the duration of the backend query.
	// Only a tad bit fishy...
	origName := state.QName()
	state.Clear()
	state.Req.Question[0].Name = dnsutil.Join("ns.dns.", zone)

	services, err := b.Services(ctx, state, false, opt)

	// Put the original query name back before doing anything else.
	state.Req.Question[0].Name = origName
	if err != nil {
		return nil, nil, err
	}

	seen := map[string]bool{}
	for _, serv := range services {
		what, ip := serv.HostType()
		if what == dns.TypeCNAME {
			return nil, nil, fmt.Errorf("NS record must be an IP address: %s", serv.Host)
		}
		if what != dns.TypeA && what != dns.TypeAAAA {
			continue
		}

		serv.Host = msg.Domain(serv.Key)
		ns := serv.NewNS(state.QName())
		// Every address is published, even for a name server already listed.
		extra = append(extra, newAddress(serv, ns.Ns, ip, what))
		if !seen[ns.Ns] {
			seen[ns.Ns] = true
			records = append(records, ns)
		}
	}
	return records, extra, nil
}
// SOA returns a SOA record from the backend.
//
// The record's header TTL is the backend's minimal TTL capped at 300, its
// Minttl is the backend's minimal TTL, and its serial comes from b.Serial.
// The returned error is always nil.
func SOA(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]dns.RR, error) {
	minTTL := b.MinTTL(state)

	// The header TTL never exceeds 300.
	ttl := minTTL
	if ttl > 300 {
		ttl = 300
	}

	mbox := dnsutil.Join(hostmaster, zone)
	ns := dnsutil.Join("ns.dns", zone)

	soa := &dns.SOA{
		Hdr: dns.RR_Header{
			Name:   zone,
			Rrtype: dns.TypeSOA,
			Ttl:    ttl,
			Class:  dns.ClassINET,
		},
		Mbox:    mbox,
		Ns:      ns,
		Serial:  b.Serial(state),
		Refresh: 7200,
		Retry:   1800,
		Expire:  86400,
		Minttl:  minTTL,
	}
	return []dns.RR{soa}, nil
}
// BackendError writes an error response to the client.
//
// The reply carries rcode, is marked authoritative and has the zone's SOA in
// its authority section. err is handed back unchanged together with
// dns.RcodeSuccess.
func BackendError(ctx context.Context, b ServiceBackend, zone string, rcode int, state request.Request, err error, opt Options) (int, error) {
	reply := new(dns.Msg)
	reply.SetRcode(state.Req, rcode)
	reply.Authoritative = true

	soa, _ := SOA(ctx, b, zone, state, opt)
	reply.Ns = soa

	state.W.WriteMsg(reply)

	// Return success as the rcode to signal we have written to the client.
	return dns.RcodeSuccess, err
}
// newAddress builds an address record named name with address ip and the TTL
// of s. It returns an A record when what is dns.TypeA and an AAAA record for
// any other value.
func newAddress(s msg.Service, name string, ip net.IP, what uint16) dns.RR {
	hdr := dns.RR_Header{Name: name, Rrtype: what, Class: dns.ClassINET, Ttl: s.TTL}

	switch what {
	case dns.TypeA:
		return &dns.A{Hdr: hdr, A: ip}
	default:
		// Should always be dns.TypeAAAA
		return &dns.AAAA{Hdr: hdr, AAAA: ip}
	}
}
// checkForApex checks the special apex.dns directory for records that will be
// returned as A or AAAA.
//
// Queries for any name other than the zone itself go straight to the backend.
// For the zone name, the backend is first queried for apex.dns.<zone>; if
// that fails, it is queried again with the original name.
func checkForApex(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]msg.Service, error) {
	if state.Name() != zone {
		return b.Services(ctx, state, false, opt)
	}

	// The zone name itself is queried: fake the question name to search for
	// the special entry, the same trick the NS lookup uses.
	origName := state.QName()
	state.Clear()
	state.Req.Question[0].Name = dnsutil.Join("apex.dns", zone)
	services, err := b.Services(ctx, state, false, opt)

	// The original name is restored on both outcomes.
	state.Req.Question[0].Name = origName
	if err != nil {
		return b.Services(ctx, state, false, opt)
	}
	return services, nil
}
// item is the key used to remember which records were already emitted.
type item struct {
	name string // name of the record (either owner or something else unique).
	port uint16 // port of the record; left at 0 when addr is set.
	addr string // address of the record (A and AAAA).
}

// isDuplicate reports whether the combination (name, addr, port) is already
// present in m, and records it in m when it is not. When addr is non-empty
// the port is ignored and the entry is keyed on (name, addr) alone.
func isDuplicate(m map[item]struct{}, name, addr string, port uint16) bool {
	k := item{name: name, port: port}
	if addr != "" {
		k = item{name: name, addr: addr}
	}

	if _, seen := m[k]; seen {
		return true
	}
	m[k] = struct{}{}
	return false
}
// hostmaster is the label SOA joins with the zone name to form the Mbox field.
const hostmaster = "hostmaster"
// Package cache implements a cache.
package cache
import (
"hash/fnv"
"net"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Cache is a plugin that looks up responses in a cache and caches replies.
// It has a success and a denial of existence cache.
type Cache struct {
// Next is the next handler in the plugin chain.
Next plugin.Handler
// Zones lists the zones this cache handles; ServeDNS hands queries that match none of them to Next.
Zones []string
// Label values attached to the metrics reported by this cache.
zonesMetricLabel string
viewMetricLabel string
// Negative (denial) cache: storage, capacity, and the maximum and minimum TTL applied to entries.
ncache *cache.Cache
ncap int
nttl time.Duration
minnttl time.Duration
// Positive (success) cache: storage, capacity, and the maximum and minimum TTL applied to entries.
pcache *cache.Cache
pcap int
pttl time.Duration
minpttl time.Duration
failttl time.Duration // TTL for caching SERVFAIL responses
// Prefetch.
// NOTE(review): the code that consumes these three settings is not visible here.
prefetch int
duration time.Duration
percentage int
// Stale serve
staleUpTo time.Duration
verifyStale bool
// Positive/negative zone exceptions
// Responses for names matching these zones are not stored in the respective cache.
pexcept []string
nexcept []string
// Keep ttl option
keepttl bool
// Testing.
// now supplies the current time; New sets it to time.Now.
now func() time.Time
}
// New returns an initialized Cache with default settings. It's up to the
// caller to set the Next handler.
func New() *Cache {
	c := &Cache{
		Zones: []string{"."},
		now:   time.Now,
	}

	// Positive (success) cache defaults.
	c.pcap = defaultCap
	c.pcache = cache.New(defaultCap)
	c.pttl = maxTTL
	c.minpttl = minTTL

	// Negative (denial) cache defaults; SERVFAIL uses the minimal negative TTL.
	c.ncap = defaultCap
	c.ncache = cache.New(defaultCap)
	c.nttl = maxNTTL
	c.minnttl = minNTTL
	c.failttl = minNTTL

	// Prefetch defaults.
	c.prefetch = 0
	c.duration = 1 * time.Minute
	c.percentage = 10

	return c
}
// key reports whether the response m may be cached and, if so, the hash
// under which it is stored. It returns (false, 0) for anything that must not
// be cached: truncated responses, responses classified as OtherError, Meta or
// Update, and responses without a question section.
// qname holds the already lowercased qname.
func key(qname string, m *dns.Msg, t response.Type, do, cd bool) (bool, uint64) {
	// We don't store truncated responses.
	if m.Truncated {
		return false, 0
	}
	// Nor errors or Meta or Update.
	if t == response.OtherError || t == response.Meta || t == response.Update {
		return false, 0
	}
	// A response without a question cannot be keyed. key runs before the
	// response is matched against the request, so guard here instead of
	// panicking on m.Question[0].
	if len(m.Question) == 0 {
		return false, 0
	}
	return true, hash(qname, m.Question[0].Qtype, do, cd)
}
// one and zero are the markers hash writes for a set or cleared flag.
var one = []byte("1")
var zero = []byte("0")

// hash returns the 64-bit FNV-1 hash over, in order: the DO flag marker, the
// CD flag marker, the two bytes of qtype (high byte first) and qname.
func hash(qname string, qtype uint16, do, cd bool) uint64 {
	h := fnv.New64()

	for _, flag := range [2]bool{do, cd} {
		marker := zero
		if flag {
			marker = one
		}
		h.Write(marker)
	}

	h.Write([]byte{byte(qtype >> 8), byte(qtype)})
	h.Write([]byte(qname))
	return h.Sum64()
}
// computeTTL clamps msgTTL into the range [minTTL, maxTTL]. The floor is
// applied first and the ceiling second, so maxTTL wins if the bounds cross.
func computeTTL(msgTTL, minTTL, maxTTL time.Duration) time.Duration {
	if msgTTL < minTTL {
		msgTTL = minTTL
	}
	if msgTTL > maxTTL {
		return maxTTL
	}
	return msgTTL
}
// ResponseWriter is a response writer that caches the reply message.
// It embeds the underlying dns.ResponseWriter and the Cache it stores into.
type ResponseWriter struct {
dns.ResponseWriter
*Cache
state request.Request
server string // Server handling the request.
do bool // When true the original request had the DO bit set.
cd bool // When true the original request had the CD bit set.
ad bool // When true the original request had the AD bit set.
prefetch bool // When true write nothing back to the client.
remoteAddr net.Addr // When non-nil, returned by RemoteAddr instead of the embedded writer's address.
wildcardFunc func() string // function to retrieve wildcard name that synthesized the result.
pexcept []string // positive zone exceptions
nexcept []string // negative zone exceptions
}
// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in
// prefetch requests. It ensures RemoteAddr() can be called even after the
// original connection has already been closed.
func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter {
	// Capture the address right away: by the time the prefetch request is
	// actually made the connection might be closed.
	remote := state.W.RemoteAddr()

	// The protocol of the client that triggered the prefetch is irrelevant.
	// request.Proto derives the response size from the address type, so a TCP
	// address keeps the message from being truncated unnecessarily.
	if udp, isUDP := remote.(*net.UDPAddr); isUDP {
		remote = &net.TCPAddr{IP: udp.IP, Port: udp.Port, Zone: udp.Zone}
	}

	pw := &ResponseWriter{
		ResponseWriter: state.W,
		Cache:          c,
		state:          state,
		server:         server,
		prefetch:       true,
		remoteAddr:     remote,
	}
	pw.do = state.Do()
	pw.cd = state.Req.CheckingDisabled
	return pw
}
// RemoteAddr implements the dns.ResponseWriter interface. It returns the
// stored remote address when there is one and otherwise asks the embedded
// writer.
func (w *ResponseWriter) RemoteAddr() net.Addr {
	if w.remoteAddr == nil {
		return w.ResponseWriter.RemoteAddr()
	}
	return w.remoteAddr
}
// WriteMsg implements the dns.ResponseWriter interface.
//
// It works on a copy of res: the copy is classified, its record TTLs are set
// to the capped cache TTL, it is stored in the cache when it is cacheable and
// matches the request, and finally it is written to the client unless this
// writer is used for a prefetch.
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
	res = res.Copy()
	mt, _ := response.Typify(res, w.now().UTC())

	// cacheable is false for anything we don't want to cache.
	cacheable, k := key(w.state.Name(), res, mt, w.do, w.cd)

	msgTTL := dnsutil.MinimalTTL(res, mt)
	var duration time.Duration
	switch mt {
	case response.ServerError:
		duration = w.failttl
	case response.NameError, response.NoData:
		duration = computeTTL(msgTTL, w.minnttl, w.nttl)
	default:
		duration = computeTTL(msgTTL, w.minpttl, w.pttl)
	}

	// Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.)
	ttl := uint32(duration.Seconds())
	res.Answer = filterRRSlice(res.Answer, ttl, false)
	res.Ns = filterRRSlice(res.Ns, ttl, false)
	res.Extra = filterRRSlice(res.Extra, ttl, false)

	// Unset the AD bit if the requester is not OK with DNSSEC, but retain it
	// if the requester set the AD bit in the request, per RFC6840 5.7-5.8.
	if !(w.do || w.ad) {
		res.AuthenticatedData = false
	}

	if cacheable && duration > 0 {
		if !w.state.Match(res) {
			// Don't log it, but increment counter
			cacheDrops.WithLabelValues(w.server, w.zonesMetricLabel, w.viewMetricLabel).Inc()
		} else {
			w.set(res, k, mt, duration)
			cacheSize.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.pcache.Len()))
			cacheSize.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.ncache.Len()))
		}
	}

	// A prefetch only refreshes the cache; nothing goes back to a client.
	if w.prefetch {
		return nil
	}
	return w.ResponseWriter.WriteMsg(res)
}
// set stores m under key in the positive or negative cache, depending on the
// response classification mt. The caller guarantees that duration is > 0 and
// that key is valid. Responses for names in the matching exception list are
// not stored.
func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) {
	switch mt {
	case response.NoError, response.Delegation:
		if excluded := plugin.Zones(w.pexcept).Matches(m.Question[0].Name); excluded != "" {
			// zone is in exception list, do not cache
			return
		}
		entry := newItem(m, w.now(), duration)
		if w.wildcardFunc != nil {
			entry.wildcard = w.wildcardFunc()
		}
		if evicted := w.pcache.Add(key, entry); evicted {
			evictions.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Inc()
		}
		// when pre-fetching, remove the negative cache entry if it exists
		if w.prefetch {
			w.ncache.Remove(key)
		}

	case response.NameError, response.NoData, response.ServerError:
		if excluded := plugin.Zones(w.nexcept).Matches(m.Question[0].Name); excluded != "" {
			// zone is in exception list, do not cache
			return
		}
		entry := newItem(m, w.now(), duration)
		if w.wildcardFunc != nil {
			entry.wildcard = w.wildcardFunc()
		}
		if evicted := w.ncache.Add(key, entry); evicted {
			evictions.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Inc()
		}

	case response.OtherError:
		// don't cache these

	default:
		log.Warningf("Caching called with unknown classification: %d", mt)
	}
}
// Write implements the dns.ResponseWriter interface. Raw writes are never
// cached; a warning is logged on every call. A prefetch writer discards the
// data and reports zero bytes written.
func (w *ResponseWriter) Write(buf []byte) (int, error) {
	log.Warning("Caching called with Write: not caching reply")
	if !w.prefetch {
		return w.ResponseWriter.Write(buf)
	}
	return 0, nil
}
// verifyStaleResponseWriter is a response writer that only writes messages if
// they should replace a stale cache entry, and otherwise discards them.
type verifyStaleResponseWriter struct {
	*ResponseWriter
	refreshed bool // true if the last WriteMsg wrote to ResponseWriter, false otherwise.
}

// newVerifyStaleResponseWriter returns a ResponseWriter to be used when
// verifying stale cache entries. It only forwards writes if an entry was
// successfully refreshed according to RFC8767, section 4 (response is NoError
// or NXDomain), and ignores any other response.
func newVerifyStaleResponseWriter(w *ResponseWriter) *verifyStaleResponseWriter {
	return &verifyStaleResponseWriter{ResponseWriter: w, refreshed: false}
}

// WriteMsg implements the dns.ResponseWriter interface. refreshed records
// whether this call forwarded res to the wrapped writer.
func (w *verifyStaleResponseWriter) WriteMsg(res *dns.Msg) error {
	w.refreshed = res.Rcode == dns.RcodeSuccess || res.Rcode == dns.RcodeNameError
	if !w.refreshed {
		return nil // discard
	}
	// stores to the cache and sends to the client
	return w.ResponseWriter.WriteMsg(res)
}
// Default TTL bounds and capacity used by New, plus the cache class names
// used as metric label values.
const (
// maxTTL and minTTL are the default upper and lower TTL bounds of the positive cache.
maxTTL = dnsutil.MaximumDefaulTTL
minTTL = dnsutil.MinimalDefaultTTL
// maxNTTL and minNTTL are the default upper and lower TTL bounds of the negative cache.
maxNTTL = dnsutil.MaximumDefaulTTL / 2
minNTTL = dnsutil.MinimalDefaultTTL
defaultCap = 10000 // default capacity of the cache.
// Success is the class for caching positive caching.
Success = "success"
// Denial is the class defined for negative caching.
Denial = "denial"
)
package cache
import "github.com/miekg/dns"
// filterRRSlice returns the RRs of rrs without any OPT RRs and with every
// remaining RR's TTL set to ttl. If dup is true each RR is _copied_ before its
// TTL is adjusted and the copies are returned; otherwise the RRs in rrs are
// modified in place. The result is always a fresh, non-nil slice.
func filterRRSlice(rrs []dns.RR, ttl uint32, dup bool) []dns.RR {
	kept := make([]dns.RR, 0, len(rrs))
	for _, rr := range rrs {
		if rr.Header().Rrtype == dns.TypeOPT {
			continue
		}
		if dup {
			rr = dns.Copy(rr)
		}
		rr.Header().Ttl = ttl
		kept = append(kept, rr)
	}
	return kept
}
// Package freq keeps track of last X seen events. The events themselves are not stored
// here. So the Freq type should be added next to the thing it is tracking.
package freq
import (
"sync"
"time"
)
// Freq tracks the frequencies of things.
type Freq struct {
	// Last time we saw a query for this element.
	last time.Time
	// Number of hits in the last time slice.
	hits int

	sync.RWMutex
}

// New returns a new Freq whose last-seen time is t and whose hit count is 0.
func New(t time.Time) *Freq {
	f := new(Freq)
	f.last = t
	return f
}

// Update registers a hit at time now and returns the resulting number of hits.
// If the previous hit is older than now - d the count starts again at 1,
// otherwise it is incremented. The last-seen time is set to now either way.
func (f *Freq) Update(d time.Duration, now time.Time) int {
	f.Lock()
	defer f.Unlock()

	// The previous hit fell out of the window: start counting from scratch.
	if cutoff := now.Add(-d); f.last.Before(cutoff) {
		f.hits = 0
	}
	f.last = now
	f.hits++
	return f.hits
}

// Hits returns the number of hits that we have seen, according to the updates
// we have done to f.
func (f *Freq) Hits() int {
	f.RLock()
	n := f.hits
	f.RUnlock()
	return n
}

// Reset resets f to time t and hits to hits.
func (f *Freq) Reset(t time.Time, hits int) {
	f.Lock()
	f.last, f.hits = t, hits
	f.Unlock()
}
//go:build gofuzz
package cache
import (
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes cache: it feeds data to a freshly created cache plugin (New())
// through fuzz.Do and returns that call's result.
func Fuzz(data []byte) int {
return fuzz.Do(New(), data)
}
package cache
import (
"context"
"math"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the plugin.Handler interface.
//
// Queries whose name is not in one of c.Zones are handed to the next plugin.
// For the others the cache is consulted: on a miss the query is forwarded
// through a caching ResponseWriter; on a hit the reply is built from the cached
// item, optionally triggering a prefetch or a serve-stale refresh first.
func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rc := r.Copy() // We potentially modify r, to prevent other plugins from seeing this (r is a pointer), copy r into rc.
state := request.Request{W: w, Req: rc}
do := state.Do()
cd := r.CheckingDisabled
ad := r.AuthenticatedData
zone := plugin.Zones(c.Zones).Matches(state.Name())
if zone == "" {
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, rc)
}
now := c.now().UTC()
server := metrics.WithServer(ctx)
// On cache refresh, we will just use the DO bit from the incoming query for the refresh since we key our cache
// with the query DO bit. That means two separate cache items for the query DO bit true or false. In the situation
// in which upstream doesn't support DNSSEC, the two cache items will effectively be the same. Regardless, any
// DNSSEC RRs in the response are written to cache with the response.
i := c.getIgnoreTTL(now, state, server)
if i == nil {
// Cache miss: forward the query through a writer that stores the answer.
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, ad: ad, cd: cd,
nexcept: c.nexcept, pexcept: c.pexcept, wildcardFunc: wildcardFunc(ctx)}
return c.doRefresh(ctx, state, crr)
}
ttl := i.ttl(now)
if ttl < 0 {
// serve stale behavior
if c.verifyStale {
// NOTE(review): unlike the cache-miss writer above, this one does not set ad,
// nexcept, pexcept or wildcardFunc — confirm that is intentional.
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, cd: cd}
cw := newVerifyStaleResponseWriter(crr)
ret, err := c.doRefresh(ctx, state, cw)
if cw.refreshed {
// The refresh produced a reply that was already written to the client.
return ret, err
}
}
// Adjust the time to get a 0 TTL in the reply built from a stale item.
now = now.Add(time.Duration(ttl) * time.Second)
if !c.verifyStale {
// Immediate mode: answer from the stale item and refresh in the background.
cw := newPrefetchResponseWriter(server, state, c)
go c.doPrefetch(ctx, state, cw, i, now)
}
servedStale.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
} else if c.shouldPrefetch(i, now) {
cw := newPrefetchResponseWriter(server, state, c)
go c.doPrefetch(ctx, state, cw, i, now)
}
if i.wildcard != "" {
// Set wildcard source record name to metadata
metadata.SetValueFunc(ctx, "zone/wildcard", func() string {
return i.wildcard
})
}
if c.keepttl {
// If keepttl is enabled we fake the current time to the stored
// one so that we always get the original TTL
now = i.stored
}
resp := i.toMsg(r, now, do, ad)
w.WriteMsg(resp)
return dns.RcodeSuccess, nil
}
// wildcardFunc returns a function that, when called, looks up the
// "zone/wildcard" metadata value in ctx and returns it, or "" when no such
// metadata value function is registered.
func wildcardFunc(ctx context.Context) func() string {
	return func() string {
		// Get wildcard source record name from metadata
		valueOf := metadata.ValueFunc(ctx, "zone/wildcard")
		if valueOf == nil {
			return ""
		}
		return valueOf()
	}
}
// doPrefetch counts a prefetch, refreshes the entry for state through cw and
// then carries the hit statistics of the old item i over to the refreshed one.
func (c *Cache) doPrefetch(ctx context.Context, state request.Request, cw *ResponseWriter, i *item, now time.Time) {
	cachePrefetches.WithLabelValues(cw.server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
	c.doRefresh(ctx, state, cw)

	// Prefetching replaces item i, and with it the frequency information
	// gathered so far. Copy that information into the item that is now
	// stored in the cache, if there is one.
	refreshed := c.exists(state)
	if refreshed == nil {
		return
	}
	refreshed.Reset(now, i.Hits())
}
// doRefresh passes the request in state to the next plugin, using cw as the
// response writer, and returns that plugin's result.
func (c *Cache) doRefresh(ctx context.Context, state request.Request, cw dns.ResponseWriter) (int, error) {
return plugin.NextOrFailure(c.Name(), c.Next, ctx, cw, state.Req)
}
// shouldPrefetch records a hit on i and reports whether i should be prefetched:
// prefetching must be enabled (c.prefetch > 0), i must have been hit at least
// c.prefetch times and its remaining TTL must be at or below c.percentage
// percent of its original TTL (rounded up).
func (c *Cache) shouldPrefetch(i *item, now time.Time) bool {
	if c.prefetch <= 0 {
		return false
	}
	// The hit is recorded even when the item turns out not to be eligible.
	i.Update(c.duration, now)
	if i.Hits() < c.prefetch {
		return false
	}
	fraction := float64(c.percentage) / 100
	threshold := int(math.Ceil(fraction * float64(i.origTTL)))
	return i.ttl(now) <= threshold
}
// Name implements the Handler interface; it returns the plugin name "cache".
func (c *Cache) Name() string { return "cache" }
// getIgnoreTTL looks up the item for state, first in the negative cache and then
// in the positive cache. An item is returned when it matches the query name and
// type and either still has TTL left or, when c.staleUpTo is set, expired less
// than c.staleUpTo ago. It returns nil otherwise.
// The request, hit (per class) and miss counters are updated along the way.
func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item {
	key := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled)
	cacheRequests.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc()

	// servable reports whether a cached entry may be handed out for this query.
	servable := func(entry *item) bool {
		remaining := entry.ttl(now)
		if !entry.matches(state) {
			return false
		}
		if remaining > 0 {
			return true
		}
		return c.staleUpTo > 0 && -remaining < int(c.staleUpTo.Seconds())
	}

	if v, ok := c.ncache.Get(key); ok {
		if entry := v.(*item); servable(entry) {
			cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel, c.viewMetricLabel).Inc()
			return entry
		}
	}
	if v, ok := c.pcache.Get(key); ok {
		if entry := v.(*item); servable(entry) {
			cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel, c.viewMetricLabel).Inc()
			return entry
		}
	}

	cacheMisses.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc()
	return nil
}
// exists returns the cached item stored under the key of state, checking the
// negative cache before the positive cache, or nil when neither holds one.
// Neither the TTL nor the query name is checked.
func (c *Cache) exists(state request.Request) *item {
	key := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled)

	found, ok := c.ncache.Get(key)
	if !ok {
		found, ok = c.pcache.Get(key)
	}
	if !ok {
		return nil
	}
	return found.(*item)
}
package cache
import (
"strings"
"time"
"github.com/coredns/coredns/plugin/cache/freq"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// item is a cached DNS response together with the bookkeeping needed to age it
// and to track how often it is requested.
type item struct {
// Name and QType are taken from the first question of the cached message.
Name string
QType uint16
// Rcode, AuthenticatedData and RecursionAvailable are copied from the message header.
Rcode int
AuthenticatedData bool
RecursionAvailable bool
// Answer, Ns and Extra hold the message sections; Extra has OPT records removed.
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
// wildcard is the wildcard source record name; when non-empty it is published
// as "zone/wildcard" metadata when the item is served.
wildcard string
// origTTL is the TTL in seconds the item was stored with.
origTTL uint32
// stored is the (UTC) time the item was created.
stored time.Time
// Freq tracks the hits on this item.
*freq.Freq
}
// newItem builds a cache item from message m. The item takes its name and type
// from the first question (if any), shares the Answer and Ns slices with m and
// gets a copy of the Extra section without OPT records. d is the TTL to store
// the item with and now the time of storing.
func newItem(m *dns.Msg, now time.Time, d time.Duration) *item {
	// Don't copy OPT records as these are hop-by-hop.
	extra := make([]dns.RR, 0, len(m.Extra))
	for _, rr := range m.Extra {
		if rr.Header().Rrtype != dns.TypeOPT {
			extra = append(extra, rr)
		}
	}

	it := &item{
		Rcode:              m.Rcode,
		AuthenticatedData:  m.AuthenticatedData,
		RecursionAvailable: m.RecursionAvailable,
		Answer:             m.Answer,
		Ns:                 m.Ns,
		Extra:              extra,
		origTTL:            uint32(d.Seconds()),
		stored:             now.UTC(),
		Freq:               new(freq.Freq),
	}
	if len(m.Question) != 0 {
		q := m.Question[0]
		it.Name = q.Name
		it.QType = q.Qtype
	}
	return it
}
// toMsg turns i into a message, it tailors the reply to m.
//
// now is the time used to compute the remaining TTL that is written into every
// returned record; do and ad tell whether the requester set the DO and AD bits.
// The records of i are copied, so the cached item is not modified.
//
// The Authoritative bit should be set to 0, but some client stub resolver implementations, most notably,
// on some legacy systems (e.g. ubuntu 14.04 with glibc version 2.20), the low-level glibc function `getaddrinfo`
// used by Python/Ruby/etc. will discard answers that do not have this bit set.
// So we're forced to always set this to 1; regardless if the answer came from the cache or not.
// On newer systems (e.g. ubuntu 16.04 with glibc version 2.23), this issue is resolved.
// So we may set this bit back to 0 in the future ?
func (i *item) toMsg(m *dns.Msg, now time.Time, do bool, ad bool) *dns.Msg {
	m1 := new(dns.Msg)
	m1.SetReply(m)

	// Set this to true as some DNS clients discard the *entire* packet when it's non-authoritative.
	// This is probably not according to spec, but the bit itself is not super useful as this point, so
	// just set it to true.
	m1.Authoritative = true
	m1.AuthenticatedData = i.AuthenticatedData
	if !do && !ad {
		// When DNSSEC was not wanted, it can't be authenticated data.
		// However, retain the AD bit if the requester set the AD bit, per RFC6840 5.7-5.8
		m1.AuthenticatedData = false
	}
	m1.RecursionAvailable = i.RecursionAvailable
	m1.Rcode = i.Rcode

	// A negative remaining TTL would wrap around to a huge value when converted
	// to uint32; an expired item is served with TTL 0 instead.
	remaining := i.ttl(now)
	if remaining < 0 {
		remaining = 0
	}
	ttl := uint32(remaining)

	// filterRRSlice allocates the result slices itself, so no pre-allocation is needed.
	m1.Answer = filterRRSlice(i.Answer, ttl, true)
	m1.Ns = filterRRSlice(i.Ns, ttl, true)
	m1.Extra = filterRRSlice(i.Extra, ttl, true)
	return m1
}
// ttl returns the number of whole seconds i has left at time now: its original
// TTL minus the seconds elapsed since it was stored. The result is negative once
// the item has expired.
func (i *item) ttl(now time.Time) int {
	elapsed := now.UTC().Sub(i.stored)
	return int(i.origTTL) - int(elapsed.Seconds())
}
// matches reports whether i was stored for the same query type and the same
// query name (compared case-insensitively) as the request in state.
func (i *item) matches(state request.Request) bool {
	return state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name)
}
package cache
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache"
clog "github.com/coredns/coredns/plugin/pkg/log"
)
// log is the logger of the cache plugin.
var log = clog.NewWithPlugin("cache")
// init registers the cache plugin under the name "cache" with setup as its setup function.
func init() { plugin.Register("cache", setup) }
// setup parses the cache configuration from c and adds the resulting Cache to
// the plugin chain of the server block. Parse errors are returned wrapped with
// the plugin name.
func setup(c *caddy.Controller) error {
ca, err := cacheParse(c)
if err != nil {
return plugin.Error("cache", err)
}
// The view name is read at startup and used as a metric label.
c.OnStartup(func() error {
ca.viewMetricLabel = dnsserver.GetConfig(c).ViewName
return nil
})
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
ca.Next = next
return ca
})
return nil
}
// cacheParse reads the cache directive from c and returns the configured Cache.
//
// The directive has the form `cache [ttl] [zones..]` and may be followed by a
// block with the properties success, denial, prefetch, serve_stale, servfail,
// disable and keepttl. The directive may appear only once per server block;
// a second occurrence yields plugin.ErrOnce. Any invalid argument results in a
// nil Cache and an error.
func cacheParse(c *caddy.Controller) (*Cache, error) {
	ca := New()

	j := 0
	for c.Next() {
		if j > 0 {
			return nil, plugin.ErrOnce
		}
		j++

		// cache [ttl] [zones..]
		args := c.RemainingArgs()
		if len(args) > 0 {
			// first args may be just a number, then it is the ttl, if not it is a zone
			ttl, err := strconv.Atoi(args[0])
			if err == nil {
				// Reserve 0 (and smaller for future things)
				if ttl <= 0 {
					return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", ttl)
				}
				ca.pttl = time.Duration(ttl) * time.Second
				ca.nttl = time.Duration(ttl) * time.Second
				args = args[1:]
			}
		}
		origins := plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys)

		// Refinements? In an extra block.
		for c.NextBlock() {
			switch c.Val() {
			// first number is cap, second is an new ttl
			case Success:
				args := c.RemainingArgs()
				if len(args) == 0 {
					return nil, c.ArgErr()
				}
				pcap, err := strconv.Atoi(args[0])
				if err != nil {
					return nil, err
				}
				ca.pcap = pcap
				if len(args) > 1 {
					pttl, err := strconv.Atoi(args[1])
					if err != nil {
						return nil, err
					}
					// Reserve 0 (and smaller for future things)
					if pttl <= 0 {
						return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", pttl)
					}
					ca.pttl = time.Duration(pttl) * time.Second
					if len(args) > 2 {
						minpttl, err := strconv.Atoi(args[2])
						if err != nil {
							return nil, err
						}
						// Reserve < 0
						if minpttl < 0 {
							return nil, fmt.Errorf("cache min TTL can not be negative: %d", minpttl)
						}
						ca.minpttl = time.Duration(minpttl) * time.Second
					}
				}
			case Denial:
				args := c.RemainingArgs()
				if len(args) == 0 {
					return nil, c.ArgErr()
				}
				ncap, err := strconv.Atoi(args[0])
				if err != nil {
					return nil, err
				}
				ca.ncap = ncap
				if len(args) > 1 {
					nttl, err := strconv.Atoi(args[1])
					if err != nil {
						return nil, err
					}
					// Reserve 0 (and smaller for future things)
					if nttl <= 0 {
						return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", nttl)
					}
					ca.nttl = time.Duration(nttl) * time.Second
					if len(args) > 2 {
						minnttl, err := strconv.Atoi(args[2])
						if err != nil {
							return nil, err
						}
						// Reserve < 0
						if minnttl < 0 {
							return nil, fmt.Errorf("cache min TTL can not be negative: %d", minnttl)
						}
						ca.minnttl = time.Duration(minnttl) * time.Second
					}
				}
			case "prefetch":
				// prefetch AMOUNT [DURATION [PERCENTAGE%]]
				args := c.RemainingArgs()
				if len(args) == 0 || len(args) > 3 {
					return nil, c.ArgErr()
				}
				amount, err := strconv.Atoi(args[0])
				if err != nil {
					return nil, err
				}
				if amount < 0 {
					return nil, fmt.Errorf("prefetch amount should be positive: %d", amount)
				}
				ca.prefetch = amount

				if len(args) > 1 {
					dur, err := time.ParseDuration(args[1])
					if err != nil {
						return nil, err
					}
					ca.duration = dur
				}
				if len(args) > 2 {
					pct := args[2]
					// An empty (quoted) argument has no last character to inspect;
					// reject it here instead of indexing out of range below.
					if pct == "" {
						return nil, errors.New("percentage should not be empty")
					}
					if x := pct[len(pct)-1]; x != '%' {
						return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x)
					}
					pct = pct[:len(pct)-1]

					num, err := strconv.Atoi(pct)
					if err != nil {
						return nil, err
					}
					if num < 10 || num > 90 {
						return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num)
					}
					ca.percentage = num
				}

			case "serve_stale":
				// serve_stale [DURATION [immediate|verify]]
				args := c.RemainingArgs()
				if len(args) > 2 {
					return nil, c.ArgErr()
				}
				// Default stale window when no duration is given.
				ca.staleUpTo = 1 * time.Hour
				if len(args) > 0 {
					d, err := time.ParseDuration(args[0])
					if err != nil {
						return nil, err
					}
					if d < 0 {
						return nil, errors.New("invalid negative duration for serve_stale")
					}
					ca.staleUpTo = d
				}
				ca.verifyStale = false
				if len(args) > 1 {
					mode := strings.ToLower(args[1])
					if mode != "immediate" && mode != "verify" {
						return nil, fmt.Errorf("invalid value for serve_stale refresh mode: %s", mode)
					}
					ca.verifyStale = mode == "verify"
				}
			case "servfail":
				args := c.RemainingArgs()
				if len(args) != 1 {
					return nil, c.ArgErr()
				}
				d, err := time.ParseDuration(args[0])
				if err != nil {
					return nil, err
				}
				if d < 0 {
					return nil, errors.New("invalid negative ttl for servfail")
				}
				if d > 5*time.Minute {
					// RFC 2308 prohibits caching SERVFAIL longer than 5 minutes
					return nil, errors.New("caching SERVFAIL responses over 5 minutes is not permitted")
				}
				ca.failttl = d
			case "disable":
				// disable [success|denial] [zones]...
				args := c.RemainingArgs()
				if len(args) < 1 {
					return nil, c.ArgErr()
				}

				var zones []string
				if len(args) > 1 {
					for _, z := range args[1:] { // args[1:] define the list of zones to disable
						nz := plugin.Name(z).Normalize()
						if nz == "" {
							return nil, fmt.Errorf("invalid disabled zone: %s", z)
						}
						zones = append(zones, nz)
					}
				} else {
					// if no zones specified, default to root
					zones = []string{"."}
				}

				switch args[0] { // args[0] defines which cache to disable
				case Denial:
					ca.nexcept = zones
				case Success:
					ca.pexcept = zones
				default:
					return nil, fmt.Errorf("cache type for disable must be %q or %q", Success, Denial)
				}
			case "keepttl":
				args := c.RemainingArgs()
				if len(args) != 0 {
					return nil, c.ArgErr()
				}
				ca.keepttl = true
			default:
				return nil, c.ArgErr()
			}
		}

		ca.Zones = origins
		ca.zonesMetricLabel = strings.Join(origins, ",")
		// The caches are created last so they pick up the configured capacities.
		ca.pcache = cache.New(ca.pcap)
		ca.ncache = cache.New(ca.ncap)
	}

	return ca, nil
}
// Package chaos implements a plugin that answers 'CH version.bind TXT' type queries.
package chaos
import (
"context"
"math/rand"
"os"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Chaos allows CoreDNS to reply to CH TXT queries and return author or
// version information.
type Chaos struct {
// Next is the plugin that receives every query Chaos does not answer itself.
Next plugin.Handler
// Version is returned for version.bind. and version.server. queries.
Version string
// Authors are returned, in random order, for authors.bind. queries.
Authors []string
}
// ServeDNS implements the plugin.Handler interface.
//
// Only CHAOS class TXT queries for authors.bind., version.bind.,
// version.server., hostname.bind. and id.server. are answered here; every other
// query is passed to the next plugin.
func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT {
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
}
m := new(dns.Msg)
m.SetReply(r)
hdr := dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0}
switch state.Name() {
default:
// Not a name we know about: let the next plugin handle it.
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
case "authors.bind.":
// One TXT record per author, in a random permutation seeded with the current time.
rnd := rand.New(rand.NewSource(time.Now().Unix()))
for _, i := range rnd.Perm(len(c.Authors)) {
m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{c.Authors[i]}})
}
case "version.bind.", "version.server.":
m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{c.Version}}}
case "hostname.bind.", "id.server.":
hostname, err := os.Hostname()
if err != nil {
// Fall back to a fixed name when the hostname cannot be determined.
hostname = "localhost"
}
m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{trim(hostname)}}}
}
w.WriteMsg(m)
return 0, nil
}
// Name implements the Handler interface; it returns the plugin name "chaos".
func (c Chaos) Name() string { return "chaos" }
//go:build gofuzz
package chaos
import (
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes chaos: it feeds data to a zero-value Chaos plugin through fuzz.Do
// and returns that call's result.
func Fuzz(data []byte) int {
c := Chaos{}
return fuzz.Do(c, data)
}
//go:generate go run owners_generate.go
package chaos
import (
"sort"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
// init registers the chaos plugin under the name "chaos" with setup as its setup function.
func init() { plugin.Register("chaos", setup) }
// setup parses the chaos directive from c and adds a Chaos handler, configured
// with the parsed version and authors, to the plugin chain. Parse errors are
// returned wrapped with the plugin name.
func setup(c *caddy.Controller) error {
version, authors, err := parse(c)
if err != nil {
return plugin.Error("chaos", err)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return Chaos{Next: next, Version: version, Authors: authors}
})
return nil
}
// parse reads the chaos directive `chaos [VERSION] [AUTHORS...]` from c and
// returns the version string and the list of authors to serve.
//
// Without arguments the version is "<AppName>-<AppVersion>" and the authors are
// Owners. With only a version, the authors are Owners as well. When authors are
// given they are de-duplicated and sorted. The version and every author are cut
// to at most 255 bytes by trim. The returned error is always nil.
func parse(c *caddy.Controller) (string, []string, error) {
	// Set here so we pick up AppName and AppVersion that get set in coremain's init().
	chaosVersion = caddy.AppName + "-" + caddy.AppVersion
	version := ""

	if c.Next() {
		args := c.RemainingArgs()
		if len(args) == 0 {
			return trim(chaosVersion), Owners, nil
		}
		if len(args) == 1 {
			return trim(args[0]), Owners, nil
		}

		// Limit the version to 255 chars here too, exactly as in the
		// single-argument case above.
		version = trim(args[0])

		// Collect the authors in a set to drop duplicates.
		authors := make(map[string]struct{})
		for _, a := range args[1:] {
			authors[a] = struct{}{}
		}
		list := []string{}
		for k := range authors {
			k = trim(k) // limit size to 255 chars
			list = append(list, k)
		}
		// Map iteration order is random; sort for a stable result.
		sort.Strings(list)
		return version, list, nil
	}

	return version, Owners, nil
}
// trim cuts s down to its first 255 bytes; shorter strings are returned unchanged.
func trim(s string) string {
	const limit = 255
	if len(s) > limit {
		return s[:limit]
	}
	return s
}
// chaosVersion is the default version string; parse sets it to
// caddy.AppName + "-" + caddy.AppVersion.
var chaosVersion string
package debug
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
// init registers the debug plugin under the name "debug" with setup as its setup function.
func init() { plugin.Register("debug", setup) }
// setup handles the debug directive: it takes no arguments and sets the Debug
// flag on the server block's config. Any argument is reported as an error
// wrapped with the plugin name.
func setup(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
for c.Next() {
if c.NextArg() {
return plugin.Error("debug", c.ArgErr())
}
config.Debug = true
}
return nil
}
package debug
import (
"bytes"
"fmt"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
// Hexdump converts the dns message m to a hex dump Wireshark can import.
// See https://www.wireshark.org/docs/man-pages/text2pcap.html.
// This output looks like this:
//
//	000000 dc bd 01 00 00 01 00 00 00 00 00 01 07 65 78 61
//	000010 6d 70 6c 65 05 6c 6f 63 61 6c 00 00 01 00 01 00
//	000020 00 29 10 00 00 00 80 00 00 00
//	00002a
//
// Hexdump will use log.Debug to write the dump to the log, each line
// is prefixed with 'debug: ' so the data can be easily extracted.
//
// The values in v are logged in front of the dump. Nothing is logged when
// debug logging is off or when m packs to zero bytes.
func Hexdump(m *dns.Msg, v ...interface{}) {
if !log.D.Value() {
return
}
// The pack error is ignored; an empty buffer means there is nothing to dump.
buf, _ := m.Pack()
if len(buf) == 0 {
return
}
out := "\n" + string(hexdump(buf))
v = append(v, out)
log.Debug(v...)
}
// Hexdumpf dumps a DNS message as Hexdump, but allows a format string.
// The dump is appended to the formatted output on a new line. Nothing is logged
// when debug logging is off or when m packs to zero bytes.
func Hexdumpf(m *dns.Msg, format string, v ...interface{}) {
if !log.D.Value() {
return
}
// The pack error is ignored; an empty buffer means there is nothing to dump.
buf, _ := m.Pack()
if len(buf) == 0 {
return
}
// Extend the caller's format with a verb for the dump itself.
format += "\n%s"
v = append(v, hexdump(buf))
log.Debugf(format, v...)
}
// hexdump renders data as lines of up to 16 space-separated hex bytes. Every
// line starts with prefix and the six-digit hex offset of its first byte; a
// final line carries only prefix and the total length.
func hexdump(data []byte) []byte {
	var out bytes.Buffer
	for offset, c := range data {
		switch {
		case offset == 0:
			fmt.Fprintf(&out, "%s%06x", prefix, offset)
		case offset%16 == 0:
			// Every 16 bytes a new line with its own offset starts.
			fmt.Fprintf(&out, "\n%s%06x", prefix, offset)
		}
		fmt.Fprintf(&out, " %02x", c)
	}
	fmt.Fprintf(&out, "\n%s%06x", prefix, len(data))
	return out.Bytes()
}

// prefix starts every line of a hex dump.
const prefix = "debug: "
package dnstap
import (
"io"
"time"
tap "github.com/dnstap/golang-dnstap"
fs "github.com/farsightsec/golang-framestream"
"google.golang.org/protobuf/proto"
)
// encoder wraps a golang-framestream.Encoder.
type encoder struct {
	fs *fs.Encoder
}

// newEncoder creates a bidirectional framestream encoder for dnstap protobuf
// content that writes to w, using timeout as the encoder's Timeout option.
func newEncoder(w io.Writer, timeout time.Duration) (*encoder, error) {
	opts := &fs.EncoderOptions{
		ContentType:   []byte("protobuf:dnstap.Dnstap"),
		Bidirectional: true,
		Timeout:       timeout,
	}
	stream, err := fs.NewEncoder(w, opts)
	if err != nil {
		return nil, err
	}
	return &encoder{fs: stream}, nil
}

// writeMsg marshals msg to protobuf and writes it as one frame.
func (e *encoder) writeMsg(msg *tap.Dnstap) error {
	frame, err := proto.Marshal(msg)
	if err != nil {
		return err
	}
	// n < len(frame) should return an error?
	if _, err := e.fs.Write(frame); err != nil {
		return err
	}
	return nil
}

// flush flushes the underlying framestream encoder.
func (e *encoder) flush() error {
	return e.fs.Flush()
}

// close closes the underlying framestream encoder.
func (e *encoder) close() error {
	return e.fs.Close()
}
package dnstap
import (
"context"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/dnstap/msg"
"github.com/coredns/coredns/plugin/pkg/replacer"
"github.com/coredns/coredns/request"
tap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)
// Dnstap is the dnstap handler.
type Dnstap struct {
// Next is the next plugin in the chain.
Next plugin.Handler
// io receives the dnstap payloads produced by this handler.
io tapper
// repl expands ExtraFormat for a request.
repl replacer.Replacer
// IncludeRawMessage will include the raw DNS message into the dnstap messages if true.
IncludeRawMessage bool
// Identity and Version are sent along in every dnstap payload.
Identity []byte
Version []byte
// ExtraFormat, when non-empty, is the source of the payload's Extra field.
ExtraFormat string
// MultipleTcpWriteBuf and MultipleQueue scale the default write buffer size
// and queue size of the I/O routine.
MultipleTcpWriteBuf int // *Mb
MultipleQueue int // *10000
}
// TapMessage sends the message m to the dnstap interface. No request metadata
// is available here, so when ExtraFormat is set it is used verbatim as the
// "Extra" field; otherwise "Extra" is left unset.
func (h *Dnstap) TapMessage(m *tap.Message) {
	var extra []byte
	if h.ExtraFormat != "" {
		extra = []byte(h.ExtraFormat)
	}
	h.tapWithExtra(m, extra)
}
// TapMessageWithMetadata sends the message m to the dnstap interface. When
// ExtraFormat is set, the "Extra" field is filled with ExtraFormat expanded by
// the replacer for ctx and state; otherwise "Extra" is left unset.
func (h *Dnstap) TapMessageWithMetadata(ctx context.Context, m *tap.Message, state request.Request) {
	var extra []byte
	if h.ExtraFormat != "" {
		extra = []byte(h.repl.Replace(ctx, state, nil, h.ExtraFormat))
	}
	h.tapWithExtra(m, extra)
}
// tapWithExtra wraps m in a Dnstap payload of type MESSAGE, adds the handler's
// Identity and Version plus the given extra bytes, and hands it to h.io.
func (h *Dnstap) tapWithExtra(m *tap.Message, extra []byte) {
t := tap.Dnstap_MESSAGE
h.io.Dnstap(&tap.Dnstap{Type: &t, Message: m, Identity: h.Identity, Version: h.Version, Extra: extra})
}
// tapQuery builds a CLIENT_QUERY dnstap message for query, stamped with
// queryTime and the remote address of w, and sends it. The packed query is
// included only when IncludeRawMessage is set.
func (h *Dnstap) tapQuery(ctx context.Context, w dns.ResponseWriter, query *dns.Msg, queryTime time.Time) {
q := new(tap.Message)
msg.SetQueryTime(q, queryTime)
// The error for an unknown address type is ignored here.
msg.SetQueryAddress(q, w.RemoteAddr())
if h.IncludeRawMessage {
// A pack error is ignored; buf is stored as returned.
buf, _ := query.Pack()
q.QueryMessage = buf
}
msg.SetType(q, tap.Message_CLIENT_QUERY)
state := request.Request{W: w, Req: query}
h.TapMessageWithMetadata(ctx, q, state)
}
// ServeDNS logs the client query and response to dnstap and passes the dnstap Context.
// The query is tapped right away; the response is tapped by the ResponseWriter
// that is handed to the next plugin.
func (h *Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rw := &ResponseWriter{
ResponseWriter: w,
Dnstap: h,
query: r,
ctx: ctx,
queryTime: time.Now(),
}
// The query tap message should be sent before sending the query to the
// forwarder. Otherwise, the tap messages will come out out of order.
h.tapQuery(ctx, w, r, rw.queryTime)
return plugin.NextOrFailure(h.Name(), h.Next, ctx, rw, r)
}
// Name implements the plugin.Plugin interface; it returns the plugin name "dnstap".
func (h *Dnstap) Name() string { return "dnstap" }
package dnstap
import (
"crypto/tls"
"net"
"sync/atomic"
"time"
tap "github.com/dnstap/golang-dnstap"
)
// Defaults for the dnstap I/O routine.
const (
tcpWriteBufSize = 1024 * 1024 // there is no good explanation for why this number has this value.
queueSize = 10000 // idem.
tcpTimeout = 4 * time.Second // dial timeout, also passed to the encoder
flushTimeout = 1 * time.Second // interval of the periodic flush in serve
skipVerify = false // by default, every tls connection is verified to be secure
)
// tapper interface is used in testing to mock the Dnstap method.
type tapper interface {
// Dnstap accepts one dnstap payload for delivery.
Dnstap(*tap.Dnstap)
}
// dio implements the tapper interface. It queues payloads and writes them to
// the dnstap endpoint from a separate goroutine (see serve).
type dio struct {
// endpoint is the address to dial; proto is "tls" or a network name for net.DialTimeout.
endpoint string
proto string
// enc is the encoder for the current connection, nil while not connected.
enc *encoder
// queue holds the payloads waiting to be written.
queue chan *tap.Dnstap
// dropped counts payloads that could not be queued or written; it is accessed atomically.
dropped uint32
// quit is closed to stop the serve goroutine.
quit chan struct{}
flushTimeout time.Duration
tcpTimeout time.Duration
// skipVerify disables TLS certificate verification when true.
skipVerify bool
// tcpWriteBufSize is the write buffer size set on TCP connections.
tcpWriteBufSize int
}
// newIO returns a new and initialized pointer to a dio.
// multipleQueue scales the default queue size and multipleTcpWriteBuf scales
// the default TCP write buffer size; timeouts and skipVerify start at their
// package defaults.
func newIO(proto, endpoint string, multipleQueue int, multipleTcpWriteBuf int) *dio {
return &dio{
endpoint: endpoint,
proto: proto,
queue: make(chan *tap.Dnstap, multipleQueue*queueSize),
quit: make(chan struct{}),
flushTimeout: flushTimeout,
tcpTimeout: tcpTimeout,
skipVerify: skipVerify,
tcpWriteBufSize: multipleTcpWriteBuf * tcpWriteBufSize,
}
}
// dial connects to the dnstap endpoint and installs a fresh encoder on success.
//
// For proto "tls" a TLS connection over tcp is made, otherwise proto is used as
// the network for net.DialTimeout. On TCP connections the write buffer size is
// set and Nagle's algorithm is left enabled. If dialing fails the current
// encoder is left untouched; if setting up the encoder fails the new connection
// is closed again and the encoder is cleared.
func (d *dio) dial() error {
	var (
		conn net.Conn
		err  error
	)

	if d.proto == "tls" {
		config := &tls.Config{
			InsecureSkipVerify: d.skipVerify,
		}
		dialer := &net.Dialer{
			Timeout: d.tcpTimeout,
		}
		conn, err = tls.DialWithDialer(dialer, "tcp", d.endpoint, config)
	} else {
		conn, err = net.DialTimeout(d.proto, d.endpoint, d.tcpTimeout)
	}
	if err != nil {
		return err
	}

	if tcpConn, ok := conn.(*net.TCPConn); ok {
		tcpConn.SetWriteBuffer(d.tcpWriteBufSize)
		tcpConn.SetNoDelay(false)
	}

	enc, err := newEncoder(conn, d.tcpTimeout)
	if err != nil {
		// Nobody else holds conn; close it so a failed encoder setup does not
		// leak the connection on every (re)dial attempt.
		conn.Close()
		d.enc = nil
		return err
	}
	d.enc = enc
	return nil
}
// connect dials the dnstap endpoint and starts the serve goroutine.
// The goroutine is started even when dialing fails; the dial error is returned.
func (d *dio) connect() error {
err := d.dial()
go d.serve()
return err
}
// Dnstap enqueues the payload for log. It never blocks: when the queue cannot
// take the payload it is dropped and the dropped counter is incremented.
func (d *dio) Dnstap(payload *tap.Dnstap) {
select {
case d.queue <- payload:
default:
atomic.AddUint32(&d.dropped, 1)
}
}
// close signals the I/O routine to stop by closing the quit channel. It does
// not wait for the routine to finish.
func (d *dio) close() { close(d.quit) }
// write sends payload through the current encoder. Without an encoder the
// payload is counted as dropped and nil is returned; when the encoder fails the
// payload is counted as dropped as well and the error is returned.
func (d *dio) write(payload *tap.Dnstap) error {
	var err error
	if d.enc != nil {
		err = d.enc.writeMsg(payload)
		if err == nil {
			return nil
		}
	}
	atomic.AddUint32(&d.dropped, 1)
	return err
}
// serve is the I/O loop of d. It writes queued payloads to the endpoint,
// redials after a failed write, and once per flushTimeout of inactivity reports
// dropped messages and either flushes the encoder or tries to (re)connect.
// It returns when the quit channel is closed.
func (d *dio) serve() {
timeout := time.NewTimer(d.flushTimeout)
defer timeout.Stop()
for {
// The timer is restarted on every iteration.
timeout.Reset(d.flushTimeout)
select {
case <-d.quit:
if d.enc == nil {
return
}
// Flush what is pending and close the encoder before stopping.
d.enc.flush()
d.enc.close()
return
case payload := <-d.queue:
if err := d.write(payload); err != nil {
// The write failed: try to set up a new connection. The dial error is ignored.
d.dial()
}
case <-timeout.C:
// Report and reset the number of dropped messages.
if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 {
log.Warningf("Dropped dnstap messages: %d", dropped)
}
if d.enc == nil {
d.dial()
} else {
d.enc.flush()
}
}
}
}
package msg
import (
"fmt"
"net"
"time"
tap "github.com/dnstap/golang-dnstap"
)
// Addressable copies of the dnstap protocol and family values; the message
// fields are pointers, so the setters below point them at these variables.
var (
protoUDP = tap.SocketProtocol_UDP
protoTCP = tap.SocketProtocol_TCP
familyINET = tap.SocketFamily_INET
familyINET6 = tap.SocketFamily_INET6
)
// SetQueryAddress adds the query address and port to the message. This also sets
// the SocketFamily and SocketProtocol: the protocol follows the address type
// (TCP or UDP) and the family is INET6 when the IP has no IPv4 form.
// For any other address type an error is returned; the family has then been set
// to INET and nothing else is changed.
func SetQueryAddress(t *tap.Message, addr net.Addr) error {
	t.SocketFamily = &familyINET

	var (
		ip   net.IP
		port int
	)
	switch a := addr.(type) {
	case *net.TCPAddr:
		t.SocketProtocol = &protoTCP
		ip, port = a.IP, a.Port
	case *net.UDPAddr:
		t.SocketProtocol = &protoUDP
		ip, port = a.IP, a.Port
	default:
		return fmt.Errorf("unknown address type: %T", a)
	}

	t.QueryAddress = ip
	p := uint32(port)
	t.QueryPort = &p
	if ip.To4() == nil {
		t.SocketFamily = &familyINET6
	}
	return nil
}
// SetResponseAddress adds the response address and port to the message. This
// also sets the SocketFamily and SocketProtocol: the protocol follows the
// address type (TCP or UDP) and the family is INET6 when the IP has no IPv4 form.
// For any other address type an error is returned; the family has then been set
// to INET and nothing else is changed.
func SetResponseAddress(t *tap.Message, addr net.Addr) error {
	t.SocketFamily = &familyINET

	var (
		ip   net.IP
		port int
	)
	switch a := addr.(type) {
	case *net.TCPAddr:
		t.SocketProtocol = &protoTCP
		ip, port = a.IP, a.Port
	case *net.UDPAddr:
		t.SocketProtocol = &protoUDP
		ip, port = a.IP, a.Port
	default:
		return fmt.Errorf("unknown address type: %T", a)
	}

	t.ResponseAddress = ip
	p := uint32(port)
	t.ResponsePort = &p
	if ip.To4() == nil {
		t.SocketFamily = &familyINET6
	}
	return nil
}
// SetQueryTime sets the time of the query in t, split into whole seconds since
// the Unix epoch and the nanosecond part of ti.
func SetQueryTime(t *tap.Message, ti time.Time) {
qts := uint64(ti.Unix())
qtn := uint32(ti.Nanosecond())
t.QueryTimeSec = &qts
t.QueryTimeNsec = &qtn
}
// SetResponseTime sets the time of the response in t, split into whole seconds
// since the Unix epoch and the nanosecond part of ti.
func SetResponseTime(t *tap.Message, ti time.Time) {
rts := uint64(ti.Unix())
rtn := uint32(ti.Nanosecond())
t.ResponseTimeSec = &rts
t.ResponseTimeNsec = &rtn
}
// SetType sets the type in t to typ.
func SetType(t *tap.Message, typ tap.Message_Type) { t.Type = &typ }
package dnstap
import (
"net/url"
"os"
"strconv"
"strings"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/replacer"
)
// log is the logger of the dnstap plugin.
var log = clog.NewWithPlugin("dnstap")
// init registers the dnstap plugin under the name "dnstap" with setup as its setup function.
func init() { plugin.Register("dnstap", setup) }
// parseConfig reads every dnstap directive from c and returns one Dnstap
// handler per directive.
//
// A directive has the form `dnstap ENDPOINT [full] [WRITEBUF] [QUEUE]`:
//   - ENDPOINT is a tls:// or tcp:// URL, or otherwise a unix socket path with
//     an optional unix:// prefix;
//   - "full" as second argument includes the raw DNS message in the tap messages;
//   - WRITEBUF and QUEUE are positive integer multipliers for the TCP write
//     buffer size and the queue size (both default to 1).
//
// The optional block accepts skipverify, identity, version and extra.
// An error is returned when the endpoint is missing or not parsable, when a
// multiplier is not a positive integer or when a block property lacks its argument.
func parseConfig(c *caddy.Controller) ([]*Dnstap, error) {
	dnstaps := []*Dnstap{}

	for c.Next() { // directive name
		d := Dnstap{
			MultipleTcpWriteBuf: 1,
			MultipleQueue:       1,
		}
		d.repl = replacer.New()

		args := c.RemainingArgs()
		if len(args) == 0 {
			return nil, c.ArgErr()
		}
		endpoint := args[0]

		// The multipliers size a channel and a socket buffer. A negative queue
		// multiplier makes channel creation panic and zero leaves the queue
		// unbuffered, so anything but a positive integer is rejected instead of
		// being silently turned into 0.
		if len(args) >= 3 {
			n, err := strconv.Atoi(args[2])
			if err != nil || n < 1 {
				return nil, c.Errf("invalid write buffer multiplier %q: must be a positive integer", args[2])
			}
			d.MultipleTcpWriteBuf = n
		}
		if len(args) >= 4 {
			n, err := strconv.Atoi(args[3])
			if err != nil || n < 1 {
				return nil, c.Errf("invalid queue multiplier %q: must be a positive integer", args[3])
			}
			d.MultipleQueue = n
		}

		var tapIO *dio
		if strings.HasPrefix(endpoint, "tls://") {
			// remote network endpoint
			endpointURL, err := url.Parse(endpoint)
			if err != nil {
				return nil, c.ArgErr()
			}
			tapIO = newIO("tls", endpointURL.Host, d.MultipleQueue, d.MultipleTcpWriteBuf)
		} else if strings.HasPrefix(endpoint, "tcp://") {
			// remote network endpoint
			endpointURL, err := url.Parse(endpoint)
			if err != nil {
				return nil, c.ArgErr()
			}
			tapIO = newIO("tcp", endpointURL.Host, d.MultipleQueue, d.MultipleTcpWriteBuf)
		} else {
			// Everything else is treated as a unix socket path.
			endpoint = strings.TrimPrefix(endpoint, "unix://")
			tapIO = newIO("unix", endpoint, d.MultipleQueue, d.MultipleTcpWriteBuf)
		}
		d.io = tapIO

		d.IncludeRawMessage = len(args) >= 2 && args[1] == "full"

		// Defaults; both can be overridden in the block below. A hostname lookup
		// error is ignored and leaves the identity empty.
		hostname, _ := os.Hostname()
		d.Identity = []byte(hostname)
		d.Version = []byte(caddy.AppName + "-" + caddy.AppVersion)

		for c.NextBlock() {
			switch c.Val() {
			case "skipverify":
				tapIO.skipVerify = true
			case "identity":
				if !c.NextArg() {
					return nil, c.ArgErr()
				}
				d.Identity = []byte(c.Val())
			case "version":
				if !c.NextArg() {
					return nil, c.ArgErr()
				}
				d.Version = []byte(c.Val())
			case "extra":
				if !c.NextArg() {
					return nil, c.ArgErr()
				}
				d.ExtraFormat = c.Val()
			}
		}

		dnstaps = append(dnstaps, &d)
	}
	return dnstaps, nil
}
// setup parses all dnstap directives of the server block and installs one
// handler per directive. Each handler connects its I/O routine at startup and
// closes it on restart and on final shutdown. The handlers are chained to each
// other; only the last one points to the next plugin of the chain.
func setup(c *caddy.Controller) error {
dnstaps, err := parseConfig(c)
if err != nil {
return plugin.Error("dnstap", err)
}
for i := range dnstaps {
dnstap := dnstaps[i]
c.OnStartup(func() error {
// A failed connect is only logged; the I/O routine keeps trying to dial.
if err := dnstap.io.(*dio).connect(); err != nil {
log.Errorf("No connection to dnstap endpoint: %s", err)
}
return nil
})
c.OnRestart(func() error {
dnstap.io.(*dio).close()
return nil
})
c.OnFinalShutdown(func() error {
dnstap.io.(*dio).close()
return nil
})
if i == len(dnstaps)-1 {
// last dnstap plugin in block: point next to next plugin
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
dnstap.Next = next
return dnstap
})
} else {
// not last dnstap plugin in block: point next to next dnstap
nextDnstap := dnstaps[i+1]
dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler {
dnstap.Next = nextDnstap
return dnstap
})
}
}
return nil
}
package dnstap
import (
"context"
"time"
"github.com/coredns/coredns/plugin/dnstap/msg"
"github.com/coredns/coredns/request"
tap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)
// ResponseWriter captures the client response and logs it to dnstap.
type ResponseWriter struct {
// queryTime is the time the query was received.
queryTime time.Time
// query is the original request; it is used to build the request state for the tap message.
query *dns.Msg
// ctx is the request context, passed on when the response is tapped.
ctx context.Context
dns.ResponseWriter
*Dnstap
}
// WriteMsg writes back the response to the client and THEN works on logging the request and response to dnstap.
// When writing to the client fails that error is returned and nothing is
// tapped; otherwise a CLIENT_RESPONSE message is sent and nil is returned.
func (w *ResponseWriter) WriteMsg(resp *dns.Msg) error {
err := w.ResponseWriter.WriteMsg(resp)
if err != nil {
return err
}
r := new(tap.Message)
msg.SetQueryTime(r, w.queryTime)
msg.SetResponseTime(r, time.Now())
// The error for an unknown address type is ignored here.
msg.SetQueryAddress(r, w.RemoteAddr())
if w.IncludeRawMessage {
// A pack error is ignored; buf is stored as returned.
buf, _ := resp.Pack()
r.ResponseMessage = buf
}
msg.SetType(r, tap.Message_CLIENT_RESPONSE)
state := request.Request{W: w.ResponseWriter, Req: w.query}
w.TapMessageWithMetadata(w.ctx, r, state)
return nil
}
package plugin
import "context"
// Done is a non-blocking function that returns true if the context has been canceled.
// It relies on the context contract: Err is non-nil exactly when the Done
// channel has been closed.
func Done(ctx context.Context) bool {
	return ctx.Err() != nil
}
package msg
import (
"path"
"strings"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/miekg/dns"
)
// Path converts a domain name to an etcd path: the labels of s are reversed and
// joined below /prefix. If s looks like service.staging.skydns.local. and prefix
// is skydns, the resulting key will be /skydns/local/skydns/staging/service .
func Path(s, prefix string) string {
	labels := dns.SplitDomainName(s)

	elems := make([]string, 0, len(labels)+1)
	elems = append(elems, "/"+prefix+"/")
	// Walk the labels from the rightmost one to get them in path order.
	for i := len(labels) - 1; i >= 0; i-- {
		elems = append(elems, labels[i])
	}
	return path.Join(elems...)
}
// Domain is the opposite of Path: it turns an etcd key such as
// /skydns/local/skydns/staging/service back into a domain name by dropping the
// leading prefix element, reversing the remaining elements and joining them
// with dnsutil.Join. A single trailing slash is ignored.
//
// NOTE(review): an empty s or "/" leaves fewer than two elements, in which case
// the final slice expression panics — confirm callers always pass a full key.
func Domain(s string) string {
l := strings.Split(s, "/")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
}
// start with 1, to strip /skydns: index 0 is the empty element in front of the
// leading slash; after reversing, the prefix is the last element and is cut off below.
for i, j := 1, len(l)-1; i < j; i, j = i+1, j-1 {
l[i], l[j] = l[j], l[i]
}
return dnsutil.Join(l[1 : len(l)-1]...)
}
// PathWithWildcard acts as Path, but if the name contains a wildcard label
// ("*" or "any"), the path is cut off right before the first wildcard seen from
// the top of the tree, so that a higher level search can be done and the
// matching names can be selected afterwards. So service.*.skydns.local will
// look for all services under skydns.local and will later check for names that
// match service.*.skydns.local. The returned bool is true when a wildcard was found.
func PathWithWildcard(s, prefix string) (string, bool) {
	labels := dns.SplitDomainName(s)

	elems := make([]string, 0, len(labels)+1)
	elems = append(elems, "/"+prefix+"/")
	// Walk the labels from the rightmost one, i.e. in path order.
	for i := len(labels) - 1; i >= 0; i-- {
		label := labels[i]
		if label == "*" || label == "any" {
			return path.Join(elems...), true
		}
		elems = append(elems, label)
	}
	return path.Join(elems...), false
}
// Package msg defines the Service structure which is used for service discovery.
package msg
import (
"net"
"strings"
"github.com/miekg/dns"
)
// Service defines a discoverable service in etcd. It is the rdata from a SRV
// record, but with a twist. Host (Target in SRV) must be a domain name, but
// if it looks like an IP address (4/6), we will treat it like an IP address
// (see HostType). All fields except Key are (un)marshalled as JSON and are
// omitted when empty.
type Service struct {
// Host is either a domain name or an IP address in text form.
Host string `json:"host,omitempty"`
// Port is used as the Port of a synthesized SRV record.
Port int `json:"port,omitempty"`
// Priority is used as SRV Priority and as MX Preference.
Priority int `json:"priority,omitempty"`
// Weight is not read by the record constructors in this file; NewSRV takes
// the weight to use as an argument.
Weight int `json:"weight,omitempty"`
// Text is the content of a synthesized TXT record.
Text string `json:"text,omitempty"`
Mail bool `json:"mail,omitempty"` // Be an MX record. Priority becomes Preference.
// TTL is copied into the header of every record built from this Service.
TTL uint32 `json:"ttl,omitempty"`
// When a SRV record with a "Host: IP-address" is added, we synthesize
// a srv.Target domain name. Normally we convert the full Key where
// the record lives to a DNS name and use this as the srv.Target. When
// TargetStrip > 0 we strip the left most TargetStrip labels from the
// DNS name.
TargetStrip int `json:"targetstrip,omitempty"`
// Group is used to group (or *not* to group) different services
// together. Services with an identical Group are returned in the same
// answer.
Group string `json:"group,omitempty"`
// Key is the etcd key where we found this service; it is excluded from
// JSON (un)marshalling.
Key string `json:"-"`
}
// NewSRV builds an SRV record owned by name from the Service. The target is
// the fully qualified Host, with the left-most TargetStrip labels removed
// when TargetStrip > 0. Priority and Port come from the Service; weight is
// supplied by the caller.
func (s *Service) NewSRV(name string, weight uint16) *dns.SRV {
	target := dns.Fqdn(s.Host)
	if s.TargetStrip > 0 {
		target = targetStrip(target, s.TargetStrip)
	}

	hdr := dns.RR_Header{Name: name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: s.TTL}
	return &dns.SRV{
		Hdr:      hdr,
		Priority: uint16(s.Priority),
		Weight:   weight,
		Port:     uint16(s.Port),
		Target:   target,
	}
}
// NewMX builds an MX record owned by name from the Service. The exchange is
// the fully qualified Host (minus TargetStrip left-most labels when
// TargetStrip > 0) and the Service Priority becomes the MX Preference.
func (s *Service) NewMX(name string) *dns.MX {
	exchange := dns.Fqdn(s.Host)
	if s.TargetStrip > 0 {
		exchange = targetStrip(exchange, s.TargetStrip)
	}

	hdr := dns.RR_Header{Name: name, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: s.TTL}
	return &dns.MX{
		Hdr:        hdr,
		Preference: uint16(s.Priority),
		Mx:         exchange,
	}
}
// NewA builds an A record owned by name that carries ip, using the Service TTL.
func (s *Service) NewA(name string, ip net.IP) *dns.A {
	rr := new(dns.A)
	rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.TTL}
	rr.A = ip
	return rr
}
// NewAAAA builds an AAAA record owned by name that carries ip, using the Service TTL.
func (s *Service) NewAAAA(name string, ip net.IP) *dns.AAAA {
	rr := new(dns.AAAA)
	rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.TTL}
	rr.AAAA = ip
	return rr
}
// NewCNAME builds a CNAME record owned by name that points at the fully
// qualified form of target, using the Service TTL.
func (s *Service) NewCNAME(name string, target string) *dns.CNAME {
	rr := new(dns.CNAME)
	rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: s.TTL}
	rr.Target = dns.Fqdn(target)
	return rr
}
// NewTXT builds a TXT record owned by name whose strings are the Service
// Text cut up by split255, using the Service TTL.
func (s *Service) NewTXT(name string) *dns.TXT {
	rr := new(dns.TXT)
	rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: s.TTL}
	rr.Txt = split255(s.Text)
	return rr
}
// NewPTR builds a PTR record owned by name that points at the fully
// qualified form of target, using the Service TTL.
func (s *Service) NewPTR(name string, target string) *dns.PTR {
	rr := new(dns.PTR)
	rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: s.TTL}
	rr.Ptr = dns.Fqdn(target)
	return rr
}
// NewNS builds an NS record owned by name from the Service. The name server
// is the fully qualified Host, with the left-most TargetStrip labels removed
// when TargetStrip > 0.
func (s *Service) NewNS(name string) *dns.NS {
	nameserver := dns.Fqdn(s.Host)
	if s.TargetStrip > 0 {
		nameserver = targetStrip(nameserver, s.TargetStrip)
	}

	hdr := dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: s.TTL}
	return &dns.NS{Hdr: hdr, Ns: nameserver}
}
// Group filters sx by the Group attribute. The group is taken from the
// services with the shortest keys (fewest "/" separators), starting from the
// first service. When no group is selected sx is returned unchanged. When two
// services on the selected key depth carry different, non-empty groups there
// is no single group and sx is returned unchanged as well. Otherwise only the
// services that have the selected group, or no group at all, are returned.
func Group(sx []Service) []Service {
	if len(sx) == 0 {
		return sx
	}

	// depth[i] is the number of slashes in sx[i].Key. Entries after an early
	// break of the scan below keep their zero value.
	depth := make([]int, len(sx))
	chosen, minDepth := sx[0].Group, strings.Count(sx[0].Key, "/")
	for idx := range sx {
		d := strings.Count(sx[idx].Key, "/")
		depth[idx] = d
		if d >= minDepth {
			continue
		}
		// A shorter key without a group ends the scan.
		if sx[idx].Group == "" {
			break
		}
		minDepth, chosen = d, sx[idx].Group
	}
	if chosen == "" {
		return sx
	}

	kept := []Service{}
	for idx, svc := range sx {
		switch {
		case svc.Group == "":
			kept = append(kept, svc)
		case depth[idx] == minDepth && svc.Group != chosen:
			// Disagreement on the same level.
			return sx
		case svc.Group == chosen:
			kept = append(kept, svc)
		}
	}
	return kept
}
// split255 splits s into consecutive chunks of at most 255 bytes each.
//
// A string of 255 bytes or fewer (including the empty string) is returned as
// a single chunk. Longer strings yield full 255-byte chunks followed by the
// remainder. An input whose length is an exact multiple of 255 does not get
// an extra empty chunk at the end.
func split255(s string) []string {
	const maxLen = 255
	if len(s) <= maxLen {
		return []string{s}
	}

	chunks := make([]string, 0, (len(s)+maxLen-1)/maxLen)
	for len(s) > maxLen {
		chunks = append(chunks, s[:maxLen])
		s = s[maxLen:]
	}
	// What is left is between 1 and 255 bytes long.
	return append(chunks, s)
}
// targetStrip strips "targetstrip" labels from the left side of the fully qualified name.
func targetStrip(name string, targetStrip int) string {
offset, end := 0, false
for range targetStrip {
offset, end = dns.NextLabel(name, offset)
}
if end {
// We overshot the name, use the original one.
offset = 0
}
name = name[offset:]
return name
}
package msg
import (
"net"
"github.com/miekg/dns"
)
// HostType classifies what is stored in the Service Host field, reusing the
// dns.TypeXXX constants as identifiers:
//
//	dns.TypeA:     Host parses as an IPv4 address; normalized is its 4-byte form.
//	dns.TypeAAAA:  Host parses as an IPv6 address; normalized is its 16-byte form.
//	dns.TypeCNAME: Host is not an IP address and Text is empty.
//	dns.TypeTXT:   Host is not an IP address and Text is set.
//
// normalized is nil for the last two. Note that a service can double/triple
// as a TXT record or MX record.
func (s *Service) HostType() (what uint16, normalized net.IP) {
	ip := net.ParseIP(s.Host)
	if ip == nil {
		if len(s.Text) == 0 {
			return dns.TypeCNAME, nil
		}
		return dns.TypeTXT, nil
	}
	if v4 := ip.To4(); v4 != nil {
		return dns.TypeA, v4
	}
	return dns.TypeAAAA, ip.To16()
}
package file
import (
"github.com/coredns/coredns/plugin/file/tree"
"github.com/miekg/dns"
)
// ClosestEncloser returns the closest encloser for qname: qname itself or the
// first name found in the zone after repeatedly dropping the left-most label.
// When none of those names is found before the last label is reached, the
// result of searching for the zone origin is returned.
func (z *Zone) ClosestEncloser(qname string) (*tree.Elem, bool) {
	for {
		next, last := dns.NextLabel(qname, 0)
		if last {
			break
		}
		if hit, _ := z.Search(qname); hit != nil {
			return hit, true
		}
		// Drop the left-most label and try the parent.
		qname = qname[next:]
	}
	return z.Search(z.origin)
}
package file
import (
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/miekg/dns"
)
// substituteDNAME performs the DNAME substitution defined by RFC 6672,
// assuming the QTYPE of the query is not DNAME: the owner suffix of qname is
// replaced by target. It returns the empty string when qname is not a proper
// subdomain of owner (qname == owner does not match).
func substituteDNAME(qname, owner, target string) string {
	if !dns.IsSubDomain(owner, qname) || qname == owner {
		return ""
	}

	labels := dns.SplitDomainName(qname)
	// Keep only the labels to the left of the owner name.
	keep := len(labels) - dns.CountLabel(owner)
	labels = append(labels[0:keep], dns.SplitDomainName(target)...)
	return dnsutil.Join(labels...)
}
// synthesizeCNAME builds the CNAME RR that results from applying the DNAME d
// to qname. The CNAME is owned by qname, points at the substituted name and
// inherits the TTL of the DNAME RR.
//
// It returns nil if the DNAME substitution has no match.
func synthesizeCNAME(qname string, d *dns.DNAME) *dns.CNAME {
	dnameHdr := d.Header()
	substituted := substituteDNAME(qname, dnameHdr.Name, d.Target)
	if substituted == "" {
		return nil
	}

	return &dns.CNAME{
		Hdr: dns.RR_Header{
			Name:   qname,
			Rrtype: dns.TypeCNAME,
			Class:  dns.ClassINET,
			Ttl:    dnameHdr.Ttl,
		},
		Target: substituted,
	}
}
// Package file implements a file backend.
package file
import (
"context"
"fmt"
"io"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/transfer"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// log is the package-level logger, created for the plugin name "file".
var log = clog.NewWithPlugin("file")
type (
// File is the plugin that reads zone data from disk.
File struct {
// Next is the next handler in the plugin chain; it may be nil.
Next plugin.Handler
Zones
// transfer is the transfer plugin, if one was found at startup; it is
// used to send notifies.
transfer *transfer.Transfer
// Fall decides for which names a query without an answer is passed on to Next.
Fall fall.F
}
// Zones maps zone names to a *Zone.
Zones struct {
Z map[string]*Zone // A map mapping zone (origin) to the Zone's data
Names []string // All the keys from the map Z as a string slice.
}
)
// ServeDNS implements the plugin.Handler interface.
//
// It picks the zone that matches the query name, answers NOTIFY messages for
// secondary zones, refuses AXFR/IXFR, and otherwise answers from the zone data
// via Zone.Lookup. Queries for names outside all configured zones go to the
// next plugin, or get REFUSED when there is none. The returned int is a DNS
// rcode; dns.RcodeSuccess is returned whenever a reply was written here.
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
qname := state.Name()
// TODO(miek): match the qname better in the map
zone := plugin.Zones(f.Zones.Names).Matches(qname)
if zone == "" {
// If no next plugin is configured, it's more correct to return REFUSED as file acts as an authoritative server
if f.Next == nil {
return dns.RcodeRefused, nil
}
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
z, ok := f.Z[zone]
if !ok || z == nil {
return dns.RcodeServerFailure, nil
}
// If transfer is not loaded, we'll see these, answer with refused (no transfer allowed).
if state.QType() == dns.TypeAXFR || state.QType() == dns.TypeIXFR {
return dns.RcodeRefused, nil
}
// This only applies when we are a secondary for the zone.
if r.Opcode == dns.OpcodeNotify {
if z.isNotify(state) {
// Acknowledge the notify first, then check whether a transfer is needed.
// The error returned by WriteMsg is not checked.
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
w.WriteMsg(m)
log.Infof("Notify from %s for %s: checking transfer", state.IP(), zone)
ok, err := z.shouldTransfer()
if ok {
z.TransferIn()
} else {
log.Infof("Notify from %s for %s: no SOA serial increase seen", state.IP(), zone)
}
if err != nil {
log.Warningf("Notify from %s for %s: failed primary check: %s", state.IP(), zone, err)
}
return dns.RcodeSuccess, nil
}
// Not an acceptable notify for this zone: no reply is written.
log.Infof("Dropping notify from %s for %s", state.IP(), zone)
return dns.RcodeSuccess, nil
}
// Read the expired flag under the zone's read lock.
z.RLock()
exp := z.Expired
z.RUnlock()
if exp {
log.Errorf("Zone %s is expired", zone)
return dns.RcodeServerFailure, nil
}
answer, ns, extra, result := z.Lookup(ctx, state, qname)
// Only on NXDOMAIN we will fallthrough.
// `z.Lookup` can also return NOERROR for NXDOMAIN, see the comment "Hacky way to get around empty-non-terminals" inside `Zone.Lookup`.
// It's safe to fallthrough with `result` Success (NOERROR) since all other return points in Lookup with Success have answer(s).
if len(answer) == 0 && (result == NameError || result == Success) && f.Fall.Through(qname) {
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Answer, m.Ns, m.Extra = answer, ns, extra
switch result {
case Success:
case NoData:
case NameError:
m.Rcode = dns.RcodeNameError
case Delegation:
m.Authoritative = false
case ServerFailure:
// If the result is SERVFAIL and the answer is non-empty, then the SERVFAIL came from an
// external CNAME lookup and the answer contains the CNAME with no target record. We should
// write the CNAME record to the client instead of sending an empty SERVFAIL response.
if len(m.Answer) == 0 {
return dns.RcodeServerFailure, nil
}
// The rcode in the response should be the rcode received from the target lookup. RFC 6604 section 3
m.Rcode = dns.RcodeServerFailure
}
// The error returned by WriteMsg is not checked.
w.WriteMsg(m)
return dns.RcodeSuccess, nil
}
// Name implements the Handler interface; it always returns "file".
func (f File) Name() string {
	const pluginName = "file"
	return pluginName
}
// serialErr is the error type Parse returns when the SOA serial found in the
// zone file equals the serial the caller already has.
type serialErr struct {
	err, zone, origin string // message, zone file name and zone origin
	serial            int64  // the unchanged SOA serial
}

// Error implements the error interface.
func (s *serialErr) Error() string {
	const format = "%s for origin %s in file %s, with %d SOA serial"
	return fmt.Sprintf(format, s.err, s.origin, s.zone, s.serial)
}
// Parse reads zone data for origin from f and returns a new Zone.
//
// fileName is used for error messages and is handed to the zone parser
// ($INCLUDE is allowed). serial is the SOA serial the caller already has:
// when serial >= 0 and the first SOA record in the data carries the same
// serial, parsing stops and a *serialErr is returned to signal that nothing
// changed. Pass -1 to always load the zone.
//
// An error is returned when a record cannot be inserted into the zone, when
// the zone parser reports an error (the parser error is wrapped), or when the
// data contains no SOA record. The parser error is checked first, so a syntax
// error in front of the SOA record is reported as such and not as a missing
// SOA record.
func Parse(f io.Reader, origin, fileName string, serial int64) (*Zone, error) {
	zp := dns.NewZoneParser(f, dns.Fqdn(origin), fileName)
	zp.SetIncludeAllowed(true)
	z := NewZone(origin, fileName)
	seenSOA := false
	for rr, ok := zp.Next(); ok; rr, ok = zp.Next() {
		if !seenSOA {
			if s, ok := rr.(*dns.SOA); ok {
				seenSOA = true
				// -1 is a valid serial if we failed to load the file on startup.
				if serial >= 0 && s.Serial == uint32(serial) { // same serial
					return nil, &serialErr{err: "no change in SOA serial", origin: origin, zone: fileName, serial: serial}
				}
			}
		}
		if err := z.Insert(rr); err != nil {
			return nil, err
		}
	}
	// Next returning false means either end of input or a parse error.
	if err := zp.Err(); err != nil {
		return nil, fmt.Errorf("failed to parse file %q for origin %s with error %w", fileName, origin, err)
	}
	if !seenSOA {
		return nil, fmt.Errorf("file %q has no SOA record for origin %s", fileName, origin)
	}
	return z, nil
}
//go:build gofuzz
package file
import (
"strings"
"github.com/coredns/coredns/plugin/pkg/fuzz"
"github.com/coredns/coredns/plugin/test"
)
// Fuzz is the fuzzing entry point for the file plugin: it serves the fixed
// miek.nl. zone from fuzzMiekNL and feeds data to the plugin via fuzz.Do.
func Fuzz(data []byte) int {
	const origin = "miek.nl."
	// The parse error is ignored; the zone text is a fixed constant.
	parsed, _ := Parse(strings.NewReader(fuzzMiekNL), origin, "stdin", 0)
	zones := Zones{
		Z:     map[string]*Zone{origin: parsed},
		Names: []string{origin},
	}
	handler := File{Next: test.ErrorHandler(), Zones: zones}
	return fuzz.Do(handler, data)
}
// fuzzMiekNL is the zone file text for miek.nl. that Fuzz loads.
const fuzzMiekNL = `
$TTL 30M
$ORIGIN miek.nl.
@ IN SOA linode.atoom.net. miek.miek.nl. (
1282630057 ; Serial
4H ; Refresh
1H ; Retry
7D ; Expire
4H ) ; Negative Cache TTL
IN NS linode.atoom.net.
IN NS ns-ext.nlnetlabs.nl.
IN NS omval.tednet.nl.
IN NS ext.ns.whyscream.net.
IN MX 1 aspmx.l.google.com.
IN MX 5 alt1.aspmx.l.google.com.
IN MX 5 alt2.aspmx.l.google.com.
IN MX 10 aspmx2.googlemail.com.
IN MX 10 aspmx3.googlemail.com.
IN A 139.162.196.78
IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
a IN A 139.162.196.78
IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
www IN CNAME a
archive IN CNAME a
srv IN SRV 10 10 8080 a.miek.nl.
mx IN MX 10 a.miek.nl.`
package file
import (
"context"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin/file/rrutil"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Result is the outcome of a Lookup. ServeDNS translates it into the rcode
// and the authoritative flag of the reply.
type Result int
const (
// Success is a successful lookup.
Success Result = iota
// NameError indicates that the name does not exist (NXDOMAIN).
NameError
// Delegation indicates the lookup resulted in a delegation.
Delegation
// NoData indicates the lookup resulted in a NODATA.
NoData
// ServerFailure indicates a server failure during the lookup.
ServerFailure
)
// Lookup looks up qname in the zone for the query type of state. When the
// request has the DO bit set (state.Do()), DNSSEC records are included.
// Three sets of records are returned, one for the answer, one for authority and one for the additional section,
// together with a Result describing the outcome.
//
// ctx carries the CNAME/DNAME chase depth under dnsserver.LoopKey{}; once that
// depth exceeds 8, ServerFailure is returned.
func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) ([]dns.RR, []dns.RR, []dns.RR, Result) {
qtype := state.QType()
do := state.Do()
// If z is a secondary zone we might not have transferred it, meaning we have
// all zone context setup, except the actual record. This means (for one thing) the apex
// is empty and we don't have a SOA record.
// Take a snapshot of apex and tree under the read lock; the rest of the
// lookup works on the snapshot.
z.RLock()
ap := z.Apex
tr := z.Tree
z.RUnlock()
if ap.SOA == nil {
return nil, nil, nil, ServerFailure
}
// SOA and NS queries for the apex are answered directly from the apex data.
if qname == z.origin {
switch qtype {
case dns.TypeSOA:
return ap.soa(do), ap.ns(do), nil, Success
case dns.TypeNS:
nsrrs := ap.ns(do)
glue := tr.Glue(nsrrs, do) // technically this isn't glue
return nsrrs, nil, glue, Success
}
}
var (
found, shot bool
parts string
i int
elem, wildElem *tree.Elem
)
loop, _ := ctx.Value(dnsserver.LoopKey{}).(int)
if loop > 8 {
// We're back here for the 9th time; we have a loop and need to bail out.
// Note the answer we're returning will be incomplete (more cnames to be followed) or
// illegal (wildcard cname with multiple identical records). For now it's more important
// to protect ourselves than to give the client a valid answer. We return with an error
// to let the server handle what to do.
return nil, nil, nil, ServerFailure
}
// Lookup:
// * Per label from the right, look if it exists. We do this to find potential
// delegation records.
// * If the per-label search finds nothing, we will look for the wildcard at the
// level. If found we keep it around. If we don't find the complete name we will
// use the wildcard.
//
// Main for-loop handles delegation and finding or not finding the qname.
// If found we check if it is a CNAME/DNAME and do CNAME processing
// We also check if we have type and do a nodata response.
//
// If not found, we check the potential wildcard, and use that for further processing.
// If not found and no wildcard we will process this as an NXDOMAIN response.
for {
parts, shot = z.nameFromRight(qname, i)
// We overshot the name, break and check if we previously found something.
if shot {
break
}
elem, found = tr.Search(parts)
if !found {
// Apex will always be found, when we are here we can search for a wildcard
// and save the result of that search. So when nothing match, but we have a
// wildcard we should expand the wildcard.
wildcard := replaceWithAsteriskLabel(parts)
if wild, found := tr.Search(wildcard); found {
wildElem = wild
}
// Keep on searching, because maybe we hit an empty-non-terminal (which isn't
// stored in the tree). Only when we have matched the full qname (and possible
// wildcard) can we be confident that we didn't find anything.
i++
continue
}
// If we see DNAME records, we should return those.
if dnamerrs := elem.Type(dns.TypeDNAME); dnamerrs != nil {
// Only one DNAME is allowed per name. We just pick the first one to synthesize from.
dname := dnamerrs[0]
if cname := synthesizeCNAME(state.Name(), dname.(*dns.DNAME)); cname != nil {
var (
answer, ns, extra []dns.RR
rcode Result
)
// We don't need to chase CNAME chain for synthesized CNAME
if qtype == dns.TypeCNAME {
answer = []dns.RR{cname}
ns = ap.ns(do)
extra = nil
rcode = Success
} else {
ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1)
answer, ns, extra, rcode = z.externalLookup(ctx, state, elem, []dns.RR{cname})
}
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeDNAME)
dnamerrs = append(dnamerrs, sigs...)
}
// The relevant DNAME RR should be included in the answer section,
// if the DNAME is being employed as a substitution instruction.
answer = append(dnamerrs, answer...)
return answer, ns, extra, rcode
}
// The domain name that owns a DNAME record is allowed to have other RR types
// at that domain name, except those have restrictions on what they can coexist
// with (e.g. another DNAME). So there is nothing special left here.
}
// If we see NS records, it means the name has been delegated, and we should return the delegation.
if nsrrs := elem.Type(dns.TypeNS); nsrrs != nil {
// If the query is specifically for DS and the qname matches the delegated name, we should
// return the DS in the answer section and leave the rest empty, i.e. just continue the loop
// and continue searching.
if qtype == dns.TypeDS && elem.Name() == qname {
i++
continue
}
glue := tr.Glue(nsrrs, do)
if do {
dss := typeFromElem(elem, dns.TypeDS, do)
nsrrs = append(nsrrs, dss...)
}
return nil, nsrrs, glue, Delegation
}
i++
}
// The loop above only breaks once shot is true, so this branch looks
// unreachable; it is kept as a defensive guard.
if found && !shot {
return nil, nil, nil, ServerFailure
}
// Found entire name.
if found && shot {
if rrs := elem.Type(dns.TypeCNAME); len(rrs) > 0 && qtype != dns.TypeCNAME {
ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1)
return z.externalLookup(ctx, state, elem, rrs)
}
rrs := elem.Type(qtype)
// NODATA
if len(rrs) == 0 {
ret := ap.soa(do)
if do {
nsec := typeFromElem(elem, dns.TypeNSEC, do)
ret = append(ret, nsec...)
}
return nil, ret, nil, NoData
}
// Additional section processing for MX, SRV. Check response and see if any of the names are in bailiwick -
// if so add IP addresses to the additional section.
additional := z.additionalProcessing(rrs, do)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, qtype)
rrs = append(rrs, sigs...)
}
return rrs, ap.ns(do), additional, Success
}
// Haven't found the original name.
// Found wildcard.
if wildElem != nil {
// set metadata value for the wildcard record that synthesized the result
metadata.SetValueFunc(ctx, "zone/wildcard", func() string {
return wildElem.Name()
})
if rrs := wildElem.TypeForWildcard(dns.TypeCNAME, qname); len(rrs) > 0 && qtype != dns.TypeCNAME {
ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1)
return z.externalLookup(ctx, state, wildElem, rrs)
}
rrs := wildElem.TypeForWildcard(qtype, qname)
// NODATA response.
if len(rrs) == 0 {
ret := ap.soa(do)
if do {
nsec := typeFromElem(wildElem, dns.TypeNSEC, do)
ret = append(ret, nsec...)
}
return nil, ret, nil, NoData
}
auth := ap.ns(do)
if do {
// An NSEC is needed to say no longer name exists under this wildcard.
if deny, found := tr.Prev(qname); found {
nsec := typeFromElem(deny, dns.TypeNSEC, do)
auth = append(auth, nsec...)
}
sigs := wildElem.TypeForWildcard(dns.TypeRRSIG, qname)
sigs = rrutil.SubTypeSignature(sigs, qtype)
rrs = append(rrs, sigs...)
}
return rrs, auth, nil, Success
}
rcode := NameError
// Hacky way to get around empty-non-terminals. If a longer name does exist, but this qname, does not, it
// must be an empty-non-terminal. If so, we do the proper NXDOMAIN handling, but set the rcode to be success.
if x, found := tr.Next(qname); found {
if dns.IsSubDomain(qname, x.Name()) {
rcode = Success
}
}
ret := ap.soa(do)
if do {
deny, found := tr.Prev(qname)
if !found {
goto Out
}
nsec := typeFromElem(deny, dns.TypeNSEC, do)
ret = append(ret, nsec...)
if rcode != NameError {
goto Out
}
ce, found := z.ClosestEncloser(qname)
// wildcard denial only for NXDOMAIN
if found {
// wildcard denial
wildcard := "*." + ce.Name()
if ss, found := tr.Prev(wildcard); found {
// Only add this nsec if it is different than the one already added
if ss.Name() != deny.Name() {
nsec := typeFromElem(ss, dns.TypeNSEC, do)
ret = append(ret, nsec...)
}
}
}
}
Out:
return nil, ret, nil, rcode
}
// typeFromElem returns the records of type tp stored in elem. When do is
// true, the RRSIGs in elem that cover tp are appended to them.
func typeFromElem(elem *tree.Elem, tp uint16, do bool) []dns.RR {
	rrs := elem.Type(tp)
	if !do {
		return rrs
	}
	covering := rrutil.SubTypeSignature(elem.Type(dns.TypeRRSIG), tp)
	return append(rrs, covering...)
}
// soa returns the apex SOA record in a fresh slice, followed by its
// signatures (a.SIGSOA) when do is true.
func (a Apex) soa(do bool) []dns.RR {
	rrs := []dns.RR{a.SOA}
	if do {
		rrs = append(rrs, a.SIGSOA...)
	}
	return rrs
}
// ns returns the apex NS records, followed by their signatures (a.SIGNS)
// when do is true.
//
// The result never has spare capacity shared with a.NS: the slice is clipped
// to its length before anything is appended. Without that, appending the
// signatures here, or a later append by the caller (Lookup appends NSEC
// records to the result), could write into the backing array of a.NS, which
// is shared between all lookups on the zone.
func (a Apex) ns(do bool) []dns.RR {
	ns := a.NS[:len(a.NS):len(a.NS)]
	if do {
		return append(ns, a.SIGNS...)
	}
	return ns
}
// externalLookup follows the CNAME in rrs[0], which must be a *dns.CNAME.
// As long as the target is found in this zone the chain is followed locally
// (at most 9 in-zone hops); as soon as a target is not in the zone it is
// resolved through z.doLookup. When the request has the DO bit set, the
// RRSIGs covering each CNAME and the final records are added as well.
//
// It returns the accumulated answer records, the apex NS records as
// authority, no additional records, and the Result of the lookup.
func (z *Zone) externalLookup(ctx context.Context, state request.Request, elem *tree.Elem, rrs []dns.RR) ([]dns.RR, []dns.RR, []dns.RR, Result) {
qtype := state.QType()
do := state.Do()
if do {
// Signatures for the CNAME we were handed come from elem.
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeCNAME)
rrs = append(rrs, sigs...)
}
targetName := rrs[0].(*dns.CNAME).Target
elem, _ = z.Search(targetName)
if elem == nil {
// Target is not in this zone: resolve it upstream.
lookupRRs, result := z.doLookup(ctx, state, targetName, qtype)
rrs = append(rrs, lookupRRs...)
return rrs, z.ns(do), nil, result
}
// i counts the in-zone CNAME hops followed below.
i := 0
Redo:
cname := elem.Type(dns.TypeCNAME)
if len(cname) > 0 {
rrs = append(rrs, cname...)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, dns.TypeCNAME)
rrs = append(rrs, sigs...)
}
targetName := cname[0].(*dns.CNAME).Target
elem, _ = z.Search(targetName)
if elem == nil {
lookupRRs, result := z.doLookup(ctx, state, targetName, qtype)
rrs = append(rrs, lookupRRs...)
return rrs, z.ns(do), nil, result
}
i++
// Give up on long (or looping) chains and return what we have so far.
if i > 8 {
return rrs, z.ns(do), nil, Success
}
goto Redo
}
// elem is the end of the chain: add the records of the queried type, if any.
targets := elem.Type(qtype)
if len(targets) > 0 {
rrs = append(rrs, targets...)
if do {
sigs := elem.Type(dns.TypeRRSIG)
sigs = rrutil.SubTypeSignature(sigs, qtype)
rrs = append(rrs, sigs...)
}
}
return rrs, z.ns(do), nil, Success
}
// doLookup resolves target/qtype through z.Upstream and maps the outcome to
// a Result: a lookup error is ServerFailure, a nil message is Success without
// records, NXDOMAIN and SERVFAIL rcodes map to NameError and ServerFailure,
// NOERROR with an empty answer is NoData, and everything else is Success.
// Except for the first two cases the answer section of the reply is returned.
func (z *Zone) doLookup(ctx context.Context, state request.Request, target string, qtype uint16) ([]dns.RR, Result) {
	reply, err := z.Upstream.Lookup(ctx, state, target, qtype)
	switch {
	case err != nil:
		return nil, ServerFailure
	case reply == nil:
		return nil, Success
	case reply.Rcode == dns.RcodeNameError:
		return reply.Answer, NameError
	case reply.Rcode == dns.RcodeServerFailure:
		return reply.Answer, ServerFailure
	case reply.Rcode == dns.RcodeSuccess && len(reply.Answer) == 0:
		return reply.Answer, NoData
	default:
		return reply.Answer, Success
	}
}
// additionalProcessing collects the records for the additional section: for
// every SRV target and MX exchange in answer that lies inside this zone and
// exists in it, the A and AAAA records of that name are returned, each set
// followed by its covering RRSIGs when do is true.
func (z *Zone) additionalProcessing(answer []dns.RR, do bool) (extra []dns.RR) {
	addrTypes := [...]uint16{dns.TypeA, dns.TypeAAAA}

	for _, rr := range answer {
		var target string
		switch t := rr.(type) {
		case *dns.SRV:
			target = t.Target
		case *dns.MX:
			target = t.Mx
		}
		// Skip other record types and names outside of the zone.
		if target == "" || !dns.IsSubDomain(z.origin, target) {
			continue
		}

		elem, _ := z.Search(target)
		if elem == nil {
			continue
		}

		sigs := elem.Type(dns.TypeRRSIG)
		for _, addrType := range addrTypes {
			addrs := elem.Type(addrType)
			if addrs == nil {
				continue
			}
			extra = append(extra, addrs...)
			if do {
				extra = append(extra, rrutil.SubTypeSignature(sigs, addrType)...)
			}
		}
	}
	return extra
}
package file
import (
"net"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// isNotify reports whether state is a NOTIFY message that this zone should
// act on. That is the case when the opcode is NOTIFY, the zone has transfer
// sources configured (z.TransferFrom, i.e. it is a secondary) and the remote
// IP equals the host part of one of those sources. TransferFrom entries that
// are not in host:port form are skipped.
func (z *Zone) isNotify(state request.Request) bool {
	if state.Req.Opcode != dns.OpcodeNotify || len(z.TransferFrom) == 0 {
		return false
	}

	remote := state.IP()
	for _, primary := range z.TransferFrom {
		if host, _, err := net.SplitHostPort(primary); err == nil && host == remote {
			return true
		}
	}
	return false
}
package file
import (
"os"
"path/filepath"
"time"
"github.com/coredns/coredns/plugin/transfer"
)
// Reload starts a goroutine that re-reads the zone file every
// z.ReloadInterval and swaps the new data in when the SOA serial differs from
// the current one. If z.ReloadInterval is zero, no reloading will be done.
//
// When t is non-nil, notifies are sent through it after every successful
// reload. The goroutine stops when a value is received from z.reloadShutdown
// (or that channel is closed). Reload itself always returns nil; problems
// during a reload are logged and the reload is retried at the next tick.
func (z *Zone) Reload(t *transfer.Transfer) error {
	if z.ReloadInterval == 0 {
		return nil
	}
	tick := time.NewTicker(z.ReloadInterval)

	go func() {
		for {
			select {
			case <-tick.C:
				zFile := z.File()
				reader, err := os.Open(filepath.Clean(zFile))
				if err != nil {
					log.Errorf("Failed to open zone %q in %q: %v", z.origin, zFile, err)
					continue
				}

				serial := z.SOASerialIfDefined()
				zone, err := Parse(reader, z.origin, zFile, serial)
				reader.Close()
				if err != nil {
					// An unchanged serial is the normal case and not worth logging.
					if _, ok := err.(*serialErr); !ok {
						log.Errorf("Parsing zone %q: %v", z.origin, err)
					}
					continue
				}

				// copy elements we need
				z.Lock()
				z.Apex = zone.Apex
				z.Tree = zone.Tree
				z.Unlock()

				// Log the serial from the freshly parsed zone, which only this
				// goroutine holds, instead of reading z.SOA without the lock.
				log.Infof("Successfully reloaded zone %q in %q with %d SOA serial", z.origin, zFile, zone.SOA.Serial)

				if t != nil {
					if err := t.Notify(z.origin); err != nil {
						log.Warningf("Failed sending notifies: %s", err)
					}
				}

			case <-z.reloadShutdown:
				tick.Stop()
				return
			}
		}
	}()
	return nil
}
// SOASerialIfDefined returns the serial of the zone's SOA record, or -1 when
// the zone has no SOA record. The SOA is read under the zone's read lock.
func (z *Zone) SOASerialIfDefined() int64 {
	z.RLock()
	defer z.RUnlock()

	if z.SOA == nil {
		return -1
	}
	return int64(z.SOA.Serial)
}
// Package rrutil provides functions to find certain RRs in slices.
package rrutil
import "github.com/miekg/dns"
// SubTypeSignature returns the RRSIGs in rrs that cover subtype. Records
// that are not RRSIGs are skipped. The result is never nil; it is empty when
// nothing matches.
func SubTypeSignature(rrs []dns.RR, subtype uint16) []dns.RR {
	matched := []dns.RR{}
	// there may be multiple keys that have signed this subtype
	for _, rr := range rrs {
		rrsig, isSig := rr.(*dns.RRSIG)
		if !isSig || rrsig.TypeCovered != subtype {
			continue
		}
		matched = append(matched, rrsig)
	}
	return matched
}
package file
import (
"math/rand"
"time"
"github.com/miekg/dns"
)
// TransferIn retrieves the zone from the masters, parses it and sets it live.
//
// The addresses in z.TransferFrom are tried in order with an AXFR request and
// the first one that delivers the zone without an error wins. When every
// address fails, the last error is returned and the live zone is left
// untouched. A zone without TransferFrom addresses is a no-op. A successful
// transfer also clears the Expired flag.
func (z *Zone) TransferIn() error {
if len(z.TransferFrom) == 0 {
return nil
}
m := new(dns.Msg)
m.SetAxfr(z.origin)
// Records are collected in z1 and only swapped into z at the end.
z1 := z.CopyWithoutApex()
var (
Err error
tr string
)
Transfer:
for _, tr = range z.TransferFrom {
t := new(dns.Transfer)
c, err := t.In(m, tr)
if err != nil {
log.Errorf("Failed to setup transfer `%s' with `%q': %v", z.origin, tr, err)
Err = err
continue Transfer
}
for env := range c {
if env.Error != nil {
log.Errorf("Failed to transfer `%s' from %q: %v", z.origin, tr, env.Error)
Err = env.Error
continue Transfer
}
for _, rr := range env.RR {
// NOTE(review): z1 is created once before the loop, so records inserted
// during a failed attempt are still in z1 when the next address is tried.
// Confirm that this is intended.
if err := z1.Insert(rr); err != nil {
log.Errorf("Failed to parse transfer `%s' from: %q: %v", z.origin, tr, err)
Err = err
continue Transfer
}
}
}
// This address delivered the whole zone; stop trying the others.
Err = nil
break
}
if Err != nil {
return Err
}
// Swap in the new data under the write lock.
z.Lock()
z.Tree = z1.Tree
z.Apex = z1.Apex
z.Expired = false
z.Unlock()
log.Infof("Transferred: %s from %s", z.origin, tr)
return nil
}
// shouldTransfer asks the primaries of the zone (z.TransferFrom, in order)
// for the SOA record over TCP and compares serials. It returns true when the
// first SOA serial received is higher than the local one according to less,
// or when the zone has no local SOA yet. When no primary returned a SOA
// record it returns false.
//
// The returned error is the exchange error of the last primary tried, or nil
// when that exchange succeeded; it may be non-nil together with false.
//
// The remote serial is kept in an int64 so that every uint32 serial,
// including 4294967295, stays distinguishable from the "no SOA seen" marker
// -1 on platforms where int is 32 bits wide.
func (z *Zone) shouldTransfer() (bool, error) {
	c := new(dns.Client)
	c.Net = "tcp" // do this query over TCP to minimize spoofing

	m := new(dns.Msg)
	m.SetQuestion(z.origin, dns.TypeSOA)

	var Err error
	serial := int64(-1)

Transfer:
	for _, tr := range z.TransferFrom {
		Err = nil
		ret, _, err := c.Exchange(m, tr)
		if err != nil || ret.Rcode != dns.RcodeSuccess {
			Err = err
			continue
		}
		for _, a := range ret.Answer {
			if a.Header().Rrtype != dns.TypeSOA {
				continue
			}
			// Checked assertion: a record that claims to be a SOA but has a
			// different Go type is skipped instead of causing a panic.
			if soa, ok := a.(*dns.SOA); ok {
				serial = int64(soa.Serial)
				break Transfer
			}
		}
	}
	if serial == -1 {
		return false, Err
	}
	if z.SOA == nil {
		return true, Err
	}
	return less(z.SOA.Serial, uint32(serial)), Err
}
// less reports whether serial a is smaller than serial b under RFC 1982
// serial number arithmetic: the distance between the two is compared against
// MaxSerialIncrement to decide whether the numerically smaller value has
// wrapped around.
func less(a, b uint32) bool {
	if a >= b {
		// a is numerically larger (or equal); it only counts as smaller
		// when the gap is too big to be a plain increment.
		return a-b > MaxSerialIncrement
	}
	return b-a <= MaxSerialIncrement
}
// Update updates the secondary zone according to its SOA. It will run for the life time of the server
// and uses the SOA parameters. Every refresh interval it will check for a new SOA serial. If that fails (for all
// servers) it will retry every retry interval. If the zone failed to transfer before the expire interval, the zone
// will be marked expired.
//
// Update blocks and does not return; the error result is never produced.
func (z *Zone) Update() error {
// If we don't have a SOA, we don't have a zone, wait for it to appear.
// NOTE(review): z.SOA is read here without taking the zone lock.
for z.SOA == nil {
time.Sleep(1 * time.Second)
}
// retryActive is true while the last refresh/retry attempt has failed.
retryActive := false
Restart:
// (Re)read the timers from the current SOA and start fresh tickers.
refresh := time.Second * time.Duration(z.SOA.Refresh)
retry := time.Second * time.Duration(z.SOA.Retry)
expire := time.Second * time.Duration(z.SOA.Expire)
refreshTicker := time.NewTicker(refresh)
retryTicker := time.NewTicker(retry)
expireTicker := time.NewTicker(expire)
// Inside the select below, break only leaves the select and continue starts
// the next iteration of this for loop; both keep the current tickers running.
for {
select {
case <-expireTicker.C:
if !retryActive {
break
}
// NOTE(review): Expired is written here without the zone lock, while
// ServeDNS reads it under RLock.
z.Expired = true
case <-retryTicker.C:
if !retryActive {
break
}
time.Sleep(jitter(2000)) // 2s randomize
ok, err := z.shouldTransfer()
if err != nil {
log.Warningf("Failed retry check %s", err)
continue
}
if ok {
if err := z.TransferIn(); err != nil {
// transfer failed, leave retryActive true
break
}
}
// no errors, stop timers and restart
retryActive = false
refreshTicker.Stop()
retryTicker.Stop()
expireTicker.Stop()
goto Restart
case <-refreshTicker.C:
time.Sleep(jitter(5000)) // 5s randomize
ok, err := z.shouldTransfer()
if err != nil {
log.Warningf("Failed refresh check %s", err)
retryActive = true
continue
}
if ok {
if err := z.TransferIn(); err != nil {
// transfer failed
retryActive = true
break
}
}
// no errors, stop timers and restart
retryActive = false
refreshTicker.Stop()
retryTicker.Stop()
expireTicker.Stop()
goto Restart
}
}
}
// jitter returns a random duration of a whole number of milliseconds in
// [0, n). Like rand.Intn, it panics when n <= 0.
func jitter(n int) time.Duration {
	return time.Millisecond * time.Duration(rand.Intn(n))
}
// MaxSerialIncrement is the maximum difference between two serial numbers. If the difference between
// two serials is greater than this number, the smaller one is considered greater.
// The value is 2^31 - 1; less uses it for its serial number comparison.
const MaxSerialIncrement uint32 = 2147483647
package file
import (
"errors"
"os"
"path/filepath"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/plugin/transfer"
)
// init registers the setup function under the plugin name "file".
func init() {
	plugin.Register("file", setup)
}
// setup configures the file plugin from the controller c.
//
// It parses the configuration with fileParse, registers startup, restart-failed
// and shutdown hooks on c, and adds the resulting File handler to the plugin
// chain of the server config. A parse failure is returned wrapped by plugin.Error.
func setup(c *caddy.Controller) error {
zones, fall, err := fileParse(c)
if err != nil {
return plugin.Error("file", err)
}
f := File{Zones: zones, Fall: fall}
// get the transfer plugin, so we can send notifies and send notifies on startup as well.
c.OnStartup(func() error {
t := dnsserver.GetConfig(c).Handler("transfer")
if t == nil {
// no transfer handler configured: there is nobody to notify.
return nil
}
f.transfer = t.(*transfer.Transfer) // if found this must be OK.
// notify from a separate goroutine; the hook itself returns immediately.
go func() {
for _, n := range zones.Names {
f.transfer.Notify(n)
}
}()
return nil
})
// NOTE(review): this hook only checks that a transfer handler exists and then
// uses f.transfer, which is assigned in the startup hook above - confirm that
// hook has always run before a failed restart is reported.
c.OnRestartFailed(func() error {
t := dnsserver.GetConfig(c).Handler("transfer")
if t == nil {
return nil
}
go func() {
for _, n := range zones.Names {
f.transfer.Notify(n)
}
}()
return nil
})
// Every zone gets its own shutdown hook and a startup hook that calls Reload
// at most once, guarded by the zone's StartupOnce.
for _, n := range zones.Names {
z := zones.Z[n]
c.OnShutdown(z.OnShutdown)
c.OnStartup(func() error {
z.StartupOnce.Do(func() { z.Reload(f.transfer) })
return nil
})
}
// Record the next handler in f and hand f back as this plugin's handler.
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
f.Next = next
return f
})
return nil
}
// fileParse parses every "file db.file [zones...]" directive found in c.
//
// It returns the zones keyed by origin (plus the origin names in the order they
// were seen), the fallthrough configuration and an error. The reload interval
// starts at one minute and can be overridden by a "reload" property. A zone file
// that cannot be opened is only fatal when reload is 0; otherwise a warning is
// logged and an empty zone created by NewZone is kept for that origin.
func fileParse(c *caddy.Controller) (Zones, fall.F, error) {
z := make(map[string]*Zone)
names := []string{}
fall := fall.F{}
config := dnsserver.GetConfig(c)
// openErr is declared outside the loop and never reset, so once one file fails
// to open, the files of later directives are not parsed either.
var openErr error
reload := 1 * time.Minute
for c.Next() {
// file db.file [zones...]
if !c.NextArg() {
return Zones{}, fall, c.ArgErr()
}
fileName := c.Val()
origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys)
// relative file names are resolved against the configured root, if any.
if !filepath.IsAbs(fileName) && config.Root != "" {
fileName = filepath.Join(config.Root, fileName)
}
reader, err := os.Open(filepath.Clean(fileName))
if err != nil {
openErr = err
}
// The closure scopes the deferred Close to this directive instead of to the
// whole function.
err = func() error {
defer reader.Close()
for i := range origins {
// start with an empty zone; it is replaced below when the file parses.
z[origins[i]] = NewZone(origins[i], fileName)
if openErr == nil {
// rewind, the same file is parsed once per origin.
reader.Seek(0, 0)
zone, err := Parse(reader, origins[i], fileName, 0)
if err != nil {
return err
}
z[origins[i]] = zone
}
names = append(names, origins[i])
}
return nil
}()
if err != nil {
return Zones{}, fall, err
}
for c.NextBlock() {
switch c.Val() {
case "fallthrough":
fall.SetZonesFromArgs(c.RemainingArgs())
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return Zones{}, fall, errors.New("reload duration value is expected")
}
d, err := time.ParseDuration(t[0])
if err != nil {
return Zones{}, fall, plugin.Error("file", err)
}
reload = d
case "upstream":
// remove soon
c.RemainingArgs()
default:
return Zones{}, fall, c.Errf("unknown property '%s'", c.Val())
}
}
// apply the reload interval in effect after this block to all its origins.
for i := range origins {
z[origins[i]].ReloadInterval = reload
z[origins[i]].Upstream = upstream.New()
}
}
if openErr != nil {
if reload == 0 {
// reload hasn't been set make this a fatal error
return Zones{}, fall, plugin.Error("file", openErr)
}
log.Warningf("Failed to open %q: trying again in %s", openErr, reload)
}
return Zones{Z: z, Names: names}, fall, nil
}
package file
// OnShutdown shuts down any running go-routines for this zone. The shutdown
// signal is sent on reloadShutdown only when a positive ReloadInterval is
// configured. It always returns nil.
func (z *Zone) OnShutdown() error {
	if z.ReloadInterval > 0 {
		z.reloadShutdown <- true
	}
	return nil
}
package tree
// All traverses the tree and returns all elements. An empty tree yields nil.
func (t *Tree) All() []*Elem {
	if root := t.Root; root != nil {
		return root.all(nil)
	}
	return nil
}
// all appends the elements of the subtree rooted at n to found, visiting the
// left subtree, then n itself, then the right subtree, and returns the result.
func (n *Node) all(found []*Elem) []*Elem {
	if left := n.Left; left != nil {
		found = left.all(found)
	}
	found = append(found, n.Elem)
	if right := n.Right; right != nil {
		return right.all(found)
	}
	return found
}
package tree
import (
"github.com/miekg/dns"
)
// AuthWalk calls fn for every element stored in the tree, visiting them
// in-order (left subtree, node, right subtree). If fn returns a non-nil error
// the walk stops and that error is returned. If fn alters stored values' sort
// relationships, future tree operation behaviors are undefined.
//
// fn receives three arguments: the current element, the map holding all RRs of
// that element, and a boolean telling whether the name is considered
// authoritative. An empty tree is not walked and yields nil.
func (t *Tree) AuthWalk(fn func(*Elem, map[uint16][]dns.RR, bool) error) error {
	root := t.Root
	if root == nil {
		return nil
	}
	delegated := make(map[string]struct{})
	return root.authwalk(delegated, fn)
}
// authwalk walks the subtree rooted at n in-order and calls fn for each element.
//
// ns collects the names of all elements seen so far that carry NS records; it is
// shared by the whole walk. An element is reported as non-authoritative when one
// of the name suffixes produced by the label loop below is present in ns. The
// first non-nil error returned by fn (or by a recursive call) aborts the walk and
// is returned.
func (n *Node) authwalk(ns map[string]struct{}, fn func(*Elem, map[uint16][]dns.RR, bool) error) error {
if n.Left != nil {
if err := n.Left.authwalk(ns, fn); err != nil {
return err
}
}
// Check if the current name is a subdomain of *any* of the delegated names we've seen, if so, skip this name.
// The ordering of the tree and how we walk it guarantees we see parents first.
if n.Elem.Type(dns.TypeNS) != nil {
ns[n.Elem.Name()] = struct{}{}
}
auth := true
i := 0
// Probe the suffixes of the name at the offsets reported by dns.NextLabel until
// it signals the end of the name.
// NOTE(review): i advances by one byte per iteration, not by one label, so the
// same suffix may be looked up several times - confirm this is intended.
for {
j, end := dns.NextLabel(n.Elem.Name(), i)
if end {
break
}
if _, ok := ns[n.Elem.Name()[j:]]; ok {
auth = false
break
}
i++
}
if err := fn(n.Elem, n.Elem.m, auth); err != nil {
return err
}
if n.Right != nil {
if err := n.Right.authwalk(ns, fn); err != nil {
return err
}
}
return nil
}
package tree
import "github.com/miekg/dns"
// Elem is an element in the tree. It holds all RRs stored for one owner name.
type Elem struct {
m map[uint16][]dns.RR // RRs of this element, keyed by RR type
name string // owner name, filled in lazily by Name
}
// newElem returns a new elem holding only rr, stored under rr's type.
func newElem(rr dns.RR) *Elem {
	return &Elem{
		m: map[uint16][]dns.RR{
			rr.Header().Rrtype: {rr},
		},
	}
}
// Types returns the types of the records in e. The returned list is not sorted.
func (e *Elem) Types() []uint16 {
	types := make([]uint16, 0, len(e.m))
	for rrtype := range e.m {
		types = append(types, rrtype)
	}
	return types
}
// Type returns the RRs with type qtype from e, or nil when e has none.
func (e *Elem) Type(qtype uint16) []dns.RR {
	return e.m[qtype]
}
// TypeForWildcard returns copies of the RRs with type qtype from e, each with
// its owner name set to qname. The stored RRs are left untouched. It returns nil
// when e holds no RRs of that type.
func (e *Elem) TypeForWildcard(qtype uint16, qname string) []dns.RR {
	stored := e.m[qtype]
	if stored == nil {
		return nil
	}
	renamed := make([]dns.RR, 0, len(stored))
	for _, rr := range stored {
		c := dns.Copy(rr)
		c.Header().Name = qname
		renamed = append(renamed, c)
	}
	return renamed
}
// All returns all RRs from e, regardless of type. The result is never nil; the
// order of the types follows map iteration and is therefore unspecified.
func (e *Elem) All() []dns.RR {
	all := make([]dns.RR, 0)
	for rrtype := range e.m {
		all = append(all, e.m[rrtype]...)
	}
	return all
}
// Name returns the name for this node. The name is taken from the first RR of an
// arbitrary RR set in e and cached in e.name; an element without RRs yields "".
func (e *Elem) Name() string {
	if e.name == "" {
		for _, rrs := range e.m {
			e.name = rrs[0].Header().Name
			break
		}
	}
	return e.name
}
// Empty returns true if e does not contain any RRs, i.e. is an empty-non-terminal.
func (e *Elem) Empty() bool {
	return len(e.m) == 0
}
// Insert inserts rr into e. If rr is equal to existing RRs, the RR will be added
// anyway. The backing map is created on first use.
func (e *Elem) Insert(rr dns.RR) {
	rrtype := rr.Header().Rrtype
	if e.m == nil {
		e.m = make(map[uint16][]dns.RR)
	}
	// appending to a missing (nil) entry creates the RR set.
	e.m[rrtype] = append(e.m[rrtype], rr)
}
// Delete removes all RRs of type rr.Header().Rrtype from e. It is a no-op when e
// has no backing map.
func (e *Elem) Delete(rr dns.RR) {
	if e.m != nil {
		delete(e.m, rr.Header().Rrtype)
	}
}
// Less is a tree helper function that calls less with name as the first and the
// name of a as the second argument.
func Less(a *Elem, name string) int {
	return less(name, a.Name())
}
package tree
import (
"github.com/coredns/coredns/plugin/file/rrutil"
"github.com/miekg/dns"
)
// Glue returns any potential glue records for nsrrs. Only NS records whose
// target passes dns.IsSubDomain(owner, target) are looked up; RRs of other types
// are skipped. When do is true the matching signatures are included. The result
// is never nil.
func (t *Tree) Glue(nsrrs []dns.RR, do bool) []dns.RR {
	glue := []dns.RR{}
	for _, rr := range nsrrs {
		ns, isNS := rr.(*dns.NS)
		if !isNS {
			continue
		}
		if !dns.IsSubDomain(ns.Header().Name, ns.Ns) {
			continue
		}
		glue = append(glue, t.searchGlue(ns.Ns, do)...)
	}
	return glue
}
// searchGlue looks up A and AAAA for name.
//
// The tree is searched once for name; the result lists the A records first and
// then the AAAA records. When do is true each group is followed by the RRSIGs
// that rrutil.SubTypeSignature selects for that type. The result is never nil
// and is empty when name is not in the tree.
func (t *Tree) searchGlue(name string, do bool) []dns.RR {
	glue := []dns.RR{}
	elem, found := t.Search(name)
	if !found {
		return glue
	}
	for _, qtype := range [...]uint16{dns.TypeA, dns.TypeAAAA} {
		glue = append(glue, elem.Type(qtype)...)
		if do {
			sigs := rrutil.SubTypeSignature(elem.Type(dns.TypeRRSIG), qtype)
			glue = append(glue, sigs...)
		}
	}
	return glue
}
package tree
import (
"bytes"
"strings"
"github.com/miekg/dns"
)
// less returns <0 when a is less than b, 0 when they are equal and
// >0 when a is larger than b.
// The function orders names in DNSSEC canonical order: RFC 4034s section-6.1
//
// See https://bert-hubert.blogspot.co.uk/2015/10/how-to-do-fast-canonical-ordering-of.html
// for a blog article on this implementation, although here we still go label by label.
//
// The values of a and b are *not* lowercased before the comparison! Each label
// is however lowercased (and has its \DDD escapes processed by doDDD) on a copy
// right before it is compared.
func less(a, b string) int {
// i counts labels from the right; aj and bj are the end offsets of the label
// currently being compared.
i := 1
aj := len(a)
bj := len(b)
for {
ai, oka := dns.PrevLabel(a, i)
bi, okb := dns.PrevLabel(b, i)
// both names report their start at the same label count: they are equal.
if oka && okb {
return 0
}
// sadly this []byte will allocate... TODO(miek): check if this is needed
// for a name, otherwise compare the strings.
ab := []byte(strings.ToLower(a[ai:aj]))
bb := []byte(strings.ToLower(b[bi:bj]))
doDDD(ab)
doDDD(bb)
res := bytes.Compare(ab, bb)
if res != 0 {
return res
}
// labels are equal, move one label to the left in both names.
i++
aj, bj = ai, bi
}
}
// doDDD rewrites every \DDD escape (a backslash followed by three decimal
// digits) in b, in place, to the single byte computed by dddToByte. The bytes
// after an escape are shifted three positions to the left. The length of b is
// not changed, so the bytes past the shortened content keep their old values.
func doDDD(b []byte) {
	end := len(b)
	for pos := 0; pos < end; pos++ {
		if pos+3 >= end || b[pos] != '\\' {
			continue
		}
		if !isDigit(b[pos+1]) || !isDigit(b[pos+2]) || !isDigit(b[pos+3]) {
			continue
		}
		b[pos] = dddToByte(b[pos:])
		// close the three-byte gap left by the digits.
		copy(b[pos+1:end-3], b[pos+4:end])
		end -= 3
	}
}
// isDigit reports whether b is an ASCII decimal digit.
func isDigit(b byte) bool {
	return '0' <= b && b <= '9'
}
// dddToByte converts the three ASCII digits s[1], s[2] and s[3] to the byte they
// denote in decimal. The arithmetic is done in byte, so values above 255 wrap.
func dddToByte(s []byte) byte {
	hundreds, tens, ones := s[1]-'0', s[2]-'0', s[3]-'0'
	return hundreds*100 + tens*10 + ones
}
package tree
import "fmt"
// Print prints a Tree. Main use is to aid in debugging.
//
// An empty tree is printed as the single line "<nil>"; nothing else is written
// in that case.
func (t *Tree) Print() {
	if t.Root == nil {
		fmt.Println("<nil>")
		return
	}
	t.Root.print()
}
// print writes the names of the subtree rooted at n to standard output, level by
// level: the names of one level are printed on one line, separated by spaces.
// Missing children are queued as nil; they print nothing but still count
// towards the size of their level.
func (n *Node) print() {
	pending := newQueue()
	pending.push(n)
	remaining, upcoming := 1, 0
	for !pending.empty() {
		cur := pending.pop()
		remaining--
		if cur != nil {
			fmt.Print(cur.Elem.Name(), " ")
			pending.push(cur.Left)
			pending.push(cur.Right)
			upcoming += 2
		}
		if remaining != 0 {
			continue
		}
		// the current level is done: end the line and move to the next level.
		fmt.Println()
		remaining, upcoming = upcoming, 0
	}
	fmt.Println()
}
// queue is a first-in, first-out list of nodes.
type queue []*Node

// newQueue returns a new, empty queue.
func newQueue() queue {
	return queue{}
}

// push pushes n to the end of the queue.
func (q *queue) push(n *Node) {
	*q = append(*q, n)
}

// pop pops the first element off the queue. It panics on an empty queue.
func (q *queue) pop() *Node {
	head, rest := (*q)[0], (*q)[1:]
	*q = rest
	return head
}

// empty returns true when the queue contains zero nodes.
func (q *queue) empty() bool {
	return len(*q) == 0
}
// Copyright ©2012 The bíogo Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found at the end of this file.
// Package tree implements Left-Leaning Red Black trees as described by Robert Sedgewick.
//
// More details relating to the implementation are available at the following locations:
//
// http://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf
// http://www.cs.princeton.edu/~rs/talks/LLRB/Java/RedBlackBST.java
// http://www.teachsolaisgames.com/articles/balanced_left_leaning.html
//
// Heavily modified by Miek Gieben for use in DNS zones.
package tree
import "github.com/miekg/dns"
// The operation modes of the LLRB tree. The names match the "TD234" and "BU23"
// variants referred to in the comment on fixUp.
const (
td234 = iota
bu23
)
// Operation mode of the LLRB tree. It is a compile-time constant, so the
// mode-dependent branches in the balancing code are fixed at build time.
const mode = bu23
// init panics at program start when mode is not one of the known operation modes.
func init() {
	switch mode {
	case td234, bu23:
		// known mode, nothing to do.
	default:
		panic("tree: unknown mode")
	}
}
// A Color represents the color of a Node. It is a bool so that flipping a color
// is a plain negation.
type Color bool
const (
// Red as false give us the defined behaviour that new nodes are red. Although this
// is incorrect for the root node, that is resolved on the first insertion.
red Color = false
black Color = true
)
// A Node represents a node in the LLRB tree.
type Node struct {
Elem *Elem // the element stored in this node
Left, Right *Node // children; nil when absent
Color Color // color of this node; the zero value is red
}
// A Tree manages the root node of an LLRB tree. Public methods are exposed through this type.
// The zero value is an empty tree.
type Tree struct {
Root *Node // Root node of the tree.
Count int // Number of elements stored.
}
// Helper methods
// color returns the effective color of a Node. A nil node returns black, which
// lets callers ask for the color of a missing child without a nil check.
func (n *Node) color() Color {
	if n != nil {
		return n.Color
	}
	return black
}
// rotateLeft rotates n to the left and returns the new root of the subtree:
//
//	(a,c)b -rotL-> ((a,)b,)c
//
// The right child of n becomes the root and takes over n's color; n becomes its
// left child and is colored red. It assumes n has two children.
func (n *Node) rotateLeft() (root *Node) {
	pivot := n.Right
	n.Right, pivot.Left = pivot.Left, n
	pivot.Color, n.Color = n.Color, red
	return pivot
}
// rotateRight rotates n to the right and returns the new root of the subtree:
//
//	(a,c)b -rotR-> (,(,c)b)a
//
// The left child of n becomes the root and takes over n's color; n becomes its
// right child and is colored red. It assumes n has two children.
func (n *Node) rotateRight() (root *Node) {
	pivot := n.Left
	n.Left, pivot.Right = pivot.Right, n
	pivot.Color, n.Color = n.Color, red
	return pivot
}
// flipColors inverts the color of n and of both its children:
//
//	(aR,cR)bB -flipC-> (aB,cB)bR | (aB,cB)bR -flipC-> (aR,cR)bB
//
// It assumes n has two children.
func (n *Node) flipColors() {
	for _, node := range [...]*Node{n, n.Left, n.Right} {
		node.Color = !node.Color
	}
}
// fixUp ensures that black link balance is correct, that red nodes lean left,
// and that 4 nodes are split in the case of BU23 and properly balanced in TD234.
// It returns the (possibly new) root of the subtree that was rooted at n.
func (n *Node) fixUp() *Node {
// a red right child is rotated to the left; in TD234 mode a red right-left
// grandchild is rotated to the right first.
if n.Right.color() == red {
if mode == td234 && n.Right.Left.color() == red {
n.Right = n.Right.rotateRight()
}
n = n.rotateLeft()
}
// two consecutive red nodes on the left are fixed with a right rotation.
if n.Left.color() == red && n.Left.Left.color() == red {
n = n.rotateRight()
}
// in BU23 mode two red children are flipped together with their parent.
if mode == bu23 && n.Left.color() == red && n.Right.color() == red {
n.flipColors()
}
return n
}
// moveRedLeft flips the colors of n and its children and, when the left child of
// n's right child is red, rotates so that this red node moves over to the left
// side. It returns the (possibly new) root of the subtree. It is called by the
// delete operations before they descend to the left.
func (n *Node) moveRedLeft() *Node {
n.flipColors()
if n.Right.Left.color() == red {
n.Right = n.Right.rotateRight()
n = n.rotateLeft()
n.flipColors()
// TD234 mode only: a red right-right grandchild is rotated to the left.
if mode == td234 && n.Right.Right.color() == red {
n.Right = n.Right.rotateLeft()
}
}
return n
}
// moveRedRight flips the colors of n and its children and, when the left child of
// n's left child is red, rotates n to the right and flips the colors again. It
// returns the (possibly new) root of the subtree. It is called by the delete
// operations before they descend to the right.
func (n *Node) moveRedRight() *Node {
n.flipColors()
if n.Left.Left.color() == red {
n = n.rotateRight()
n.flipColors()
}
return n
}
// Len returns the number of elements stored in the Tree, as tracked in t.Count.
func (t *Tree) Len() int {
return t.Count
}
// Search returns the first match of qname in the Tree. The boolean reports
// whether a match was found; without a match the returned element is nil.
func (t *Tree) Search(qname string) (*Elem, bool) {
	if t.Root == nil {
		return nil, false
	}
	node, found := t.Root.search(qname)
	if node != nil {
		return node.Elem, found
	}
	return nil, found
}
// search searches the subtree rooted at n for qname. It returns the matching
// node and true, or a nil node and false when qname is not present.
func (n *Node) search(qname string) (*Node, bool) {
	cur := n
	for cur != nil {
		c := Less(cur.Elem, qname)
		if c == 0 {
			return cur, true
		}
		if c < 0 {
			cur = cur.Left
		} else {
			cur = cur.Right
		}
	}
	return cur, false
}
// Insert inserts rr into the Tree, either into the element that already exists
// for rr's owner name or into a newly created node. Count grows by the number of
// nodes added and the root is colored black afterwards.
func (t *Tree) Insert(rr dns.RR) {
	root, added := t.Root.insert(rr)
	t.Root = root
	t.Count += added
	t.Root.Color = black
}
// insert inserts rr into the subtree rooted at n. It returns the (possibly new)
// root of that subtree and d, the number of nodes that were added: 1 when a new
// node or a missing Elem was created, 0 when rr was added to an existing element.
func (n *Node) insert(rr dns.RR) (root *Node, d int) {
if n == nil {
// reached a leaf position: create the node (its zero Color is red).
return &Node{Elem: newElem(rr)}, 1
} else if n.Elem == nil {
n.Elem = newElem(rr)
return n, 1
}
// TD234 mode splits 4-nodes on the way down.
if mode == td234 {
if n.Left.color() == red && n.Right.color() == red {
n.flipColors()
}
}
switch c := Less(n.Elem, rr.Header().Name); {
case c == 0:
// same owner name: add the RR to the existing element.
n.Elem.Insert(rr)
case c < 0:
n.Left, d = n.Left.insert(rr)
default:
n.Right, d = n.Right.insert(rr)
}
// restore the left-leaning invariants on the way back up.
if n.Right.color() == red && n.Left.color() == black {
n = n.rotateLeft()
}
if n.Left.color() == red && n.Left.Left.color() == red {
n = n.rotateRight()
}
// BU23 mode splits 4-nodes on the way up.
if mode == bu23 {
if n.Left.color() == red && n.Right.color() == red {
n.flipColors()
}
}
root = n
return
}
// DeleteMin deletes the node with the minimum value in the tree. It is a no-op
// on an empty tree. Count is adjusted and a remaining root is colored black.
func (t *Tree) DeleteMin() {
	if t.Root == nil {
		return
	}
	root, delta := t.Root.deleteMin()
	t.Root = root
	t.Count += delta
	if root != nil {
		root.Color = black
	}
}
// deleteMin removes the leftmost node of the subtree rooted at n. It returns the
// new root of the subtree (nil when n itself was removed) and d, which is -1 for
// the removed node.
func (n *Node) deleteMin() (root *Node, d int) {
if n.Left == nil {
// n is the minimum: drop it.
return nil, -1
}
// make sure the left side is not a 2-node before descending into it.
if n.Left.color() == black && n.Left.Left.color() == black {
n = n.moveRedLeft()
}
n.Left, d = n.Left.deleteMin()
root = n.fixUp()
return
}
// DeleteMax deletes the node with the maximum value in the tree. It is a no-op
// on an empty tree. Count is adjusted and a remaining root is colored black.
func (t *Tree) DeleteMax() {
	if t.Root == nil {
		return
	}
	root, delta := t.Root.deleteMax()
	t.Root = root
	t.Count += delta
	if root != nil {
		root.Color = black
	}
}
// deleteMax removes the rightmost node of the subtree rooted at n. It returns the
// new root of the subtree (nil when the subtree root was removed) and d, which is
// -1 for the removed node.
func (n *Node) deleteMax() (root *Node, d int) {
// lean a red left child to the right first.
if n.Left != nil && n.Left.color() == red {
n = n.rotateRight()
}
if n.Right == nil {
// n is the maximum: drop it.
return nil, -1
}
// make sure the right side is not a 2-node before descending into it.
if n.Right.color() == black && n.Right.Left.color() == black {
n = n.moveRedRight()
}
n.Right, d = n.Right.deleteMax()
root = n.fixUp()
return
}
// Delete removes all RRs of type rr.Header().Rrtype from the element stored for
// rr's owner name. If the element is empty after that deletion, the entire node
// is deleted. Nothing happens when the tree is empty or the name is not found.
func (t *Tree) Delete(rr dns.RR) {
	if t.Root == nil {
		return
	}
	elem, _ := t.Search(rr.Header().Name)
	if elem == nil {
		return
	}
	if elem.Delete(rr); elem.Empty() {
		t.deleteNode(rr)
	}
}
// deleteNode deletes the node that matches rr according to Less(). It is a no-op
// on an empty tree. Count is adjusted and a remaining root is colored black.
func (t *Tree) deleteNode(rr dns.RR) {
	if t.Root == nil {
		return
	}
	root, delta := t.Root.delete(rr)
	t.Root = root
	t.Count += delta
	if root != nil {
		root.Color = black
	}
}
// delete removes the node whose element matches rr's owner name (according to
// Less) from the subtree rooted at n. It returns the new root of the subtree and
// d, which is -1 when a node was removed and 0 otherwise.
func (n *Node) delete(rr dns.RR) (root *Node, d int) {
if Less(n.Elem, rr.Header().Name) < 0 {
// the name sorts before n: continue in the left subtree, if there is one.
if n.Left != nil {
if n.Left.color() == black && n.Left.Left.color() == black {
n = n.moveRedLeft()
}
n.Left, d = n.Left.delete(rr)
}
} else {
if n.Left.color() == red {
n = n.rotateRight()
}
// a match without a right child can be dropped directly.
if n.Right == nil && Less(n.Elem, rr.Header().Name) == 0 {
return nil, -1
}
if n.Right != nil {
if n.Right.color() == black && n.Right.Left.color() == black {
n = n.moveRedRight()
}
if Less(n.Elem, rr.Header().Name) == 0 {
// a match with a right subtree: take over the element of its successor
// (the minimum of the right subtree) and delete that node instead.
n.Elem = n.Right.min().Elem
n.Right, d = n.Right.deleteMin()
} else {
n.Right, d = n.Right.delete(rr)
}
}
}
root = n.fixUp()
return
}
// Min returns the minimum value stored in the tree, or nil for an empty tree.
func (t *Tree) Min() *Elem {
	if root := t.Root; root != nil {
		return root.min().Elem
	}
	return nil
}
// min returns the leftmost node of the subtree rooted at n. n must not be nil.
func (n *Node) min() *Node {
	cur := n
	for cur.Left != nil {
		cur = cur.Left
	}
	return cur
}
// Max returns the maximum value stored in the tree, or nil for an empty tree.
func (t *Tree) Max() *Elem {
	if root := t.Root; root != nil {
		return root.max().Elem
	}
	return nil
}
// max returns the rightmost node of the subtree rooted at n. n must not be nil.
func (n *Node) max() *Node {
	cur := n
	for cur.Right != nil {
		cur = cur.Right
	}
	return cur
}
// Prev returns the greatest value equal to or less than the qname according to
// Less(). The boolean is false, and the element nil, when the tree is empty or
// holds no such value.
func (t *Tree) Prev(qname string) (*Elem, bool) {
	if t.Root == nil {
		return nil, false
	}
	if n := t.Root.floor(qname); n != nil {
		return n.Elem, true
	}
	return nil, false
}
// floor returns the node of the subtree rooted at n that holds the greatest
// value equal to or less than qname according to Less(), or nil if there is none.
func (n *Node) floor(qname string) *Node {
	if n == nil {
		return nil
	}
	c := Less(n.Elem, qname)
	if c == 0 {
		return n
	}
	if c < 0 {
		// qname sorts before n: the answer can only be in the left subtree.
		return n.Left.floor(qname)
	}
	// n qualifies, but something in the right subtree may be closer to qname.
	if closer := n.Right.floor(qname); closer != nil {
		return closer
	}
	return n
}
// Next returns the smallest value equal to or greater than the qname according
// to Less(). The boolean is false, and the element nil, when the tree is empty
// or holds no such value.
func (t *Tree) Next(qname string) (*Elem, bool) {
	if t.Root == nil {
		return nil, false
	}
	if n := t.Root.ceil(qname); n != nil {
		return n.Elem, true
	}
	return nil, false
}
// ceil returns the node of the subtree rooted at n that holds the smallest value
// equal to or greater than qname according to Less(), or nil if there is none.
func (n *Node) ceil(qname string) *Node {
	if n == nil {
		return nil
	}
	c := Less(n.Elem, qname)
	if c == 0 {
		return n
	}
	if c > 0 {
		// qname sorts after n: the answer can only be in the right subtree.
		return n.Right.ceil(qname)
	}
	// n qualifies, but something in the left subtree may be closer to qname.
	if closer := n.Left.ceil(qname); closer != nil {
		return closer
	}
	return n
}
/*
Copyright ©2012 The bíogo Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the bíogo project nor the names of its authors and
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package tree
import "github.com/miekg/dns"
// Walk performs fn on all values stored in the tree in in-order depth first.
// fn receives the element and the map holding its RRs. If a non-nil error is
// returned the Walk was interrupted by an fn returning that error. If fn alters
// stored values' sort relationships, future tree operation behaviors are
// undefined. An empty tree is not walked and yields nil.
func (t *Tree) Walk(fn func(*Elem, map[uint16][]dns.RR) error) error {
	if root := t.Root; root != nil {
		return root.walk(fn)
	}
	return nil
}
// walk calls fn for the left subtree of n, for n itself and for the right
// subtree, in that order. The first non-nil error stops the walk and is returned.
func (n *Node) walk(fn func(*Elem, map[uint16][]dns.RR) error) error {
	if left := n.Left; left != nil {
		if err := left.walk(fn); err != nil {
			return err
		}
	}
	if err := fn(n.Elem, n.Elem.m); err != nil {
		return err
	}
	if right := n.Right; right != nil {
		return right.walk(fn)
	}
	return nil
}
package file
import "github.com/miekg/dns"
// replaceWithAsteriskLabel replaces the left most label with '*'. It returns the
// empty string when dns.NextLabel reports that qname has no further label.
func replaceWithAsteriskLabel(qname string) (wildcard string) {
	next, end := dns.NextLabel(qname, 0)
	if !end {
		wildcard = "*." + qname[next:]
	}
	return wildcard
}
package file
import (
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/plugin/transfer"
"github.com/miekg/dns"
)
// Transfer implements the transfer.Transfer interface. It hands the request to
// the zone registered for zone and returns transfer.ErrNotAuthoritative when
// there is no (non-nil) zone under that name.
func (f File) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
	if z := f.Z[zone]; z != nil {
		return z.Transfer(serial)
	}
	return nil, transfer.ErrNotAuthoritative
}
// Transfer transfers a zone with serial in the returned channel and implements
// IXFR fallback, by just sending a single SOA record.
//
// The records are produced by a separate goroutine that closes the channel when
// it is done. When serial is non-zero and equals the serial of the zone's SOA,
// only the SOA is sent. Otherwise the apex records are sent first, then the RRs
// of every element in the tree, and finally the SOA once more. An error is
// returned when the zone has no apex (see ApexIfDefined).
func (z *Zone) Transfer(serial uint32) (<-chan []dns.RR, error) {
	apex, err := z.ApexIfDefined()
	if err != nil {
		return nil, err
	}

	out := make(chan []dns.RR)
	go func() {
		defer close(out)

		soa := apex[0]
		if serial != 0 && soa.(*dns.SOA).Serial == serial { // ixfr fallback, only send SOA
			out <- []dns.RR{soa}
			return
		}

		out <- apex
		z.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error {
			out <- e.All()
			return nil
		})
		out <- []dns.RR{soa}
	}()
	return out, nil
}
package file
import (
"fmt"
"path/filepath"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/miekg/dns"
)
// Zone is a structure that contains all data related to a DNS zone.
type Zone struct {
origin string // fully qualified zone name, see NewZone
origLen int // number of labels in origin
file string // cleaned path of the zone file; use File and SetFile for locked access
*tree.Tree // all records that are not kept in Apex
Apex
Expired bool
sync.RWMutex
StartupOnce sync.Once // guards the one-time start-up work registered in setup
TransferFrom []string
ReloadInterval time.Duration // OnShutdown only signals reloadShutdown when this is positive
reloadShutdown chan bool // receives true from OnShutdown
Upstream *upstream.Upstream // Upstream for looking up external names during the resolution process.
}
// Apex contains the apex records of a zone: SOA, NS and their potential signatures.
type Apex struct {
SOA *dns.SOA // the zone's SOA record, nil when none has been inserted
NS []dns.RR // NS records owned by the zone origin
SIGSOA []dns.RR // RRSIGs covering the SOA
SIGNS []dns.RR // RRSIGs covering the NS records of the zone origin
}
// NewZone returns a new zone. The origin is the fully qualified form of name,
// file is stored as a cleaned path, and the zone starts out with an empty tree
// and an unbuffered shutdown channel.
func NewZone(name, file string) *Zone {
	origin := dns.Fqdn(name)

	z := new(Zone)
	z.origin = origin
	z.origLen = dns.CountLabel(origin)
	z.file = filepath.Clean(file)
	z.Tree = &tree.Tree{}
	z.reloadShutdown = make(chan bool)
	return z
}
// Copy copies a zone: the result is a new zone with the same origin, file,
// TransferFrom, Expired flag and Apex records as z. The tree of the copy is empty.
func (z *Zone) Copy() *Zone {
	dup := z.CopyWithoutApex()
	dup.Apex = z.Apex
	return dup
}
// CopyWithoutApex copies zone z without the Apex records: the result is a new
// zone with the same origin, file, TransferFrom and Expired flag as z. The tree
// of the copy is empty.
func (z *Zone) CopyWithoutApex() *Zone {
	dup := NewZone(z.origin, z.file)
	dup.Expired = z.Expired
	dup.TransferFrom = z.TransferFrom
	return dup
}
// Insert inserts r into z.
//
// The owner name is lowercased for every type except SRV, and so are the domain
// names in the rdata of NS, SOA, CNAME and MX records. SOA records, NS records
// owned by the origin and the RRSIGs covering them are stored in the zone's Apex
// instead of in the tree. NSEC3 and NSEC3PARAM records are rejected with an
// error; every other record is inserted into the tree.
func (z *Zone) Insert(r dns.RR) error {
// r.Header().Name = strings.ToLower(r.Header().Name)
if r.Header().Rrtype != dns.TypeSRV {
r.Header().Name = strings.ToLower(r.Header().Name)
}
switch h := r.Header().Rrtype; h {
case dns.TypeNS:
r.(*dns.NS).Ns = strings.ToLower(r.(*dns.NS).Ns)
// NS records of the origin belong to the apex; others fall through to the tree.
if r.Header().Name == z.origin {
z.NS = append(z.NS, r)
return nil
}
case dns.TypeSOA:
r.(*dns.SOA).Ns = strings.ToLower(r.(*dns.SOA).Ns)
r.(*dns.SOA).Mbox = strings.ToLower(r.(*dns.SOA).Mbox)
// a later SOA replaces an earlier one.
z.SOA = r.(*dns.SOA)
return nil
case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
return fmt.Errorf("NSEC3 zone is not supported, dropping RR: %s for zone: %s", r.Header().Name, z.origin)
case dns.TypeRRSIG:
x := r.(*dns.RRSIG)
switch x.TypeCovered {
case dns.TypeSOA:
z.SIGSOA = append(z.SIGSOA, x)
return nil
case dns.TypeNS:
if r.Header().Name == z.origin {
z.SIGNS = append(z.SIGNS, x)
return nil
}
}
case dns.TypeCNAME:
r.(*dns.CNAME).Target = strings.ToLower(r.(*dns.CNAME).Target)
case dns.TypeMX:
r.(*dns.MX).Mx = strings.ToLower(r.(*dns.MX).Mx)
case dns.TypeSRV:
// SRV records are stored as-is, the target is deliberately not lowercased:
// r.(*dns.SRV).Target = strings.ToLower(r.(*dns.SRV).Target)
}
z.Tree.Insert(r)
return nil
}
// File retrieves the file path in a safe way: it is read under the zone's read lock.
func (z *Zone) File() string {
	z.RLock()
	path := z.file
	z.RUnlock()
	return path
}
// SetFile updates the file path in a safe way: it is written under the zone's
// write lock. The path is stored as given.
func (z *Zone) SetFile(path string) {
	z.Lock()
	defer z.Unlock()
	z.file = path
}
// ApexIfDefined returns the apex nodes from z. The SOA record is the first
// record, if it does not exist, an error is returned. The SOA is followed by its
// signatures, the NS records and the NS signatures, in that order. The zone is
// read-locked for the duration of the call.
func (z *Zone) ApexIfDefined() ([]dns.RR, error) {
	z.RLock()
	defer z.RUnlock()

	if z.SOA == nil {
		return nil, fmt.Errorf("no SOA")
	}

	apex := []dns.RR{z.SOA}
	for _, set := range [][]dns.RR{z.SIGSOA, z.NS, z.SIGNS} {
		apex = append(apex, set...)
	}
	return apex, nil
}
// nameFromRight returns the labels from the right, starting with the
// origin and then i labels extra. When we are overshooting the name
// the returned boolean is set to true, and qname is returned unchanged.
// For i <= 0 the origin itself is returned.
func (z *Zone) nameFromRight(qname string, i int) (string, bool) {
	if i <= 0 {
		return z.origin, false
	}

	// qname must at least contain as many labels as the origin.
	for j := 1; j <= z.origLen; j++ {
		if _, overshoot := dns.PrevLabel(qname, j); overshoot {
			return qname, true
		}
	}

	// step over the i extra labels, remembering where the last one starts.
	start := 0
	for extra := 1; extra <= i; extra++ {
		idx, overshoot := dns.PrevLabel(qname, z.origLen+extra)
		if overshoot {
			return qname, true
		}
		start = idx
	}
	return qname[start:], false
}
package forward
import (
"context"
"net"
"net/netip"
"time"
"github.com/coredns/coredns/plugin/dnstap/msg"
"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/coredns/coredns/request"
tap "github.com/dnstap/golang-dnstap"
"github.com/miekg/dns"
)
// toDnstap will send the forward and received message to the dnstap plugin.
//
// For every dnstap plugin registered on f a FORWARDER_QUERY message is sent and,
// when reply is not nil, a FORWARDER_RESPONSE message as well. host is the
// "ip:port" of the upstream; it is reported as a TCP address when the effective
// protocol is "tcp" (opts.ForceTCP wins over opts.PreferUDP, which wins over the
// protocol of the request) and as a UDP address otherwise. start is used as the
// query time of both messages.
func toDnstap(ctx context.Context, f *Forward, host string, state request.Request, opts proxy.Options, reply *dns.Msg, start time.Time) {
	ap, _ := netip.ParseAddrPort(host) // this is preparsed and can't err here
	ip := net.IP(ap.Addr().AsSlice())
	port := int(ap.Port())

	proto := state.Proto()
	if opts.ForceTCP {
		proto = "tcp"
	} else if opts.PreferUDP {
		proto = "udp"
	}

	var upstreamAddr net.Addr
	if proto == "tcp" {
		upstreamAddr = &net.TCPAddr{IP: ip, Port: port}
	} else {
		upstreamAddr = &net.UDPAddr{IP: ip, Port: port}
	}

	for _, tapPlugin := range f.tapPlugins {
		// Query
		query := new(tap.Message)
		msg.SetQueryTime(query, start)
		// Forwarder dnstap messages are from the perspective of the downstream server
		// (upstream is the forward server)
		msg.SetQueryAddress(query, state.W.RemoteAddr())
		msg.SetResponseAddress(query, upstreamAddr)
		if tapPlugin.IncludeRawMessage {
			buf, _ := state.Req.Pack()
			query.QueryMessage = buf
		}
		msg.SetType(query, tap.Message_FORWARDER_QUERY)
		tapPlugin.TapMessageWithMetadata(ctx, query, state)

		// Response
		if reply == nil {
			continue
		}
		resp := new(tap.Message)
		if tapPlugin.IncludeRawMessage {
			buf, _ := reply.Pack()
			resp.ResponseMessage = buf
		}
		msg.SetQueryTime(resp, start)
		msg.SetQueryAddress(resp, state.W.RemoteAddr())
		msg.SetResponseAddress(resp, upstreamAddr)
		msg.SetResponseTime(resp, time.Now())
		msg.SetType(resp, tap.Message_FORWARDER_RESPONSE)
		tapPlugin.TapMessageWithMetadata(ctx, resp, state)
	}
}
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
// inband healthchecking.
package forward
import (
"context"
"crypto/tls"
"errors"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/debug"
"github.com/coredns/coredns/plugin/dnstap"
"github.com/coredns/coredns/plugin/metadata"
clog "github.com/coredns/coredns/plugin/pkg/log"
proxyPkg "github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
otext "github.com/opentracing/opentracing-go/ext"
)
// log is the logger of this plugin, created for the plugin name "forward".
var log = clog.NewWithPlugin("forward")
const (
// defaultExpire is the expire value a Forward created by New starts with.
defaultExpire = 10 * time.Second
// hcInterval is the health check interval a Forward created by New starts with.
hcInterval = 500 * time.Millisecond
)
// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list
// of proxies each representing one upstream proxy.
type Forward struct {
concurrent int64 // atomic counters need to be first in struct for proper alignment
proxies []*proxyPkg.Proxy // the upstreams, in the order they were added with SetProxy
p Policy // selects and orders the proxies for a query, see List
hcInterval time.Duration // interval handed to Proxy.Start in SetProxy
from string // name the query must match for this instance to handle it
ignored []string // names below from that are not handled by this instance
nextAlternateRcodes []int // upstream rcodes for which a following Forward handler is tried instead
tlsConfig *tls.Config
tlsServerName string
maxfails uint32 // handed to Proxy.Down; 0 disables the health check kick-off on errors
expire time.Duration
maxConcurrent int64 // limit on concurrent queries; values <= 0 disable the limit
failfastUnhealthyUpstreams bool // give up instead of picking a random upstream when all are down
opts proxyPkg.Options // also here for testing
// ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded
// the maximum allowed (maxConcurrent)
ErrLimitExceeded error
tapPlugins []*dnstap.Dnstap // when dnstap plugins are loaded, we use to this to send messages out.
Next plugin.Handler
}
// New returns a new Forward with the package defaults: two allowed failures, an
// empty TLS configuration, the default expire and health check intervals, the
// random policy, the root zone "." as origin and health checks that query "."
// with recursion desired.
func New() *Forward {
	return &Forward{
		maxfails:   2,
		tlsConfig:  new(tls.Config),
		expire:     defaultExpire,
		p:          new(random),
		from:       ".",
		hcInterval: hcInterval,
		opts: proxyPkg.Options{
			ForceTCP:           false,
			PreferUDP:          false,
			HCRecursionDesired: true,
			HCDomain:           ".",
		},
	}
}
// SetProxy appends p to the proxy list and starts healthchecking it with the
// health check interval configured on f.
func (f *Forward) SetProxy(p *proxyPkg.Proxy) {
f.proxies = append(f.proxies, p)
p.Start(f.hcInterval)
}
// SetProxyOptions replaces the proxy options of f with opts. ServeDNS hands a
// copy of these options to the proxy for every query.
func (f *Forward) SetProxyOptions(opts proxyPkg.Options) {
f.opts = opts
}
// SetTapPlugin appends one or more dnstap plugins to the tap plugin list:
// tapPlugin itself, followed by every dnstap plugin that is directly chained
// behind it through Next. The chain ends at the first Next that is not a dnstap
// plugin.
func (f *Forward) SetTapPlugin(tapPlugin *dnstap.Dnstap) {
	for cur := tapPlugin; ; {
		f.tapPlugins = append(f.tapPlugins, cur)
		next, ok := cur.Next.(*dnstap.Dnstap)
		if !ok {
			return
		}
		cur = next
	}
}
// Len returns the number of configured proxies.
func (f *Forward) Len() int {
	return len(f.proxies)
}
// Name implements plugin.Handler. It returns the plugin name "forward".
func (f *Forward) Name() string {
	return "forward"
}
// ServeDNS implements plugin.Handler.
//
// Queries that do not match this instance (see match) are passed on to the next
// plugin. Matching queries are sent to the proxies returned by List, one after
// the other, until an upstream answers, the context is done or defaultTimeout
// has passed.
//
// Return values:
//   - (dns.RcodeRefused, f.ErrLimitExceeded) when maxConcurrent is exceeded,
//   - (0, nil) after a reply - or a FORMERR for a mismatching reply - was written to w,
//   - the result of the next handler when the reply's rcode is listed in
//     nextAlternateRcodes and the next handler is a *Forward,
//   - (dns.RcodeServerFailure, err) with the last upstream error, or with
//     ErrNoHealthy when no upstream error was recorded.
func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if !f.match(state) {
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
if f.maxConcurrent > 0 {
// the counter is incremented for every query, also for rejected ones, and
// decremented again when ServeDNS returns.
count := atomic.AddInt64(&(f.concurrent), 1)
defer atomic.AddInt64(&(f.concurrent), -1)
if count > f.maxConcurrent {
maxConcurrentRejectCount.Add(1)
return dns.RcodeRefused, f.ErrLimitExceeded
}
}
fails := 0
var span, child ot.Span
var upstreamErr error
span = ot.SpanFromContext(ctx)
i := 0
list := f.List()
deadline := time.Now().Add(defaultTimeout)
start := time.Now()
for time.Now().Before(deadline) && ctx.Err() == nil {
if i >= len(list) {
// reached the end of list, reset to begin
i = 0
fails = 0
}
proxy := list[i]
i++
if proxy.Down(f.maxfails) {
fails++
if fails < len(f.proxies) {
continue
}
healthcheckBrokenCount.Add(1)
// All upstreams are dead, return servfail if all upstreams are down
if f.failfastUnhealthyUpstreams {
break
}
// assume healthcheck is completely broken and randomly
// select an upstream to connect to.
r := new(random)
proxy = r.List(f.proxies)[0]
}
if span != nil {
child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context()))
otext.PeerAddress.Set(child, proxy.Addr())
ctx = ot.ContextWithSpan(ctx, child)
}
metadata.SetValueFunc(ctx, "forward/upstream", func() string {
return proxy.Addr()
})
var (
ret *dns.Msg
err error
)
// opts is a per-attempt copy, so forcing TCP below does not change f.opts.
opts := f.opts
for {
ret, err = proxy.Connect(ctx, state, opts)
if err == proxyPkg.ErrCachedClosed { // Remote side closed conn, can only happen with TCP.
continue
}
// Retry with TCP if truncated and prefer_udp configured.
if ret != nil && ret.Truncated && !opts.ForceTCP && opts.PreferUDP {
opts.ForceTCP = true
continue
}
break
}
if child != nil {
child.Finish()
}
if len(f.tapPlugins) != 0 {
toDnstap(ctx, f, proxy.Addr(), state, opts, ret, start)
}
// remember the outcome of the last attempt for the final return value.
upstreamErr = err
if err != nil {
// Kick off health check to see if *our* upstream is broken.
if f.maxfails != 0 {
proxy.Healthcheck()
}
if fails < len(f.proxies) {
continue
}
break
}
// Check if the reply is correct; if not return FormErr.
if !state.Match(ret) {
debug.Hexdumpf(ret, "Wrong reply for id: %d, %s %d", ret.Id, state.QName(), state.QType())
formerr := new(dns.Msg)
formerr.SetRcode(state.Req, dns.RcodeFormatError)
w.WriteMsg(formerr)
return 0, nil
}
// Check if we have an alternate Rcode defined, check if we match on the code
for _, alternateRcode := range f.nextAlternateRcodes {
if alternateRcode == ret.Rcode && f.Next != nil { // In case we do not have a Next handler, just continue normally
if _, ok := f.Next.(*Forward); ok { // Only continue if the next forwarder is also a Forwarder
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
}
}
w.WriteMsg(ret)
return 0, nil
}
// no upstream produced a usable reply.
if upstreamErr != nil {
return dns.RcodeServerFailure, upstreamErr
}
return dns.RcodeServerFailure, ErrNoHealthy
}
// match reports whether the query name in state lies under f.from and is not
// excluded by isAllowedDomain.
func (f *Forward) match(state request.Request) bool {
	return plugin.Name(f.from).Matches(state.Name()) && f.isAllowedDomain(state.Name())
}
// isAllowedDomain reports whether name may be forwarded. A name equal to
// f.from is always allowed; any other name is rejected as soon as it matches
// one of the entries in f.ignored.
func (f *Forward) isAllowedDomain(name string) bool {
	if dns.Name(f.from) == dns.Name(name) {
		return true
	}
	for _, excluded := range f.ignored {
		if !plugin.Name(excluded).Matches(name) {
			continue
		}
		return false
	}
	return true
}
// ForceTCP reports whether TCP is forced for upstream requests, even when the
// client request came in over UDP.
func (f *Forward) ForceTCP() bool {
	return f.opts.ForceTCP
}
// PreferUDP reports whether UDP is preferred for upstream requests, even when
// the client request came in over TCP.
func (f *Forward) PreferUDP() bool {
	return f.opts.PreferUDP
}
// List returns the proxies of f in the order chosen by the policy f.p.
func (f *Forward) List() []*proxyPkg.Proxy {
	return f.p.List(f.proxies)
}
// ErrNoHealthy means no healthy proxies left.
var ErrNoHealthy = errors.New("no healthy proxies")

// ErrNoForward means no forwarder defined.
var ErrNoForward = errors.New("no forwarder defined")

// ErrCachedClosed means cached connection was closed by peer.
var ErrCachedClosed = errors.New("cached connection was closed by peer")
// Options holds the transport and health check settings that can be configured.
type Options struct {
// ForceTCP makes upstream DNS requests use TCP. It has precedence over the PreferUDP flag.
ForceTCP bool
// PreferUDP makes upstream DNS requests use UDP.
PreferUDP bool
// HCRecursionDesired sets the recursion desired flag for Proxy health check requests.
HCRecursionDesired bool
// HCDomain sets the domain queried by Proxy health check requests.
HCDomain string
}
// defaultTimeout is the duration (5 seconds) added to the current time to form
// the deadline for trying upstreams.
var defaultTimeout = 5 * time.Second
//go:build gofuzz
package forward
import (
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/fuzz"
"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/miekg/dns"
)
// f is the Forward instance created in init and exercised by Fuzz.
var f *Forward
// init is abused to set up the environment to fuzz against: it creates f and
// starts another server that reflects every request back, registered with f
// once as a TCP proxy and once as a UDP proxy.
func init() {
	f = New()
	reflector := dnstest.NewServer(r{}.reflectHandler)

	overTCP := proxy.NewProxy("FuzzForwardPlugin1", reflector.Addr, "tcp")
	f.SetProxy(overTCP)

	overUDP := proxy.NewProxy("FuzzForwardPlugin2", reflector.Addr, "udp")
	f.SetProxy(overUDP)
}
// Fuzz feeds data to the package-level Forward instance f through fuzz.Do and
// returns the result of that call.
func Fuzz(data []byte) int {
	outcome := fuzz.Do(f, data)
	return outcome
}
// r is an empty type that only carries reflectHandler.
type r struct{}

// reflectHandler answers every request with a reply message built from the
// request itself.
func (r r) reflectHandler(w dns.ResponseWriter, req *dns.Msg) {
	reply := &dns.Msg{}
	reply.SetReply(req)
	w.WriteMsg(reply)
}
package forward
import (
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/proxy"
"github.com/coredns/coredns/plugin/pkg/rand"
)
// Policy defines a policy we use for selecting upstreams.
type Policy interface {
// List returns the given proxies in the order in which they should be tried.
List([]*proxy.Proxy) []*proxy.Proxy
// String returns the name of the policy.
String() string
}
// random is a policy that implements random upstream selection.
type random struct{}

// String returns the policy name "random".
func (r *random) String() string {
	return "random"
}
// List returns the proxies in p in a random order. A single proxy is returned
// as is, two proxies are either swapped or returned as is, and any other
// number is reordered into a new slice following a random permutation.
func (r *random) List(p []*proxy.Proxy) []*proxy.Proxy {
	n := len(p)
	if n == 1 {
		return p
	}
	if n == 2 {
		if rn.Int()%2 != 0 {
			return p
		}
		return []*proxy.Proxy{p[1], p[0]} // swap
	}

	order := rn.Perm(n)
	shuffled := make([]*proxy.Proxy, n)
	for dst, src := range order {
		shuffled[dst] = p[src]
	}
	return shuffled
}
// roundRobin is a policy that selects hosts based on round robin ordering.
type roundRobin struct {
	robin uint32 // counter advanced atomically on every List call
}

// String returns the policy name "round_robin".
func (r *roundRobin) String() string {
	return "round_robin"
}
// List advances the round robin counter and returns a new slice that starts
// with the proxy selected by the counter, followed by the remaining proxies of
// p in their original order.
func (r *roundRobin) List(p []*proxy.Proxy) []*proxy.Proxy {
	size := uint32(len(p))
	first := atomic.AddUint32(&r.robin, 1) % size

	ordered := make([]*proxy.Proxy, 0, len(p))
	ordered = append(ordered, p[first])
	ordered = append(ordered, p[:first]...)
	ordered = append(ordered, p[first+1:]...)
	return ordered
}
// sequential is a policy that selects hosts based on sequential ordering.
type sequential struct{}

// String returns the policy name "sequential".
func (r *sequential) String() string {
	return "sequential"
}
// List returns p unchanged, so the proxies are tried in the given order.
func (r *sequential) List(p []*proxy.Proxy) []*proxy.Proxy {
	inOrder := p
	return inOrder
}
// rn is the package-level random source used by the random policy; it is
// seeded with the current time in nanoseconds at package initialization.
var rn = rand.New(time.Now().UnixNano())
package forward
import (
"crypto/tls"
"errors"
"fmt"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/dnstap"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/proxy"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// init registers setup under the plugin name "forward".
func init() {
plugin.Register("forward", setup)
}
// setup parses all forward stanzas and adds each resulting Forward to the
// server configuration. Every Forward except the last has its Next set to the
// following Forward; the last one gets the next plugin in the chain. Startup
// and shutdown hooks are registered per Forward, and a dnstap handler, when
// present, is attached at startup.
func setup(c *caddy.Controller) error {
	fs, err := parseForward(c)
	if err != nil {
		return plugin.Error("forward", err)
	}

	lastIdx := len(fs) - 1
	for i := range fs {
		f := fs[i]
		if f.Len() > max {
			return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
		}

		if i != lastIdx {
			// middle forward: point next to next forward
			nextForward := fs[i+1]
			dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler {
				f.Next = nextForward
				return f
			})
		} else {
			// last forward: point next to next plugin
			dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
				f.Next = next
				return f
			})
		}

		c.OnStartup(f.OnStartup)
		c.OnStartup(func() error {
			taph := dnsserver.GetConfig(c).Handler("dnstap")
			if taph == nil {
				return nil
			}
			f.SetTapPlugin(taph.(*dnstap.Dnstap))
			return nil
		})
		c.OnShutdown(f.OnShutdown)
	}
	return nil
}
// OnStartup calls Start with the configured health check interval on every
// proxy of f. It always returns nil.
func (f *Forward) OnStartup() (err error) {
	for _, upstream := range f.proxies {
		upstream.Start(f.hcInterval)
	}
	return nil
}
// OnShutdown calls Stop on every configured proxy. It always returns nil.
func (f *Forward) OnShutdown() error {
	for _, upstream := range f.proxies {
		upstream.Stop()
	}
	return nil
}
// parseForward parses every forward stanza available from c and returns the
// resulting Forwards in order. Parsing stops at the first stanza that fails,
// in which case nil and that error are returned.
func parseForward(c *caddy.Controller) ([]*Forward, error) {
	stanzas := []*Forward{}
	for {
		if !c.Next() {
			return stanzas, nil
		}
		stanza, err := parseStanza(c)
		if err != nil {
			return nil, err
		}
		stanzas = append(stanzas, stanza)
	}
}
// parseStanza parses a single forward stanza: the FROM zone, the list of TO
// upstreams and the optional property block. It returns the configured Forward;
// on error the partially configured Forward is returned together with the error.
func parseStanza(c *caddy.Controller) (*Forward, error) {
f := New()
if !c.Args(&f.from) {
return f, c.ArgErr()
}
// Keep the FROM value as written so it can be quoted in the warning below.
origFrom := f.from
zones := plugin.Host(f.from).NormalizeExact()
if len(zones) == 0 {
return f, fmt.Errorf("unable to normalize '%s'", f.from)
}
f.from = zones[0] // there can only be one here, won't work with non-octet reverse
if len(zones) > 1 {
log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", origFrom, f.from)
}
to := c.RemainingArgs()
if len(to) == 0 {
return f, c.ArgErr()
}
toHosts, err := parse.HostPortOrFile(to...)
if err != nil {
return f, err
}
// transports[i] records the transport of f.proxies[i] for the per-proxy
// configuration further down.
transports := make([]string, len(toHosts))
allowedTrans := map[string]bool{"dns": true, "tls": true}
for i, host := range toHosts {
trans, h := parse.Transport(host)
if !allowedTrans[trans] {
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
}
p := proxy.NewProxy("forward", h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
// Parse the optional property block.
for c.NextBlock() {
if err := parseBlock(c, f); err != nil {
return f, err
}
}
if f.tlsServerName != "" {
f.tlsConfig.ServerName = f.tlsServerName
}
// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
// in upcoming connections to the same TLS server.
f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies))
// Apply the parsed settings to every proxy.
for i := range f.proxies {
// Only set this for proxies that need it.
if transports[i] == transport.TLS {
f.proxies[i].SetTLSConfig(f.tlsConfig)
}
f.proxies[i].SetExpire(f.expire)
f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired)
// when TLS is used, checks are set to tcp-tls
if f.opts.ForceTCP && transports[i] != transport.TLS {
f.proxies[i].GetHealthchecker().SetTCPTransport()
}
f.proxies[i].GetHealthchecker().SetDomain(f.opts.HCDomain)
}
return f, nil
}
// parseBlock parses the property the controller is currently positioned on
// and stores the result in f. It returns an error for unknown properties, for
// a wrong number of arguments and for values that cannot be parsed or are out
// of range.
func parseBlock(c *caddy.Controller, f *Forward) error {
	config := dnsserver.GetConfig(c)
	switch c.Val() {
	case "except":
		ignore := c.RemainingArgs()
		if len(ignore) == 0 {
			return c.ArgErr()
		}
		for i := range ignore {
			f.ignored = append(f.ignored, plugin.Host(ignore[i]).NormalizeExact()...)
		}
	case "max_fails":
		if !c.NextArg() {
			return c.ArgErr()
		}
		n, err := strconv.ParseUint(c.Val(), 10, 32)
		if err != nil {
			return err
		}
		f.maxfails = uint32(n)
	case "health_check":
		if !c.NextArg() {
			return c.ArgErr()
		}
		dur, err := time.ParseDuration(c.Val())
		if err != nil {
			return err
		}
		if dur < 0 {
			// Use %s so the duration is printed in its readable form (for
			// example "-1s"), consistent with the "expire" property, rather
			// than as a raw nanosecond count.
			return fmt.Errorf("health_check can't be negative: %s", dur)
		}
		f.hcInterval = dur
		f.opts.HCDomain = "."
		// Remaining arguments are health check options.
		for c.NextArg() {
			switch hcOpts := c.Val(); hcOpts {
			case "no_rec":
				f.opts.HCRecursionDesired = false
			case "domain":
				if !c.NextArg() {
					return c.ArgErr()
				}
				hcDomain := c.Val()
				if _, ok := dns.IsDomainName(hcDomain); !ok {
					return fmt.Errorf("health_check: invalid domain name %s", hcDomain)
				}
				f.opts.HCDomain = plugin.Name(hcDomain).Normalize()
			default:
				return fmt.Errorf("health_check: unknown option %s", hcOpts)
			}
		}
	case "force_tcp":
		if c.NextArg() {
			return c.ArgErr()
		}
		f.opts.ForceTCP = true
	case "prefer_udp":
		if c.NextArg() {
			return c.ArgErr()
		}
		f.opts.PreferUDP = true
	case "tls":
		args := c.RemainingArgs()
		if len(args) > 3 {
			return c.ArgErr()
		}
		// Relative paths are resolved against the configured root, if any.
		for i := range args {
			if !filepath.IsAbs(args[i]) && config.Root != "" {
				args[i] = filepath.Join(config.Root, args[i])
			}
		}
		tlsConfig, err := pkgtls.NewTLSConfigFromArgs(args...)
		if err != nil {
			return err
		}
		f.tlsConfig = tlsConfig
	case "tls_servername":
		if !c.NextArg() {
			return c.ArgErr()
		}
		f.tlsServerName = c.Val()
	case "expire":
		if !c.NextArg() {
			return c.ArgErr()
		}
		dur, err := time.ParseDuration(c.Val())
		if err != nil {
			return err
		}
		if dur < 0 {
			return fmt.Errorf("expire can't be negative: %s", dur)
		}
		f.expire = dur
	case "policy":
		if !c.NextArg() {
			return c.ArgErr()
		}
		switch x := c.Val(); x {
		case "random":
			f.p = &random{}
		case "round_robin":
			f.p = &roundRobin{}
		case "sequential":
			f.p = &sequential{}
		default:
			return c.Errf("unknown policy '%s'", x)
		}
	case "max_concurrent":
		if !c.NextArg() {
			return c.ArgErr()
		}
		n, err := strconv.Atoi(c.Val())
		if err != nil {
			return err
		}
		if n < 0 {
			return fmt.Errorf("max_concurrent can't be negative: %d", n)
		}
		f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val())
		f.maxConcurrent = int64(n)
	case "next":
		args := c.RemainingArgs()
		if len(args) == 0 {
			return c.ArgErr()
		}
		// Every argument must be a known rcode name; matching is case-insensitive.
		for _, rcode := range args {
			var rc int
			var ok bool
			if rc, ok = dns.StringToRcode[strings.ToUpper(rcode)]; !ok {
				return fmt.Errorf("%s is not a valid rcode", rcode)
			}
			f.nextAlternateRcodes = append(f.nextAlternateRcodes, rc)
		}
	case "failfast_all_unhealthy_upstreams":
		args := c.RemainingArgs()
		if len(args) != 0 {
			return c.ArgErr()
		}
		f.failfastUnhealthyUpstreams = true
	default:
		return c.Errf("unknown property '%s'", c.Val())
	}
	return nil
}
// max is the maximum number of upstreams that a single forward stanza may configure.
// NOTE(review): this identifier shadows the Go 1.21 builtin max within this package.
const max = 15 // Maximum number of upstreams.
package metadata
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Metadata implements collecting metadata information from all plugins that
// implement the Provider interface.
type Metadata struct {
// Zones lists the zones for which Collect gathers metadata.
Zones []string
// Providers are the plugins asked for metadata by Collect.
Providers []Provider
// Next is the next handler in the plugin chain.
Next plugin.Handler
}
// Name implements the Handler interface and returns "metadata".
func (m *Metadata) Name() string {
	return "metadata"
}
// ContextWithMetadata is exported for use by provider tests
func ContextWithMetadata(ctx context.Context) context.Context {
return context.WithValue(ctx, key{}, md{})
}
// ServeDNS implements the plugin.Handler interface. It only passes the request
// on to the next plugin and returns that plugin's result.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	return plugin.NextOrFailure(m.Name(), m.Next, ctx, w, r)
}
// Collect attaches an empty metadata store to ctx and, when the query name
// falls in one of m.Zones, lets every provider add its metadata. The
// resulting context is returned.
func (m *Metadata) Collect(ctx context.Context, state request.Request) context.Context {
	ctx = ContextWithMetadata(ctx)
	if plugin.Zones(m.Zones).Matches(state.Name()) == "" {
		// Query is outside the configured zones: no provider is consulted.
		return ctx
	}
	for _, provider := range m.Providers {
		ctx = provider.Metadata(ctx, state)
	}
	return ctx
}
// Package metadata provides an API that allows plugins to add metadata to the context.
// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata
// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to
// execute it is its responsibility to provide some form of caching. During the handling of a
// query it is expected the metadata stays constant.
//
// Basic example:
//
// Implement the Provider interface for a plugin p:
//
// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
// metadata.SetValueFunc(ctx, "test/something", func() string { return "myvalue" })
// return ctx
// }
//
// Basic example with caching:
//
// func (p P) Metadata(ctx context.Context, state request.Request) context.Context {
// cached := ""
// f := func() string {
// if cached != "" {
// return cached
// }
// cached = expensiveFunc()
// return cached
// }
// metadata.SetValueFunc(ctx, "test/something", f)
// return ctx
// }
//
// If you need access to this metadata from another plugin:
//
// // ...
// valueFunc := metadata.ValueFunc(ctx, "test/something")
// value := valueFunc()
// // use 'value'
package metadata
import (
"context"
"strings"
"github.com/coredns/coredns/request"
)
// Provider is the interface that must be implemented by each plugin that wants
// to provide metadata information to other plugins.
type Provider interface {
// Metadata adds metadata to the context and returns a (potentially) new context.
// Note: this method should work quickly, because it is called for every request
// from the metadata plugin.
Metadata(ctx context.Context, state request.Request) context.Context
}
// Func is the type of the functions stored in the metadata; calling one
// returns the value of its label.
type Func func() string
// IsLabel reports whether label is a valid label name: it must contain a
// slash that is preceded by at least one character (the namespace) and that
// is not the last character of the label.
func IsLabel(label string) bool {
	slash := strings.Index(label, "/")
	// Neither the namespace before the slash nor the part after it may be empty.
	return slash > 0 && slash < len(label)-1
}
// Labels returns all metadata keys stored in the context, or nil when the
// context carries no metadata. These label names should be named as:
// plugin/NAME, where NAME is something descriptive.
func Labels(ctx context.Context) []string {
	store, ok := ctx.Value(key{}).(md)
	if !ok {
		return nil
	}
	return keys(store)
}
// ValueFuncs returns the map[string]Func stored in the context, or nil if the
// context carries no metadata.
func ValueFuncs(ctx context.Context) map[string]Func {
	store, ok := ctx.Value(key{}).(md)
	if !ok {
		return nil
	}
	return store
}
// ValueFunc returns the value function registered for label. Nil is returned
// when the context carries no metadata or the label is unknown. Calling the
// returned function yields the value of the label.
func ValueFunc(ctx context.Context, label string) Func {
	store, ok := ctx.Value(key{}).(md)
	if !ok {
		return nil
	}
	return store[label]
}
// SetValueFunc sets the metadata label to the value function f, overwriting
// any existing value, and returns true. If the context carries no metadata
// this is a noop and false is returned.
func SetValueFunc(ctx context.Context, label string, f Func) bool {
	store, ok := ctx.Value(key{}).(md)
	if !ok {
		return false
	}
	store[label] = f
	return true
}
// md is the metadata information storage: it maps a label to its value function.
type md map[string]Func
// key defines the type of key that is used to save metadata into the context.
type key struct{}
// keys returns the labels of m in map iteration order, which is unspecified.
// The result is never nil.
func keys(m map[string]Func) []string {
	labels := make([]string, 0, len(m))
	for label := range m {
		labels = append(labels, label)
	}
	return labels
}
package metadata
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
// init registers setup under the plugin name "metadata".
func init() {
	plugin.Register("metadata", setup)
}
// setup parses the metadata configuration, adds the plugin to the server
// configuration and registers a startup hook that records every configured
// handler implementing Provider.
func setup(c *caddy.Controller) error {
	m, err := metadataParse(c)
	if err != nil {
		return err
	}

	dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
		m.Next = next
		return m
	})

	// Providers can only be discovered at startup, once all handlers exist.
	collectProviders := func() error {
		for _, handler := range dnsserver.GetConfig(c).Handlers() {
			provider, ok := handler.(Provider)
			if !ok {
				continue
			}
			m.Providers = append(m.Providers, provider)
		}
		return nil
	}
	c.OnStartup(collectProviders)
	return nil
}
// metadataParse reads the metadata directive from c. The zones are taken from
// the directive's arguments or from the server block keys. A property block or
// a second metadata directive results in an argument error.
func metadataParse(c *caddy.Controller) (*Metadata, error) {
	c.Next()
	args := c.RemainingArgs()
	meta := &Metadata{
		Zones: plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys),
	}

	if c.NextBlock() || c.Next() {
		return nil, plugin.Error("metadata", c.ArgErr())
	}
	return meta, nil
}
package metrics
import (
"context"
"github.com/coredns/coredns/core/dnsserver"
)
// WithServer returns the listening address of the server currently handling
// the request, in the form <scheme>://[<bind>]:<port>. Normally this is
// something like "dns://:53", but if the bind plugin is used, i.e. "bind
// 127.0.0.53", it will be "dns://127.0.0.53:53", etc. If no server is found
// in ctx the empty string is returned.
//
// Basic usage with a metric:
//
//	<metric>.WithLabelValues(metrics.WithServer(ctx), labels..).Add(1)
func WithServer(ctx context.Context) string {
	if srv := ctx.Value(dnsserver.Key{}); srv != nil {
		return srv.(*dnsserver.Server).Addr
	}
	return ""
}
// WithView returns the name of the view currently handling the request, or
// the empty string when no view is stored in ctx.
//
// Basic usage with a metric:
//
//	<metric>.WithLabelValues(metrics.WithView(ctx), labels..).Add(1)
func WithView(ctx context.Context) string {
	if view := ctx.Value(dnsserver.ViewKey{}); view != nil {
		return view.(string)
	}
	return ""
}
package metrics
import (
"context"
"path/filepath"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ServeDNS implements the Handler interface. It passes the request down the
// plugin chain through a recording writer and afterwards reports the request
// and response metrics. The status and error of the next plugin are returned
// unchanged.
func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := request.Request{W: w, Req: r}

	// Measure the request before any plugin further down can modify it.
	originalSize := r.Len()

	qname := state.QName()
	zone := plugin.Zones(m.ZoneNames()).Matches(qname)
	if zone == "" {
		zone = "."
	}

	// The recorder captures the status code and size of the reply.
	rec := NewRecorder(w)
	status, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, rec, r)

	// When no response was written, the status returned from the next plugin
	// is what ends up as rcode of the DNS response, so report that instead.
	// see https://github.com/coredns/coredns/blob/master/core/dnsserver/server.go#L318
	rc := status
	if plugin.ClientWrite(status) {
		rc = rec.Rcode
	}

	writer := m.authoritativePlugin(rec.Caller)
	vars.Report(WithServer(ctx), state, zone, WithView(ctx), rcode.ToString(rc), writer,
		rec.Len, rec.Start, vars.WithOriginalReqSize(originalSize))
	return status, err
}
// Name implements the Handler interface and returns "prometheus".
func (m *Metrics) Name() string {
	return "prometheus"
}
// authoritativePlugin returns the name of the plugin that made the write. Each
// entry of caller is a source file path whose parent directory is taken as the
// plugin name, e.g. .../coredns/plugin/whoami/whoami.go --> whoami. The first
// name that is a known plugin wins; the empty string is returned if none is.
func (m *Metrics) authoritativePlugin(caller [3]string) string {
	for i := 0; i < len(caller); i++ {
		// Paths are file system specific, so use filepath to take them apart.
		name := filepath.Base(filepath.Dir(caller[i]))
		if _, known := m.plugins[name]; known {
			return name
		}
	}
	return ""
}
// Package metrics implements a handler and plugin that provides Prometheus metrics.
package metrics
import (
"context"
"net"
"net/http"
"sync"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics .
type Metrics struct {
// Next is the next handler in the plugin chain.
Next plugin.Handler
// Addr is the address the metrics HTTP server listens on.
Addr string
// Reg is the registry served on /metrics and used by MustRegister.
Reg *prometheus.Registry
// ln is the listener created in OnStartup; lnSetup records that it is set up.
ln net.Listener
lnSetup bool
// mux and srv are the HTTP multiplexer and server created in OnStartup.
mux *http.ServeMux
srv *http.Server
// zoneNames is the slice form of zoneMap, rebuilt on every change; both are guarded by zoneMu.
zoneNames []string
zoneMap map[string]struct{}
zoneMu sync.RWMutex
plugins map[string]struct{} // all available plugins, used to determine which plugin made the client write
}
// New returns a new Metrics instance for the given address. It uses the
// default prometheus registry, starts with an empty zone set and records the
// plugins known to caddy.
func New(addr string) *Metrics {
	return &Metrics{
		Addr:    addr,
		Reg:     prometheus.DefaultRegisterer.(*prometheus.Registry),
		zoneMap: make(map[string]struct{}),
		plugins: pluginList(caddy.ListPlugins()),
	}
}
// MustRegister registers c with m.Reg. An already-registered error is
// ignored; any other registration error is logged as fatal.
func (m *Metrics) MustRegister(c prometheus.Collector) {
	err := m.Reg.Register(c)
	if err == nil {
		return
	}
	// ignore any duplicate error, but fatal on any other kind of error
	if _, duplicate := err.(prometheus.AlreadyRegisteredError); duplicate {
		return
	}
	log.Fatalf("Cannot register metrics collector: %s", err)
}
// AddZone adds zone z to m and rebuilds the list of zone names.
func (m *Metrics) AddZone(z string) {
	m.zoneMu.Lock()
	defer m.zoneMu.Unlock()

	m.zoneMap[z] = struct{}{}
	m.zoneNames = keys(m.zoneMap)
}
// RemoveZone removes zone z from m and rebuilds the list of zone names.
func (m *Metrics) RemoveZone(z string) {
	m.zoneMu.Lock()
	defer m.zoneMu.Unlock()

	delete(m.zoneMap, z)
	m.zoneNames = keys(m.zoneMap)
}
// ZoneNames returns the current list of zones of m, read under the read lock.
func (m *Metrics) ZoneNames() []string {
	m.zoneMu.RLock()
	defer m.zoneMu.RUnlock()
	return m.zoneNames
}
// OnStartup sets up the metrics on startup: it opens a TCP listener on m.Addr,
// serves m.Reg on /metrics in a separate goroutine and records the actual
// listen address in ListenAddr. A listen failure is logged and returned.
func (m *Metrics) OnStartup() error {
ln, err := reuseport.Listen("tcp", m.Addr)
if err != nil {
log.Errorf("Failed to start metrics handler: %s", err)
return err
}
m.ln = ln
m.lnSetup = true
m.mux = http.NewServeMux()
m.mux.Handle("/metrics", promhttp.HandlerFor(m.Reg, promhttp.HandlerOpts{}))
// creating some helper variables to avoid data races on m.srv and m.ln
server := &http.Server{Handler: m.mux}
m.srv = server
// The goroutine uses the local server and ln rather than the fields of m.
// The error returned by Serve is not inspected.
go func() {
server.Serve(ln)
}()
ListenAddr = ln.Addr().String() // For tests.
return nil
}
// OnRestart stops the listener on reload. When the listener was set up, the
// address is unset in u before the server is stopped; otherwise nothing is done.
func (m *Metrics) OnRestart() error {
	if m.lnSetup {
		u.Unset(m.Addr)
		return m.stopServer()
	}
	return nil
}
// stopServer shuts down the metrics HTTP server, waiting at most
// shutdownTimeout, and then closes the listener. It is a noop when the
// listener was never set up. A shutdown failure is logged and returned, and
// the listener is then left marked as set up.
func (m *Metrics) stopServer() error {
	if !m.lnSetup {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()

	err := m.srv.Shutdown(ctx)
	if err != nil {
		log.Infof("Failed to stop prometheus http server: %s", err)
		return err
	}

	m.lnSetup = false
	m.ln.Close()
	return nil
}
// OnFinalShutdown tears down the metrics listener on shutdown and restart by
// stopping the server.
func (m *Metrics) OnFinalShutdown() error {
	return m.stopServer()
}
// keys returns the keys of m in map iteration order, which is unspecified.
// The result is never nil.
func keys(m map[string]struct{}) []string {
	names := make([]string, 0, len(m))
	for name := range m {
		names = append(names, name)
	}
	return names
}
// pluginList builds the set of DNS plugin names from the plugin map returned
// by caddy. Only entries of m["others"] that start with the "dns." prefix and
// have a non-empty name after it are included; the prefix is removed from the
// stored name. All other entries are skipped.
func pluginList(m map[string][]string) map[string]struct{} {
	const prefix = "dns."

	pm := map[string]struct{}{}
	for _, p := range m["others"] {
		// Checking the prefix explicitly keeps entries of other server types
		// (and a bare "dns.") out of the set; cutting the first four bytes of
		// every sufficiently long entry would add bogus names.
		if len(p) <= len(prefix) || p[:len(prefix)] != prefix {
			continue
		}
		pm[p[len(prefix):]] = struct{}{}
	}
	return pm
}
// ListenAddr is assigned the address of the prometheus listener. Its use is mainly in tests where
// we listen on "localhost:0" and need to retrieve the actual address.
var ListenAddr string
// shutdownTimeout is the maximum amount of time the metrics plugin will wait
// before erroring when it tries to close the metrics server.
const shutdownTimeout time.Duration = time.Second * 5
// buildInfo is a gauge labeled by version, revision and goversion that
// describes the CoreDNS build.
var buildInfo = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: plugin.Namespace,
Name: "build_info",
Help: "A metric with a constant '1' value labeled by version, revision, and goversion from which CoreDNS was built.",
}, []string{"version", "revision", "goversion"})
package metrics
import (
"runtime"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/miekg/dns"
)
// Recorder is a dnstest.Recorder specific to the metrics plugin. In addition
// to what the embedded recorder captures, it remembers who called WriteMsg.
type Recorder struct {
*dnstest.Recorder
// Caller[N] holds the file name returned by the call to runtime.Caller(N+1)
// made in WriteMsg.
Caller [3]string
}
// NewRecorder makes and returns a new Recorder that wraps w in a
// dnstest.Recorder.
func NewRecorder(w dns.ResponseWriter) *Recorder {
	return &Recorder{
		Recorder: dnstest.NewRecorder(w),
	}
}
// WriteMsg stores the source files of the three nearest callers in r.Caller
// and then calls the WriteMsg method of the embedded recorder.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
	for depth := range r.Caller {
		// Skip depth+1 frames: 1 is the direct caller of WriteMsg.
		_, r.Caller[depth], _, _ = runtime.Caller(depth + 1)
	}
	return r.Recorder.WriteMsg(res)
}
package metrics
import (
"sync"
"github.com/prometheus/client_golang/prometheus"
)
// reg maps an address to a prometheus registry; the embedded RWMutex guards
// access to the map.
type reg struct {
sync.RWMutex
r map[string]*prometheus.Registry
}
func newReg() *reg { return ®{r: make(map[string]*prometheus.Registry)} }
// getOrSet returns the registry already stored for addr. If there is none, pr
// is stored for addr and returned.
func (r *reg) getOrSet(addr string, pr *prometheus.Registry) *prometheus.Registry {
	r.Lock()
	defer r.Unlock()

	existing, found := r.r[addr]
	if !found {
		r.r[addr] = pr
		return pr
	}
	return existing
}
package metrics
import (
"net"
"runtime"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/coremain"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/uniq"
)
var (
// log is the logger of this plugin, created with the plugin name "prometheus".
log = clog.NewWithPlugin("prometheus")
// u holds the startup function registered per metrics address (see setup).
u = uniq.New()
// registry holds the prometheus registry used per metrics address.
registry = newReg()
)
// init registers setup under the plugin name "prometheus".
func init() {
	plugin.Register("prometheus", setup)
}
// setup parses the prometheus configuration and wires the plugin into the
// server. The same three hooks are registered for startup and for a failed
// restart: queue the listener start for m.Addr, run all queued starts, and
// mark every configured plugin as enabled per listen address. Restart and
// final shutdown hooks stop the listener and reset the plugin gauge.
func setup(c *caddy.Controller) error {
	m, err := parse(c)
	if err != nil {
		return plugin.Error("prometheus", err)
	}
	m.Reg = registry.getOrSet(m.Addr, m.Reg)

	queueListener := func() error {
		m.Reg = registry.getOrSet(m.Addr, m.Reg)
		u.Set(m.Addr, m.OnStartup)
		return nil
	}
	startListeners := func() error {
		return u.ForEach()
	}
	markPluginsEnabled := func() error {
		conf := dnsserver.GetConfig(c)
		for _, h := range conf.ListenHosts {
			addrstr := conf.Transport + "://" + net.JoinHostPort(h, conf.Port)
			for _, p := range conf.Handlers() {
				vars.PluginEnabled.WithLabelValues(addrstr, conf.Zone, conf.ViewName, p.Name()).Set(1)
			}
		}
		return nil
	}

	// The registration order below is kept exactly as the hooks must run.
	c.OnStartup(queueListener)
	c.OnRestartFailed(queueListener)
	c.OnStartup(startListeners)
	c.OnRestartFailed(startListeners)
	c.OnStartup(markPluginsEnabled)
	c.OnRestartFailed(markPluginsEnabled)

	c.OnRestart(m.OnRestart)
	c.OnRestart(func() error {
		vars.PluginEnabled.Reset()
		return nil
	})
	c.OnFinalShutdown(m.OnFinalShutdown)

	// Initialize metrics.
	buildInfo.WithLabelValues(coremain.CoreVersion, coremain.GitCommit, runtime.Version()).Set(1)

	dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
		m.Next = next
		return m
	})
	return nil
}
// parse reads the prometheus directive from c. The directive may occur only
// once (plugin.ErrOnce otherwise) and accepts at most one argument, the
// host:port address to export the metrics on; without it defaultAddr is used.
// The zones of the server block are added to the returned Metrics.
func parse(c *caddy.Controller) (*Metrics, error) {
	met := New(defaultAddr)

	seen := false
	for c.Next() {
		if seen {
			return nil, plugin.ErrOnce
		}
		seen = true

		for _, z := range plugin.OriginsFromArgsOrServerBlock(nil /* args */, c.ServerBlockKeys) {
			met.AddZone(z)
		}

		args := c.RemainingArgs()
		if len(args) > 1 {
			return met, c.ArgErr()
		}
		if len(args) == 1 {
			met.Addr = args[0]
			if _, _, e := net.SplitHostPort(met.Addr); e != nil {
				return met, e
			}
		}
	}
	return met, nil
}
// defaultAddr is the address where the metrics are exported by default.
const defaultAddr = "localhost:9153"
package vars
import (
"github.com/miekg/dns"
)
// monitorType is the set of query types that qTypeString reports by name; all
// other types are reported as "other".
var monitorType = map[uint16]struct{}{
dns.TypeAAAA: {},
dns.TypeA: {},
dns.TypeCNAME: {},
dns.TypeDNSKEY: {},
dns.TypeDS: {},
dns.TypeMX: {},
dns.TypeNSEC3: {},
dns.TypeNSEC: {},
dns.TypeNS: {},
dns.TypePTR: {},
dns.TypeRRSIG: {},
dns.TypeSOA: {},
dns.TypeSRV: {},
dns.TypeTXT: {},
dns.TypeHTTPS: {},
// Meta Qtypes
dns.TypeIXFR: {},
dns.TypeAXFR: {},
dns.TypeANY: {},
}
// qTypeString returns the text representation of qtype when it is listed in
// monitorType. RR types not in that list will have "other" returned.
func qTypeString(qtype uint16) string {
	_, monitored := monitorType[qtype]
	if !monitored {
		return "other"
	}
	return dns.Type(qtype).String()
}
package vars
import (
"time"
"github.com/coredns/coredns/request"
)
// ReportOptions holds the optional settings of the Report function.
type ReportOptions struct {
	// OriginalReqSize is the size of the request as originally received.
	OriginalReqSize int
}

// ReportOption is a function that modifies a ReportOptions value.
type ReportOption func(*ReportOptions)

// WithOriginalReqSize returns an option that sets the original request size
// to size.
func WithOriginalReqSize(size int) ReportOption {
	setSize := func(target *ReportOptions) {
		target.OriginalReqSize = size
	}
	return setSize
}
// Report reports the metrics data associated with request. This function is exported because it is also
// called from core/dnsserver to report requests hitting the server that should not be handled and are thus
// not sent down the plugin chain. The request size reported is the original
// request size when that option is set to a positive value, and req.Len() otherwise.
func Report(server string, req request.Request, zone, view, rcode, plugin string,
	size int, start time.Time, opts ...ReportOption) {
	var options ReportOptions
	for _, apply := range opts {
		apply(&options)
	}

	// Proto and Family.
	proto := req.Proto()
	family := "1"
	if req.Family() == 2 {
		family = "2"
	}

	if req.Do() {
		RequestDo.WithLabelValues(server, zone, view).Inc()
	}

	RequestCount.WithLabelValues(server, zone, view, proto, family, qTypeString(req.QType())).Inc()
	RequestDuration.WithLabelValues(server, zone, view).Observe(time.Since(start).Seconds())
	ResponseSize.WithLabelValues(server, zone, view, proto).Observe(float64(size))

	requestSize := req.Len()
	if options.OriginalReqSize > 0 {
		requestSize = options.OriginalReqSize
	}
	RequestSize.WithLabelValues(server, zone, view, proto).Observe(float64(requestSize))

	ResponseRcode.WithLabelValues(server, zone, view, rcode, plugin).Inc()
}
package plugin
import (
"fmt"
"net"
"runtime"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/pkg/cidr"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/miekg/dns"
)
// See core/dnsserver/address.go - we should unify these two impls.
// Zones represents a list of zone names.
type Zones []string
// Matches checks if qname is a subdomain of any of the zones in z and returns
// the most specific, i.e. longest, zone that matches. The empty string
// signals a not found condition.
func (z Zones) Matches(qname string) string {
	best := ""
	for _, candidate := range z {
		// Prefer the *longest* matching zone, otherwise we may end up in a parent.
		if dns.IsSubDomain(candidate, qname) && len(candidate) > len(best) {
			best = candidate
		}
	}
	return best
}
// Normalize replaces every zone in z, in place, with its normalized form (see
// Name.Normalize). The zones in z must be domain names, without a port or
// protocol prefix.
func (z Zones) Normalize() {
	for i, zone := range z {
		z[i] = Name(zone).Normalize()
	}
}
// Name represents a domain name. It offers helpers to match and normalize names.
type Name string
// Matches checks to see if child is a subdomain of (or the same domain as) n.
// This method assures that names can be easily and consistently matched.
func (n Name) Matches(child string) bool {
	return dns.Name(n) == dns.Name(child) || dns.IsSubDomain(string(n), child)
}
// Normalize makes n fully qualified and lowercases it.
func (n Name) Normalize() string {
	fqdn := dns.Fqdn(string(n))
	return strings.ToLower(fqdn)
}
// Host represents a host from the Corefile, may contain port.
type Host string
// Normalize will return the host portion of host, stripping
// of any port or transport. The host will also be fully qualified and lowercased.
// A deprecation warning naming the caller is logged on every call.
// An empty string is returned on failure
// Deprecated: use OriginsFromArgsOrServerBlock or NormalizeExact
func (h Host) Normalize() string {
	caller := ""
	_, file, line, ok := runtime.Caller(1)
	if ok {
		caller = fmt.Sprintf("(%v line %d) ", file, line)
	}
	log.Warning("An external plugin " + caller + "is using the deprecated function Normalize. " +
		"This will be removed in a future versions of CoreDNS. The plugin should be updated to use " +
		"OriginsFromArgsOrServerBlock or NormalizeExact instead.")

	_, hostport := parse.Transport(string(h))
	hosts, _, err := SplitHostPort(hostport)
	if err != nil {
		// Failure is signalled by the empty string.
		return ""
	}
	return Name(hosts[0]).Normalize()
}
// MustNormalize returns the first host of h with any transport and port
// removed, fully qualified and lowercased. Every call logs a deprecation
// warning that names the caller's file and line when they are available.
// The error from SplitHostPort is returned unchanged on failure.
// Deprecated: use OriginsFromArgsOrServerBlock or NormalizeExact
func (h Host) MustNormalize() (string, error) {
	// runtime.Caller must stay in this function body: skip=1 has to resolve to
	// whoever called MustNormalize.
	caller := ""
	_, file, line, ok := runtime.Caller(1)
	if ok {
		caller = fmt.Sprintf("(%v line %d) ", file, line)
	}
	log.Warning("An external plugin " + caller + "is using the deprecated function MustNormalize. " +
		"This will be removed in a future versions of CoreDNS. The plugin should be updated to use " +
		"OriginsFromArgsOrServerBlock or NormalizeExact instead.")

	_, rest := parse.Transport(string(h))
	hosts, _, err := SplitHostPort(rest)
	if err != nil {
		return "", err
	}
	return Name(hosts[0]).Normalize(), nil
}
// NormalizeExact returns all hosts that h expands to, with any transport and
// port removed and each host fully qualified and lowercased.
// nil is returned when h cannot be split; this function is meant to be called
// after the Corefile has already been vetted.
func (h Host) NormalizeExact() []string {
	_, rest := parse.Transport(string(h))
	hosts, _, err := SplitHostPort(rest)
	if err != nil {
		return nil
	}
	for i, host := range hosts {
		hosts[i] = Name(host).Normalize()
	}
	return hosts
}
// SplitHostPort splits s up in a host(s) and port portion, taking reverse address notation into account.
// The string s should *not* be prefixed with any protocols, i.e. dns://. SplitHostPort can return
// multiple hosts when a reverse notation on a non-octet boundary is given.
//
// The returned port is empty when s has no numeric suffix after its last colon. An error is
// returned when s is empty or ends in a colon, when the host part is longer than 255 bytes or
// is rejected by dns.IsDomainName, or when it parses as a CIDR but starts with ':' (or with
// '0' while also containing a ':').
func SplitHostPort(s string) (hosts []string, port string, err error) {
// If there is: :[0-9]+ on the end we assume this is the port. This works for (ascii) domain
// names and our reverse syntax, which always needs a /mask *before* the port.
// So from the back, find first colon, and then check if it's a number.
colon := strings.LastIndex(s, ":")
// For an empty s both sides are -1, so the empty string is rejected here as well.
if colon == len(s)-1 {
return nil, "", fmt.Errorf("expecting data after last colon: %q", s)
}
if colon != -1 {
// Only strip the suffix when Atoi accepts it; the port is re-rendered with Itoa, so
// it comes back in canonical decimal form.
if p, err := strconv.Atoi(s[colon+1:]); err == nil {
port = strconv.Itoa(p)
s = s[:colon]
}
}
// TODO(miek): this should take escaping into account.
if len(s) > 255 {
return nil, "", fmt.Errorf("specified zone is too long: %d > 255", len(s))
}
if _, ok := dns.IsDomainName(s); !ok {
return nil, "", fmt.Errorf("zone is not a valid domain name: %s", s)
}
// Check if it parses as a reverse zone, if so we use that. Must be fully specified IP and mask.
_, n, err := net.ParseCIDR(s)
if err != nil {
// Not a CIDR: s is returned as the single host, unchanged.
return []string{s}, port, nil
}
// s is non-empty here (the empty string was rejected above), so s[0] is safe.
if s[0] == ':' || (s[0] == '0' && strings.Contains(s, ":")) {
return nil, "", fmt.Errorf("invalid CIDR %s", s)
}
// now check if multiple hosts must be returned.
nets := cidr.Split(n)
hosts = cidr.Reverse(nets)
return hosts, port, nil
}
// OriginsFromArgsOrServerBlock returns the normalized args when args is not
// empty. Otherwise it returns a new slice holding the first normalized host of
// every serverblock entry.
func OriginsFromArgsOrServerBlock(args, serverblock []string) []string {
	if len(args) > 0 {
		origins := []string{}
		for _, arg := range args {
			// An argument that fails to normalize yields no hosts and is
			// therefore dropped silently.
			origins = append(origins, Host(arg).NormalizeExact()...)
		}
		return origins
	}

	origins := make([]string, len(serverblock))
	for i, key := range serverblock {
		// Expansion of these already happened in dnsserver/register.go.
		// NOTE(review): indexing [0] panics if NormalizeExact returns an empty
		// slice for a server block key.
		origins[i] = Host(key).NormalizeExact()[0]
	}
	return origins
}
// Package cache implements a cache. The cache holds 256 shards, each shard
// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it
// just randomly evicts elements when it gets full.
package cache
import (
"hash/fnv"
"sync"
)
// Hash returns the 64-bit FNV-1 hash (hash/fnv.New64) of what.
func Hash(what []byte) uint64 {
	hasher := fnv.New64()
	// hash.Hash documents that Write never returns an error.
	_, _ = hasher.Write(what)
	return hasher.Sum64()
}
// Cache is a sharded cache: it holds a fixed array of shardSize shards.
type Cache struct {
shards [shardSize]*shard
}
// shard is a cache with random eviction.
type shard struct {
// items holds the cached elements, indexed by key.
items map[uint64]interface{}
// size is the number of items the shard may hold before Add evicts one.
size int
// The embedded RWMutex guards items.
sync.RWMutex
}
// New returns a new cache whose size is split evenly over all shards. Every
// shard gets room for at least 4 elements, however small size is.
func New(size int) *Cache {
	perShard := max(size/shardSize, 4)

	c := new(Cache)
	for i := range c.shards {
		c.shards[i] = newShard(perShard)
	}
	return c
}
// Add stores el under key, overwriting any element already stored there.
// It returns true if an existing element was evicted to make room.
func (c *Cache) Add(key uint64, el interface{}) bool {
	// The low bits of the key select the shard.
	target := c.shards[key&(shardSize-1)]
	return target.Add(key, el)
}
// Get returns the element stored under key and whether it was found.
func (c *Cache) Get(key uint64) (interface{}, bool) {
	// The low bits of the key select the shard.
	target := c.shards[key&(shardSize-1)]
	return target.Get(key)
}
// Remove deletes the element stored under key from its shard.
func (c *Cache) Remove(key uint64) {
	// The low bits of the key select the shard.
	target := c.shards[key&(shardSize-1)]
	target.Remove(key)
}
// Len returns the number of elements in the cache, summed over all shards.
// Each shard is locked on its own, so the total is not a single atomic snapshot.
func (c *Cache) Len() int {
	total := 0
	for i := range c.shards {
		total += c.shards[i].Len()
	}
	return total
}
// Walk calls Walk with f on every shard of the cache, in shard order.
func (c *Cache) Walk(f func(map[uint64]interface{}, uint64) bool) {
	for i := range c.shards {
		c.shards[i].Walk(f)
	}
}
// newShard returns an empty shard that holds up to size elements before Add
// starts evicting.
func newShard(size int) *shard {
	s := new(shard)
	s.items = make(map[uint64]interface{})
	s.size = size
	return s
}
// Add stores el under key, overwriting any existing element for that key.
// When the shard is full and key is new, one arbitrary element is deleted
// first. It returns true if such an eviction happened.
func (s *shard) Add(key uint64, el interface{}) bool {
	s.Lock()
	evicted := false
	if _, exists := s.items[key]; !exists && len(s.items) >= s.size {
		// Map iteration order is unspecified, so the first key seen serves as
		// the victim.
		for victim := range s.items {
			delete(s.items, victim)
			evicted = true
			break
		}
	}
	s.items[key] = el
	s.Unlock()
	return evicted
}
// Remove deletes the element stored under key, if any, under the write lock.
func (s *shard) Remove(key uint64) {
	s.Lock()
	defer s.Unlock()
	delete(s.items, key)
}
// Evict removes one element from the shard: the first key map iteration
// yields. It does nothing when the shard is empty.
func (s *shard) Evict() {
	s.Lock()
	defer s.Unlock()
	for victim := range s.items {
		delete(s.items, victim)
		return
	}
}
// Get returns the element stored under key and whether it was present. The
// lookup is done under the read lock.
func (s *shard) Get(key uint64) (interface{}, bool) {
	s.RLock()
	defer s.RUnlock()
	el, found := s.items[key]
	return el, found
}
// Len returns the number of elements currently held by the shard.
func (s *shard) Len() int {
	s.RLock()
	defer s.RUnlock()
	return len(s.items)
}
// Walk calls f for every key that was in the shard when Walk started. The keys
// are snapshotted under the read lock; each call to f is then made while
// holding the write lock. Walking stops as soon as f returns false.
func (s *shard) Walk(f func(map[uint64]interface{}, uint64) bool) {
	s.RLock()
	keys := make([]uint64, 0, len(s.items))
	for k := range s.items {
		keys = append(keys, k)
	}
	s.RUnlock()

	for _, k := range keys {
		s.Lock()
		keepGoing := f(s.items, k)
		s.Unlock()
		if !keepGoing {
			return
		}
	}
}
// shardSize is the number of shards in a Cache. Shards are selected with
// key & (shardSize - 1), which only spreads keys over all shards when this is
// a power of two.
const shardSize = 256
// Package cidr contains functions that deal with classless reverse zones in the DNS.
package cidr
import (
"math"
"net"
"strings"
"github.com/apparentlymart/go-cidr/cidr"
"github.com/miekg/dns"
)
// Split returns a slice of non-overlapping subnets that in union equal the
// subnet n, and where each subnet falls on a reverse name segment boundary.
// For IPv4 this is any multiple of 8 bits (/8, /16, /24 or /32); for IPv6 it
// is any multiple of 4 bits. A network already on a boundary is returned as is.
func Split(n *net.IPNet) []string {
	cidrStr := n.String()

	// A colon in the textual form is used to detect IPv6.
	boundary := 8
	if strings.Contains(cidrStr, ":") {
		boundary = 4
	}

	ones, _ := n.Mask.Size()
	if ones%boundary == 0 {
		return []string{cidrStr}
	}

	// Round the prefix length up to the next boundary and enumerate the
	// subnets of that size.
	target := int(math.Ceil(float64(ones)/float64(boundary))) * boundary
	subnets := nets(n, target)
	out := make([]string, 0, len(subnets))
	for _, sub := range subnets {
		out = append(out, sub.String())
	}
	return out
}
// nets returns a slice of prefixes with the desired mask subnetted from the original network.
// The first entry reuses network.IP with a newPrefixLen mask; every further entry comes from
// cidr.NextSubnet applied to the previous one, for at most 2^(newPrefixLen-prefixLen) entries.
func nets(network *net.IPNet, newPrefixLen int) []*net.IPNet {
prefixLen, _ := network.Mask.Size()
// Number of newPrefixLen-sized subnets that fit in the original prefix.
maxSubnets := int(math.Exp2(float64(newPrefixLen)) / math.Exp2(float64(prefixLen)))
nets := []*net.IPNet{{IP: network.IP, Mask: net.CIDRMask(newPrefixLen, 8*len(network.IP))}}
for i := 1; i < maxSubnets; i++ {
next, exceeds := cidr.NextSubnet(nets[len(nets)-1], newPrefixLen)
// NOTE(review): next is appended even when exceeds is reported; the loop only stops afterwards.
nets = append(nets, next)
if exceeds {
break
}
}
return nets
}
// Reverse returns the reverse zones that are authoritative for each net in nets. The result
// has the same length as nets; an entry stays empty when dns.ReverseAddr rejects the address
// parsed from the corresponding element.
func Reverse(nets []string) []string {
rev := make([]string, len(nets))
for i := range nets {
// The parse error is dropped here. NOTE(review): for unparsable input ip is nil, which
// presumably makes dns.ReverseAddr fail below so that n is never dereferenced — confirm.
ip, n, _ := net.ParseCIDR(nets[i])
r, err1 := dns.ReverseAddr(ip.String())
if err1 != nil {
continue
}
ones, bits := n.Mask.Size()
// get the size, in bits, of each portion of hostname defined in the reverse address. (8 for IPv4, 4 for IPv6)
sizeDigit := 8
if len(n.IP) == net.IPv6len {
sizeDigit = 4
}
// Get the first lower octet boundary to see what encompassing zone we should be authoritative for.
mod := (bits - ones) % sizeDigit
nearest := (bits - ones) + mod
offset := 0
var end bool
// Advance offset past nearest/sizeDigit labels of r, stopping early at the end of the
// name; what remains from offset on is the zone.
for range nearest / sizeDigit {
offset, end = dns.NextLabel(r, offset)
if end {
break
}
}
rev[i] = r[offset:]
}
return rev
}
package dnstest
import (
"time"
"github.com/miekg/dns"
)
// MultiRecorder is a type of ResponseWriter that captures all messages written to it.
type MultiRecorder struct {
// Len is the running total of the sizes recorded by WriteMsg and Write.
Len int
// Msgs holds every message passed to WriteMsg, in call order.
Msgs []*dns.Msg
// Start is set to the creation time by NewMultiRecorder.
Start time.Time
dns.ResponseWriter
}
// NewMultiRecorder returns a MultiRecorder that wraps w. It starts with an
// empty, non-nil message list and Start set to the current time.
func NewMultiRecorder(w dns.ResponseWriter) *MultiRecorder {
	rec := new(MultiRecorder)
	rec.ResponseWriter = w
	rec.Msgs = make([]*dns.Msg, 0)
	rec.Start = time.Now()
	return rec
}
// WriteMsg adds the length of res to Len, appends res to Msgs and then hands
// res to the underlying ResponseWriter, returning that writer's error.
func (r *MultiRecorder) WriteMsg(res *dns.Msg) error {
	size := res.Len()
	r.Len += size
	r.Msgs = append(r.Msgs, res)
	return r.ResponseWriter.WriteMsg(res)
}
// Write passes buf to the underlying ResponseWriter and, when that succeeds,
// adds the number of bytes written to Len.
func (r *MultiRecorder) Write(buf []byte) (int, error) {
	n, err := r.ResponseWriter.Write(buf)
	if err != nil {
		// Failed writes are not counted.
		return n, err
	}
	r.Len += n
	return n, nil
}
// Package dnstest allows for easy testing of DNS client against a test server.
package dnstest
import (
"time"
"github.com/miekg/dns"
)
// Recorder is a type of ResponseWriter that captures
// the rcode written to it and also the size of the message
// written in the response. An rcode does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type Recorder struct {
dns.ResponseWriter
// Rcode is the rcode of the last message passed to WriteMsg.
Rcode int
// Len is the running total of the sizes recorded by WriteMsg and Write.
Len int
// Msg is the last message passed to WriteMsg.
Msg *dns.Msg
// Start is set to the creation time by NewRecorder.
Start time.Time
}
// NewRecorder returns a Recorder that wraps w. It captures the DNS rcode and
// the length of the response messages written through it. Start is set to the
// current time.
func NewRecorder(w dns.ResponseWriter) *Recorder {
	// Rcode and Msg are left at their zero values (0 and nil).
	rec := &Recorder{ResponseWriter: w, Start: time.Now()}
	return rec
}
// WriteMsg records the rcode and size of res, keeps res as the last message
// seen and then calls the underlying ResponseWriter's WriteMsg method.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
	r.Rcode = res.Rcode
	// WriteMsg may be called several times (axfr for instance): sizes are
	// summed, while only the most recent message is kept.
	size := res.Len()
	r.Len += size
	r.Msg = res
	return r.ResponseWriter.WriteMsg(res)
}
// Write passes buf to the underlying ResponseWriter and, when that succeeds,
// adds the number of bytes written to Len.
func (r *Recorder) Write(buf []byte) (int, error) {
	n, err := r.ResponseWriter.Write(buf)
	if err != nil {
		// Failed writes are not counted.
		return n, err
	}
	r.Len += n
	return n, nil
}
package dnstest
import (
"net"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/miekg/dns"
)
// A Server is a DNS server listening on a system-chosen port on the local
// loopback interface, for use in end-to-end DNS tests.
type Server struct {
Addr string // Address the server is listening on.
s1 *dns.Server // udp
s2 *dns.Server // tcp
}
// NewServer starts and returns a new Server. The caller should call Close when
// finished, to shut it down.
//
// f is registered as the handler for "." via dns.HandleFunc. The TCP listener is opened on
// ":0" and the UDP socket is then bound to the address the listener got; this is tried up
// to five times, and NewServer panics if no pair could be set up. It returns only after
// both servers have signalled that they started.
func NewServer(f dns.HandlerFunc) *Server {
dns.HandleFunc(".", f)
// Closed by the NotifyStartedFunc of the UDP and TCP server respectively.
ch1 := make(chan bool)
ch2 := make(chan bool)
s1 := &dns.Server{} // udp
s2 := &dns.Server{} // tcp
for range 5 { // 5 attempts
s2.Listener, _ = reuseport.Listen("tcp", ":0")
if s2.Listener == nil {
continue
}
// Reuse the address picked for TCP so both transports share one port.
s1.PacketConn, _ = net.ListenPacket("udp", s2.Listener.Addr().String())
if s1.PacketConn != nil {
break
}
// perhaps UDP port is in use, try again
s2.Listener.Close()
s2.Listener = nil
}
// A nil listener here means every attempt failed: it is reset to nil after each failed UDP bind.
if s2.Listener == nil {
panic("dnstest.NewServer(): failed to create new server")
}
s1.NotifyStartedFunc = func() { close(ch1) }
s2.NotifyStartedFunc = func() { close(ch2) }
go s1.ActivateAndServe()
go s2.ActivateAndServe()
// Block until both servers report they are up.
<-ch1
<-ch2
return &Server{s1: s1, s2: s2, Addr: s2.Listener.Addr().String()}
}
// Close shuts down the UDP server first and then the TCP server. Errors from
// Shutdown are not reported.
func (s *Server) Close() {
	for _, srv := range []*dns.Server{s.s1, s.s2} {
		srv.Shutdown()
	}
}
package dnsutil
import "github.com/miekg/dns"
// DuplicateCNAME returns true if records already contains a CNAME whose
// Target equals the Target of r. Records of other types are ignored.
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
	for i := range records {
		cname, isCNAME := records[i].(*dns.CNAME)
		if isCNAME && cname.Target == r.Target {
			return true
		}
	}
	return false
}
package dnsutil
import (
"strings"
"github.com/miekg/dns"
)
// Join joins labels to form a fully qualified domain name. If the last label is
// the root label it is ignored. No other syntax checks are performed.
//
// Calling Join without any labels does not panic; the result is then whatever
// dns.Fqdn returns for the empty string.
func Join(labels ...string) string {
	// Guard the index below: with no labels there is no last element to inspect.
	if n := len(labels); n > 0 && labels[n-1] == "." {
		return strings.Join(labels[:n-1], ".") + "."
	}
	return dns.Fqdn(strings.Join(labels, "."))
}
package dnsutil
import (
"net"
"strings"
)
// ExtractAddressFromReverse turns a standard PTR reverse record name
// into an IP address. This works for ipv4 or ipv6.
//
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
// fails the empty string is returned.
func ExtractAddressFromReverse(reverseName string) string {
f := reverse
var search string
switch {
case strings.HasSuffix(reverseName, IP4arpa):
search = strings.TrimSuffix(reverseName, IP4arpa)
case strings.HasSuffix(reverseName, IP6arpa):
search = strings.TrimSuffix(reverseName, IP6arpa)
f = reverse6
default:
return ""
}
// Reverse the segments and then combine them.
return f(strings.Split(search, "."))
}
// IsReverse returns 0 if name is not in a reverse zone. Anything > 0 indicates
// name is in a reverse zone: 1 for in-addr.arpa. (IPv4) and 2 for ip6.arpa.
// (IPv6).
func IsReverse(name string) int {
	switch {
	case strings.HasSuffix(name, IP4arpa):
		return 1
	case strings.HasSuffix(name, IP6arpa):
		return 2
	default:
		return 0
	}
}
// reverse reverses slice in place, joins the labels with dots and returns the
// result as an IPv4 address string. The empty string is returned when the
// joined labels do not form a valid IPv4 address.
func reverse(slice []string) string {
	for lo, hi := 0, len(slice)-1; lo < hi; lo, hi = lo+1, hi-1 {
		slice[lo], slice[hi] = slice[hi], slice[lo]
	}
	dotted := strings.Join(slice, ".")
	ip := net.ParseIP(dotted).To4()
	if ip == nil {
		return ""
	}
	return ip.String()
}
// reverse6 reverses the nibble labels in slice in place and combines them
// according to RFC3596:
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2
// is reversed to 2001:db8::567:89ab
// The empty string is returned when the result is not a valid IP address.
func reverse6(slice []string) string {
	for lo, hi := 0, len(slice)-1; lo < hi; lo, hi = lo+1, hi-1 {
		slice[lo], slice[hi] = slice[hi], slice[lo]
	}

	// Glue every four nibbles into one colon-separated group; a trailing run
	// of fewer than four labels is dropped.
	groups := make([]string, 0, len(slice)/4)
	for start := 0; start+4 <= len(slice); start += 4 {
		groups = append(groups, strings.Join(slice[start:start+4], ""))
	}

	ip := net.ParseIP(strings.Join(groups, ":")).To16()
	if ip == nil {
		return ""
	}
	return ip.String()
}
// Reverse tree suffixes. Both start with a dot and end with the root dot.
const (
// IP4arpa is the reverse tree suffix for v4 IP addresses.
IP4arpa = ".in-addr.arpa."
// IP6arpa is the reverse tree suffix for v6 IP addresses.
IP6arpa = ".ip6.arpa."
)
package dnsutil
import (
"time"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns"
)
// MinimalTTL scans the message and returns the lowest TTL found, taking the
// response.Type of the message into account. MinimalDefaultTTL is returned for
// any type other than NoError, NameError or NoData, and for messages that
// carry no records besides a lone OPT. The result never exceeds
// MaximumDefaulTTL.
func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration {
	switch mt {
	case response.NoError, response.NameError, response.NoData:
		// These are the only types whose records are inspected.
	default:
		return MinimalDefaultTTL
	}

	// No records or OPT is the only record, return a short ttl as a fail safe.
	if len(m.Answer)+len(m.Ns) == 0 &&
		(len(m.Extra) == 0 || (len(m.Extra) == 1 && m.Extra[0].Header().Rrtype == dns.TypeOPT)) {
		return MinimalDefaultTTL
	}

	lowest := MaximumDefaulTTL
	// lower replaces lowest when ttl (in seconds) is smaller.
	lower := func(ttl uint32) {
		if ttl < uint32(lowest.Seconds()) {
			lowest = time.Duration(ttl) * time.Second
		}
	}

	for _, rr := range m.Answer {
		lower(rr.Header().Ttl)
	}
	for _, rr := range m.Ns {
		lower(rr.Header().Ttl)
	}
	for _, rr := range m.Extra {
		// OPT records use TTL field for extended rcode and flags
		if rr.Header().Rrtype != dns.TypeOPT {
			lower(rr.Header().Ttl)
		}
	}
	return lowest
}
const (
// MinimalDefaultTTL is the absolute lowest TTL we use in CoreDNS.
MinimalDefaultTTL = 5 * time.Second
// MaximumDefaulTTL is the maximum TTL we use on RRsets in CoreDNS.
// TODO: rename as MaximumDefaultTTL
MaximumDefaulTTL = 1 * time.Hour
)
package dnsutil
import (
"errors"
"github.com/miekg/dns"
)
// TrimZone removes the zone component from q. It returns the trimmed
// name or an error if zone is longer than qname. The trimmed name will be returned
// without a trailing dot.
func TrimZone(q string, z string) (string, error) {
	zl := dns.CountLabel(z)
	i, ok := dns.PrevLabel(q, zl)
	if ok || i-1 < 0 {
		// The space in front of "for" keeps q and the rest of the message apart.
		return "", errors.New("trimzone: overshot qname: " + q + " for zone " + z)
	}
	// q[:i] still includes the '.' in front of the zone; drop it.
	return q[:i-1], nil
}
package doh
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"github.com/miekg/dns"
)
// MimeType is the DoH mimetype that should be used. NewRequest sets it as both
// the Content-Type and the Accept header.
const MimeType = "application/dns-message"
// Path is the URL path that should be used. NewRequest appends it to the URL it is given.
const Path = "/dns-query"
// NewRequest returns a new DoH request given a HTTP method, URL and dns.Msg.
//
// The URL should not have a path, so please exclude /dns-query. The URL will
// be prefixed with https:// by default, unless it's already prefixed with
// either http:// or https://.
//
// For GET the packed message is sent base64url-encoded (unpadded) in the "dns"
// query parameter; for POST it is sent as the request body. Any other method
// is rejected with an error.
func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) {
	wire, err := m.Pack()
	if err != nil {
		return nil, err
	}

	if !(strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://")) {
		url = fmt.Sprintf("https://%s", url)
	}

	var (
		target string
		body   io.Reader
	)
	switch method {
	case http.MethodGet:
		encoded := base64.RawURLEncoding.EncodeToString(wire)
		target = fmt.Sprintf("%s%s?dns=%s", url, Path, encoded)
	case http.MethodPost:
		target = fmt.Sprintf("%s%s", url, Path)
		body = bytes.NewReader(wire)
	default:
		return nil, fmt.Errorf("method not allowed: %s", method)
	}

	req, err := http.NewRequest(method, target, body)
	if err != nil {
		return req, err
	}
	req.Header.Set("Content-Type", MimeType)
	req.Header.Set("Accept", MimeType)
	return req, nil
}
// ResponseToMsg reads the body of resp and unpacks it into a dns message. The
// body is closed before returning.
func ResponseToMsg(resp *http.Response) (*dns.Msg, error) {
	body := resp.Body
	defer body.Close()
	return toMsg(body)
}
// RequestToMsg converts a http.Request to a dns message. GET requests are
// decoded from the "dns" query parameter and POST requests from the body; any
// other method yields an error.
func RequestToMsg(req *http.Request) (*dns.Msg, error) {
	if req.Method == http.MethodGet {
		return requestToMsgGet(req)
	}
	if req.Method == http.MethodPost {
		return requestToMsgPost(req)
	}
	return nil, fmt.Errorf("method not allowed: %s", req.Method)
}
// requestToMsgPost reads the request body and unpacks it into a dns message.
// The body is closed before returning.
func requestToMsgPost(req *http.Request) (*dns.Msg, error) {
	body := req.Body
	defer body.Close()
	return toMsg(body)
}
// requestToMsgGet extracts the dns message from the "dns" query parameter of a
// GET request. The parameter must be present exactly once.
func requestToMsgGet(req *http.Request) (*dns.Msg, error) {
	encoded, present := req.URL.Query()["dns"]
	switch {
	case !present:
		return nil, fmt.Errorf("no 'dns' query parameter found")
	case len(encoded) != 1:
		return nil, fmt.Errorf("multiple 'dns' query values found")
	}
	return base64ToMsg(encoded[0])
}
// toMsg reads at most 65536 bytes from r and unpacks them into a dns message.
// A read error yields a nil message; an unpack error is returned together with
// the (possibly partially filled) message.
func toMsg(r io.ReadCloser) (*dns.Msg, error) {
	limited := http.MaxBytesReader(nil, r, 65536)
	raw, err := io.ReadAll(limited)
	if err != nil {
		return nil, err
	}
	msg := &dns.Msg{}
	err = msg.Unpack(raw)
	return msg, err
}
// base64ToMsg decodes b64 with b64Enc and unpacks the bytes into a dns
// message. A decode error yields a nil message; an unpack error is returned
// together with the (possibly partially filled) message.
func base64ToMsg(b64 string) (*dns.Msg, error) {
	raw, err := b64Enc.DecodeString(b64)
	if err != nil {
		return nil, err
	}
	msg := &dns.Msg{}
	err = msg.Unpack(raw)
	return msg, err
}
// b64Enc is the encoding used to decode the "dns" query parameter: unpadded
// base64url (base64.RawURLEncoding).
var b64Enc = base64.RawURLEncoding
// Package edns provides function useful for adding/inspecting OPT records to/in messages.
package edns
import (
"errors"
"sync"
"github.com/miekg/dns"
)
// sup is the package-wide set of extra EDNS0 option codes registered through SetSupportedOption.
var sup = &supported{m: make(map[uint16]struct{})}
// supported is a set of option codes; the embedded RWMutex guards m.
type supported struct {
m map[uint16]struct{}
sync.RWMutex
}
// SetSupportedOption adds a new supported option to the set of EDNS0 options
// that we support. Plugins typically call this in their setup code to signal
// support for a new option.
// By default we support:
// dns.EDNS0NSID, dns.EDNS0EXPIRE, dns.EDNS0COOKIE, dns.EDNS0TCPKEEPALIVE, dns.EDNS0PADDING. These
// values are not in this map and checked directly in the server.
func SetSupportedOption(option uint16) {
	sup.Lock()
	defer sup.Unlock()
	sup.m[option] = struct{}{}
}
// SupportedOption returns true if the option code was registered with
// SetSupportedOption as an extra EDNS0 option.
func SupportedOption(option uint16) bool {
	sup.RLock()
	defer sup.RUnlock()
	_, registered := sup.m[option]
	return registered
}
// Version checks the EDNS version in the request. If error
// is nil everything is OK and we can invoke the plugin. If non-nil, the
// returned Msg is valid to be returned to the client (and should).
//
// Requests without an OPT record, or with EDNS version 0, pass. Anything else
// gets a BADVERS reply that carries a version 0 OPT record.
func Version(req *dns.Msg) (*dns.Msg, error) {
	opt := req.IsEdns0()
	if opt == nil || opt.Version() == 0 {
		return nil, nil
	}

	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Rcode = dns.RcodeBadVers

	o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
	o.SetVersion(0)
	o.SetExtendedRcode(dns.RcodeBadVers)
	reply.Extra = []dns.RR{o}

	return reply, errors.New("EDNS0 BADVERS")
}
// Size returns a normalized size based on proto: dns.MaxMsgSize for "tcp",
// otherwise size raised to at least dns.MinMsgSize.
func Size(proto string, size uint16) uint16 {
	switch {
	case proto == "tcp":
		return dns.MaxMsgSize
	case size < dns.MinMsgSize:
		return dns.MinMsgSize
	default:
		return size
	}
}
// Package fall handles the fallthrough logic used in plugins that support it. Be careful when including this
// functionality in your plugin. Why? In the DNS only 1 source is authoritative for a set of names. Fallthrough
// breaks this convention by allowing a plugin to query multiple sources, depending on the replies it got sofar.
//
// This may cause issues in downstream caches, where different answers for the same query can potentially confuse clients.
// On the other hand this is a powerful feature that can aid in migration or other edge cases.
//
// The take away: be mindful of this and don't blindly assume it's a good feature to have in your plugin.
//
// See https://github.com/coredns/coredns/issues/2723 for some discussion on this, which includes this quote:
//
// TL;DR: `fallthrough` is indeed risky and hackish, but still a good feature of CoreDNS as it allows to quickly answer boring edge cases.
package fall
import (
"slices"
"github.com/coredns/coredns/plugin"
)
// F can be nil to allow for no fallthrough, empty allow all zones to fallthrough or
// contain a zone list that is checked.
type F struct {
// Zones is the list of zones that Through matches a qname against.
Zones []string
}
// Through reports whether we should fall through for qname, that is whether
// qname matches one of f's zones. Note that we've named the variable in each
// plugin "Fall", so this then reads Fall.Through().
func (f F) Through(qname string) bool {
	matched := plugin.Zones(f.Zones).Matches(qname)
	return matched != ""
}
// setZones replaces f.Zones with the normalized form of zones. A zone that
// normalizes to several hosts contributes all of them; one that fails to
// normalize contributes nothing.
func (f *F) setZones(zones []string) {
	normalized := []string{}
	for _, zone := range zones {
		normalized = append(normalized, plugin.Host(zone).NormalizeExact()...)
	}
	f.Zones = normalized
}
// SetZonesFromArgs sets the zones in f to the passed value, or to the zones of
// Root (".") if the slice is empty.
func (f *F) SetZonesFromArgs(zones []string) {
	if len(zones) > 0 {
		f.setZones(zones)
		return
	}
	f.setZones(Root.Zones)
}
// Equal returns true if f and g hold the same zones in the same order.
func (f *F) Equal(g F) bool {
	mine, theirs := f.Zones, g.Zones
	return slices.Equal(mine, theirs)
}
// Zero is an F with an empty (non-nil) zone list.
var Zero = func() F {
return F{[]string{}}
}()
// Root is an F whose only zone is ".".
var Root = func() F {
return F{[]string{"."}}
}()
// Package fuzz contains functions that enable fuzzing of plugins.
package fuzz
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
// Do will fuzz p - used by gofuzz. See Makefile.fuzz for comments and context.
// It returns 1 when p.ServeDNS returns an error and 0 otherwise, including
// when data is not a usable query.
func Do(p plugin.Handler, data []byte) int {
	msg := new(dns.Msg)
	if msg.Unpack(data) != nil {
		// The plugin will never be called when this happens.
		return 0
	}
	// A message that unpacks but has no question section is discarded. The
	// server parts guarantee a question before calling the plugins; mimic that.
	if len(msg.Question) == 0 {
		return 0
	}

	_, err := p.ServeDNS(context.TODO(), &test.ResponseWriter{}, msg)
	if err != nil {
		return 1
	}
	return 0
}
package log
import (
"sync"
)
// Listener listens for all log prints of plugin loggers aka loggers with plugin name.
// When a plugin logger gets called, it should first call the same method in the Listener object.
// A usage example is, the external plugin k8s_event will replicate log prints to Kubernetes events.
//
// Name identifies a listener: RegisterListener and DeregisterListener compare listeners by it.
type Listener interface {
Name() string
Debug(plugin string, v ...interface{})
Debugf(plugin string, format string, v ...interface{})
Info(plugin string, v ...interface{})
Infof(plugin string, format string, v ...interface{})
Warning(plugin string, v ...interface{})
Warningf(plugin string, format string, v ...interface{})
Error(plugin string, v ...interface{})
Errorf(plugin string, format string, v ...interface{})
Fatal(plugin string, v ...interface{})
Fatalf(plugin string, format string, v ...interface{})
}
// listeners is the set of registered Listener objects; the embedded RWMutex guards the slice.
type listeners struct {
listeners []Listener
sync.RWMutex
}
// ls is the package-wide listener registry, created in init.
var ls *listeners
// init creates the registry with an empty listener list.
func init() {
ls = &listeners{}
ls.listeners = make([]Listener, 0)
}
// RegisterListener registers a listener object. A listener that is already
// registered under the same Name is replaced in place; otherwise the new one
// is appended. The returned error is always nil.
func RegisterListener(new Listener) error {
	ls.Lock()
	defer ls.Unlock()

	for i, existing := range ls.listeners {
		if existing.Name() != new.Name() {
			continue
		}
		ls.listeners[i] = new
		return nil
	}
	ls.listeners = append(ls.listeners, new)
	return nil
}
// DeregisterListener removes the first registered listener whose Name equals
// that of old. Nothing happens when no such listener exists. The returned
// error is always nil.
func DeregisterListener(old Listener) error {
	ls.Lock()
	defer ls.Unlock()

	for i, existing := range ls.listeners {
		if existing.Name() != old.Name() {
			continue
		}
		ls.listeners = append(ls.listeners[:i], ls.listeners[i+1:]...)
		return nil
	}
	return nil
}
// debug forwards a Debug call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) debug(plugin string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Debug(plugin, v...)
	}
	ls.RUnlock()
}
// debugf forwards a Debugf call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) debugf(plugin string, format string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Debugf(plugin, format, v...)
	}
	ls.RUnlock()
}
// info forwards an Info call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) info(plugin string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Info(plugin, v...)
	}
	ls.RUnlock()
}
// infof forwards an Infof call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) infof(plugin string, format string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Infof(plugin, format, v...)
	}
	ls.RUnlock()
}
// warning forwards a Warning call for plugin to every registered listener
// while holding the read lock.
func (ls *listeners) warning(plugin string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Warning(plugin, v...)
	}
	ls.RUnlock()
}
// warningf forwards a Warningf call for plugin to every registered listener
// while holding the read lock.
func (ls *listeners) warningf(plugin string, format string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Warningf(plugin, format, v...)
	}
	ls.RUnlock()
}
// error forwards an Error call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) error(plugin string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Error(plugin, v...)
	}
	ls.RUnlock()
}
// errorf forwards an Errorf call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) errorf(plugin string, format string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Errorf(plugin, format, v...)
	}
	ls.RUnlock()
}
// fatal forwards a Fatal call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) fatal(plugin string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Fatal(plugin, v...)
	}
	ls.RUnlock()
}
// fatalf forwards a Fatalf call for plugin to every registered listener while
// holding the read lock.
func (ls *listeners) fatalf(plugin string, format string, v ...interface{}) {
	ls.RLock()
	for i := range ls.listeners {
		ls.listeners[i].Fatalf(plugin, format, v...)
	}
	ls.RUnlock()
}
// Package log implements a small wrapper around the std lib log package. It
// implements log levels by prefixing the logs with [INFO], [DEBUG], [WARNING]
// or [ERROR]. Debug logging is available and enabled if the *debug* plugin is
// used.
//
// log.Info("this is some logging"), will log on the Info level.
//
// log.Debug("this is debug output"), will log in the Debug level, etc.
package log
import (
"fmt"
"io"
golog "log"
"os"
"sync/atomic"
)
// D controls whether we should output debug logs. Debug output is enabled
// with Set and disabled again with Clear.
var D = &d{}
// d holds the debug flag as an atomic.Bool.
type d struct {
on atomic.Bool
}
// Set enables debug logging.
func (d *d) Set() {
d.on.Store(true)
}
// Clear disables debug logging.
func (d *d) Clear() {
d.on.Store(false)
}
// Value returns if debug logging is enabled.
func (d *d) Value() bool {
return d.on.Load()
}
// logf formats v according to format and prints the result through the
// standard logger, prefixed with level.
func logf(level, format string, v ...interface{}) {
	message := fmt.Sprintf(format, v...)
	golog.Print(level + message)
}
// log renders v with fmt.Sprint and prints the result through the standard
// logger, prefixed with level.
func log(level string, v ...interface{}) {
	message := fmt.Sprint(v...)
	golog.Print(level + message)
}
// Debug is equivalent to log.Print(), but prefixed with "[DEBUG] ". It only
// outputs something while debug logging is enabled through D.
func Debug(v ...interface{}) {
	if D.Value() {
		log(debug, v...)
	}
}
// Debugf is equivalent to log.Printf(), but prefixed with "[DEBUG] ". It only
// outputs something while debug logging is enabled through D.
func Debugf(format string, v ...interface{}) {
	if D.Value() {
		logf(debug, format, v...)
	}
}
// Info is equivalent to log.Print, but prefixed with "[INFO] ". It is always printed,
// regardless of D.
func Info(v ...interface{}) { log(info, v...) }
// Infof is equivalent to log.Printf, but prefixed with "[INFO] ". It is always printed,
// regardless of D.
func Infof(format string, v ...interface{}) { logf(info, format, v...) }
// Warning is equivalent to log.Print, but prefixed with "[WARNING] ". It is always
// printed, regardless of D.
func Warning(v ...interface{}) { log(warning, v...) }
// Warningf is equivalent to log.Printf, but prefixed with "[WARNING] ". It is always
// printed, regardless of D.
func Warningf(format string, v ...interface{}) { logf(warning, format, v...) }
// Error is equivalent to log.Print, but prefixed with "[ERROR] ". It is always printed,
// regardless of D.
func Error(v ...interface{}) { log(err, v...) }
// Errorf is equivalent to log.Printf, but prefixed with "[ERROR] ". It is always printed,
// regardless of D.
func Errorf(format string, v ...interface{}) { logf(err, format, v...) }
// Fatal is equivalent to log.Print, but prefixed with "[FATAL] ", and calling
// os.Exit(1) afterwards, so it never returns to the caller.
func Fatal(v ...interface{}) { log(fatal, v...); os.Exit(1) }
// Fatalf is equivalent to log.Printf, but prefixed with "[FATAL] ", and calling
// os.Exit(1) afterwards, so it never returns to the caller.
func Fatalf(format string, v ...interface{}) { logf(fatal, format, v...); os.Exit(1) }
// Discard sets the output of the standard logger to io.Discard, so that
// everything logged afterwards is dropped.
func Discard() { golog.SetOutput(io.Discard) }
// Level prefixes put in front of every log line; each ends in a space.
const (
debug = "[DEBUG] "
err = "[ERROR] "
fatal = "[FATAL] "
info = "[INFO] "
warning = "[WARNING] "
)
package log
import (
"fmt"
"os"
)
// P is a logger that includes the plugin doing the logging.
type P struct {
// plugin is the prefix put in front of every message; NewWithPlugin sets it to
// "plugin/<name>: ".
plugin string
}
// NewWithPlugin returns a logger that includes "plugin/name: " in the log message.
// I.e [INFO] plugin/<name>: message.
func NewWithPlugin(name string) P {
	prefix := "plugin/" + name + ": "
	return P{plugin: prefix}
}
// logf formats v according to format and logs the result, prefixed with level
// and the plugin prefix.
func (p P) logf(level, format string, v ...interface{}) {
	message := fmt.Sprintf(format, v...)
	log(level+p.plugin, message)
}
// log logs v prefixed with level and the plugin prefix.
func (p P) log(level string, v ...interface{}) {
	prefix := level + p.plugin
	log(prefix, v...)
}
// Debug logs as log.Debug. While debug logging is enabled it first notifies
// the registered listeners and then logs with the plugin prefix; otherwise it
// does nothing.
func (p P) Debug(v ...interface{}) {
	if D.Value() {
		ls.debug(p.plugin, v...)
		p.log(debug, v...)
	}
}
// Debugf logs as log.Debugf. While debug logging is enabled it first notifies
// the registered listeners and then logs with the plugin prefix; otherwise it
// does nothing.
func (p P) Debugf(format string, v ...interface{}) {
	if D.Value() {
		ls.debugf(p.plugin, format, v...)
		p.logf(debug, format, v...)
	}
}
// Info logs as log.Info. The registered listeners are notified first, then the
// message is logged with the plugin prefix.
func (p P) Info(v ...interface{}) {
ls.info(p.plugin, v...)
p.log(info, v...)
}
// Infof logs as log.Infof.
func (p P) Infof(format string, v ...interface{}) {
ls.infof(p.plugin, format, v...)
p.logf(info, format, v...)
}
// Warning logs as log.Warning.
func (p P) Warning(v ...interface{}) {
ls.warning(p.plugin, v...)
p.log(warning, v...)
}
// Warningf logs as log.Warningf.
func (p P) Warningf(format string, v ...interface{}) {
ls.warningf(p.plugin, format, v...)
p.logf(warning, format, v...)
}
// Error logs as log.Error.
func (p P) Error(v ...interface{}) {
ls.error(p.plugin, v...)
p.log(err, v...)
}
// Errorf logs as log.Errorf.
func (p P) Errorf(format string, v ...interface{}) {
ls.errorf(p.plugin, format, v...)
p.logf(err, format, v...)
}
// Fatal logs as log.Fatal and calls os.Exit(1).
func (p P) Fatal(v ...interface{}) {
ls.fatal(p.plugin, v...)
p.log(fatal, v...)
os.Exit(1)
}
// Fatalf logs as log.Fatalf and calls os.Exit(1).
func (p P) Fatalf(format string, v ...interface{}) {
ls.fatalf(p.plugin, format, v...)
p.logf(fatal, format, v...)
os.Exit(1)
}
// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written.
package nonwriter
import (
"github.com/miekg/dns"
)
// Writer embeds a dns.ResponseWriter and keeps the message handed to WriteMsg
// in Msg instead of sending it to the client.
type Writer struct {
	dns.ResponseWriter
	Msg *dns.Msg
}

// New returns a Writer that embeds w. Msg stays nil until WriteMsg is called.
func New(w dns.ResponseWriter) *Writer {
	nw := new(Writer)
	nw.ResponseWriter = w
	return nw
}

// WriteMsg stores res in w.Msg and returns nil; the embedded ResponseWriter
// is not invoked.
func (w *Writer) WriteMsg(res *dns.Msg) error {
	w.Msg = res
	return nil
}
package parse
import (
"errors"
"fmt"
"net"
"os"
"strings"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// ErrNoNameservers is returned by HostPortOrFile when, after processing all
// of its arguments without error, the resulting list of servers is empty.
var ErrNoNameservers = errors.New("no nameservers found")
// stripZone removes a zone identifier from host by cutting the string at its
// last '%': that character and everything after it are dropped. A host that
// contains no '%' is returned unchanged.
func stripZone(host string) string {
	if idx := strings.LastIndex(host, "%"); idx >= 0 {
		return host[:idx]
	}
	return host
}
// HostPortOrFile parses the strings in s, each string can either be a
// address, [scheme://]address:port or a filename. The address part is checked
// and in case of filename a resolv.conf like file is (assumed) and parsed and
// the nameservers found are returned.
//
// Entries with the unix transport are passed through as "unix://<rest>"
// without further checks. An IP address without a port gets the default port
// of its transport appended (and, for every transport except dns, the
// "<transport>://" prefix re-added). An entry that already is ip:port is
// returned exactly as given.
//
// On the first invalid entry the servers collected so far are returned
// together with an error. If everything parsed but no servers were found,
// ErrNoNameservers is returned.
func HostPortOrFile(s ...string) ([]string, error) {
var servers []string
for _, h := range s {
// Split off the optional "<transport>://" prefix.
trans, host := Transport(h)
if len(host) == 0 {
return servers, fmt.Errorf("invalid address: %q", h)
}
if trans == transport.UNIX {
servers = append(servers, trans+"://"+host)
continue
}
addr, _, err := net.SplitHostPort(host)
if err != nil {
// Parse didn't work, it is not a addr:port combo
hostNoZone := stripZone(host)
if net.ParseIP(hostNoZone) == nil {
// Not an IP either: last resort is to treat it as a file name.
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
continue
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
// A bare IP address: append the default port for the transport.
var ss string
switch trans {
case transport.DNS:
ss = net.JoinHostPort(host, transport.Port)
case transport.TLS:
ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort)
case transport.QUIC:
ss = transport.QUIC + "://" + net.JoinHostPort(host, transport.QUICPort)
case transport.GRPC:
ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort)
case transport.HTTPS:
ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort)
}
servers = append(servers, ss)
continue
}
// host:port split succeeded; the host part must be an IP, otherwise try
// the whole string as a file name.
if net.ParseIP(stripZone(addr)) == nil {
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
continue
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
// Valid ip:port: keep the original string, including any transport prefix.
servers = append(servers, h)
}
if len(servers) == 0 {
return servers, ErrNoNameservers
}
return servers, nil
}
// tryFile treats s as the path of a resolv.conf-like file and parses it with
// dns.ClientConfigFromFile. On success it returns one "host:port" entry per
// server in the parsed configuration, using the configuration's port and with
// any zone identifier stripped from the host.
//
// A missing file is reported with an error that names the path and wraps the
// underlying error; any other error is returned unchanged.
func tryFile(s string) ([]string, error) {
	c, err := dns.ClientConfigFromFile(s)
	if err != nil {
		// Errors from opening a file are normally wrapped (e.g. in an
		// *os.PathError), so a plain == comparison against os.ErrNotExist does
		// not match; errors.Is looks through the wrapping.
		if errors.Is(err, os.ErrNotExist) {
			return nil, fmt.Errorf("failed to open file %q: %w", s, err)
		}
		return nil, err
	}

	servers := make([]string, 0, len(c.Servers))
	for _, srv := range c.Servers {
		servers = append(servers, net.JoinHostPort(stripZone(srv), c.Port))
	}
	return servers, nil
}
// HostPort validates that the host part of s is an IP address and returns it
// joined with a port. s may be "ip" or "ip:port"; when no port is present (or
// it is empty) defaultPort is used. An error is returned when the host part is
// not a valid IP address.
func HostPort(s, defaultPort string) (string, error) {
	host, port, splitErr := net.SplitHostPort(s)
	if splitErr != nil {
		// Not a host:port pair: the whole input is the candidate address and
		// there is no port to keep.
		host, port = s, ""
	}
	if port == "" {
		port = defaultPort
	}

	if net.ParseIP(host) == nil {
		return "", fmt.Errorf("must specify an IP address: `%s'", host)
	}
	return net.JoinHostPort(host, port), nil
}
// Package parse contains functions that can be used in the setup code for plugins.
package parse
import (
"fmt"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// TransferIn parses transfer statements: 'transfer from [address...]'.
// Each address is normalized with HostPort, using transport.Port as the
// default port. A missing property or address list yields c.ArgErr(), an
// unknown property yields c.Errf, and "*" is rejected as an address.
func TransferIn(c *caddy.Controller) (froms []string, err error) {
	if !c.NextArg() {
		return nil, c.ArgErr()
	}

	if property := c.Val(); property != "from" {
		return nil, c.Errf("unknown property %s", property)
	}

	froms = c.RemainingArgs()
	if len(froms) == 0 {
		return nil, c.ArgErr()
	}
	for i, from := range froms {
		if from == "*" {
			return nil, fmt.Errorf("can't use '*' in transfer from")
		}
		hostPort, hpErr := HostPort(from, transport.Port)
		if hpErr != nil {
			return nil, hpErr
		}
		froms[i] = hostPort
	}
	return froms, nil
}
package parse
import (
"strings"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// Transport returns the transport named by the "<transport>://" prefix of s
// together with s stripped of that prefix. The recognized transports are tls,
// dns, quic, grpc, https and unix (as defined in the transport package). When
// s carries none of these prefixes, transport.DNS and s itself are returned.
func Transport(s string) (trans string, addr string) {
	known := [...]string{
		transport.TLS,
		transport.DNS,
		transport.QUIC,
		transport.GRPC,
		transport.HTTPS,
		transport.UNIX,
	}
	for _, name := range known {
		scheme := name + "://"
		if strings.HasPrefix(s, scheme) {
			return name, s[len(scheme):]
		}
	}
	return transport.DNS, s
}
// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
// inband healthchecking.
package proxy
import (
"context"
"errors"
"io"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Messages of the errors returned by Transport.Dial when the transport's stop
// channel fires at one of the stages of a dial.
const (
// ErrTransportStopped: stop was already signalled before the dial request was sent.
ErrTransportStopped = "proxy: transport stopped"
// ErrTransportStoppedDuringDial: stop was signalled while sending the dial request.
ErrTransportStoppedDuringDial = "proxy: transport stopped during dial"
// ErrTransportStoppedRetClosed: the ret channel was found closed while waiting for the answer.
ErrTransportStoppedRetClosed = "proxy: transport stopped, ret channel closed"
// ErrTransportStoppedDuringRetWait: stop was signalled while waiting for the answer on ret.
ErrTransportStoppedDuringRetWait = "proxy: transport stopped during ret wait"
)
// limitTimeout derives the next timeout from the average stored in
// currentAvg (read atomically, interpreted as a time.Duration):
//   - below minValue, minValue is returned;
//   - below half of maxValue, twice the average is returned;
//   - otherwise maxValue is returned.
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
	avg := time.Duration(atomic.LoadInt64(currentAvg))
	switch {
	case avg < minValue:
		return minValue
	case avg < maxValue/2:
		return 2 * avg
	default:
		return maxValue
	}
}
// averageTimeout moves the average stored in currentAvg towards
// observedDuration: the difference between the observation and the current
// average, divided by weight, is atomically added to the average.
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
	previous := time.Duration(atomic.LoadInt64(currentAvg))
	delta := int64(observedDuration-previous) / weight
	atomic.AddInt64(currentAvg, delta)
}
// dialTimeout returns the timeout to use for the next dial, computed by
// limitTimeout from the average dial time and bounded by minDialTimeout and
// maxDialTimeout.
func (t *Transport) dialTimeout() time.Duration {
	timeout := limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
	return timeout
}
// updateDialTimeout folds newDialTime into the average dial time, weighted by
// cumulativeAvgWeight.
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
	avg := &t.avgDialTime
	averageTimeout(avg, newDialTime, cumulativeAvgWeight)
}
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
//
// When a TLS config is set, proto is forced to "tcp-tls". The request is sent
// to connManager over t.dial and the answer is read from t.ret: a non-nil
// connection is a cache hit and is returned with the bool result set to true;
// a nil answer is a cache miss, in which case a new connection is dialed using
// the adaptive dial timeout and the bool result is false. If the stop channel
// fires at any stage, an error carrying one of the ErrTransportStopped*
// messages is returned.
//
// Note that on a failed dial the returned *persistConn is still non-nil (it
// wraps whatever the dialer returned), so callers must check the error.
func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"
}
// Check if transport is stopped before attempting to dial
select {
case <-t.stop:
return nil, false, errors.New(ErrTransportStopped)
default:
}
// Use select to avoid blocking if connManager has stopped
select {
case t.dial <- proto:
// Successfully sent dial request
case <-t.stop:
return nil, false, errors.New(ErrTransportStoppedDuringDial)
}
// Receive response with stop awareness
select {
case pc, ok := <-t.ret:
if !ok {
// ret channel was closed by connManager during stop
return nil, false, errors.New(ErrTransportStoppedRetClosed)
}
if pc != nil {
// Cache hit: reuse the connection handed back by connManager.
connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1)
return pc, true, nil
}
// Cache miss: dial a new connection and feed the observed dial time
// (successful or not) back into the average.
connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1)
reqTime := time.Now()
timeout := t.dialTimeout()
if proto == "tcp-tls" {
conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
t.updateDialTimeout(time.Since(reqTime))
return &persistConn{c: conn}, false, err
}
conn, err := dns.DialTimeout(proto, t.addr, timeout)
t.updateDialTimeout(time.Since(reqTime))
return &persistConn{c: conn}, false, err
case <-t.stop:
return nil, false, errors.New(ErrTransportStoppedDuringRetWait)
}
}
// Connect selects an upstream, sends the request and waits for a response.
//
// The protocol is "tcp" when opts.ForceTCP is set, "udp" when opts.PreferUDP
// is set, and otherwise the protocol of the incoming request. The request is
// sent with a freshly generated message Id; the original Id is restored on
// state.Req when Connect returns and is set on the returned response.
// Responses whose Id does not match the one sent are dropped. On success the
// connection is yielded back to the transport for reuse and the request
// duration is recorded; on a write or read error the connection is closed.
// ErrCachedClosed is returned when a cached connection fails with io.EOF.
func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
start := time.Now()
var proto string
switch {
case opts.ForceTCP: // TCP flag has precedence over UDP flag
proto = "tcp"
case opts.PreferUDP:
proto = "udp"
default:
proto = state.Proto()
}
pc, cached, err := p.transport.Dial(proto)
if err != nil {
return nil, err
}
// Set buffer size correctly for this client.
pc.c.UDPSize = uint16(state.Size())
if pc.c.UDPSize < 512 {
pc.c.UDPSize = 512
}
pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
// Remember the original Id and replace it with a fresh one for the upstream
// query; the deferred func restores it on every return path.
originId := state.Req.Id
state.Req.Id = dns.Id()
defer func() {
state.Req.Id = originId
}()
if err := pc.c.WriteMsg(state.Req); err != nil {
pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
}
return nil, err
}
var ret *dns.Msg
// A single read deadline covers the whole read loop below.
pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
for {
ret, err = pc.c.ReadMsg()
if err != nil {
if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
// For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way.
// (e.g. sending >512 byte responses without an eDNS0 OPT RR).
// Instead of returning an error, return an empty response with TC bit set. This will make the
// client retry over TCP (if that's supported) or at least receive a clean
// error. The connection is still good so we break before the close.
// Truncate the response.
ret = truncateResponse(ret)
break
}
pc.c.Close() // not giving it back
if err == io.EOF && cached {
return nil, ErrCachedClosed
}
// Restore the original Id on a partially read response.
if ret != nil {
ret.Id = originId
}
return ret, err
}
// drop out-of-order responses
if state.Req.Id == ret.Id {
break
}
}
// Restore the original Id on the response before handing it back.
ret.Id = originId
p.transport.Yield(pc)
// Label the duration metric with the rcode name, or its number if unknown.
rc, ok := dns.RcodeToString[ret.Rcode]
if !ok {
rc = strconv.Itoa(ret.Rcode)
}
requestDuration.WithLabelValues(p.proxyName, p.addr, rc).Observe(time.Since(start).Seconds())
return ret, nil
}
// cumulativeAvgWeight is the weight passed to averageTimeout by
// updateDialTimeout: each new dial time moves the average by 1/4 of the
// difference between the observation and the current average.
const cumulativeAvgWeight = 4
// shouldTruncateResponse reports whether a read error should be answered with
// a truncated response instead of being returned. That is the case when err
// is a *dns.Error matching dns.ErrBuf (upstream set the TC bit but did not
// truncate the response, so we get ErrBuf instead of overflow), or when the
// error text contains "overflow".
func shouldTruncateResponse(err error) bool {
	if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) {
		return true
	}
	return strings.Contains(err.Error(), "overflow")
}
// truncateResponse empties the Answer, Ns and Extra sections of response, sets
// its TC (truncated) bit and returns the same message. The message is modified
// in place.
func truncateResponse(response *dns.Msg) *dns.Msg {
	response.Answer, response.Ns, response.Extra = nil, nil, nil
	response.Truncated = true
	return response
}
package proxy
import (
"crypto/tls"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// HealthChecker checks the upstream health.
//
// Besides Check, the interface exposes getters and setters for the settings a
// checker uses: TLS configuration, the recursion-desired flag, the domain to
// query, the transport and the read/write timeouts.
type HealthChecker interface {
// Check runs one health check against the given proxy and returns its error, if any.
Check(*Proxy) error
SetTLSConfig(*tls.Config)
GetTLSConfig() *tls.Config
SetRecursionDesired(bool)
GetRecursionDesired() bool
SetDomain(domain string)
GetDomain() string
// SetTCPTransport makes the checker use TCP for its probes.
SetTCPTransport()
GetReadTimeout() time.Duration
SetReadTimeout(time.Duration)
GetWriteTimeout() time.Duration
SetWriteTimeout(time.Duration)
}
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
type dnsHc struct {
	c                *dns.Client
	recursionDesired bool
	domain           string
	proxyName        string
}

// NewHealthChecker returns a new HealthChecker based on transport. For the
// dns and tls transports it returns a dnsHc whose client uses UDP with one
// second read and write timeouts. For any other transport a warning is logged
// and nil is returned.
func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain string) HealthChecker {
	if trans != transport.DNS && trans != transport.TLS {
		log.Warningf("No healthchecker for transport %q", trans)
		return nil
	}

	client := &dns.Client{
		Net:          "udp",
		ReadTimeout:  1 * time.Second,
		WriteTimeout: 1 * time.Second,
	}
	return &dnsHc{
		c:                client,
		recursionDesired: recursionDesired,
		domain:           domain,
		proxyName:        proxyName,
	}
}

// SetTLSConfig installs cfg on the client and switches it to "tcp-tls".
func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
	h.c.TLSConfig = cfg
	h.c.Net = "tcp-tls"
}

// GetTLSConfig returns the client's TLS configuration.
func (h *dnsHc) GetTLSConfig() *tls.Config {
	cfg := h.c.TLSConfig
	return cfg
}

// SetRecursionDesired sets the RD flag used in health check queries.
func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
	h.recursionDesired = recursionDesired
}

// GetRecursionDesired returns the RD flag used in health check queries.
func (h *dnsHc) GetRecursionDesired() bool {
	rd := h.recursionDesired
	return rd
}

// SetDomain sets the name queried by health checks.
func (h *dnsHc) SetDomain(domain string) {
	h.domain = domain
}

// GetDomain returns the name queried by health checks.
func (h *dnsHc) GetDomain() string {
	name := h.domain
	return name
}

// SetTCPTransport switches the client to "tcp".
func (h *dnsHc) SetTCPTransport() {
	h.c.Net = "tcp"
}

// GetReadTimeout returns the client's read timeout.
func (h *dnsHc) GetReadTimeout() time.Duration {
	d := h.c.ReadTimeout
	return d
}

// SetReadTimeout sets the client's read timeout.
func (h *dnsHc) SetReadTimeout(t time.Duration) {
	h.c.ReadTimeout = t
}

// GetWriteTimeout returns the client's write timeout.
func (h *dnsHc) GetWriteTimeout() time.Duration {
	d := h.c.WriteTimeout
	return d
}

// SetWriteTimeout sets the client's write timeout.
func (h *dnsHc) SetWriteTimeout(t time.Duration) {
	h.c.WriteTimeout = t
}
// Check sends one health check query to the proxy's address. When the query
// fails, the failure metric and the proxy's fail counter are incremented and
// the error is returned; when it succeeds the fail counter is reset to zero.
// Check is used as the up.Func in the up.Probe.
func (h *dnsHc) Check(p *Proxy) error {
	if err := h.send(p.addr); err != nil {
		healthcheckFailureCount.WithLabelValues(p.proxyName, p.addr).Add(1)
		p.incrementFails()
		return err
	}

	atomic.StoreUint32(&p.fails, 0)
	return nil
}
// send issues an NS query for h.domain (with RD set per h.recursionDesired)
// to addr. An exchange error is suppressed when a message came back anyway
// and it is either a response or has opcode QUERY: if we got a header we're
// alright, only I/O errors and the like count as failures.
func (h *dnsHc) send(addr string) error {
	ping := new(dns.Msg)
	ping.SetQuestion(h.domain, dns.TypeNS)
	ping.RecursionDesired = h.recursionDesired

	m, _, err := h.c.Exchange(ping, addr)
	if err == nil || m == nil {
		return err
	}
	// Silly check, something sane came back.
	if m.Response || m.Opcode == dns.OpcodeQuery {
		return nil
	}
	return err
}
package proxy
import (
"crypto/tls"
"sort"
"time"
"github.com/miekg/dns"
)
// persistConn holds a dns.Conn together with the time it was last used.
type persistConn struct {
c *dns.Conn
used time.Time // set by Transport.Yield when the connection is handed back
}
// Transport holds the persistent connection cache for one upstream address.
// The cache itself (conns) is manipulated by the connManager goroutine; Dial
// and Yield communicate with it over the dial, ret and yield channels.
type Transport struct {
avgDialTime int64 // running average of the dial time, in nanoseconds; accessed atomically
conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired.
addr string
tlsConfig *tls.Config
proxyName string
dial chan string // Dial sends the wanted protocol here
yield chan *persistConn // Yield sends connections here to be cached
ret chan *persistConn // connManager answers dial requests here; nil means no cached connection
stop chan bool // closed by Stop to terminate connManager
}
// newTransport creates a Transport for addr with empty connection buckets,
// unbuffered channels, the default expire time and an initial average dial
// time of half of maxDialTimeout. The connection manager is not started.
func newTransport(proxyName, addr string) *Transport {
	return &Transport{
		proxyName:   proxyName,
		addr:        addr,
		expire:      defaultExpire,
		avgDialTime: int64(maxDialTimeout / 2),
		conns:       [typeTotalCount][]*persistConn{},
		dial:        make(chan string),
		yield:       make(chan *persistConn),
		ret:         make(chan *persistConn),
		stop:        make(chan bool),
	}
}
// connManager manages the persistent connection cache for UDP and TCP.
//
// It loops until the stop channel fires, serving four kinds of events: dial
// requests (answered on t.ret with the most recently yielded unexpired
// connection, or nil), yielded connections (appended to their bucket),
// ticks every defaultExpire (expired connections are cleaned up) and stop
// (all connections are cleaned up and t.ret is closed).
func (t *Transport) connManager() {
ticker := time.NewTicker(defaultExpire)
defer ticker.Stop()
Wait:
for {
select {
case proto := <-t.dial:
transtype := stringToTransportType(proto)
// take the last used conn - complexity O(1)
if stack := t.conns[transtype]; len(stack) > 0 {
pc := stack[len(stack)-1]
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[transtype] = stack[:len(stack)-1]
t.ret <- pc
continue Wait
}
// clear entire cache if the last conn is expired
t.conns[transtype] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
}
// Nothing usable in the cache: tell Dial to open a new connection.
t.ret <- nil
case pc := <-t.yield:
transtype := t.transportTypeFromConn(pc)
t.conns[transtype] = append(t.conns[transtype], pc)
case <-ticker.C:
t.cleanup(false)
case <-t.stop:
t.cleanup(true)
close(t.ret)
return
}
}
}
// closeConns closes every connection in conns; errors from Close are ignored.
func closeConns(conns []*persistConn) {
	for i := range conns {
		conns[i].c.Close()
	}
}
// cleanup removes connections from cache. With all set, every bucket is
// emptied; otherwise only the connections whose last use is not after
// now-expire are removed. Removed connections are closed in a separate
// goroutine.
func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire)
for transtype, stack := range t.conns {
if len(stack) == 0 {
continue
}
if all {
t.conns[transtype] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
continue
}
// The oldest connection is still fresh, so nothing in this bucket is stale.
if stack[0].used.After(staleTime) {
continue
}
// connections in stack are sorted by "used"
good := sort.Search(len(stack), func(i int) bool {
return stack[i].used.After(staleTime)
})
t.conns[transtype] = stack[good:]
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack[:good])
}
}
// yieldTimeout bounds how long Yield waits to hand a connection back. It is
// hard to pin a value to this; the important thing is to not block forever,
// and losing a cached connection is not terrible.
const yieldTimeout = 25 * time.Millisecond
// Yield returns the connection to transport for reuse, after stamping it with
// the current time as its last-used time.
//
// The hand-over is bounded by yieldTimeout: on a very busy forwarder an
// unbounded send would block the calling goroutine and work would pile up.
// Giving a connection back is only an optimization, so when the timeout
// expires the connection is simply not cached.
func (t *Transport) Yield(pc *persistConn) {
	pc.used = time.Now()

	expired := time.After(yieldTimeout)
	select {
	case t.yield <- pc:
	case <-expired:
	}
}
// Start launches the transport's connection manager in its own goroutine.
func (t *Transport) Start() {
	go t.connManager()
}
// Stop stops the transport's connection manager by closing the stop channel.
func (t *Transport) Stop() {
	close(t.stop)
}
// SetExpire sets the duration after which a cached connection is considered
// expired.
func (t *Transport) SetExpire(expire time.Duration) {
	t.expire = expire
}
// SetTLSConfig sets the TLS config in transport. Once a config is set, Dial
// uses "tcp-tls" for every connection.
func (t *Transport) SetTLSConfig(cfg *tls.Config) {
	t.tlsConfig = cfg
}
const (
// defaultExpire is the initial expire time of cached connections and the
// interval at which connManager cleans up the cache.
defaultExpire = 10 * time.Second
// minDialTimeout and maxDialTimeout bound the adaptive dial timeout.
minDialTimeout = 1 * time.Second
maxDialTimeout = 30 * time.Second
)
package proxy
import (
"crypto/tls"
"runtime"
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/up"
)
// Proxy defines an upstream host.
type Proxy struct {
fails uint32 // number of failed health checks; accessed atomically
addr string
proxyName string
transport *Transport
readTimeout time.Duration // read deadline applied in Connect
// health checking
probe *up.Probe
health HealthChecker // nil when NewHealthChecker has no checker for the transport
}
// NewProxy returns a new proxy for addr with zero fails, a two second read
// timeout, its own transport and probe, and a health checker for trans that
// queries "." with recursion desired. A finalizer is registered that stops the
// transport when the proxy is garbage collected.
func NewProxy(proxyName, addr, trans string) *Proxy {
	p := new(Proxy)
	p.addr = addr
	p.fails = 0
	p.probe = up.New()
	p.readTimeout = 2 * time.Second
	p.transport = newTransport(proxyName, addr)
	p.health = NewHealthChecker(proxyName, trans, true, ".")
	p.proxyName = proxyName

	runtime.SetFinalizer(p, (*Proxy).finalizer)
	return p
}
// Addr returns the address of the upstream this proxy talks to.
func (p *Proxy) Addr() string {
	return p.addr
}
// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
//
// The health checker is nil when NewHealthChecker has no checker for the
// proxy's transport; in that case only the transport is configured, instead
// of panicking on the nil interface.
func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
	p.transport.SetTLSConfig(cfg)
	if p.health != nil {
		p.health.SetTLSConfig(cfg)
	}
}
// SetExpire sets the expire duration in the lower p.transport.
func (p *Proxy) SetExpire(expire time.Duration) {
	p.transport.SetExpire(expire)
}
// GetHealthchecker returns the proxy's health checker, which may be nil.
func (p *Proxy) GetHealthchecker() HealthChecker {
	hc := p.health
	return hc
}
// Fails returns the current number of failed health checks, read atomically.
func (p *Proxy) Fails() uint32 {
	n := atomic.LoadUint32(&p.fails)
	return n
}
// Healthcheck kicks off a round of health checks for this proxy through its
// probe. When the proxy has no health checker a warning is logged and nothing
// else happens.
func (p *Proxy) Healthcheck() {
	if p.health == nil {
		log.Warning("No healthchecker")
		return
	}

	check := func() error {
		return p.health.Check(p)
	}
	p.probe.Do(check)
}
// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
// A maxfails of 0 disables the check: the proxy is then never reported down.
func (p *Proxy) Down(maxfails uint32) bool {
	return maxfails != 0 && atomic.LoadUint32(&p.fails) > maxfails
}
// Stop stops the health checking goroutine by stopping the proxy's probe.
func (p *Proxy) Stop() {
	p.probe.Stop()
}
// finalizer stops the proxy's transport; NewProxy registers it with
// runtime.SetFinalizer.
func (p *Proxy) finalizer() {
	p.transport.Stop()
}
// Start starts the proxy's health checking probe with the given duration and
// then starts the transport's connection manager.
func (p *Proxy) Start(duration time.Duration) {
	probe, tr := p.probe, p.transport
	probe.Start(duration)
	tr.Start()
}
// SetReadTimeout sets the read timeout that Connect applies while waiting for
// the upstream's response.
func (p *Proxy) SetReadTimeout(duration time.Duration) {
	p.readTimeout = duration
}
// incrementFails increments the number of fails safely: the counter saturates
// at the maximum uint32 value instead of wrapping around to zero.
//
// The overflow check and the increment are done as one atomic step with a
// compare-and-swap loop. Checking first and adding afterwards would let
// concurrent callers each pass the check and together push the counter past
// the maximum.
func (p *Proxy) incrementFails() {
	for {
		cur := atomic.LoadUint32(&p.fails)
		next := cur + 1
		if next < cur {
			// Already at the maximum; leave the counter alone.
			return
		}
		if atomic.CompareAndSwapUint32(&p.fails, cur, next) {
			return
		}
		// The counter changed under us (another increment or a reset); retry.
	}
}
const (
// maxTimeout is the write deadline Connect sets on the upstream connection.
maxTimeout = 2 * time.Second
)
package proxy
import (
"net"
)
// transportType identifies one of the connection buckets kept by Transport.
type transportType int

const (
	typeUDP transportType = iota
	typeTCP
	typeTLS
	typeTotalCount // keep this last
)

// stringToTransportType maps a protocol name to its transportType: "tcp" is
// typeTCP and "tcp-tls" is typeTLS. Everything else, including "udp", is
// typeUDP.
func stringToTransportType(s string) transportType {
	if s == "tcp" {
		return typeTCP
	}
	if s == "tcp-tls" {
		return typeTLS
	}
	return typeUDP
}
// transportTypeFromConn determines the bucket a connection belongs to: typeUDP
// when the underlying connection is a *net.UDPConn, otherwise typeTLS when the
// transport has a TLS config and typeTCP when it has none.
func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
	switch pc.c.Conn.(type) {
	case *net.UDPConn:
		return typeUDP
	}
	if t.tlsConfig != nil {
		return typeTLS
	}
	return typeTCP
}
// Package rand is used for concurrency safe random number generator.
// This package provides a thread-safe wrapper around math/rand for use in
// load balancing and server selection. It is NOT suitable for cryptographic
// purposes and should not be used for security-sensitive operations.
package rand
import (
"math/rand"
"sync"
)
// Rand wraps a *rand.Rand with a mutex so that it can be used from multiple
// goroutines.
type Rand struct {
	m sync.Mutex
	r *rand.Rand
}

// New returns a new Rand whose source is seeded with seed.
func New(seed int64) *Rand {
	src := rand.NewSource(seed)
	return &Rand{r: rand.New(src)}
}

// Int returns a non-negative pseudo-random int from the Source in Rand.r.
func (r *Rand) Int() int {
	r.m.Lock()
	defer r.m.Unlock()
	return r.r.Int()
}

// Perm returns, as a slice of n ints, a pseudo-random permutation of the
// integers in the half-open interval [0,n) from the Source in Rand.r.
func (r *Rand) Perm(n int) []int {
	r.m.Lock()
	defer r.m.Unlock()
	return r.r.Perm(n)
}
package rcode
import (
"strconv"
"github.com/miekg/dns"
)
// ToString converts rcode to its name from dns.RcodeToString, or to
// "RCODE"+value when the rcode has no entry there.
func ToString(rcode int) string {
	name, known := dns.RcodeToString[rcode]
	if !known {
		return "RCODE" + strconv.Itoa(rcode)
	}
	return name
}
package replacer
import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Replacer replaces labels for values in strings. It carries no state.
type Replacer struct{}

// New makes a new replacer. This only needs to be called once in the setup and
// then call Replace for each incoming message. A replacer is safe for concurrent use.
func New() Replacer {
	var r Replacer
	return r
}

// Replace performs a replacement of values on s and returns the string with
// the replaced values. The parsed form of s is obtained through loadFormat.
func (r Replacer) Replace(ctx context.Context, state request.Request, rr *dnstest.Recorder, s string) string {
	format := loadFormat(s)
	return format.Replace(ctx, state, rr)
}
const (
// headerReplacer is the opening of header labels such as "{>id}".
headerReplacer = "{>"
// EmptyValue is the default empty value.
EmptyValue = "-"
)
// labels are all supported labels that can be used in the default Replacer.
// parseFormat uses this set to tell labels apart from literal text; the
// values are produced by appendValue.
var labels = map[string]struct{}{
"{type}": {},
"{name}": {},
"{class}": {},
"{proto}": {},
"{size}": {},
"{remote}": {},
"{port}": {},
"{local}": {},
// Header values.
headerReplacer + "id}": {},
headerReplacer + "opcode}": {},
headerReplacer + "do}": {},
headerReplacer + "bufsize}": {},
// Recorded replacements.
"{rcode}": {},
"{rsize}": {},
"{duration}": {},
headerReplacer + "rflags}": {},
}
// appendValue appends the current value of label to b and returns the
// extended slice.
//
// Labels taken from the recorder ({rcode}, {rsize}, {duration}, {>rflags})
// are handled first and yield EmptyValue when the recorder (or, where needed,
// its message) is nil. All remaining labels are taken from state; they yield
// EmptyValue when state is the zero request.Request or the label is unknown.
func appendValue(b []byte, state request.Request, rr *dnstest.Recorder, label string) []byte {
switch label {
// Recorded replacements.
case "{rcode}":
if rr == nil || rr.Msg == nil {
return append(b, EmptyValue...)
}
// Prefer the rcode's name; fall back to its number.
if rcode := dns.RcodeToString[rr.Rcode]; rcode != "" {
return append(b, rcode...)
}
return strconv.AppendInt(b, int64(rr.Rcode), 10)
case "{rsize}":
if rr == nil {
return append(b, EmptyValue...)
}
return strconv.AppendInt(b, int64(rr.Len), 10)
case "{duration}":
if rr == nil {
return append(b, EmptyValue...)
}
// Seconds since the recorder's start time, followed by 's'.
secs := time.Since(rr.Start).Seconds()
return append(strconv.AppendFloat(b, secs, 'f', -1, 64), 's')
case headerReplacer + "rflags}":
if rr != nil && rr.Msg != nil {
return appendFlags(b, rr.Msg.MsgHdr)
}
return append(b, EmptyValue...)
}
// Everything below needs the request; bail out on an empty one.
if (request.Request{}) == state {
return append(b, EmptyValue...)
}
switch label {
case "{type}":
return append(b, state.Type()...)
case "{name}":
return append(b, state.Name()...)
case "{class}":
return append(b, state.Class()...)
case "{proto}":
return append(b, state.Proto()...)
case "{size}":
return strconv.AppendInt(b, int64(state.Req.Len()), 10)
case "{remote}":
return appendAddrToRFC3986(b, state.IP())
case "{port}":
return append(b, state.Port()...)
case "{local}":
return appendAddrToRFC3986(b, state.LocalIP())
// Header placeholders.
case headerReplacer + "id}":
return strconv.AppendInt(b, int64(state.Req.Id), 10)
case headerReplacer + "opcode}":
return strconv.AppendInt(b, int64(state.Req.Opcode), 10)
case headerReplacer + "do}":
return strconv.AppendBool(b, state.Do())
case headerReplacer + "bufsize}":
return strconv.AppendInt(b, int64(state.Size()), 10)
default:
return append(b, EmptyValue...)
}
}
// appendFlags appends the names of all header flags that are set in h to b,
// separated by commas and in the fixed order qr, aa, tc, rd, ra, z, ad, cd.
// When no flag is set, b is returned unchanged.
func appendFlags(b []byte, h dns.MsgHdr) []byte {
	flags := [...]struct {
		set  bool
		name string
	}{
		{h.Response, "qr"},
		{h.Authoritative, "aa"},
		{h.Truncated, "tc"},
		{h.RecursionDesired, "rd"},
		{h.RecursionAvailable, "ra"},
		{h.Zero, "z"},
		{h.AuthenticatedData, "ad"},
		{h.CheckingDisabled, "cd"},
	}

	start := len(b)
	for _, f := range flags {
		if f.set {
			b = append(b, f.name...)
			b = append(b, ',')
		}
	}
	if len(b) > start {
		b = b[:len(b)-1] // drop the trailing ','
	}
	return b
}
// appendAddrToRFC3986 appends addr to b. An address containing a ':' is
// taken to be IPv6 and is wrapped in square brackets; any other address is
// appended as is.
func appendAddrToRFC3986(b []byte, addr string) []byte {
	if !strings.Contains(addr, ":") {
		return append(b, addr...)
	}
	b = append(b, '[')
	b = append(b, addr...)
	return append(b, ']')
}
// nodeType tells how the value of a node is turned into output.
type nodeType int
const (
typeLabel nodeType = iota // "{type}"
typeLiteral // "foo"
typeMetadata // "{/metadata}"
)
// A node represents a segment of a parsed format. For example: "A {type}"
// contains two nodes: "A " (literal); and "{type}" (label).
type node struct {
value string // Literal value, label or metadata label
typ nodeType
}
// A replacer is an ordered list of all the nodes in a format.
type replacer []node
// parseFormat splits the format string s into an ordered list of nodes.
// A "{...}" segment that is a key of labels becomes a label node, a segment
// starting with "{/" becomes a metadata node (with the "{/" and "}" stripped),
// and everything else - including unknown "{...}" segments - becomes literal
// text.
func parseFormat(s string) replacer {
// Assume there is a literal between each label - its cheaper to over
// allocate once than allocate twice.
rep := make(replacer, 0, strings.Count(s, "{")*2)
for {
// We find the right bracket then backtrack to find the left bracket.
// This allows us to handle formats like: "{ {foo} }".
j := strings.IndexByte(s, '}')
if j < 0 {
break
}
i := strings.LastIndexByte(s[:j], '{')
if i < 0 {
// Handle: "A } {foo}" by treating "A }" as a literal
rep = append(rep, node{
value: s[:j+1],
typ: typeLiteral,
})
s = s[j+1:]
continue
}
val := s[i : j+1]
var typ nodeType
switch _, ok := labels[val]; {
case ok:
typ = typeLabel
case strings.HasPrefix(val, "{/"):
// Strip "{/}" from metadata labels
val = val[2 : len(val)-1]
typ = typeMetadata
default:
// Given: "A {X}" val is "{X}" expand it to the whole literal.
val = s[:j+1]
typ = typeLiteral
}
// Append any leading literal. Given "A {type}" the literal is "A "
// (for a literal node the leading text is already part of val).
if i != 0 && typ != typeLiteral {
rep = append(rep, node{
value: s[:i],
typ: typeLiteral,
})
}
rep = append(rep, node{
value: val,
typ: typ,
})
s = s[j+1:]
}
// Whatever follows the last '}' is a trailing literal.
if len(s) != 0 {
rep = append(rep, node{
value: s,
typ: typeLiteral,
})
}
return rep
}
// replacerCache caches parsed formats, keyed by the format string.
var replacerCache sync.Map // map[string]replacer

// loadFormat returns the parsed form of s, parsing and caching it on first
// use. When several callers parse the same format concurrently, all of them
// get the value that was stored first.
func loadFormat(s string) replacer {
	cached, found := replacerCache.Load(s)
	if !found {
		cached, _ = replacerCache.LoadOrStore(s, parseFormat(s))
	}
	return cached.(replacer)
}
// bufPool stores scratch buffers used by replacer.Replace. New buffers are
// empty byte slices with a capacity of 256.
var bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 0, 256)
},
}
// Replace renders the parsed format: literal nodes are copied, label nodes are
// resolved with appendValue, and metadata nodes are resolved through
// metadata.ValueFunc (EmptyValue when no value function exists). The output
// is built in a scratch buffer from bufPool that is returned to the pool
// afterwards.
func (r replacer) Replace(ctx context.Context, state request.Request, rr *dnstest.Recorder) string {
	buf := bufPool.Get().([]byte)
	for _, n := range r {
		switch n.typ {
		case typeLiteral:
			buf = append(buf, n.value...)
		case typeLabel:
			buf = appendValue(buf, state, rr, n.value)
		case typeMetadata:
			fm := metadata.ValueFunc(ctx, n.value)
			if fm == nil {
				buf = append(buf, EmptyValue...)
			} else {
				buf = append(buf, fm()...)
			}
		}
	}
	out := string(buf)
	//nolint:staticcheck
	bufPool.Put(buf[:0])
	return out
}
package response
import "fmt"
// Class holds sets of Types
type Class int

const (
	// All is a meta class encompassing all the classes.
	All Class = iota
	// Success is a class for a successful response.
	Success
	// Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist)
	Denial
	// Error is a class for errors, right now defined as not Success and not Denial
	Error
)

// String returns the lower-case name of the class, or the empty string for a
// value that is not one of the defined classes.
func (c Class) String() string {
	names := [...]string{
		All:     "all",
		Success: "success",
		Denial:  "denial",
		Error:   "error",
	}
	if c < 0 || int(c) >= len(names) {
		return ""
	}
	return names[c]
}

// ClassFromString returns the class whose name is s. If no class matches,
// the All class and an error are returned.
func ClassFromString(s string) (Class, error) {
	for _, c := range []Class{All, Success, Denial, Error} {
		if c.String() == s {
			return c, nil
		}
	}
	return All, fmt.Errorf("invalid Class: %s", s)
}
// Classify classifies the Type t, it returns its Class: NoError and
// Delegation are Success, NameError and NoData are Denial, and every other
// type is Error.
func Classify(t Type) Class {
	if t == NoError || t == Delegation {
		return Success
	}
	if t == NameError || t == NoData {
		return Denial
	}
	return Error
}
package response
import (
"fmt"
"time"
"github.com/miekg/dns"
)
// Type is the type of the message.
type Type int

const (
	// NoError indicates a positive reply
	NoError Type = iota
	// NameError is a NXDOMAIN in header, SOA in auth.
	NameError
	// ServerError is a set of errors we want to cache, for now it contains SERVFAIL and NOTIMPL.
	ServerError
	// NoData indicates name found, but not the type: NOERROR in header, SOA in auth.
	NoData
	// Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked).
	Delegation
	// Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR.
	Meta
	// Update is an dynamic update message.
	Update
	// OtherError indicates any other error: don't cache these.
	OtherError
)

// toString maps every Type to its textual name.
var toString = map[Type]string{
	NoError:     "NOERROR",
	NameError:   "NXDOMAIN",
	ServerError: "SERVERERROR",
	NoData:      "NODATA",
	Delegation:  "DELEGATION",
	Meta:        "META",
	Update:      "UPDATE",
	OtherError:  "OTHERERROR",
}

// String returns the textual name of t, or the empty string for an unknown Type.
func (t Type) String() string { return toString[t] }

// TypeFromString returns the type from the string s. If no type matches,
// the OtherError type and an error are returned.
func TypeFromString(s string) (Type, error) {
	for t, str := range toString {
		if s == str {
			return t, nil
		}
	}
	// Return OtherError, as documented, rather than NoError: a caller that
	// ignores the error must not mistake an unknown name for a positive reply.
	return OtherError, fmt.Errorf("invalid Type: %s", s)
}
// Typify classifies a message, it returns the Type together with the OPT
// record found in m (nil when m is nil or carries no OPT record). The time t
// is only used to check RRSIG validity when the DO bit is set.
func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) {
	if m == nil {
		return OtherError, nil
	}

	opt := m.IsEdns0()
	do := opt != nil && opt.Do()

	// Updates, notifies and zone transfers are recognised before anything else.
	switch m.Opcode {
	case dns.OpcodeUpdate:
		return Update, opt
	case dns.OpcodeNotify:
		return Meta, opt
	}
	if len(m.Question) > 0 {
		switch m.Question[0].Qtype {
		case dns.TypeAXFR, dns.TypeIXFR:
			return Meta, opt
		}
	}

	// With the DO bit set, an expired signature anywhere in the message
	// turns the whole message into an OtherError.
	if do && typifyExpired(m, t) {
		return OtherError, opt
	}

	if m.Rcode == dns.RcodeSuccess && len(m.Answer) > 0 {
		return NoError, opt
	}

	// Inspect the authority section for SOA and NS records.
	hasSOA, nsCount := false, 0
	for _, rr := range m.Ns {
		switch rr.Header().Rrtype {
		case dns.TypeSOA:
			hasSOA = true
		case dns.TypeNS:
			nsCount++
		}
	}

	switch {
	case hasSOA && m.Rcode == dns.RcodeSuccess:
		return NoData, opt
	case hasSOA && m.Rcode == dns.RcodeNameError:
		return NameError, opt
	case m.Rcode == dns.RcodeServerFailure || m.Rcode == dns.RcodeNotImplemented:
		return ServerError, opt
	case nsCount > 0 && m.Rcode == dns.RcodeSuccess:
		return Delegation, opt
	case m.Rcode == dns.RcodeSuccess:
		return NoError, opt
	}
	return OtherError, opt
}
// typifyExpired reports whether any RRSIG in the answer, authority or
// additional section of m is outside its validity period at time t. The
// sections are checked in that order and the search stops at the first hit.
func typifyExpired(m *dns.Msg, t time.Time) bool {
	return typifyExpiredRRSIG(m.Answer, t) ||
		typifyExpiredRRSIG(m.Ns, t) ||
		typifyExpiredRRSIG(m.Extra, t)
}
// typifyExpiredRRSIG reports whether rrs contains an RRSIG record whose
// validity period does not include t. Records of any other type are ignored.
func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool {
	for _, r := range rrs {
		if r.Header().Rrtype != dns.TypeRRSIG {
			continue
		}
		// The header type alone does not guarantee the concrete Go type, so
		// use a checked assertion instead of risking a panic.
		sig, ok := r.(*dns.RRSIG)
		if !ok {
			continue
		}
		if !sig.ValidityPeriod(t) {
			return true
		}
	}
	return false
}
//go:build go1.11 && (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd)
package reuseport
import (
"context"
"net"
"syscall"
"github.com/coredns/coredns/plugin/pkg/log"
"golang.org/x/sys/unix"
)
// control is a net.ListenConfig Control hook that tries to enable
// SO_REUSEPORT on the socket. It is best effort: any failure is logged as a
// warning and nil is returned.
func control(network, address string, c syscall.RawConn) error {
	err := c.Control(func(fd uintptr) {
		if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
			log.Warningf("Failed to set SO_REUSEPORT on socket: %s", err)
		}
	})
	// Control itself can fail; report that too instead of dropping it silently.
	if err != nil {
		log.Warningf("Failed to access socket to set SO_REUSEPORT: %s", err)
	}
	return nil
}
// Listen announces on the local network address. See net.Listen for more information.
// The listener is created through a net.ListenConfig whose Control hook
// attempts to set SO_REUSEPORT on the socket.
func Listen(network, addr string) (net.Listener, error) {
	return (&net.ListenConfig{Control: control}).Listen(context.Background(), network, addr)
}
// ListenPacket announces on the local network address. See net.ListenPacket for more information.
// The connection is created through a net.ListenConfig whose Control hook
// attempts to set SO_REUSEPORT on the socket.
func ListenPacket(network, addr string) (net.PacketConn, error) {
	return (&net.ListenConfig{Control: control}).ListenPacket(context.Background(), network, addr)
}
package tls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"time"
)
// setTLSDefaults applies the default protocol versions (TLS 1.2 up to
// TLS 1.3) and the default cipher suites to ctls. Each cipher suite is
// listed exactly once.
func setTLSDefaults(ctls *tls.Config) {
	ctls.MinVersion = tls.VersionTLS12
	ctls.MaxVersion = tls.VersionTLS13
	ctls.CipherSuites = []uint16{
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
	}
}
// NewTLSConfigFromArgs returns a TLS config based upon the passed
// in list of arguments. Typically these come straight from the
// Corefile.
// no args
// - creates a Config with no cert and using system CAs
// - use for a client that talks to a server with a public signed cert (CA installed in system)
// - the client will not be authenticated by the server since there is no cert
//
// one arg: the path to CA PEM file
// - creates a Config with no cert using a specific CA
// - use for a client that talks to a server with a private signed cert (CA not installed in system)
// - the client will not be authenticated by the server since there is no cert
//
// two args: path to cert PEM file, the path to private key PEM file
// - creates a Config with a cert, using system CAs to validate the other end
// - use for:
// - a server; or,
// - a client that talks to a server with a public cert and needs certificate-based authentication
// - the other end will authenticate this end via the provided cert
// - the cert of the other end will be verified via system CAs
//
// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file
// - creates a Config with the cert, using specified CA to validate the other end
// - use for:
// - a server; or,
// - a client that talks to a server with a privately signed cert and needs certificate-based
// authentication
// - the other end will authenticate this end via the provided cert
// - this end will verify the other end's cert using the specified CA
func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) {
	var (
		cfg *tls.Config
		err error
	)

	switch n := len(args); n {
	case 0:
		// no certificate, system CAs
		cfg, err = NewTLSClientConfig("")
	case 1:
		// no certificate, CA taken from args[0]
		cfg, err = NewTLSClientConfig(args[0])
	case 2:
		// certificate and key, system CAs
		cfg, err = NewTLSConfig(args[0], args[1], "")
	case 3:
		// certificate, key and CA
		cfg, err = NewTLSConfig(args[0], args[1], args[2])
	default:
		return nil, fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", n)
	}

	if err != nil {
		return nil, err
	}
	return cfg, nil
}
// NewTLSConfig returns a TLS config that includes a certificate
// Use for server TLS config or when using a client certificate
// If caPath is empty, system CAs will be used
//
// The error from loading the key pair is wrapped, so callers can inspect the
// underlying cause with errors.Is / errors.As.
func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, fmt.Errorf("could not load TLS cert: %w", err)
	}

	roots, err := loadRoots(caPath)
	if err != nil {
		return nil, err
	}

	tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots}
	setTLSDefaults(tlsConfig)

	return tlsConfig, nil
}
// NewTLSClientConfig returns a TLS config for a client connection, without
// any client certificate.
// If caPath is empty, system CAs will be used
func NewTLSClientConfig(caPath string) (*tls.Config, error) {
	pool, err := loadRoots(caPath)
	if err != nil {
		return nil, err
	}

	cfg := new(tls.Config)
	cfg.RootCAs = pool
	setTLSDefaults(cfg)
	return cfg, nil
}
// loadRoots builds a certificate pool from the PEM file at caPath. An empty
// caPath returns a nil pool and no error, which makes crypto/tls fall back to
// the system CAs. An error is returned when the file cannot be read or does
// not contain at least one parsable certificate.
func loadRoots(caPath string) (*x509.CertPool, error) {
	if caPath == "" {
		return nil, nil
	}

	pem, err := os.ReadFile(filepath.Clean(caPath))
	if err != nil {
		return nil, fmt.Errorf("error reading %s: %w", caPath, err)
	}

	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(pem) {
		// There is no underlying error value here; name the offending file instead.
		return nil, fmt.Errorf("could not read root certs: no valid PEM certificates found in %s", caPath)
	}
	return roots, nil
}
// NewHTTPSTransport returns an HTTP transport that uses cc as its TLS client
// configuration, honours the proxy settings from the environment and dials
// with 30 second connect and keep-alive timeouts.
func NewHTTPSTransport(cc *tls.Config) *http.Transport {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	return &http.Transport{
		Proxy:               http.ProxyFromEnvironment,
		Dial:                dialer.Dial,
		TLSHandshakeTimeout: 10 * time.Second,
		TLSClientConfig:     cc,
		MaxIdleConnsPerHost: 25,
	}
}
// Package uniq keeps track of "thing" that are either "todo" or "done". Multiple
// identical events will only be processed once.
package uniq
// U keeps track of item to be done.
type U struct {
	u map[string]item
}

// item couples a function with its processing state.
type item struct {
	state int          // either todo or done
	f     func() error // function to be executed.
}

// New returns a new initialized U.
func New() U { return U{u: map[string]item{}} }

// Set sets function f in U under key. If the key already exists it is not overwritten.
func (u U) Set(key string, f func() error) {
	if _, exists := u.u[key]; !exists {
		u.u[key] = item{state: todo, f: f}
	}
}

// Unset removes the key.
func (u U) Unset(key string) { delete(u.u, key) }

// ForEach iterates over u and executes f for each element that is 'todo' and sets it to 'done'.
// The error returned by f is discarded; ForEach always returns nil.
func (u U) ForEach() error {
	for key, it := range u.u {
		if it.state == todo {
			it.f()
		}
		u.u[key] = item{state: done, f: it.f}
	}
	return nil
}

// Processing states of an item.
const (
	todo = 1
	done = 2
)
// Package up is used to run a function for some duration. If a new function is added while a previous run is
// still ongoing, nothing new will be executed.
package up
import (
"sync"
"time"
)
// Probe is used to run a single Func until it returns nil (indicating a target is healthy). If an Func
// is already in progress no new one will be added, i.e. there is always a maximum of 1 checks in flight.
//
// There is a tradeoff to be made in figuring out quickly that an upstream is healthy and not doing much work
// (sending queries) to find that out. Having some kind of exp. backoff here won't help much, because you don't want
// to backoff too much. You then also need random queries to be performed every so often to quickly detect a working
// upstream. In the end we just send a query every 0.5 second to check the upstream. This hopefully strikes a balance
// between getting information about the upstream state quickly and not doing too much work. Note that 0.5s is still an
// eternity in DNS, so we may actually want to shorten it.
type Probe struct {
	sync.Mutex
	inprogress int           // one of idle, active or stop
	interval   time.Duration // pause between two attempts, set by Start
}

// Func is used to determine if a target is alive. If so this function must return nil.
type Func func() error

// New returns a pointer to an initialized Probe.
func New() *Probe { return &Probe{} }

// Do will probe target, if a probe is already in progress this is a noop.
// It is also a noop once Stop has been called.
func (p *Probe) Do(f Func) {
	p.Lock()
	if p.inprogress != idle {
		p.Unlock()
		return
	}
	p.inprogress = active
	interval := p.interval
	p.Unlock()

	// Passed the lock. Now run f for as long as it returns an error. Once it
	// returns nil the goroutine ends and another Func can be accepted.
	go func() {
		for {
			if err := f(); err == nil {
				break
			}
			time.Sleep(interval)

			p.Lock()
			stopped := p.inprogress == stop
			p.Unlock()
			if stopped {
				return
			}
		}

		p.Lock()
		// Do not undo a Stop that arrived while f was still running.
		if p.inprogress != stop {
			p.inprogress = idle
		}
		p.Unlock()
	}()
}

// Stop stops the probing.
func (p *Probe) Stop() {
	p.Lock()
	p.inprogress = stop
	p.Unlock()
}

// Start will initialize the probe manager, after which probes can be initiated with Do.
func (p *Probe) Start(interval time.Duration) {
	p.Lock()
	p.interval = interval
	p.Unlock()
}

// States of a Probe.
const (
	idle = iota
	active
	stop
)
// Package upstream abstracts a upstream lookups so that plugins can handle them in an unified way.
package upstream
import (
"context"
"fmt"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin/pkg/nonwriter"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Upstream is used to resolve CNAME or other external targets via CoreDNS itself.
// It carries no state of its own; Lookup takes the running server from the
// context it is given.
type Upstream struct{}
// New creates a new Upstream to resolve names using the coredns process.
func New() *Upstream { return &Upstream{} }
// Lookup routes lookups to our selves to make it follow the plugin chain *again*, but with a (possibly) new query. As
// we are doing the query against ourselves again, there is no actual new hop, as such RFC 6891 does not apply and we
// need the EDNS0 option present in the *original* query to be present here too.
//
// The server is taken from ctx; when ctx does not carry one an error is
// returned. The reply is captured instead of being written to the client and
// returned to the caller.
func (u *Upstream) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
	server, ok := ctx.Value(dnsserver.Key{}).(*dnsserver.Server)
	if !ok {
		return nil, fmt.Errorf("no full server is running")
	}

	query := state.NewWithQuestion(name, typ)
	capture := nonwriter.New(state.W)
	server.ServeDNS(ctx, capture, query.Req)

	return capture.Msg, nil
}
// Package plugin provides some types and functions common among plugin.
package plugin
import (
"context"
"errors"
"fmt"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
)
type (
// Plugin is a middle layer which represents the traditional
// idea of plugin: it chains one Handler to the next by being
// passed the next Handler in the chain.
Plugin func(Handler) Handler
// Handler is like dns.Handler except ServeDNS may return an rcode
// and/or error.
//
// If ServeDNS writes to the response body, it should return a status
// code. CoreDNS assumes *no* reply has yet been written if the status
// code is one of the following:
//
// * SERVFAIL (dns.RcodeServerFailure)
//
// * REFUSED (dns.RcodeRefused)
//
// * FORMERR (dns.RcodeFormatError)
//
// * NOTIMP (dns.RcodeNotImplemented)
//
// All other response codes signal other handlers above it that the
// response message is already written, and that they should not write
// to it also.
//
// If ServeDNS encounters an error, it should return the error value
// so it can be logged by designated error-handling plugin.
//
// If writing a response after calling another ServeDNS method, the
// returned rcode SHOULD be used when writing the response.
//
// If handling errors after calling another ServeDNS method, the
// returned error value SHOULD be logged or handled accordingly.
//
// Otherwise, return values should be propagated down the plugin
// chain by returning them unchanged.
Handler interface {
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
Name() string
}
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler
// documentation for more information.
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
)
// ServeDNS implements the Handler interface by calling f itself with the
// given arguments and returning its result unchanged.
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(ctx, w, r)
}
// Name implements the Handler interface. Every HandlerFunc reports the same
// fixed name, "handlerfunc".
func (f HandlerFunc) Name() string { return "handlerfunc" }
// Error returns err with 'plugin/name: ' prefixed to it. The returned error
// wraps err, so errors.Is and errors.As can still reach the original error.
func Error(name string, err error) error { return fmt.Errorf("%s/%s: %w", "plugin", name, err) }
// NextOrFailure calls next.ServeDNS when next is not nil, otherwise it will return, a ServerFailure and a `no next plugin found` error.
// When ctx carries a tracing span, a child span named after next is started
// for the duration of the call and handed on through the context.
func NextOrFailure(name string, next Handler, ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	if next == nil {
		return dns.RcodeServerFailure, Error(name, errors.New("no next plugin found"))
	}

	if parent := ot.SpanFromContext(ctx); parent != nil {
		child := parent.Tracer().StartSpan(next.Name(), ot.ChildOf(parent.Context()))
		defer child.Finish()
		ctx = ot.ContextWithSpan(ctx, child)
	}
	return next.ServeDNS(ctx, w, r)
}
// ClientWrite returns true if the response has been written to the client.
// Each plugin to adhere to this protocol.
// SERVFAIL, REFUSED, FORMERR and NOTIMP mean nothing has been written yet;
// every other rcode means a reply was already sent.
func ClientWrite(rcode int) bool {
	switch rcode {
	case dns.RcodeServerFailure,
		dns.RcodeRefused,
		dns.RcodeFormatError,
		dns.RcodeNotImplemented:
		return false
	default:
		return true
	}
}
// Namespace is the namespace used for the metrics.
const Namespace = "coredns"
// TimeBuckets is a set of 16 exponential buckets that starts at 0.00025 and
// doubles with every bucket.
var TimeBuckets = prometheus.ExponentialBuckets(0.00025, 2, 16) // from 0.25ms to 8 seconds
// SlimTimeBuckets is low cardinality set of duration buckets: 5 exponential
// buckets that start at 0.00025 and grow by a factor of 10.
var SlimTimeBuckets = prometheus.ExponentialBuckets(0.00025, 10, 5) // from 0.25ms to 2.5 seconds
// NativeHistogramBucketFactor controls the resolution of Prometheus native histogram buckets.
// See: https://pkg.go.dev/github.com/prometheus/client_golang@v1.19.0/prometheus#section-readme
var NativeHistogramBucketFactor = 1.05
// ErrOnce is returned when a plugin doesn't support multiple setups per server.
var ErrOnce = errors.New("this plugin can only be used once per Server Block")
package plugin
import "github.com/coredns/caddy"
// Register registers your plugin with CoreDNS and allows it to be called when the server is running.
// The plugin is registered with caddy under name for the "dns" server type,
// with action as its setup function.
func Register(name string, action caddy.SetupFunc) {
	p := caddy.Plugin{ServerType: "dns", Action: action}
	caddy.RegisterPlugin(name, p)
}
package rewrite
import (
"context"
"fmt"
"strings"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// classRule is a rewrite rule that replaces the class of the first question
// of a request.
type classRule struct {
// fromClass is the query class a request must have for the rule to apply.
fromClass uint16
// toClass is the query class written into a matching request.
toClass uint16
// NextAction is the value reported by Mode.
NextAction string
}
// newClassRule creates a class matching rule. args[0] is the class to match
// and args[1] the class to rewrite to; both are looked up case-insensitively
// in dns.StringToClass. An error is returned when fewer than two arguments
// are given or when either class name is unknown.
func newClassRule(nextAction string, args ...string) (Rule, error) {
	// Guard against a short argument list instead of panicking on the index below.
	if len(args) < 2 {
		return nil, fmt.Errorf("too few (%d) arguments for a class rule", len(args))
	}

	fromName, toName := strings.ToUpper(args[0]), strings.ToUpper(args[1])

	from, ok := dns.StringToClass[fromName]
	if !ok {
		return nil, fmt.Errorf("invalid class %q", fromName)
	}
	to, ok := dns.StringToClass[toName]
	if !ok {
		return nil, fmt.Errorf("invalid class %q", toName)
	}
	return &classRule{fromClass: from, toClass: to, NextAction: nextAction}, nil
}
// Rewrite rewrites the current request: when the class of the first question
// equals fromClass it is replaced by toClass. Rules with a zero fromClass or
// toClass never match.
func (rule *classRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	if rule.fromClass == 0 || rule.toClass == 0 {
		return nil, RewriteIgnored
	}

	question := &state.Req.Question[0]
	if question.Qclass != rule.fromClass {
		return nil, RewriteIgnored
	}
	question.Qclass = rule.toClass
	return nil, RewriteDone
}
// Mode returns the processing mode, i.e. the NextAction the rule was created with.
func (rule *classRule) Mode() string { return rule.NextAction }
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// UpstreamInt wraps the Upstream API for dependency injection during testing
type UpstreamInt interface {
Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
}
// cnameTargetRule is cname target rewrite rule.
type cnameTargetRule struct {
// rewriteType selects how paramFromTarget is matched against a CNAME target
// (exact, prefix, suffix, substring or regex).
rewriteType string
// paramFromTarget is the value (or pattern) a CNAME target is matched against.
paramFromTarget string
// paramToTarget is the replacement used to build the new target.
paramToTarget string
// nextAction is the value reported by Mode.
nextAction string
Upstream UpstreamInt // Upstream for looking up external names during the resolution process.
}
// cnameTargetRuleWithReqState is cname target rewrite rule state: a copy of
// the rule together with the request and context it was created for.
type cnameTargetRuleWithReqState struct {
rule cnameTargetRule
state request.Request
ctx context.Context
}
// getFromAndToTarget returns the CNAME target the rule applies to and the
// target it must be rewritten to. For an exact rule these are the configured
// values; for the other rule types "from" is inputCName when it matches and
// "to" is the rewritten name. Two empty strings are returned when
// inputCName does not match, or when a regex pattern does not compile.
func (r *cnameTargetRule) getFromAndToTarget(inputCName string) (from string, to string) {
	switch r.rewriteType {
	case ExactMatch:
		return r.paramFromTarget, r.paramToTarget
	case PrefixMatch:
		if strings.HasPrefix(inputCName, r.paramFromTarget) {
			return inputCName, r.paramToTarget + strings.TrimPrefix(inputCName, r.paramFromTarget)
		}
	case SuffixMatch:
		if strings.HasSuffix(inputCName, r.paramFromTarget) {
			return inputCName, strings.TrimSuffix(inputCName, r.paramFromTarget) + r.paramToTarget
		}
	case SubstringMatch:
		if strings.Contains(inputCName, r.paramFromTarget) {
			return inputCName, strings.ReplaceAll(inputCName, r.paramFromTarget, r.paramToTarget)
		}
	case RegexMatch:
		// The pattern comes from configuration: an invalid one must not
		// panic while a response is being rewritten, so treat it as no match.
		pattern, err := regexp.Compile(r.paramFromTarget)
		if err != nil {
			log.Errorf("Invalid cname rewrite regex %q: %v", r.paramFromTarget, err)
			return "", ""
		}
		regexGroups := pattern.FindStringSubmatch(inputCName)
		if len(regexGroups) == 0 {
			return "", ""
		}
		// Expand the "{N}" placeholders with the captured groups.
		substitution := r.paramToTarget
		for groupIndex, groupValue := range regexGroups {
			groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}"
			substitution = strings.ReplaceAll(substitution, groupIndexStr, groupValue)
		}
		return inputCName, substitution
	}
	return "", ""
}
// RewriteResponse rewrites the target of a CNAME record in res. When rr is a
// CNAME whose target matches the rule, the new target is resolved through the
// rule's Upstream and the answer section of res is rebuilt from the CNAME
// records of res (pointing at the new target) followed by the upstream
// records owned by the new target. If the upstream lookup fails or yields no
// message, res is left untouched.
func (r *cnameTargetRuleWithReqState) RewriteResponse(res *dns.Msg, rr dns.RR) {
	if rr.Header().Rrtype != dns.TypeCNAME {
		return
	}
	cname, ok := rr.(*dns.CNAME)
	if !ok {
		return
	}
	fromTarget, toTarget := r.rule.getFromAndToTarget(cname.Target)
	if cname.Target != fromTarget {
		return
	}

	// create upstream request with the new target with the same qtype
	r.state.Req.Question[0].Name = toTarget
	upRes, err := r.rule.Upstream.Lookup(r.ctx, r.state, toTarget, r.state.Req.Question[0].Qtype)
	if err != nil {
		log.Errorf("Error upstream request %v", err)
		return
	}
	if upRes == nil {
		// Without an upstream message there is nothing to merge.
		log.Errorf("Error upstream request: no response for %q", toTarget)
		return
	}

	var newAnswer []dns.RR
	// keep the CNAME records of the original response, pointing them at the new target
	for _, answer := range res.Answer {
		if c, isCNAME := answer.(*dns.CNAME); isCNAME {
			c.Target = toTarget
			newAnswer = append(newAnswer, answer)
		}
	}
	// add the upstream records that belong to the new target
	for _, answer := range upRes.Answer {
		if answer.Header().Name == toTarget {
			newAnswer = append(newAnswer, answer)
		}
	}
	res.Answer = newAnswer
	// if not propagated, the truncated response might get cached,
	// and it will be impossible to resolve the full response
	res.Truncated = upRes.Truncated
}
// newCNAMERule creates a cname target rewrite rule. It accepts either
// "FROM TO" (an exact match rule) or "TYPE FROM TO", where TYPE is one of the
// supported match types. FROM and TO are lower-cased. For a regex rule the
// pattern is compiled here, so a broken pattern is rejected at setup time.
func newCNAMERule(nextAction string, args ...string) (Rule, error) {
	var rewriteType, paramFromTarget, paramToTarget string

	switch len(args) {
	case 3:
		rewriteType = strings.ToLower(args[0])
		switch rewriteType {
		case ExactMatch, PrefixMatch, SuffixMatch, SubstringMatch, RegexMatch:
		default:
			return nil, fmt.Errorf("unknown cname rewrite type: %s", rewriteType)
		}
		paramFromTarget, paramToTarget = strings.ToLower(args[1]), strings.ToLower(args[2])
	case 2:
		rewriteType = ExactMatch
		paramFromTarget, paramToTarget = strings.ToLower(args[0]), strings.ToLower(args[1])
	default:
		if len(args) > 3 {
			return nil, fmt.Errorf("too many (%d) arguments for a cname rule", len(args))
		}
		return nil, fmt.Errorf("too few (%d) arguments for a cname rule", len(args))
	}

	if rewriteType == RegexMatch {
		// Validate the exact pattern that is used when responses are rewritten.
		if _, err := regexp.Compile(paramFromTarget); err != nil {
			return nil, fmt.Errorf("invalid regex pattern in a cname rule: %s", paramFromTarget)
		}
	}

	rule := cnameTargetRule{
		rewriteType:     rewriteType,
		paramFromTarget: paramFromTarget,
		paramToTarget:   paramToTarget,
		nextAction:      nextAction,
		Upstream:        upstream.New(),
	}
	return &rule, nil
}
// Rewrite rewrites the current request. The request itself is not changed:
// a usable rule returns a single response rule holding a copy of the rule,
// the request state and the context. A nil rule, or one with an empty
// rewrite type, source or destination, is ignored.
func (r *cnameTargetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	if r == nil || r.rewriteType == "" || r.paramFromTarget == "" || r.paramToTarget == "" {
		return nil, RewriteIgnored
	}

	withState := &cnameTargetRuleWithReqState{rule: *r, state: state, ctx: ctx}
	return ResponseRules{withState}, RewriteDone
}
// Mode returns the processing mode, i.e. the nextAction the rule was created with.
func (r *cnameTargetRule) Mode() string { return r.nextAction }
// Package rewrite is a plugin for rewriting requests internally to something different.
package rewrite
import (
"context"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// edns0LocalRule is a rewrite rule for EDNS0_LOCAL options.
type edns0LocalRule struct {
mode string
action string
code uint16
data []byte
revert bool
}
// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable.
type edns0VariableRule struct {
mode string
action string
code uint16
variable string
revert bool
}
// edns0NsidRule is a rewrite rule for EDNS0_NSID options.
type edns0NsidRule struct {
mode string
action string
revert bool
}
// edns0SetResponseRule is a response rule that removes the option with the
// given code from a response.
type edns0SetResponseRule struct {
code uint16
}
// RewriteResponse removes the first EDNS0 option with the rule's code from
// res. A response without an OPT record is left untouched.
func (r *edns0SetResponseRule) RewriteResponse(res *dns.Msg, _ dns.RR) {
	ednsOpt := res.IsEdns0()
	if ednsOpt == nil {
		// No OPT record in the response, so there is nothing to remove.
		return
	}
	for idx, opt := range ednsOpt.Option {
		if opt.Option() == r.code {
			ednsOpt.Option = append(ednsOpt.Option[:idx], ednsOpt.Option[idx+1:]...)
			return
		}
	}
}
// edns0ReplaceResponseRule is a response rule that puts a saved option back
// in place of the option with the same code.
type edns0ReplaceResponseRule[T dns.EDNS0] struct {
	code   uint16 // option code to look for in the response
	source T      // option written in place of the one found
}

// RewriteResponse replaces the first EDNS0 option of res that has the rule's
// code with the saved source option. A response without an OPT record is left
// untouched.
func (r *edns0ReplaceResponseRule[T]) RewriteResponse(res *dns.Msg, _ dns.RR) {
	ednsOpt := res.IsEdns0()
	if ednsOpt == nil {
		// No OPT record in the response, so there is nothing to replace.
		return
	}
	for idx, opt := range ednsOpt.Option {
		if opt.Option() == r.code {
			ednsOpt.Option[idx] = r.source
			return
		}
	}
}
// setupEdns0Opt will retrieve the EDNS0 OPT or create it if it does not exist.
// A newly created OPT record uses a UDP size of 4096 with the DO bit cleared.
func setupEdns0Opt(r *dns.Msg) *dns.OPT {
	if existing := r.IsEdns0(); existing != nil {
		return existing
	}
	r.SetEdns0(4096, false)
	return r.IsEdns0()
}
// unsetEdns0Option drops every option with the given code from opt. The
// remaining options are collected in a fresh slice, which is nil when nothing
// is left.
func unsetEdns0Option(opt *dns.OPT, code uint16) {
	var kept []dns.EDNS0
	for i := range opt.Option {
		if opt.Option[i].Option() == code {
			continue
		}
		kept = append(kept, opt.Option[i])
	}
	opt.Option = kept
}
// Rewrite will alter the request EDNS0 NSID option. Unset removes the
// option; Replace and Set blank the first existing NSID option; Append and
// Set add an empty NSID option (Set only when none was found). With revert
// enabled, response rules that undo the change are returned.
func (rule *edns0NsidRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	opt := setupEdns0Opt(state.Req)

	var resp ResponseRules
	switch rule.action {
	case Unset:
		unsetEdns0Option(opt, dns.EDNS0NSID)
		return nil, RewriteDone
	case Replace, Set:
		for _, o := range opt.Option {
			nsid, isNSID := o.(*dns.EDNS0_NSID)
			if !isNSID {
				continue
			}
			if rule.revert {
				saved := *nsid
				resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_NSID]{code: nsid.Code, source: &saved})
			}
			nsid.Nsid = "" // make sure it is empty for request
			return resp, RewriteDone
		}
	}

	// add option if not found
	if rule.action != Append && rule.action != Set {
		return nil, RewriteIgnored
	}
	opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""})
	if rule.revert {
		resp = append(resp, &edns0SetResponseRule{code: dns.EDNS0NSID})
	}
	return resp, RewriteDone
}
// Mode returns the processing mode the rule was created with.
func (rule *edns0NsidRule) Mode() string { return rule.mode }
// Rewrite will alter the request EDNS0 local options. Unset removes every
// option with the rule's code; Replace and Set overwrite the data of the
// first local option with that code; Append and Set add a new option (Set
// only when none was found). With revert enabled, response rules that undo
// the change are returned.
func (rule *edns0LocalRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	opt := setupEdns0Opt(state.Req)

	var resp ResponseRules
	switch rule.action {
	case Unset:
		unsetEdns0Option(opt, rule.code)
		return nil, RewriteDone
	case Replace, Set:
		for _, o := range opt.Option {
			local, isLocal := o.(*dns.EDNS0_LOCAL)
			if !isLocal || local.Code != rule.code {
				continue
			}
			if rule.revert {
				saved := *local
				resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_LOCAL]{code: rule.code, source: &saved})
			}
			local.Data = rule.data
			return resp, RewriteDone
		}
	}

	// add option if not found
	if rule.action != Append && rule.action != Set {
		return nil, RewriteIgnored
	}
	opt.Option = append(opt.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: rule.data})
	if rule.revert {
		resp = append(resp, &edns0SetResponseRule{code: rule.code})
	}
	return resp, RewriteDone
}
// Mode returns the processing mode the rule was created with.
func (rule *edns0LocalRule) Mode() string { return rule.mode }
// newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args.
// args[0] is the rule type (local, nsid or subnet) and args[1] the action;
// both are matched case-insensitively. Unset rules are built by
// newEdns0UnsetRule. For the other actions a trailing "revert" argument is
// stripped and turns on reverting of the change in the response.
func newEdns0Rule(mode string, args ...string) (Rule, error) {
	if len(args) < 2 {
		return nil, fmt.Errorf("too few arguments for an EDNS0 rule")
	}

	ruleType, action := strings.ToLower(args[0]), strings.ToLower(args[1])
	switch action {
	case Append, Replace, Set:
		// handled below
	case Unset:
		return newEdns0UnsetRule(mode, action, ruleType, args...)
	default:
		return nil, fmt.Errorf("invalid action: %q", action)
	}

	// Extract "revert" parameter.
	revert := args[len(args)-1] == "revert"
	if revert {
		args = args[:len(args)-1]
	}

	switch ruleType {
	case "local":
		if len(args) != 4 {
			return nil, fmt.Errorf("EDNS0 local rules require three or four args")
		}
		code, value := args[2], args[3]
		// A value wrapped in braces names a variable.
		if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
			return newEdns0VariableRule(mode, action, code, value, revert)
		}
		return newEdns0LocalRule(mode, action, code, value, revert)
	case "nsid":
		if len(args) != 2 {
			return nil, fmt.Errorf("EDNS0 NSID rules can accept no more than one arg")
		}
		return &edns0NsidRule{mode: mode, action: action, revert: revert}, nil
	case "subnet":
		if len(args) != 4 {
			return nil, fmt.Errorf("EDNS0 subnet rules require three or four args")
		}
		return newEdns0SubnetRule(mode, action, args[2], args[3], revert)
	}
	return nil, fmt.Errorf("invalid rule type %q", ruleType)
}
// newEdns0UnsetRule creates the EDNS0 rule for the unset action. args is the
// complete argument list, including the rule type and the action: local
// rules need one extra argument (the option code), nsid and subnet rules
// take none. Unset rules never revert.
func newEdns0UnsetRule(mode string, action string, ruleType string, args ...string) (Rule, error) {
	argc := len(args)

	switch ruleType {
	case "local":
		if argc != 3 {
			return nil, fmt.Errorf("local unset action requires exactly two arguments")
		}
		return newEdns0LocalRule(mode, action, args[2], "", false)
	case "nsid":
		if argc != 2 {
			return nil, fmt.Errorf("nsid unset action requires exactly one argument")
		}
		return &edns0NsidRule{mode: mode, action: action, revert: false}, nil
	case "subnet":
		if argc != 2 {
			return nil, fmt.Errorf("subnet unset action requires exactly one argument")
		}
		return &edns0SubnetRule{mode: mode, v4BitMaskLen: 0, v6BitMaskLen: 0, action: action, revert: false}, nil
	}
	return nil, fmt.Errorf("invalid rule type %q", ruleType)
}
// newEdns0LocalRule creates a rule for an EDNS0 local option. code is parsed
// as a 16 bit unsigned number (base prefixes are honoured). data is used
// verbatim unless it starts with "0x", in which case the remainder is decoded
// as hex. The option code is also registered as supported by the server.
func newEdns0LocalRule(mode, action, code, data string, revert bool) (*edns0LocalRule, error) {
	parsed, err := strconv.ParseUint(code, 0, 16)
	if err != nil {
		return nil, err
	}
	optionCode := uint16(parsed)

	payload := []byte(data)
	if strings.HasPrefix(data, "0x") {
		if payload, err = hex.DecodeString(data[2:]); err != nil {
			return nil, err
		}
	}

	// Add this code to the ones the server supports.
	edns.SetSupportedOption(optionCode)

	return &edns0LocalRule{
		mode:   mode,
		action: action,
		code:   optionCode,
		data:   payload,
		revert: revert,
	}, nil
}
// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution.
// code is parsed as a 16 bit unsigned number and variable must pass
// isValidVariable. The option code is also registered as supported by the
// server.
func newEdns0VariableRule(mode, action, code, variable string, revert bool) (*edns0VariableRule, error) {
	parsed, err := strconv.ParseUint(code, 0, 16)
	if err != nil {
		return nil, err
	}
	if !isValidVariable(variable) {
		return nil, fmt.Errorf("unsupported variable name %q", variable)
	}
	optionCode := uint16(parsed)

	// Add this code to the ones the server supports.
	edns.SetSupportedOption(optionCode)

	return &edns0VariableRule{
		mode:     mode,
		action:   action,
		code:     optionCode,
		variable: variable,
		revert:   revert,
	}, nil
}
// ruleData returns the data specified by the variable. The built-in
// variables are resolved from the request state; any other variable is
// stripped of its first and last character and looked up as a metadata
// label. An error is returned when no non-empty value can be found.
func (rule *edns0VariableRule) ruleData(ctx context.Context, state request.Request) ([]byte, error) {
	switch rule.variable {
	case queryName:
		return []byte(state.QName()), nil
	case queryType:
		return uint16ToWire(state.QType()), nil
	case clientIP:
		return ipToWire(state.Family(), state.IP())
	case serverIP:
		return ipToWire(state.Family(), state.LocalIP())
	case clientPort:
		return portToWire(state.Port())
	case serverPort:
		return portToWire(state.LocalPort())
	case protocol:
		return []byte(state.Proto()), nil
	}

	label := rule.variable[1 : len(rule.variable)-1]
	if fetch := metadata.ValueFunc(ctx, label); fetch != nil {
		if value := fetch(); len(value) > 0 {
			return []byte(value), nil
		}
	}
	return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
}
// Rewrite will alter the request EDNS0 local options with specified variables.
// The rule is ignored when the variable yields no data. If a local option
// with the rule's code already exists, Replace and Set overwrite its data and
// any other action is ignored; otherwise Append and Set add a new option.
// With revert enabled, response rules that undo the change are returned.
func (rule *edns0VariableRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	data, err := rule.ruleData(ctx, state)
	if err != nil || data == nil {
		return nil, RewriteIgnored
	}

	var resp ResponseRules
	opt := setupEdns0Opt(state.Req)
	for _, o := range opt.Option {
		local, isLocal := o.(*dns.EDNS0_LOCAL)
		if !isLocal || local.Code != rule.code {
			continue
		}
		// The option exists: only Replace and Set may touch it.
		if rule.action != Replace && rule.action != Set {
			return nil, RewriteIgnored
		}
		if rule.revert {
			saved := *local
			resp = append(resp, &edns0ReplaceResponseRule[*dns.EDNS0_LOCAL]{code: rule.code, source: &saved})
		}
		local.Data = data
		return resp, RewriteDone
	}

	// add option if not found
	if rule.action != Append && rule.action != Set {
		return nil, RewriteIgnored
	}
	opt.Option = append(opt.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: data})
	if rule.revert {
		resp = append(resp, &edns0SetResponseRule{code: rule.code})
	}
	return resp, RewriteDone
}
// Mode returns the processing mode the rule was created with.
func (rule *edns0VariableRule) Mode() string { return rule.mode }
// isValidVariable reports whether variable is one of the built-in variables
// or is wrapped in braces around something that has the syntax of a metadata
// label.
func isValidVariable(variable string) bool {
	switch variable {
	case queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort:
		return true
	}

	// we cannot validate the labels of metadata - but we can verify it has the syntax of a label
	if !strings.HasPrefix(variable, "{") || !strings.HasSuffix(variable, "}") {
		return false
	}
	return metadata.IsLabel(variable[1 : len(variable)-1])
}
// edns0SubnetRule is a rewrite rule for EDNS0 subnet options
type edns0SubnetRule struct {
mode string
// v4BitMaskLen is the source netmask length applied to IPv4 client addresses.
v4BitMaskLen uint8
// v6BitMaskLen is the source netmask length applied to IPv6 client addresses.
v6BitMaskLen uint8
action string
revert bool
}
// newEdns0SubnetRule creates a rule for the EDNS0 subnet option. Both mask
// lengths are parsed as unsigned numbers; the IPv4 length may not exceed 32
// bits and the IPv6 length may not exceed 128 bits.
func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string, revert bool) (*edns0SubnetRule, error) {
	const (
		maxV4Bits = net.IPv4len * 8
		maxV6Bits = net.IPv6len * 8
	)

	v4, err := strconv.ParseUint(v4BitMaskLen, 0, 16)
	if err != nil {
		return nil, err
	}
	if v4 > maxV4Bits {
		return nil, fmt.Errorf("invalid IPv4 bit mask length %d", v4)
	}

	v6, err := strconv.ParseUint(v6BitMaskLen, 0, 16)
	if err != nil {
		return nil, err
	}
	if v6 > maxV6Bits {
		return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6)
	}

	rule := &edns0SubnetRule{
		mode:         mode,
		action:       action,
		v4BitMaskLen: uint8(v4),
		v6BitMaskLen: uint8(v6),
		revert:       revert,
	}
	return rule, nil
}
// fillEcsData populates ecs with the family, source netmask and masked
// address of the client found in state. It returns an error, leaving ecs
// untouched, when state.Family() is neither 1 nor 2.
func (rule *edns0SubnetRule) fillEcsData(state request.Request, ecs *dns.EDNS0_SUBNET) error {
	family := state.Family()

	var prefixLen uint8
	var addrBits int
	switch family {
	case 1:
		prefixLen, addrBits = rule.v4BitMaskLen, 32
	case 2:
		prefixLen, addrBits = rule.v6BitMaskLen, 128
	default:
		return fmt.Errorf("unable to fill data for EDNS0 subnet due to invalid IP family")
	}

	ecs.Family = uint16(family)
	ecs.SourceScope = 0
	ecs.SourceNetmask = prefixLen

	masked := net.ParseIP(state.IP()).Mask(net.CIDRMask(int(prefixLen), addrBits))
	if family == 1 {
		ecs.Address = masked.To4()
	} else {
		ecs.Address = masked.To16()
	}
	return nil
}
// Rewrite applies the rule's action to the EDNS0 subnet option of the request.
//
// Unset removes the option. Otherwise the first subnet option found is
// refilled from the client address for the Replace and Set actions (and the
// rewrite is ignored for any other action); when no subnet option exists, one
// is added for the Append and Set actions. If rule.revert is set, the returned
// response rules describe the change. A failure of fillEcsData results in
// RewriteIgnored.
func (rule *edns0SubnetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	opt := setupEdns0Opt(state.Req)
	if rule.action == Unset {
		unsetEdns0Option(opt, dns.EDNS0SUBNET)
		return nil, RewriteDone
	}

	var revertRules ResponseRules
	for _, candidate := range opt.Option {
		subnet, isSubnet := candidate.(*dns.EDNS0_SUBNET)
		if !isSubnet {
			continue
		}
		// The first subnet option decides the outcome.
		if rule.action != Replace && rule.action != Set {
			return nil, RewriteIgnored
		}
		if rule.revert {
			// Snapshot the option before it is modified.
			snapshot := *subnet
			revertRules = append(revertRules, &edns0ReplaceResponseRule[*dns.EDNS0_SUBNET]{code: subnet.Code, source: &snapshot})
		}
		if rule.fillEcsData(state, subnet) != nil {
			return nil, RewriteIgnored
		}
		return revertRules, RewriteDone
	}

	// No subnet option was found: add one if the action allows it.
	if rule.action != Append && rule.action != Set {
		return nil, RewriteIgnored
	}
	fresh := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
	if rule.fillEcsData(state, fresh) != nil {
		return nil, RewriteIgnored
	}
	opt.Option = append(opt.Option, fresh)
	if rule.revert {
		revertRules = append(revertRules, &edns0SetResponseRule{code: dns.EDNS0SUBNET})
	}
	return revertRules, RewriteDone
}
// Mode reports the processing mode configured for this rule.
func (rule *edns0SubnetRule) Mode() string {
	return rule.mode
}
// These are all defined actions for EDNS0 rules.
const (
// Replace changes an option only when it is already present in the request.
Replace = "replace"
// Set changes an option when present and adds it otherwise.
Set = "set"
// Append adds an option only when it is not already present.
Append = "append"
// Unset removes an option (handled by the subnet rule via unsetEdns0Option).
Unset = "unset"
)
// Supported local EDNS0 variables. isValidVariable accepts exactly these
// names, in addition to metadata labels wrapped in curly braces.
const (
queryName = "{qname}"
queryType = "{qtype}"
clientIP = "{client_ip}"
clientPort = "{client_port}"
protocol = "{protocol}"
serverIP = "{server_ip}"
serverPort = "{server_port}"
)
//go:build gofuzz
package rewrite
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz feeds data to a Rewrite plugin configured with a fixed
// "edns0 subnet set 24 56" rule. It returns 0 if that configuration cannot be
// parsed, and the result of fuzz.Do otherwise.
func Fuzz(data []byte) int {
	controller := caddy.NewTestController("dns", "rewrite edns0 subnet set 24 56")
	parsed, err := rewriteParse(controller)
	if err != nil {
		return 0
	}
	return fuzz.Do(Rewrite{Rules: parsed}, data)
}
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// stringRewriter rewrites a string.
type stringRewriter interface {
	// rewriteString returns the rewritten form of src, or src itself when the
	// rewriter does not apply to it.
	rewriteString(src string) string
}

// regexStringRewriter can be used to rewrite strings by regex pattern.
// It contains all the information required to detect and execute a rewrite
// on a string.
type regexStringRewriter struct {
	// pattern is matched against the input string.
	pattern *regexp.Regexp
	// replacement is the output template; "{N}" is replaced by the text of
	// submatch N ("{0}" being the whole match).
	replacement string
}

var _ stringRewriter = &regexStringRewriter{}

// newStringRewriter returns a stringRewriter that rewrites strings matching
// pattern according to the replacement template.
func newStringRewriter(pattern *regexp.Regexp, replacement string) stringRewriter {
	return &regexStringRewriter{pattern: pattern, replacement: replacement}
}

// rewriteString returns src unchanged when the pattern does not match;
// otherwise it returns the replacement template with every "{N}" placeholder
// substituted by the corresponding submatch.
func (r *regexStringRewriter) rewriteString(src string) string {
	groups := r.pattern.FindStringSubmatch(src)
	if len(groups) == 0 {
		return src
	}
	out := r.replacement
	for idx, value := range groups {
		out = strings.ReplaceAll(out, "{"+strconv.Itoa(idx)+"}", value)
	}
	return out
}

// remapStringRewriter maps a dedicated string to another string.
// It also maps the domain part of a sub domain of that string.
type remapStringRewriter struct {
	orig        string
	replacement string
}

var _ stringRewriter = &remapStringRewriter{}

// newRemapStringRewriter returns a stringRewriter mapping orig, and any name
// ending in "."+orig, to replacement.
func newRemapStringRewriter(orig, replacement string) stringRewriter {
	return &remapStringRewriter{orig: orig, replacement: replacement}
}

// rewriteString returns the replacement for an exact match of orig. For a
// string ending in "."+orig only that trailing orig is swapped, keeping the
// leading labels and the separating dot. Any other string is returned as is.
func (r *remapStringRewriter) rewriteString(src string) string {
	if src == r.orig {
		return r.replacement
	}
	if strings.HasSuffix(src, "."+r.orig) {
		return src[:len(src)-len(r.orig)] + r.replacement
	}
	return src
}

// suffixStringRewriter maps a dedicated suffix string to another string.
type suffixStringRewriter struct {
	suffix      string
	replacement string
}

var _ stringRewriter = &suffixStringRewriter{}

// newSuffixStringRewriter returns a stringRewriter that swaps the trailing
// orig of a string for replacement.
func newSuffixStringRewriter(orig, replacement string) stringRewriter {
	return &suffixStringRewriter{suffix: orig, replacement: replacement}
}

// rewriteString replaces a trailing suffix by the replacement and returns src
// unchanged when it does not end with the suffix.
func (r *suffixStringRewriter) rewriteString(src string) string {
	if strings.HasSuffix(src, r.suffix) {
		return strings.TrimSuffix(src, r.suffix) + r.replacement
	}
	return src
}
// nameRewriterResponseRule maps a record name according to a stringRewriter.
type nameRewriterResponseRule struct {
	stringRewriter
}

// RewriteResponse passes the owner name of rr through the embedded
// stringRewriter and stores the result back into the record header.
func (r *nameRewriterResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
	hdr := rr.Header()
	hdr.Name = r.rewriteString(hdr.Name)
}
// valueRewriterResponseRule maps a record value according to a stringRewriter.
type valueRewriterResponseRule struct {
	stringRewriter
}

// RewriteResponse rewrites the value of rr as extracted by
// getRecordValueForRewrite. Records with an empty value, or whose value is
// left unchanged by the rewriter, are not touched.
func (r *valueRewriterResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
	current := getRecordValueForRewrite(rr)
	if current == "" {
		return
	}
	if rewritten := r.rewriteString(current); rewritten != current {
		setRewrittenRecordValue(rr, rewritten)
	}
}
// Match types accepted by the name, rcode and ttl rule constructors, followed
// by the keywords recognised by parseAnswerRules for answer rewrites.
const (
// ExactMatch matches only on exact match of the name in the question section of a request
ExactMatch = "exact"
// PrefixMatch matches when the name begins with the matching string
PrefixMatch = "prefix"
// SuffixMatch matches when the name ends with the matching string
SuffixMatch = "suffix"
// SubstringMatch matches on partial match of the name in the question section of a request
SubstringMatch = "substring"
// RegexMatch matches when the name in the question section of a request matches a regular expression
RegexMatch = "regex"
// AnswerMatch matches an answer rewrite
AnswerMatch = "answer"
// AutoMatch matches the auto name answer rewrite
AutoMatch = "auto"
// NameMatch matches the name answer rewrite
NameMatch = "name"
// ValueMatch matches the value answer rewrite
ValueMatch = "value"
)
// nameRuleBase holds the state shared by all name rewrite rules.
type nameRuleBase struct {
	// nextAction is the processing mode reported by Mode.
	nextAction string
	// auto enables dynamically generated response rules in responseRuleFor.
	auto bool
	// replacement is the text the concrete rules insert into the question name.
	replacement string
	// static are the response rules always returned on a match.
	static ResponseRules
}

// newNameRuleBase assembles a nameRuleBase from its parts.
func newNameRuleBase(nextAction string, auto bool, replacement string, staticResponses ResponseRules) nameRuleBase {
	var base nameRuleBase
	base.nextAction = nextAction
	base.auto = auto
	base.replacement = replacement
	base.static = staticResponses
	return base
}

// responseRuleFor returns the response rules for a matched request. In auto
// mode it dynamically creates name and value rewriters that map the name
// currently in the question section to state.Name(), followed by the static
// rules; otherwise only the static rules are returned.
func (rule *nameRuleBase) responseRuleFor(state request.Request) (ResponseRules, Result) {
	if rule.auto {
		rewriter := newRemapStringRewriter(state.Req.Question[0].Name, state.Name())
		rules := make(ResponseRules, 0, 2+len(rule.static))
		rules = append(rules, &nameRewriterResponseRule{rewriter}, &valueRewriterResponseRule{rewriter})
		return append(rules, rule.static...), RewriteDone
	}
	return rule.static, RewriteDone
}

// Mode returns the processing nextAction
func (rule *nameRuleBase) Mode() string {
	return rule.nextAction
}
// exactNameRule rewrites the current request based upon exact match of the
// name in the question section of the request.
type exactNameRule struct {
	nameRuleBase
	from string
}

// newExactNameRule creates an exact match name rule. Automatic response
// rewriting is always enabled for this rule type.
func newExactNameRule(nextAction string, orig, replacement string, answers ResponseRules) Rule {
	base := newNameRuleBase(nextAction, true, replacement, answers)
	return &exactNameRule{nameRuleBase: base, from: orig}
}

// Rewrite replaces the question name by the replacement when state.Name()
// equals the configured name, and ignores the request otherwise.
func (rule *exactNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	if state.Name() != rule.from {
		return nil, RewriteIgnored
	}
	state.Req.Question[0].Name = rule.replacement
	return rule.responseRuleFor(state)
}
// prefixNameRule rewrites the current request when the name begins with the matching string.
type prefixNameRule struct {
	nameRuleBase
	prefix string
}

// newPrefixNameRule creates a prefix match name rule.
func newPrefixNameRule(nextAction string, auto bool, prefix, replacement string, answers ResponseRules) Rule {
	base := newNameRuleBase(nextAction, auto, replacement, answers)
	return &prefixNameRule{nameRuleBase: base, prefix: prefix}
}

// Rewrite swaps a leading prefix of state.Name() for the replacement and
// stores the result as the question name. Requests whose name does not start
// with the prefix are ignored.
func (rule *prefixNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	name := state.Name()
	if !strings.HasPrefix(name, rule.prefix) {
		return nil, RewriteIgnored
	}
	state.Req.Question[0].Name = rule.replacement + name[len(rule.prefix):]
	return rule.responseRuleFor(state)
}
// suffixNameRule rewrites the current request when the name ends with the matching string.
type suffixNameRule struct {
	nameRuleBase
	suffix string
}

// newSuffixNameRule creates a suffix match name rule. When auto is requested
// the response rewriters are built up front from the replacement/suffix pair
// in reversed order and prepended to answers; the base rule itself is always
// created with auto disabled.
func newSuffixNameRule(nextAction string, auto bool, suffix, replacement string, answers ResponseRules) Rule {
	var reverse ResponseRules
	if auto {
		// For a suffix rewriter better standard response rewrites can be done
		// just by using the original suffix/replacement in the opposite order.
		rewriter := newSuffixStringRewriter(replacement, suffix)
		reverse = ResponseRules{
			&nameRewriterResponseRule{rewriter},
			&valueRewriterResponseRule{rewriter},
		}
	}
	static := append(reverse, answers...)
	return &suffixNameRule{
		nameRuleBase: newNameRuleBase(nextAction, false, replacement, static),
		suffix:       suffix,
	}
}

// Rewrite swaps a trailing suffix of state.Name() for the replacement and
// stores the result as the question name. Requests whose name does not end
// with the suffix are ignored.
func (rule *suffixNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	name := state.Name()
	if !strings.HasSuffix(name, rule.suffix) {
		return nil, RewriteIgnored
	}
	state.Req.Question[0].Name = name[:len(name)-len(rule.suffix)] + rule.replacement
	return rule.responseRuleFor(state)
}
// substringNameRule rewrites the current request based upon partial match of the
// name in the question section of the request.
type substringNameRule struct {
	nameRuleBase
	substring string
}

// newSubstringNameRule creates a substring match name rule.
func newSubstringNameRule(nextAction string, auto bool, substring, replacement string, answers ResponseRules) Rule {
	base := newNameRuleBase(nextAction, auto, replacement, answers)
	return &substringNameRule{nameRuleBase: base, substring: substring}
}

// Rewrite replaces every occurrence of the substring in state.Name() by the
// replacement and stores the result as the question name. Requests whose name
// does not contain the substring are ignored.
func (rule *substringNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	name := state.Name()
	if !strings.Contains(name, rule.substring) {
		return nil, RewriteIgnored
	}
	state.Req.Question[0].Name = strings.ReplaceAll(name, rule.substring, rule.replacement)
	return rule.responseRuleFor(state)
}
// regexNameRule rewrites the current request when the name in the question
// section of the request matches a regular expression.
type regexNameRule struct {
nameRuleBase
pattern *regexp.Regexp
}
func newRegexNameRule(nextAction string, auto bool, pattern *regexp.Regexp, replacement string, answers ResponseRules) Rule {
return ®exNameRule{
newNameRuleBase(nextAction, auto, replacement, answers),
pattern,
}
}
func (rule *regexNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
regexGroups := rule.pattern.FindStringSubmatch(state.Name())
if len(regexGroups) == 0 {
return nil, RewriteIgnored
}
s := rule.replacement
for groupIndex, groupValue := range regexGroups {
groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}"
s = strings.ReplaceAll(s, groupIndexStr, groupValue)
}
state.Req.Question[0].Name = s
return rule.responseRuleFor(state)
}
// newNameRule creates a name matching rule based on exact, prefix, suffix,
// substring, or regex match.
//
// Accepted argument forms:
//
//	FROM TO                          (exact match)
//	MATCHTYPE FROM TO [answer ...]   (explicit match type, optional answer rules)
//
// An error is returned for fewer than two arguments, an unknown match type,
// an invalid regular expression, or invalid answer rules.
func newNameRule(nextAction string, args ...string) (Rule, error) {
var matchType, rewriteQuestionFrom, rewriteQuestionTo string
if len(args) < 2 {
return nil, fmt.Errorf("too few arguments for a name rule")
}
// two arguments: implicit exact match
if len(args) == 2 {
matchType = ExactMatch
rewriteQuestionFrom = plugin.Name(args[0]).Normalize()
rewriteQuestionTo = plugin.Name(args[1]).Normalize()
}
if len(args) >= 3 {
matchType = strings.ToLower(args[0])
if matchType == RegexMatch {
// regex patterns and templates are kept verbatim at this point
rewriteQuestionFrom = args[1]
rewriteQuestionTo = args[2]
} else {
rewriteQuestionFrom = plugin.Name(args[1]).Normalize()
rewriteQuestionTo = plugin.Name(args[2]).Normalize()
}
}
// exact and suffix matches operate on fully qualified names
if matchType == ExactMatch || matchType == SuffixMatch {
if !hasClosingDot(rewriteQuestionFrom) {
rewriteQuestionFrom = rewriteQuestionFrom + "."
}
if !hasClosingDot(rewriteQuestionTo) {
rewriteQuestionTo = rewriteQuestionTo + "."
}
}
var err error
var answers ResponseRules
auto := false
// everything after the first three arguments describes answer rules
if len(args) > 3 {
auto, answers, err = parseAnswerRules(matchType, args[3:])
if err != nil {
return nil, err
}
}
switch matchType {
case ExactMatch:
// NOTE(review): the arguments are passed as (to, from), i.e. the replacement
// is compiled as the pattern here - confirm this is intended.
if _, err := isValidRegexPattern(rewriteQuestionTo, rewriteQuestionFrom); err != nil {
return nil, err
}
return newExactNameRule(nextAction, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case PrefixMatch:
return newPrefixNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case SuffixMatch:
return newSuffixNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case SubstringMatch:
return newSubstringNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil
case RegexMatch:
rewriteQuestionFromPattern, err := isValidRegexPattern(rewriteQuestionFrom, rewriteQuestionTo)
if err != nil {
return nil, err
}
// the replacement template is normalized only after it has been validated
// against the pattern; this declaration shadows the outer variable
rewriteQuestionTo := plugin.Name(args[2]).Normalize()
return newRegexNameRule(nextAction, auto, rewriteQuestionFromPattern, rewriteQuestionTo, answers), nil
default:
return nil, fmt.Errorf("name rule supports only exact, prefix, suffix, substring, and regex name matching, received: %s", matchType)
}
}
// parseAnswerRules parses the answer section of a name rule.
//
// args must start with the "answer" keyword, followed by one or more answer
// rules: "auto", "name FROM TO" or "value FROM TO". The "answer" keyword may
// be repeated in front of each subsequent rule. name is only used in error
// messages.
//
// It returns whether "auto" was requested and the explicit response rules in
// the order given. "auto" cannot be combined with explicit "name" rules.
func parseAnswerRules(name string, args []string) (auto bool, rules ResponseRules, err error) {
	if len(args) < 2 {
		return false, nil, fmt.Errorf("invalid arguments for %s rule", name)
	}

	nameRules := 0
	last := ""
	arg := 0
	for arg < len(args) {
		// The very first token has to be the answer keyword.
		if last == "" && args[arg] != AnswerMatch {
			return false, nil, fmt.Errorf("exceeded the number of arguments for a non-answer rule argument for %s rule", name)
		}
		if args[arg] == AnswerMatch {
			arg++
		}
		if len(args)-arg == 0 {
			return false, nil, fmt.Errorf("type missing for answer rule for %s rule", name)
		}
		last = args[arg]
		arg++

		switch last {
		case AutoMatch:
			auto = true
		case NameMatch, ValueMatch:
			if len(args)-arg < 2 {
				return false, nil, fmt.Errorf("%s answer rule for %s rule: 2 arguments required", last, name)
			}
			from, to := args[arg], args[arg+1]
			// Validate against the raw replacement before normalizing it.
			pattern, perr := isValidRegexPattern(from, to)
			if perr != nil {
				return false, nil, fmt.Errorf("%s answer rule for %s rule: %s", last, name, perr)
			}
			rewriter := newStringRewriter(pattern, plugin.Name(to).Normalize())
			if last == NameMatch {
				rules = append(rules, &nameRewriterResponseRule{rewriter})
				nameRules++
			} else {
				rules = append(rules, &valueRewriterResponseRule{rewriter})
			}
			arg += 2
		default:
			return false, nil, fmt.Errorf("invalid type %q for answer rule for %s rule", last, name)
		}
	}

	if auto && nameRules > 0 {
		return false, nil, fmt.Errorf("auto name answer rule cannot be combined with explicit name answer rules")
	}
	return auto, rules, nil
}
// hasClosingDot reports whether s ends with a dot.
func hasClosingDot(s string) bool {
	const closingDot = "."
	return strings.HasSuffix(s, closingDot)
}
// getSubExprUsage returns the number of distinct "{N}" placeholders, for N in
// 0..100, that occur in s. A placeholder used several times counts once.
func getSubExprUsage(s string) int {
	const maxGroupIndex = 100
	used := 0
	for idx := 0; idx <= maxGroupIndex; idx++ {
		placeholder := "{" + strconv.Itoa(idx) + "}"
		if strings.Contains(s, placeholder) {
			used++
		}
	}
	return used
}
// isValidRegexPattern compiles rewriteFrom and checks that rewriteTo does not
// use more "{N}" placeholders than the compiled pattern has subexpressions.
// It returns the compiled pattern, or an error when either check fails.
func isValidRegexPattern(rewriteFrom, rewriteTo string) (*regexp.Regexp, error) {
	pattern, err := regexp.Compile(rewriteFrom)
	if err != nil {
		return nil, fmt.Errorf("invalid regex matching pattern: %s", rewriteFrom)
	}
	if used, available := getSubExprUsage(rewriteTo), pattern.NumSubexp(); used > available {
		return nil, fmt.Errorf("the rewrite regex pattern (%s) uses more subexpressions than its corresponding matching regex pattern (%s)", rewriteTo, rewriteFrom)
	}
	return pattern, nil
}
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// rcodeResponseRule replaces one response code by another.
type rcodeResponseRule struct {
	// old is the response code to look for.
	old int
	// new is the response code written in its place.
	new int
}

// RewriteResponse sets the response code of res to r.new when it currently
// equals r.old. The resource record rr is not used.
func (r *rcodeResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
	if res.Rcode != r.old {
		return
	}
	res.Rcode = r.new
}
// rcodeRuleBase holds the state shared by all rcode rules.
type rcodeRuleBase struct {
	// nextAction is the processing mode reported by Mode.
	nextAction string
	// response is the response rule handed out when a request matches.
	response rcodeResponseRule
}

// newRCodeRuleBase creates the shared base of an rcode rule that maps the
// response code old to new.
func newRCodeRuleBase(nextAction string, old, new int) rcodeRuleBase {
	var base rcodeRuleBase
	base.nextAction = nextAction
	base.response.old = old
	base.response.new = new
	return base
}

// responseRule returns the rule's response rule together with RewriteDone
// when match is true, and RewriteIgnored otherwise.
func (rule *rcodeRuleBase) responseRule(match bool) (ResponseRules, Result) {
	if !match {
		return nil, RewriteIgnored
	}
	return ResponseRules{&rule.response}, RewriteDone
}

// Mode returns the processing nextAction
func (rule *rcodeRuleBase) Mode() string {
	return rule.nextAction
}
// exactRCodeRule selects requests whose name equals From.
type exactRCodeRule struct {
	rcodeRuleBase
	From string
}

// prefixRCodeRule selects requests whose name begins with Prefix.
type prefixRCodeRule struct {
	rcodeRuleBase
	Prefix string
}

// suffixRCodeRule selects requests whose name ends with Suffix.
type suffixRCodeRule struct {
	rcodeRuleBase
	Suffix string
}

// substringRCodeRule selects requests whose name contains Substring.
type substringRCodeRule struct {
	rcodeRuleBase
	Substring string
}

// regexRCodeRule selects requests whose name matches Pattern.
type regexRCodeRule struct {
	rcodeRuleBase
	Pattern *regexp.Regexp
}

// Rewrite returns the rcode response rule when the name in the question
// section equals From. The request itself is not modified.
func (rule *exactRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	matched := state.Name() == rule.From
	return rule.responseRule(matched)
}

// Rewrite returns the rcode response rule when the name begins with Prefix.
// The request itself is not modified.
func (rule *prefixRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	matched := strings.HasPrefix(state.Name(), rule.Prefix)
	return rule.responseRule(matched)
}

// Rewrite returns the rcode response rule when the name ends with Suffix.
// The request itself is not modified.
func (rule *suffixRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	matched := strings.HasSuffix(state.Name(), rule.Suffix)
	return rule.responseRule(matched)
}

// Rewrite returns the rcode response rule when the name contains Substring.
// The request itself is not modified.
func (rule *substringRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	matched := strings.Contains(state.Name(), rule.Substring)
	return rule.responseRule(matched)
}

// Rewrite returns the rcode response rule when the name matches Pattern.
// The request itself is not modified.
func (rule *regexRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	matched := rule.Pattern.MatchString(state.Name())
	return rule.responseRule(matched)
}
// newRCodeRule creates a name matching rule based on exact, partial, or regex match
func newRCodeRule(nextAction string, args ...string) (Rule, error) {
if len(args) < 3 {
return nil, fmt.Errorf("too few (%d) arguments for a rcode rule", len(args))
}
var oldStr, newStr string
if len(args) == 3 {
oldStr, newStr = args[1], args[2]
}
if len(args) == 4 {
oldStr, newStr = args[2], args[3]
}
old, valid := isValidRCode(oldStr)
if !valid {
return nil, fmt.Errorf("invalid matching RCODE '%s' for a rcode rule", oldStr)
}
new, valid := isValidRCode(newStr)
if !valid {
return nil, fmt.Errorf("invalid replacement RCODE '%s' for a rcode rule", newStr)
}
if len(args) == 4 {
switch strings.ToLower(args[0]) {
case ExactMatch:
return &exactRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case PrefixMatch:
return &prefixRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case SuffixMatch:
return &suffixRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case SubstringMatch:
return &substringRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[1]).Normalize(),
}, nil
case RegexMatch:
regexPattern, err := regexp.Compile(args[1])
if err != nil {
return nil, fmt.Errorf("invalid regex pattern in a rcode rule: %s", args[1])
}
return ®exRCodeRule{
newRCodeRuleBase(nextAction, old, new),
regexPattern,
}, nil
default:
return nil, fmt.Errorf("rcode rule supports only exact, prefix, suffix, substring, and regex name matching")
}
}
if len(args) > 4 {
return nil, fmt.Errorf("many few arguments for a rcode rule")
}
return &exactRCodeRule{
newRCodeRuleBase(nextAction, old, new),
plugin.Name(args[0]).Normalize(),
}, nil
}
// isValidRCode parses v as a response code. v may be a decimal number in the
// range 0..23 or a name found (case-insensitively) in dns.StringToRcode. It
// returns the numeric code and true on success, and 0 and false otherwise.
func isValidRCode(v string) (int, bool) {
	// Try parsing an integer based rcode first.
	if n, err := strconv.ParseUint(v, 10, 32); err == nil && n <= 23 {
		return int(n), true
	}
	code, known := dns.StringToRcode[strings.ToUpper(v)]
	if !known {
		return 0, false
	}
	return code, true
}
package rewrite
import (
"github.com/miekg/dns"
)
// RevertPolicy controls the overall reverting process
type RevertPolicy interface {
	// DoRevert reports whether response rewrite rules should be applied.
	DoRevert() bool
	// DoQuestionRestore reports whether the original question should be
	// restored in the response.
	DoQuestionRestore() bool
}

// revertPolicy is the default RevertPolicy implementation. Its zero value
// enables both reverting and question restoration.
type revertPolicy struct {
	noRevert, noRestore bool
}

// DoRevert is true unless reverting was disabled.
func (p revertPolicy) DoRevert() bool {
	return !p.noRevert
}

// DoQuestionRestore is true unless question restoration was disabled.
func (p revertPolicy) DoQuestionRestore() bool {
	return !p.noRestore
}

// NoRevertPolicy disables all response rewrite rules
func NoRevertPolicy() RevertPolicy {
	return revertPolicy{noRevert: true}
}

// NoRestorePolicy disables the question restoration during the response rewrite
func NoRestorePolicy() RevertPolicy {
	return revertPolicy{noRestore: true}
}

// NewRevertPolicy creates a new reverter policy by dynamically specifying all
// options.
func NewRevertPolicy(noRevert, noRestore bool) RevertPolicy {
	var policy revertPolicy
	policy.noRevert = noRevert
	policy.noRestore = noRestore
	return policy
}
// ResponseRule contains a rule to rewrite a response with.
type ResponseRule interface {
// RewriteResponse is called with the response message and one of its
// resource records; implementations may modify either.
RewriteResponse(res *dns.Msg, rr dns.RR)
}
// ResponseRules describes an ordered list of response rules to apply
// after a name rewrite. It is an alias for []ResponseRule.
type ResponseRules = []ResponseRule
// ResponseReverter reverses the operations done on the question section of a packet.
// This is needed because the client would otherwise disregard the response, i.e.
// dig will complain with ';; Question section mismatch: got example.org/HINFO/IN'
type ResponseReverter struct {
// ResponseWriter is the wrapped writer that receives the reverted response.
dns.ResponseWriter
// originalQuestion is the question captured when the reverter was created.
originalQuestion dns.Question
// ResponseRules are applied, in reverse order, to every record of a response.
ResponseRules ResponseRules
// revertPolicy decides whether the original question is restored.
revertPolicy RevertPolicy
}
// NewResponseReverter returns a pointer to a new ResponseReverter that wraps
// w, remembers the first question of r and uses policy. It starts without
// response rules.
func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg, policy RevertPolicy) *ResponseReverter {
	reverter := new(ResponseReverter)
	reverter.ResponseWriter = w
	reverter.originalQuestion = r.Question[0]
	reverter.revertPolicy = policy
	return reverter
}
// WriteMsg reverts the rewrite on a copy of res1 and passes that copy to the
// underlying ResponseWriter's WriteMsg method. Depending on the revert policy
// the original question is restored, and every record of the authority,
// answer and additional sections is run through the response rules.
func (r *ResponseReverter) WriteMsg(res1 *dns.Msg) error {
// Deep copy 'res' as to not (e.g). rewrite a message that's also stored in the cache.
res := res1.Copy()
if r.revertPolicy.DoQuestionRestore() {
// NOTE(review): this indexes Question[0] unconditionally and therefore
// assumes the response carries a question section.
res.Question[0] = r.originalQuestion
}
if len(r.ResponseRules) > 0 {
// Sections are processed in the order authority, answer, additional.
for _, rr := range res.Ns {
r.rewriteResourceRecord(res, rr)
}
for _, rr := range res.Answer {
r.rewriteResourceRecord(res, rr)
}
for _, rr := range res.Extra {
r.rewriteResourceRecord(res, rr)
}
}
return r.ResponseWriter.WriteMsg(res)
}
// rewriteResourceRecord applies all response rules to rr, last rule first.
func (r *ResponseReverter) rewriteResourceRecord(res *dns.Msg, rr dns.RR) {
	// The reverting rules need to be done in reversed order.
	for remaining := len(r.ResponseRules); remaining > 0; remaining-- {
		r.ResponseRules[remaining-1].RewriteResponse(res, rr)
	}
}
// Write passes buf to the underlying ResponseWriter unchanged and returns its
// result.
func (r *ResponseReverter) Write(buf []byte) (int, error) {
	return r.ResponseWriter.Write(buf)
}
// getRecordValueForRewrite returns the domain name carried in the data of rr
// for SRV, MX, CNAME, NS, DNAME, NAPTR, SOA (primary name server) and PTR
// records. For all other record types it returns the empty string.
func getRecordValueForRewrite(rr dns.RR) (name string) {
	switch rr.Header().Rrtype {
	case dns.TypeSRV:
		name = rr.(*dns.SRV).Target
	case dns.TypeMX:
		name = rr.(*dns.MX).Mx
	case dns.TypeCNAME:
		name = rr.(*dns.CNAME).Target
	case dns.TypeNS:
		name = rr.(*dns.NS).Ns
	case dns.TypeDNAME:
		name = rr.(*dns.DNAME).Target
	case dns.TypeNAPTR:
		name = rr.(*dns.NAPTR).Replacement
	case dns.TypeSOA:
		name = rr.(*dns.SOA).Ns
	case dns.TypePTR:
		name = rr.(*dns.PTR).Ptr
	}
	return name
}
// setRewrittenRecordValue stores value in the domain name field of rr for
// SRV, MX, CNAME, NS, DNAME, NAPTR, SOA (primary name server) and PTR
// records. Other record types are left untouched.
func setRewrittenRecordValue(rr dns.RR, value string) {
	switch rr.Header().Rrtype {
	case dns.TypeSRV:
		srv := rr.(*dns.SRV)
		srv.Target = value
	case dns.TypeMX:
		mx := rr.(*dns.MX)
		mx.Mx = value
	case dns.TypeCNAME:
		cname := rr.(*dns.CNAME)
		cname.Target = value
	case dns.TypeNS:
		ns := rr.(*dns.NS)
		ns.Ns = value
	case dns.TypeDNAME:
		dname := rr.(*dns.DNAME)
		dname.Target = value
	case dns.TypeNAPTR:
		naptr := rr.(*dns.NAPTR)
		naptr.Replacement = value
	case dns.TypeSOA:
		soa := rr.(*dns.SOA)
		soa.Ns = value
	case dns.TypePTR:
		ptr := rr.(*dns.PTR)
		ptr.Ptr = value
	}
}
package rewrite
import (
"context"
"fmt"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Result is the result of a rewrite, as returned by Rule.Rewrite.
type Result int
const (
// RewriteIgnored is returned when rewrite is not done on request.
// It is the zero value of Result.
RewriteIgnored Result = iota
// RewriteDone is returned when rewrite is done on request.
RewriteDone
)
// These are the defined processing modes, as returned by Rule.Mode.
const (
// Processing should stop after completing this rule
Stop = "stop"
// Processing should continue to next rule
Continue = "continue"
)
// Rewrite is a plugin to rewrite requests internally before being handled.
type Rewrite struct {
// Next is the handler the (possibly rewritten) request is passed to.
Next plugin.Handler
// Rules are evaluated in order by ServeDNS.
Rules []Rule
// RevertPolicy controls response reverting; ServeDNS falls back to
// NewRevertPolicy(false, false) when it is nil.
RevertPolicy
}
// ServeDNS implements the plugin.Handler interface.
//
// The rules are applied to the request in order. Response rules of every
// applied rule are collected; the first applied rule whose mode is Stop ends
// the evaluation. The request is then passed to the next plugin, with the
// response writer wrapped in a ResponseReverter when reverting is enabled and
// needed. A rewrite that yields an invalid domain name restores the original
// question and fails with dns.RcodeServerFailure.
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
// rw is a value receiver, so this default only applies to the current call.
if rw.RevertPolicy == nil {
rw.RevertPolicy = NewRevertPolicy(false, false)
}
wr := NewResponseReverter(w, r, rw.RevertPolicy)
state := request.Request{W: w, Req: r}
for _, rule := range rw.Rules {
respRules, result := rule.Rewrite(ctx, state)
if result == RewriteDone {
// reject rewrites that produced an invalid name and undo them
if _, ok := dns.IsDomainName(state.Req.Question[0].Name); !ok {
err := fmt.Errorf("invalid name after rewrite: %s", state.Req.Question[0].Name)
state.Req.Question[0] = wr.originalQuestion
return dns.RcodeServerFailure, err
}
wr.ResponseRules = append(wr.ResponseRules, respRules...)
if rule.Mode() == Stop {
if !rw.DoRevert() {
// reverting is disabled: hand over the plain response writer
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
rcode, err := plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
if plugin.ClientWrite(rcode) {
return rcode, err
}
// The next plugins didn't write a response, so write one now with the ResponseReverter.
// If server.ServeDNS does this then it will create an answer mismatch.
res := new(dns.Msg).SetRcode(r, rcode)
state.SizeAndDo(res)
// NOTE(review): the error returned by WriteMsg is discarded here.
wr.WriteMsg(res)
// return success, so server does not write a second error response to client
return dns.RcodeSuccess, err
}
}
}
// No rule stopped the processing: only wrap the writer when there is
// something to revert.
if !rw.DoRevert() || len(wr.ResponseRules) == 0 {
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}
return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
}
// Name implements the Handler interface.
func (rw Rewrite) Name() string {
	const pluginName = "rewrite"
	return pluginName
}
// Rule describes a rewrite rule.
type Rule interface {
// Rewrite rewrites the current request. It returns the response rules to
// apply to the response and whether the rule was applied (RewriteDone) or
// not (RewriteIgnored).
Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result)
// Mode returns the processing mode stop or continue.
Mode() string
}
// newRule creates a rewrite rule from the arguments of a rewrite directive.
//
// The arguments may start with the processing mode "continue" or "stop"; when
// neither is given the mode defaults to Stop (kept for backward
// compatibility). The next argument is the rule type ("name", "class",
// "type", "edns0", "ttl", "cname" or "rcode"), and the remaining arguments
// are passed to the constructor of that rule type.
//
// An error is returned when no rule type is given, the rule type is unknown,
// a class or type rule does not have exactly two arguments, or the rule
// constructor fails.
func newRule(args ...string) (Rule, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("no rule type specified for rewrite")
	}

	mode := Stop
	typeIdx := 0 // index of the rule type within args
	switch keyword := strings.ToLower(args[0]); keyword {
	case Continue, Stop:
		if len(args) < 2 {
			return nil, fmt.Errorf("%s rule must begin with a rule type", keyword)
		}
		mode = keyword
		typeIdx = 1
	}
	ruleType := strings.ToLower(args[typeIdx])
	ruleArgs := args[typeIdx+1:]

	switch ruleType {
	case "answer":
		return nil, fmt.Errorf("response rewrites must begin with a name rule")
	case "name":
		return newNameRule(mode, ruleArgs...)
	case "class":
		if len(ruleArgs) != 2 {
			return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
		}
		return newClassRule(mode, ruleArgs...)
	case "type":
		if len(ruleArgs) != 2 {
			return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
		}
		return newTypeRule(mode, ruleArgs...)
	case "edns0":
		return newEdns0Rule(mode, ruleArgs...)
	case "ttl":
		return newTTLRule(mode, ruleArgs...)
	case "cname":
		return newCNAMERule(mode, ruleArgs...)
	case "rcode":
		return newRCodeRule(mode, ruleArgs...)
	default:
		// Report the rule type as written, not the leading mode keyword.
		return nil, fmt.Errorf("invalid rule type %q", args[typeIdx])
	}
}
package rewrite
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
// init registers the setup function of the plugin under the name "rewrite".
func init() {
	plugin.Register("rewrite", setup)
}
// setup parses the rewrite directives of c and adds a Rewrite plugin holding
// the resulting rules to the server configuration. Parse errors are returned
// wrapped by plugin.Error.
func setup(c *caddy.Controller) error {
	rules, err := rewriteParse(c)
	if err != nil {
		return plugin.Error("rewrite", err)
	}

	wrap := func(next plugin.Handler) plugin.Handler {
		return Rewrite{Next: next, Rules: rules}
	}
	dnsserver.GetConfig(c).AddPlugin(wrap)
	return nil
}
// rewriteParse turns every rewrite directive found in c into a Rule. A
// directive with fewer than two arguments on its line is completed with the
// tokens of its nested block. The first rule that fails to parse aborts the
// parsing and its error is returned.
func rewriteParse(c *caddy.Controller) ([]Rule, error) {
	var parsed []Rule
	for c.Next() {
		tokens := c.RemainingArgs()
		if len(tokens) < 2 {
			// Handles rules out of nested instructions, i.e. the ones enclosed in curly brackets
			for c.NextBlock() {
				tokens = append(tokens, c.Val())
			}
		}
		r, err := newRule(tokens...)
		if err != nil {
			return nil, err
		}
		parsed = append(parsed, r)
	}
	return parsed, nil
}
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// ttlResponseRule clamps the TTL of a record to the range minTTL..maxTTL.
type ttlResponseRule struct {
	minTTL uint32
	maxTTL uint32
}

// RewriteResponse raises the TTL of rr to minTTL when it is lower, and
// otherwise lowers it to maxTTL when it is higher. The message res is not
// used.
func (r *ttlResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
	hdr := rr.Header()
	switch {
	case hdr.Ttl < r.minTTL:
		hdr.Ttl = r.minTTL
	case hdr.Ttl > r.maxTTL:
		hdr.Ttl = r.maxTTL
	}
}
// ttlRuleBase holds the state shared by all ttl rules.
type ttlRuleBase struct {
	// nextAction is the processing mode reported by Mode.
	nextAction string
	// response is the response rule handed out when a request matches.
	response ttlResponseRule
}

// newTTLRuleBase creates the shared base of a ttl rule that clamps TTLs to
// the range minTtl..maxTtl.
func newTTLRuleBase(nextAction string, minTtl, maxTtl uint32) ttlRuleBase {
	var base ttlRuleBase
	base.nextAction = nextAction
	base.response.minTTL = minTtl
	base.response.maxTTL = maxTtl
	return base
}

// responseRule returns the rule's response rule together with RewriteDone
// when match is true, and RewriteIgnored otherwise.
func (rule *ttlRuleBase) responseRule(match bool) (ResponseRules, Result) {
	if !match {
		return nil, RewriteIgnored
	}
	return ResponseRules{&rule.response}, RewriteDone
}

// Mode returns the processing nextAction
func (rule *ttlRuleBase) Mode() string {
	return rule.nextAction
}
type exactTTLRule struct {
ttlRuleBase
From string
}
type prefixTTLRule struct {
ttlRuleBase
Prefix string
}
type suffixTTLRule struct {
ttlRuleBase
Suffix string
}
type substringTTLRule struct {
ttlRuleBase
Substring string
}
type regexTTLRule struct {
ttlRuleBase
Pattern *regexp.Regexp
}
// Rewrite rewrites the current request based upon exact match of the name
// in the question section of the request.
func (rule *exactTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(rule.From == state.Name())
}
// Rewrite rewrites the current request when the name begins with the matching string.
func (rule *prefixTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.HasPrefix(state.Name(), rule.Prefix))
}
// Rewrite rewrites the current request when the name ends with the matching string.
func (rule *suffixTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.HasSuffix(state.Name(), rule.Suffix))
}
// Rewrite rewrites the current request based upon partial match of the
// name in the question section of the request.
func (rule *substringTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
return rule.responseRule(strings.Contains(state.Name(), rule.Substring))
}
// Rewrite returns the TTL response rule and RewriteDone when rule.Pattern
// matches the name in the question section; otherwise it returns
// RewriteIgnored.
func (rule *regexTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	groups := rule.Pattern.FindStringSubmatch(state.Name())
	return rule.responseRule(len(groups) != 0)
}
// newTTLRule creates a name matching rule based on exact, partial, or regex match
func newTTLRule(nextAction string, args ...string) (Rule, error) {
if len(args) < 2 {
return nil, fmt.Errorf("too few (%d) arguments for a ttl rule", len(args))
}
var s string
if len(args) == 2 {
s = args[1]
}
if len(args) == 3 {
s = args[2]
}
minTtl, maxTtl, valid := isValidTTL(s)
if !valid {
return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s)
}
if len(args) == 3 {
switch strings.ToLower(args[0]) {
case ExactMatch:
return &exactTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case PrefixMatch:
return &prefixTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case SuffixMatch:
return &suffixTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case SubstringMatch:
return &substringTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[1]).Normalize(),
}, nil
case RegexMatch:
regexPattern, err := regexp.Compile(args[1])
if err != nil {
return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1])
}
return ®exTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
regexPattern,
}, nil
default:
return nil, fmt.Errorf("ttl rule supports only exact, prefix, suffix, substring, and regex name matching")
}
}
if len(args) > 3 {
return nil, fmt.Errorf("many few arguments for a ttl rule")
}
return &exactTTLRule{
newTTLRuleBase(nextAction, minTtl, maxTtl),
plugin.Name(args[0]).Normalize(),
}, nil
}
// isValidTTL parses a TTL specification and reports whether it is valid.
//
// Accepted forms are a single value "N" (minimum and maximum are both N), a
// closed range "MIN-MAX", an open lower bound "-MAX" (minimum 0) and an open
// upper bound "MIN-" (maximum 2147483647). A bare "-", a range whose minimum
// exceeds its maximum, anything with more than one "-", and values that do
// not parse as unsigned 32-bit decimals are rejected. On rejection both
// returned TTLs are 0 and the boolean is false.
func isValidTTL(v string) (uint32, uint32, bool) {
	// Upper bound used when the range is left open on the right ("MIN-").
	const openEndedMax = 2147483647

	parts := strings.Split(v, "-")
	switch len(parts) {
	case 1:
		ttl, err := strconv.ParseUint(parts[0], 10, 32)
		if err != nil {
			return 0, 0, false
		}
		return uint32(ttl), uint32(ttl), true

	case 2:
		lowStr, highStr := parts[0], parts[1]
		if lowStr == "" && highStr == "" {
			// explicitly reject ttl directive "-" that would otherwise be
			// interpreted as 0-2147483647 which is pretty useless
			return 0, 0, false
		}

		var low, high uint64 = 0, openEndedMax
		if lowStr != "" {
			parsed, err := strconv.ParseUint(lowStr, 10, 32)
			if err != nil {
				return 0, 0, false
			}
			low = parsed
		}
		if highStr != "" {
			parsed, err := strconv.ParseUint(highStr, 10, 32)
			if err != nil {
				return 0, 0, false
			}
			high = parsed
		}
		if low > high {
			// reject invalid range
			return 0, 0, false
		}
		return uint32(low), uint32(high), true
	}
	return 0, 0, false
}
// Package rewrite is a plugin for rewriting requests internally to something different.
package rewrite
import (
"context"
"fmt"
"strings"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// typeRule is a type rewrite rule.
type typeRule struct {
// fromType is the query type that triggers the rewrite.
fromType uint16
// toType is the query type written into the question on a match.
toType uint16
// nextAction is the value reported by Mode.
nextAction string
}
// newTypeRule creates a type rewrite rule from args, where args[0] is the
// type to rewrite from and args[1] the type to rewrite to. Type names are
// looked up case-insensitively in dns.StringToType. An error is returned when
// fewer than two arguments are given or when either type name is unknown.
func newTypeRule(nextAction string, args ...string) (Rule, error) {
	// Guard the indexing below; without it a short argument list panics.
	if len(args) < 2 {
		return nil, fmt.Errorf("too few (%d) arguments for a type rule", len(args))
	}
	fromName, toName := strings.ToUpper(args[0]), strings.ToUpper(args[1])
	from, ok := dns.StringToType[fromName]
	if !ok {
		return nil, fmt.Errorf("invalid type %q", fromName)
	}
	to, ok := dns.StringToType[toName]
	if !ok {
		return nil, fmt.Errorf("invalid type %q", toName)
	}
	return &typeRule{from, to, nextAction}, nil
}
// Rewrite changes the type of the first question in the request to
// rule.toType when both rule types are non-zero and the query type equals
// rule.fromType. It returns RewriteDone when the question was changed and
// RewriteIgnored otherwise; no response rules are returned.
func (rule *typeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	if rule.fromType == 0 || rule.toType == 0 {
		return nil, RewriteIgnored
	}
	if state.QType() != rule.fromType {
		return nil, RewriteIgnored
	}
	state.Req.Question[0].Qtype = rule.toType
	return nil, RewriteDone
}
// Mode returns the nextAction this rule was created with.
func (rule *typeRule) Mode() string {
	return rule.nextAction
}
package rewrite
import (
"encoding/binary"
"fmt"
"net"
"strconv"
)
// ipToWire converts the textual IP address ipAddr to wire/binary format.
//
// family selects the output size: 1 yields the 4-byte IPv4 form and 2 yields
// the 16-byte form. An error is returned for any other family, when ipAddr
// cannot be parsed as an IP address, or when family is 1 but ipAddr has no
// IPv4 representation. The returned slice is never nil when the error is nil.
func ipToWire(family int, ipAddr string) ([]byte, error) {
	if family != 1 && family != 2 {
		return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
	}

	// net.ParseIP signals failure with a nil IP; surface that as an error
	// instead of silently handing back nil bytes.
	ip := net.ParseIP(ipAddr)
	if ip == nil {
		return nil, fmt.Errorf("invalid IP address %q", ipAddr)
	}

	if family == 1 {
		v4 := ip.To4()
		if v4 == nil {
			return nil, fmt.Errorf("IP address %q is not an IPv4 address", ipAddr)
		}
		return v4, nil
	}
	return ip.To16(), nil
}
// uint16ToWire encodes data as two bytes in big-endian order.
func uint16ToWire(data uint16) []byte {
	var wire [2]byte
	binary.BigEndian.PutUint16(wire[:], data)
	return wire[:]
}
// portToWire parses portStr as an unsigned 16-bit decimal number and returns
// it in wire/binary format (2 bytes). The parse error is returned unchanged
// when portStr is not a valid port number.
func portToWire(portStr string) ([]byte, error) {
	parsed, parseErr := strconv.ParseUint(portStr, 10, 16)
	if parseErr == nil {
		return uint16ToWire(uint16(parsed)), nil
	}
	return nil, parseErr
}
package test
import (
"os"
"path/filepath"
"testing"
)
// TempFile creates a temporary file in dir (the default temporary directory
// when dir is empty, see os.CreateTemp), writes content to it, and returns the
// file name together with a cleanup function that removes the file.
//
// On error the returned name is empty, the cleanup function is nil, and any
// partially written file has been removed.
func TempFile(dir, content string) (string, func(), error) {
	f, err := os.CreateTemp(dir, "go-test-tmpfile")
	if err != nil {
		return "", nil, err
	}
	name := f.Name()

	// Write through the handle we already hold and close it, so that no file
	// descriptor is leaked.
	if _, err := f.WriteString(content); err != nil {
		f.Close()
		os.Remove(name)
		return "", nil, err
	}
	if err := f.Close(); err != nil {
		os.Remove(name)
		return "", nil, err
	}

	rmFunc := func() { os.Remove(name) }
	return name, rmFunc, nil
}
// WritePEMFiles creates a temporary directory via t.TempDir and writes three
// files into it - ca.pem, cert.pem, and key.pem - each filled with the fixed
// PEM data embedded below. It returns the directory path, or an empty path and
// the error of the first write that failed.
func WritePEMFiles(t *testing.T) (string, error) {
t.Helper()
tempDir := t.TempDir()
// PEM certificate written to ca.pem.
data := `-----BEGIN CERTIFICATE-----
MIIC9zCCAd+gAwIBAgIJALGtqdMzpDemMA0GCSqGSIb3DQEBCwUAMBIxEDAOBgNV
BAMMB2t1YmUtY2EwHhcNMTYxMDE5MTU1NDI0WhcNNDQwMzA2MTU1NDI0WjASMRAw
DgYDVQQDDAdrdWJlLWNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA
pa4Wu/WkpJNRr8pMVE6jjwzNUOx5mIyoDr8WILSxVQcEeyVPPmAqbmYXtVZO11p9
jTzoEqF7Kgts3HVYGCk5abqbE14a8Ru/DmV5avU2hJ/NvSjtNi/O+V6SzCbg5yR9
lBR53uADDlzuJEQT9RHq7A5KitFkx4vUcXnjOQCbDogWFoYuOgNEwJPy0Raz3NJc
ViVfDqSJ0QHg02kCOMxcGFNRQ9F5aoW7QXZXZXD0tn3wLRlu4+GYyqt8fw5iNdLJ
t79yKp8I+vMTmMPz4YKUO+eCl5EY10Qs7wvoG/8QNbjH01BRN3L8iDT2WfxdvjTu
1RjPxFL92i+B7HZO7jGLfQIDAQABo1AwTjAdBgNVHQ4EFgQUZTrg+Xt87tkxDhlB
gKk9FdTOW3IwHwYDVR0jBBgwFoAUZTrg+Xt87tkxDhlBgKk9FdTOW3IwDAYDVR0T
BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEApB7JFVrZpGSOXNO3W7SlN6OCPXv9
C7rIBc8rwOrzi2mZWcBmWheQrqBo8xHif2rlFNVQxtq3JcQ8kfg/m1fHeQ/Ygzel
Z+U1OqozynDySBZdNn9i+kXXgAUCqDPp3hEQWe0os/RRpIwo9yOloBxdiX6S0NIf
VB8n8kAynFPkH7pYrGrL1HQgDFCSfa4tUJ3+9sppnCu0pNtq5AdhYx9xFb2sn+8G
xGbtCkhVk2VQ+BiCWnjYXJ6ZMzabP7wiOFDP9Pvr2ik22PRItsW/TLfHFXM1jDmc
I1rs/VUGKzcJGVIWbHrgjP68CTStGAvKgbsTqw7aLXTSqtPw88N9XVSyRg==
-----END CERTIFICATE-----`
path := filepath.Join(tempDir, "ca.pem")
if err := os.WriteFile(path, []byte(data), 0644); err != nil {
return "", err
}
// PEM certificate written to cert.pem.
data = `-----BEGIN CERTIFICATE-----
MIICozCCAYsCCQCRlf5BrvPuqjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDDAdr
dWJlLWNhMB4XDTE2MTAxOTE2MDUxOFoXDTE3MTAxOTE2MDUxOFowFTETMBEGA1UE
AwwKa3ViZS1hZG1pbjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMTw
a7wCFoiCad/N53aURfjrme+KR7FS0yf5Ur9OR/oM3BoS9stYu5Flzr35oL5T6t5G
c2ey78mUs/Cs07psnjUdKH55bDpJSdG7zW9mXNyeLwIefFcj/38SS5NBSotmLo8u
scJMGXeQpCQtfVuVJSP2bfU5u5d0KTLSg/Cor6UYonqrRB82HbOuuk8Wjaww4VHo
nCq7X8o948V6HN5ZibQOgMMo+nf0wORREHBjvwc4W7ewbaTcfoe1VNAo/QnkqxTF
ueMb2HxgghArqQSK8b44O05V0zrde25dVnmnte6sPjcV0plqMJ37jViISxsOPUFh
/ZW7zbIM/7CMcDekCiECAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAYZE8OxwRR7GR
kdd5aIriDwWfcl56cq5ICyx87U8hAZhBxk46a6a901LZPzt3xKyWIFQSRj/NYiQ+
/thjGLZI2lhkVgYtyAD4BNxDiuppQSCbkjY9tLVDdExGttEVN7+UYDWJBHy6X16Y
xSG9FE3Dvp9LI89Nq8E3dRh+Q8wu52q9HaQXjS5YtzQOtDFKPBkihXu/c6gEHj4Y
bZVk8rFiH8/CvcQxAuvNI3VVCFUKd2LeQtqwYQQ//qoiuA15krTq5Ut9eXJ8zxAw
zhDEPP4FhY+Sz+y1yWirphl7A1aZwhXVPcfWIGqpQ3jzNwUeocbH27kuLh+U4hQo
qeg10RdFnw==
-----END CERTIFICATE-----`
path = filepath.Join(tempDir, "cert.pem")
if err := os.WriteFile(path, []byte(data), 0644); err != nil {
return "", err
}
// PEM RSA private key written to key.pem.
data = `-----BEGIN RSA PRIVATE KEY-----
MIIEpgIBAAKCAQEAxPBrvAIWiIJp383ndpRF+OuZ74pHsVLTJ/lSv05H+gzcGhL2
y1i7kWXOvfmgvlPq3kZzZ7LvyZSz8KzTumyeNR0ofnlsOklJ0bvNb2Zc3J4vAh58
VyP/fxJLk0FKi2Yujy6xwkwZd5CkJC19W5UlI/Zt9Tm7l3QpMtKD8KivpRiieqtE
HzYds666TxaNrDDhUeicKrtfyj3jxXoc3lmJtA6Awyj6d/TA5FEQcGO/Bzhbt7Bt
pNx+h7VU0Cj9CeSrFMW54xvYfGCCECupBIrxvjg7TlXTOt17bl1Weae17qw+NxXS
mWownfuNWIhLGw49QWH9lbvNsgz/sIxwN6QKIQIDAQABAoIBAQDCXq9V7ZGjxWMN
OkFaLVkqJg3V91puztoMt+xNV8t+JTcOnOzrIXZuOFbl9PwLHPPP0SSRkm9LOvKl
dU26zv0OWureeKSymia7U2mcqyC3tX+bzc7WinbeSYZBnc0e7AjD1EgpBcaU1TLL
agIxY3A2oD9CKmrVPhZzTIZf/XztqTYjhvs5I2kBeT0imdYGpXkdndRyGX4I5/JQ
fnp3Czj+AW3zX7RvVnXOh4OtIAcfoG9xoNyD5LOSlJkkX0MwTS8pEBeZA+A4nb+C
ivjnOSgXWD+liisI+LpBgBbwYZ/E49x5ghZYrJt8QXSk7Bl/+UOyv6XZAm2mev6j
RLAZtoABAoGBAP2P+1PoKOwsk+d/AmHqyTCUQm0UG18LOLB/5PyWfXs/6caDmdIe
DZWeZWng1jUQLEadmoEw/CBY5+tPfHlzwzMNhT7KwUfIDQCIBoS7dzHYnwrJ3VZh
qYA05cuGHAAHqwb6UWz3y6Pa4AEVSHX6CM83CAi9jdWZ1rdZybWG+qYBAoGBAMbV
FsR/Ft+tK5ALgXGoG83TlmxzZYuZ1SnNje1OSdCQdMFCJB10gwoaRrw1ICzi40Xk
ydJwV1upGz1om9ReDAD1zQM9artmQx6+TVLiVPALuARdZE70+NrA6w3ZvxUgJjdN
ngvXUr+8SdvaYUAwFu7BulfJlwXjUS711hHW/KQhAoGBALY41QuV2mLwHlLNie7I
hlGtGpe9TXZeYB0nrG6B0CfU5LJPPSotguG1dXhDpm138/nDpZeWlnrAqdsHwpKd
yPhVjR51I7XsZLuvBdA50Q03egSM0c4UXXXPjh1XgaPb3uMi3YWMBwL4ducQXoS6
bb5M9C8j2lxZNF+L3VPhbxwBAoGBAIEWDvX7XKpTDxkxnxRfA84ZNGusb5y2fsHp
Bd+vGBUj8+kUO8Yzwm9op8vA4ebCVrMl2jGZZd3IaDryE1lIxZpJ+pPD5+tKdQEc
o67P6jz+HrYWu+zW9klvPit71qasfKMi7Rza6oo4f+sQWFsH3ZucgpJD+pyD/Ez0
pcpnPRaBAoGBANT/xgHBfIWt4U2rtmRLIIiZxKr+3mGnQdpA1J2BCh+/6AvrEx//
E/WObVJXDnBdViu0L9abE9iaTToBVri4cmlDlZagLuKVR+TFTCN/DSlVZTDkqkLI
8chzqtkH6b2b2R73hyRysWjsomys34ma3mEEPTX/aXeAF2MSZ/EWT9yL
-----END RSA PRIVATE KEY-----`
path = filepath.Join(tempDir, "key.pem")
if err := os.WriteFile(path, []byte(data), 0644); err != nil {
return "", err
}
return tempDir, nil
}
package test
import (
"context"
"fmt"
"sort"
"strings"
"github.com/miekg/dns"
)
// sect identifies one of the three record sections of a DNS message.
type sect int
// Section selectors; the values are consecutive starting at 0 (iota).
const (
// Answer is the answer section in an Msg.
Answer sect = iota
// Ns is the authoritative section in an Msg.
Ns
// Extra is the additional section in an Msg.
Extra
)
// RRSet represents a list of RRs. It implements sort.Interface, ordering
// records by their string representation.
type RRSet []dns.RR

// Len returns the number of records in the set.
func (p RRSet) Len() int {
	return len(p)
}

// Swap exchanges the records at positions i and j.
func (p RRSet) Swap(i, j int) {
	p[i], p[j] = p[j], p[i]
}

// Less reports whether the string form of record i sorts before that of
// record j.
func (p RRSet) Less(i, j int) bool {
	left, right := p[i].String(), p[j].String()
	return left < right
}
// Case represents a test case that encapsulates various data from a query and response.
// Note that if the TTL of an expected record is 303 we don't compare it with the
// TTL of the received record.
type Case struct {
// Qname and Qtype form the question; Msg builds the query from them.
Qname string
Qtype uint16
// Rcode is the expected response code, checked by Header.
Rcode int
// Do makes Msg add an OPT record with the DO bit set to the query.
Do bool
CheckingDisabled bool
RecursionAvailable bool
AuthenticatedData bool
Authoritative bool
Truncated bool
// Answer, Ns and Extra are the expected records per section.
Answer []dns.RR
Ns []dns.RR
Extra []dns.RR
Error error
}
// Msg builds the query message for c: a question for the fully qualified
// c.Qname with type c.Qtype. When c.Do is set, the additional section holds a
// single OPT record with the DO bit set and a UDP size of 4096.
func (c Case) Msg() *dns.Msg {
	query := new(dns.Msg)
	query.SetQuestion(dns.Fqdn(c.Qname), c.Qtype)
	if !c.Do {
		return query
	}

	opt := new(dns.OPT)
	opt.Hdr.Name = "."
	opt.Hdr.Rrtype = dns.TypeOPT
	opt.SetDo()
	opt.SetUDPSize(4096)
	query.Extra = []dns.RR{opt}
	return query
}
// The constructors below parse a record from its textual form. The parse error
// from dns.NewRR is deliberately dropped: the type assertion that follows
// panics when the result is not a record of the requested type.

// A returns an A record from rr. It panics on errors.
func A(rr string) *dns.A {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.A)
}

// AAAA returns an AAAA record from rr. It panics on errors.
func AAAA(rr string) *dns.AAAA {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.AAAA)
}

// CNAME returns a CNAME record from rr. It panics on errors.
func CNAME(rr string) *dns.CNAME {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.CNAME)
}

// DNAME returns a DNAME record from rr. It panics on errors.
func DNAME(rr string) *dns.DNAME {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.DNAME)
}

// SRV returns a SRV record from rr. It panics on errors.
func SRV(rr string) *dns.SRV {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.SRV)
}

// SOA returns a SOA record from rr. It panics on errors.
func SOA(rr string) *dns.SOA {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.SOA)
}

// NS returns an NS record from rr. It panics on errors.
func NS(rr string) *dns.NS {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.NS)
}

// PTR returns a PTR record from rr. It panics on errors.
func PTR(rr string) *dns.PTR {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.PTR)
}

// TXT returns a TXT record from rr. It panics on errors.
func TXT(rr string) *dns.TXT {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.TXT)
}

// CAA returns a CAA record from rr. It panics on errors.
func CAA(rr string) *dns.CAA {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.CAA)
}

// HINFO returns a HINFO record from rr. It panics on errors.
func HINFO(rr string) *dns.HINFO {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.HINFO)
}

// MX returns an MX record from rr. It panics on errors.
func MX(rr string) *dns.MX {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.MX)
}

// RRSIG returns an RRSIG record from rr. It panics on errors.
func RRSIG(rr string) *dns.RRSIG {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.RRSIG)
}

// NSEC returns an NSEC record from rr. It panics on errors.
func NSEC(rr string) *dns.NSEC {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.NSEC)
}

// DNSKEY returns a DNSKEY record from rr. It panics on errors.
func DNSKEY(rr string) *dns.DNSKEY {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.DNSKEY)
}

// DS returns a DS record from rr. It panics on errors.
func DS(rr string) *dns.DS {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.DS)
}

// NAPTR returns a NAPTR record from rr. It panics on errors.
func NAPTR(rr string) *dns.NAPTR {
	parsed, _ := dns.NewRR(rr)
	return parsed.(*dns.NAPTR)
}
// OPT returns an OPT record for the root name with version 0, the UDP buffer
// size set to bufsize (converted to uint16) and the DO bit set when do is
// true.
func OPT(bufsize int, do bool) *dns.OPT {
	opt := new(dns.OPT)
	opt.Hdr.Name = "."
	opt.Hdr.Rrtype = dns.TypeOPT
	opt.SetVersion(0)
	opt.SetUDPSize(uint16(bufsize))
	if !do {
		return opt
	}
	opt.SetDo()
	return opt
}
// Header compares resp against the expectations in tc: the response code and
// the number of records in the answer, authority and additional sections. It
// returns an error describing the first mismatch, or nil when all four agree.
func Header(tc Case, resp *dns.Msg) error {
	switch {
	case resp.Rcode != tc.Rcode:
		return fmt.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
	case len(resp.Answer) != len(tc.Answer):
		return fmt.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
	case len(resp.Ns) != len(tc.Ns):
		return fmt.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
	case len(resp.Extra) != len(tc.Extra):
		return fmt.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
	}
	return nil
}
// Section tests if the section in tc matches rr.
//
// sec selects the expected section of tc (0 = answer, 1 = authority,
// 2 = additional); any other value selects an empty expectation. The records
// in rr are compared positionally against the expected ones: owner name, TTL
// (skipped when the expected TTL is 303 or the expected record is an OPT),
// record type, and a set of type-specific fields. The first mismatch is
// returned as an error. An error is also returned when rr holds more records
// than the expected section, instead of indexing past its end.
func Section(tc Case, sec sect, rr []dns.RR) error {
	section := []dns.RR{}
	switch sec {
	case 0:
		section = tc.Answer
	case 1:
		section = tc.Ns
	case 2:
		section = tc.Extra
	}
	// Every received record needs an expected counterpart at the same index.
	if len(rr) > len(section) {
		return fmt.Errorf("section contained %d RRs, but only %d expected", len(rr), len(section))
	}

	for i, a := range rr {
		if a.Header().Name != section[i].Header().Name {
			return fmt.Errorf("RR %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name)
		}
		// 303 signals: don't care what the ttl is.
		if section[i].Header().Ttl != 303 && a.Header().Ttl != section[i].Header().Ttl {
			if _, ok := section[i].(*dns.OPT); !ok {
				// we check edns0 bufize on this one
				return fmt.Errorf("RR %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl)
			}
		}
		if a.Header().Rrtype != section[i].Header().Rrtype {
			return fmt.Errorf("RR %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype)
		}

		// Type-specific comparison; the expected record is asserted to the
		// same concrete type as the received one.
		switch x := a.(type) {
		case *dns.SRV:
			if x.Priority != section[i].(*dns.SRV).Priority {
				return fmt.Errorf("RR %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority)
			}
			if x.Weight != section[i].(*dns.SRV).Weight {
				return fmt.Errorf("RR %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight)
			}
			if x.Port != section[i].(*dns.SRV).Port {
				return fmt.Errorf("RR %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port)
			}
			if x.Target != section[i].(*dns.SRV).Target {
				return fmt.Errorf("RR %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target)
			}
		case *dns.RRSIG:
			if x.TypeCovered != section[i].(*dns.RRSIG).TypeCovered {
				return fmt.Errorf("RR %d should have a TypeCovered of %d, but has %d", i, section[i].(*dns.RRSIG).TypeCovered, x.TypeCovered)
			}
			if x.Labels != section[i].(*dns.RRSIG).Labels {
				return fmt.Errorf("RR %d should have a Labels of %d, but has %d", i, section[i].(*dns.RRSIG).Labels, x.Labels)
			}
			if x.SignerName != section[i].(*dns.RRSIG).SignerName {
				return fmt.Errorf("RR %d should have a SignerName of %s, but has %s", i, section[i].(*dns.RRSIG).SignerName, x.SignerName)
			}
		case *dns.NSEC:
			if x.NextDomain != section[i].(*dns.NSEC).NextDomain {
				return fmt.Errorf("RR %d should have a NextDomain of %s, but has %s", i, section[i].(*dns.NSEC).NextDomain, x.NextDomain)
			}
			// TypeBitMap
		case *dns.A:
			if x.A.String() != section[i].(*dns.A).A.String() {
				return fmt.Errorf("RR %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String())
			}
		case *dns.AAAA:
			if x.AAAA.String() != section[i].(*dns.AAAA).AAAA.String() {
				return fmt.Errorf("RR %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String())
			}
		case *dns.TXT:
			// Compare the concatenated text so chunking differences are ignored.
			actualTxt := strings.Join(x.Txt, "")
			expectedTxt := strings.Join(section[i].(*dns.TXT).Txt, "")
			if actualTxt != expectedTxt {
				return fmt.Errorf("RR %d should have a TXT value of %q, but has %q", i, expectedTxt, actualTxt)
			}
		case *dns.HINFO:
			if x.Cpu != section[i].(*dns.HINFO).Cpu {
				return fmt.Errorf("RR %d should have a Cpu of %s, but has %s", i, section[i].(*dns.HINFO).Cpu, x.Cpu)
			}
			if x.Os != section[i].(*dns.HINFO).Os {
				return fmt.Errorf("RR %d should have a Os of %s, but has %s", i, section[i].(*dns.HINFO).Os, x.Os)
			}
		case *dns.SOA:
			tt := section[i].(*dns.SOA)
			if x.Ns != tt.Ns {
				return fmt.Errorf("SOA nameserver should be %q, but is %q", tt.Ns, x.Ns)
			}
		case *dns.PTR:
			tt := section[i].(*dns.PTR)
			if x.Ptr != tt.Ptr {
				return fmt.Errorf("PTR ptr should be %q, but is %q", tt.Ptr, x.Ptr)
			}
		case *dns.CNAME:
			tt := section[i].(*dns.CNAME)
			if x.Target != tt.Target {
				return fmt.Errorf("CNAME target should be %q, but is %q", tt.Target, x.Target)
			}
		case *dns.MX:
			tt := section[i].(*dns.MX)
			if x.Mx != tt.Mx {
				return fmt.Errorf("MX Mx should be %q, but is %q", tt.Mx, x.Mx)
			}
			if x.Preference != tt.Preference {
				// Preference is numeric; %q would print it as a quoted rune.
				return fmt.Errorf("MX Preference should be %d, but is %d", tt.Preference, x.Preference)
			}
		case *dns.NS:
			tt := section[i].(*dns.NS)
			if x.Ns != tt.Ns {
				return fmt.Errorf("NS nameserver should be %q, but is %q", tt.Ns, x.Ns)
			}
		case *dns.OPT:
			tt := section[i].(*dns.OPT)
			if x.UDPSize() != tt.UDPSize() {
				return fmt.Errorf("OPT UDPSize should be %d, but is %d", tt.UDPSize(), x.UDPSize())
			}
			if x.Do() != tt.Do() {
				return fmt.Errorf("OPT DO should be %t, but is %t", tt.Do(), x.Do())
			}
		}
	}
	return nil
}
// CNAMEOrder makes sure that CNAMES do not appear after their target records:
// it returns an error when a CNAME in the answer section is preceded by a
// record whose owner name equals the CNAME's target.
func CNAMEOrder(res *dns.Msg) error {
	for idx, rec := range res.Answer {
		if rec.Header().Rrtype != dns.TypeCNAME {
			continue
		}
		// Only the records in front of this CNAME are inspected.
		for _, earlier := range res.Answer[:idx] {
			if earlier.Header().Name == rec.(*dns.CNAME).Target {
				return fmt.Errorf("CNAME found after target record")
			}
		}
	}
	return nil
}
// SortAndCheck sorts the three record sections of resp in place and then
// checks the header and each section against the test case tc, returning the
// first error encountered.
func SortAndCheck(resp *dns.Msg, tc Case) error {
	for _, records := range [][]dns.RR{resp.Answer, resp.Ns, resp.Extra} {
		sort.Sort(RRSet(records))
	}

	err := Header(tc, resp)
	if err == nil {
		err = Section(tc, Answer, resp.Answer)
	}
	if err == nil {
		err = Section(tc, Ns, resp.Ns)
	}
	if err == nil {
		err = Section(tc, Extra, resp.Extra)
	}
	return err
}
// ErrorHandler returns a Handler that replies to every request with a
// ServerFailure message and returns dns.RcodeServerFailure with a nil error.
func ErrorHandler() Handler {
	var servfail HandlerFunc = func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
		reply := new(dns.Msg)
		reply.SetRcode(r, dns.RcodeServerFailure)
		// The result of the write is not inspected.
		w.WriteMsg(reply)
		return dns.RcodeServerFailure, nil
	}
	return servfail
}
// NextHandler returns a Handler that ignores the request and always returns
// the given rcode and err.
func NextHandler(rcode int, err error) Handler {
	var fixed HandlerFunc = func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
		return rcode, err
	}
	return fixed
}
// Copied here to prevent an import cycle, so that we can define the above handlers.
type (
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error.
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
// Handler interface defines a plugin.
Handler interface {
// ServeDNS handles a request and returns an rcode and an error.
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
// Name returns the name of the handler.
Name() string
}
)
// ServeDNS implements the Handler interface by calling f itself.
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	rcode, err := f(ctx, w, r)
	return rcode, err
}
// Name implements the Handler interface; it always returns "handlerfunc".
func (f HandlerFunc) Name() string {
	return "handlerfunc"
}
package test
import (
"net"
"github.com/miekg/dns"
)
// ResponseWriter is useful for writing tests. It uses mostly fixed values for the client. The
// remote address defaults to 10.240.0.1 (overridable through RemoteIP) with port 40212. The
// local address is always 127.0.0.1 and port 53.
type ResponseWriter struct {
TCP bool // if TCP is true we return a TCP address instead of a UDP one.
// RemoteIP, when non-empty, replaces the default remote IP address.
RemoteIP string
// Zone is set as the zone of the remote address.
Zone string
}
// LocalAddr returns the local address 127.0.0.1:53, as a *net.TCPAddr when
// t.TCP is true and as a *net.UDPAddr otherwise.
func (t *ResponseWriter) LocalAddr() net.Addr {
	const localPort = 53
	loopback := net.ParseIP("127.0.0.1")
	if !t.TCP {
		return &net.UDPAddr{IP: loopback, Port: localPort, Zone: ""}
	}
	return &net.TCPAddr{IP: loopback, Port: localPort, Zone: ""}
}
// RemoteAddr returns the remote address on port 40212 with zone t.Zone. The
// IP is t.RemoteIP when set and 10.240.0.1 otherwise. The result is a
// *net.TCPAddr when t.TCP is true and a *net.UDPAddr otherwise.
func (t *ResponseWriter) RemoteAddr() net.Addr {
	const remotePort = 40212
	host := t.RemoteIP
	if host == "" {
		host = "10.240.0.1"
	}
	parsed := net.ParseIP(host)
	if !t.TCP {
		return &net.UDPAddr{IP: parsed, Port: remotePort, Zone: t.Zone}
	}
	return &net.TCPAddr{IP: parsed, Port: remotePort, Zone: t.Zone}
}
// Network implements dns.ResponseWriter interface; it returns an empty string.
func (t *ResponseWriter) Network() string {
	return ""
}

// WriteMsg implements dns.ResponseWriter interface; the message is discarded.
func (t *ResponseWriter) WriteMsg(m *dns.Msg) error {
	return nil
}

// Write implements dns.ResponseWriter interface; the buffer is discarded and
// its full length is reported as written.
func (t *ResponseWriter) Write(buf []byte) (int, error) {
	written := len(buf)
	return written, nil
}

// Close implements dns.ResponseWriter interface; it does nothing.
func (t *ResponseWriter) Close() error {
	return nil
}

// TsigStatus implements dns.ResponseWriter interface; it always returns nil.
func (t *ResponseWriter) TsigStatus() error {
	return nil
}

// TsigTimersOnly implements dns.ResponseWriter interface; it does nothing.
func (t *ResponseWriter) TsigTimersOnly(bool) {
}

// Hijack implements dns.ResponseWriter interface; it does nothing.
func (t *ResponseWriter) Hijack() {
}
// ResponseWriter6 returns fixed client and remote address in IPv6. The remote
// address is always fe80::42:ff:feca:4c65 and port 40212. The local address is always ::1 and port 53.
// It embeds ResponseWriter and overrides LocalAddr and RemoteAddr.
type ResponseWriter6 struct {
ResponseWriter
}
// LocalAddr returns the local address [::1]:53, as a *net.TCPAddr when t.TCP
// is true and as a *net.UDPAddr otherwise.
func (t *ResponseWriter6) LocalAddr() net.Addr {
	const localPort = 53
	loopback := net.ParseIP("::1")
	if !t.TCP {
		return &net.UDPAddr{IP: loopback, Port: localPort, Zone: ""}
	}
	return &net.TCPAddr{IP: loopback, Port: localPort, Zone: ""}
}
// RemoteAddr returns the remote address fe80::42:ff:feca:4c65 port 40212 with
// an empty zone, as a *net.TCPAddr when t.TCP is true and as a *net.UDPAddr
// otherwise.
func (t *ResponseWriter6) RemoteAddr() net.Addr {
	const remotePort = 40212
	remote := net.ParseIP("fe80::42:ff:feca:4c65")
	if !t.TCP {
		return &net.UDPAddr{IP: remote, Port: remotePort, Zone: ""}
	}
	return &net.TCPAddr{IP: remote, Port: remotePort, Zone: ""}
}
// Adapted by Miek Gieben for CoreDNS testing.
//
// License from prom2json
// Copyright 2014 Prometheus Team
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package test will scrape a target and you can inspect the variables.
// Basic usage:
//
// result := Scrape("http://localhost:9153/metrics")
// v := MetricValue("coredns_cache_capacity", result)
package test
import (
"fmt"
"io"
"mime"
"net/http"
"strconv"
"github.com/matttproud/golang_protobuf_extensions/pbutil"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
)
type (
// MetricFamily holds a prometheus metric.
MetricFamily struct {
Name string `json:"name"`
Help string `json:"help"`
Type string `json:"type"`
Metrics []interface{} `json:"metrics,omitempty"` // Either metric, summary or histogram.
}
// metric is for all "single value" metrics.
metric struct {
Labels map[string]string `json:"labels,omitempty"`
Value string `json:"value"`
}
// summary holds one summary metric: its quantiles plus sample count and sum,
// all rendered as strings.
summary struct {
Labels map[string]string `json:"labels,omitempty"`
Quantiles map[string]string `json:"quantiles,omitempty"`
Count string `json:"count"`
Sum string `json:"sum"`
}
// histogram holds one histogram metric: its buckets plus sample count and sum,
// all rendered as strings.
histogram struct {
Labels map[string]string `json:"labels,omitempty"`
Buckets map[string]string `json:"buckets,omitempty"`
Count string `json:"count"`
Sum string `json:"sum"`
}
)
// Scrape fetches the metrics exposed at url and returns them as a slice of
// *MetricFamily. The fetch runs in its own goroutine and feeds a buffered
// channel; Scrape collects until that channel is closed. The result is an
// empty, non-nil slice when nothing could be fetched.
func Scrape(url string) []*MetricFamily {
	families := make(chan *dto.MetricFamily, 1024)
	go fetchMetricFamilies(url, families)

	scraped := []*MetricFamily{}
	for {
		family, open := <-families
		if !open {
			break
		}
		scraped = append(scraped, newMetricFamily(family))
	}
	return scraped
}
// ScrapeMetricAsInt scrapes http://addr/metrics and returns the sum of all
// values of the metric family called name.
//
// When label is empty every metric of the family is summed; otherwise a
// metric's value is added once for each of its label values that equals
// label. A value that does not parse as an integer counts as 0. Entries that
// are not single-value metrics (summaries and histograms) are skipped rather
// than causing a panic. nometricvalue is returned when nothing matched.
func ScrapeMetricAsInt(addr string, name string, label string, nometricvalue int) int {
	valueToInt := func(m metric) int {
		r, err := strconv.Atoi(m.Value)
		if err != nil {
			return 0
		}
		return r
	}

	met := Scrape(fmt.Sprintf("http://%s/metrics", addr))
	found := false
	tot := 0
	for _, mf := range met {
		if mf.Name != name {
			continue
		}
		// Sum all metrics available
		for _, raw := range mf.Metrics {
			m, ok := raw.(metric)
			if !ok {
				// summary or histogram: there is no single value to add.
				continue
			}
			if label == "" {
				tot += valueToInt(m)
				found = true
				continue
			}
			for _, v := range m.Labels {
				if v == label {
					tot += valueToInt(m)
					found = true
				}
			}
		}
	}
	if !found {
		return nometricvalue
	}
	return tot
}
// MetricValue returns the value and the labels of the first metric in the
// family called name. It returns "" and nil when no family has that name.
// Only single-value metrics are supported: the first entry is asserted to be
// a metric.
func MetricValue(name string, mfs []*MetricFamily) (string, map[string]string) {
	for _, family := range mfs {
		if family.Name != name {
			continue
		}
		// Only works with Gauge and Counter...
		first := family.Metrics[0].(metric)
		return first.Value, first.Labels
	}
	return "", nil
}
// MetricValueLabel returns the value and labels of the first metric in the
// family called name that has a label whose *value* equals label. It returns
// "" and nil when there is no such metric.
func MetricValueLabel(name, label string, mfs []*MetricFamily) (string, map[string]string) {
	for _, family := range mfs {
		if family.Name != name {
			continue
		}
		for _, entry := range family.Metrics {
			single := entry.(metric)
			for _, labelValue := range single.Labels {
				if labelValue == label {
					return single.Value, single.Labels
				}
			}
		}
	}
	return "", nil
}
// newMetricFamily converts a protobuf metric family into a MetricFamily.
// Name, help text and type are copied over; each metric becomes a summary, a
// histogram or a single-value metric depending on the family type, with all
// numbers rendered as strings.
func newMetricFamily(dtoMF *dto.MetricFamily) *MetricFamily {
	mf := &MetricFamily{
		Name:    dtoMF.GetName(),
		Help:    dtoMF.GetHelp(),
		Type:    dtoMF.GetType().String(),
		Metrics: make([]interface{}, len(dtoMF.GetMetric())),
	}
	for i, m := range dtoMF.GetMetric() {
		switch dtoMF.GetType() {
		case dto.MetricType_SUMMARY:
			mf.Metrics[i] = summary{
				Labels:    makeLabels(m),
				Quantiles: makeQuantiles(m),
				Count:     fmt.Sprint(m.GetSummary().GetSampleCount()),
				Sum:       fmt.Sprint(m.GetSummary().GetSampleSum()),
			}
		case dto.MetricType_HISTOGRAM:
			mf.Metrics[i] = histogram{
				Labels:  makeLabels(m),
				Buckets: makeBuckets(m),
				Count:   fmt.Sprint(m.GetHistogram().GetSampleCount()),
				// The sum must come from the histogram, not from the
				// (absent) summary of the metric.
				Sum: fmt.Sprint(m.GetHistogram().GetSampleSum()),
			}
		default:
			mf.Metrics[i] = metric{
				Labels: makeLabels(m),
				Value:  fmt.Sprint(value(m)),
			}
		}
	}
	return mf
}
// value returns the numeric value of m, taken from its gauge, counter or
// untyped part, in that order of preference. It returns 0 when m has none of
// them.
func value(m *dto.Metric) float64 {
	switch {
	case m.GetGauge() != nil:
		return m.GetGauge().GetValue()
	case m.GetCounter() != nil:
		return m.GetCounter().GetValue()
	case m.GetUntyped() != nil:
		return m.GetUntyped().GetValue()
	}
	return 0.
}
// makeLabels returns the labels of m as a name-to-value map. The map is empty
// but non-nil when m has no labels.
func makeLabels(m *dto.Metric) map[string]string {
	pairs := m.GetLabel()
	labels := make(map[string]string, len(pairs))
	for _, pair := range pairs {
		labels[pair.GetName()] = pair.GetValue()
	}
	return labels
}
// makeQuantiles returns the quantiles of m's summary as a map from the
// quantile to its value, both formatted with fmt.Sprint.
func makeQuantiles(m *dto.Metric) map[string]string {
	quantiles := m.GetSummary().GetQuantile()
	out := make(map[string]string, len(quantiles))
	for _, q := range quantiles {
		key := fmt.Sprint(q.GetQuantile())
		out[key] = fmt.Sprint(q.GetValue())
	}
	return out
}
// makeBuckets returns the buckets of m's histogram as a map from the bucket's
// upper bound to its cumulative count, both formatted with fmt.Sprint.
func makeBuckets(m *dto.Metric) map[string]string {
	buckets := m.GetHistogram().GetBucket()
	out := make(map[string]string, len(buckets))
	for _, bucket := range buckets {
		key := fmt.Sprint(bucket.GetUpperBound())
		out[key] = fmt.Sprint(bucket.GetCumulativeCount())
	}
	return out
}
// fetchMetricFamilies issues a GET request for url, decodes the response into
// metric families and sends each of them on ch. ch is closed when the
// function returns, whatever the outcome. Any failure (building or sending
// the request, a status other than 200 OK, a decode error) ends the function
// silently.
//
// A response declared as delimited protobuf of
// io.prometheus.client.MetricFamily is read message by message; every other
// response is parsed as the text exposition format.
func fetchMetricFamilies(url string, ch chan<- *dto.MetricFamily) {
	defer close(ch)

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return
	}
	req.Header.Add("Accept", acceptHeader)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return
	}

	mediatype, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	delimitedProto := err == nil &&
		mediatype == "application/vnd.google.protobuf" &&
		params["encoding"] == "delimited" &&
		params["proto"] == "io.prometheus.client.MetricFamily"

	if !delimitedProto {
		// We could do further content-type checks here, but the
		// fallback for now will anyway be the text format
		// version 0.0.4, so just go for it and see if it works.
		var parser expfmt.TextParser
		parsed, err := parser.TextToMetricFamilies(resp.Body)
		if err != nil {
			return
		}
		for _, family := range parsed {
			ch <- family
		}
		return
	}

	for {
		family := &dto.MetricFamily{}
		_, readErr := pbutil.ReadDelimited(resp.Body, family)
		if readErr == io.EOF {
			// End of stream: all families have been delivered.
			break
		}
		if readErr != nil {
			return
		}
		ch <- family
	}
}
// acceptHeader is the Accept header value sent when scraping: delimited
// protobuf with q=0.7 and the text format version 0.0.4 with q=0.3.
const acceptHeader = `application/vnd.google.protobuf;proto=io.prometheus.client.MetricFamily;encoding=delimited;q=0.7,text/plain;version=0.0.4;q=0.3`
package transfer
import (
"fmt"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/miekg/dns"
)
// Notify sends a NOTIFY message for zone to every host configured in the "to"
// list of the best matching transfer block, skipping the "*" wildcard. The
// string zone must be lowercased.
//
// A nil receiver and a zone without a matching block are both no-ops that
// return nil. When several notifies fail, only the last error is returned.
func (t *Transfer) Notify(zone string) error {
	if t == nil { // t might be nil, mostly expected in tests, so intercept and do a noop in that case
		return nil
	}

	notifyMsg := new(dns.Msg)
	notifyMsg.SetNotify(zone)
	client := new(dns.Client)

	x := longestMatch(t.xfrs, zone)
	if x == nil {
		// return without error if there is no matching zone
		return nil
	}

	var lastErr error
	for _, target := range x.to {
		if target == "*" {
			continue
		}
		if err := sendNotify(client, notifyMsg, target); err != nil {
			lastErr = err
		}
	}
	log.Debugf("Sent notifies for zone %q to %v", zone, x.to)
	return lastErr // this only captures the last error
}
// sendNotify sends the notify message m to the server s, trying up to three
// times. It returns nil as soon as one reply carries RcodeSuccess.
//
// When no attempt succeeds the returned error reflects the last attempt: the
// transport error if that exchange failed, otherwise the most recently
// received response code.
func sendNotify(c *dns.Client, m *dns.Msg, s string) error {
	var err error
	code := dns.RcodeServerFailure
	for range 3 {
		var ret *dns.Msg
		// Plain assignment, not ":=": a short declaration here would shadow
		// err and leave the transport-error branch below unreachable.
		ret, _, err = c.Exchange(m, s)
		if err != nil {
			continue
		}
		code = ret.Rcode
		if code == dns.RcodeSuccess {
			return nil
		}
	}
	if err != nil {
		return fmt.Errorf("notify for zone %q was not accepted by %q: %q", m.Question[0].Name, s, err)
	}
	return fmt.Errorf("notify for zone %q was not accepted by %q: rcode was %q", m.Question[0].Name, s, rcode.ToString(code))
}
package transfer
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/parse"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// init registers the "transfer" plugin for the "dns" server type, with setup
// as its action.
func init() {
	transferPlugin := caddy.Plugin{
		ServerType: "dns",
		Action:     setup,
	}
	caddy.RegisterPlugin("transfer", transferPlugin)
}
// setup parses the transfer configuration and adds the resulting Transfer
// plugin to the server configuration. It also registers a startup hook that
// copies the TSIG secrets from the configuration and collects every handler
// that implements Transferer.
func setup(c *caddy.Controller) error {
	t, err := parseTransfer(c)
	if err != nil {
		return plugin.Error("transfer", err)
	}

	chain := func(next plugin.Handler) plugin.Handler {
		t.Next = next
		return t
	}
	dnsserver.GetConfig(c).AddPlugin(chain)

	c.OnStartup(func() error {
		config := dnsserver.GetConfig(c)
		t.tsigSecret = config.TsigSecret
		// find all plugins that implement Transferer and add them to Transferers
		for _, handler := range config.Handlers() {
			if tr, ok := handler.(Transferer); ok {
				t.Transferers = append(t.Transferers, tr)
			}
		}
		return nil
	})
	return nil
}
// parseTransfer reads all transfer directives from c and returns the resulting Transfer.
//
// Each directive yields one xfr whose zones come from the directive arguments or, when
// absent, from the server block keys. The only supported property is "to", which needs at
// least one argument; each argument is either "*" or a host that is normalized with
// parse.HostPort using the default transport port. A directive without any "to"
// destination is an error.
func parseTransfer(c *caddy.Controller) (*Transfer, error) {
	result := new(Transfer)
	for c.Next() {
		entry := &xfr{
			Zones: plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys),
		}

		for c.NextBlock() {
			if c.Val() != "to" {
				return nil, plugin.Error("transfer", c.Errf("unknown property %q", c.Val()))
			}

			hosts := c.RemainingArgs()
			if len(hosts) == 0 {
				return nil, c.ArgErr()
			}
			for _, h := range hosts {
				// The wildcard is stored verbatim; everything else is normalized.
				if h != "*" {
					addr, err := parse.HostPort(h, transport.Port)
					if err != nil {
						return nil, err
					}
					h = addr
				}
				entry.to = append(entry.to, h)
			}
		}

		if len(entry.to) == 0 {
			return nil, plugin.Error("transfer", c.Err("'to' is required"))
		}
		result.xfrs = append(result.xfrs, entry)
	}
	return result, nil
}
package transfer
import (
"context"
"errors"
"net"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// log is the package logger, created for the "transfer" plugin.
var log = clog.NewWithPlugin("transfer")
// Transfer is a plugin that handles zone transfers.
type Transfer struct {
Transferers []Transferer // List of plugins that implement Transferer
xfrs []*xfr // transfer blocks parsed from the configuration
tsigSecret map[string]string // copied from the server config at startup; handed to dns.Transfer for TSIG-signed requests
Next plugin.Handler // next handler in the plugin chain
}
// xfr is a single transfer configuration block: the zones it covers and the
// destinations configured with its "to" property.
type xfr struct {
Zones []string // zones this block applies to
to []string // destinations: normalized host:port strings, or "*" for any
}
// Transferer may be implemented by plugins to enable zone transfers.
type Transferer interface {
// Transfer returns a channel to which it writes responses to the transfer request.
// If the plugin is not authoritative for the zone, it should immediately return the
// transfer.ErrNotAuthoritative error. This is important, otherwise the transfer plugin can
// use plugin X while it should transfer the data from plugin Y.
//
// If serial is 0, handle as an AXFR request. Transfer should send all records
// in the zone to the channel. The SOA should be written to the channel first, followed
// by all other records, including all NS + glue records. The implementation is also responsible
// for sending the last SOA record (to signal end of the transfer). This plugin will just grab
// these records and send them back to the requester; there is little validation done.
//
// If serial is not 0, it will be handled as an IXFR request. If the serial is equal to or greater (newer) than
// the current serial for the zone, send a single SOA record to the channel and then close it.
// If the serial is less (older) than the current serial for the zone, perform an AXFR fallback
// by proceeding as if an AXFR was requested (as above).
Transfer(zone string, serial uint32) (<-chan []dns.RR, error)
}
var (
// ErrNotAuthoritative is returned by a Transferer's Transfer method when the plugin is
// not authoritative for the zone; ServeDNS then tries the next Transferer.
ErrNotAuthoritative = errors.New("not authoritative for zone")
)
// ServeDNS implements the plugin.Handler interface.
//
// Only AXFR and IXFR queries are handled; anything else is passed to the next plugin.
// Transfer queries that do not arrive over TCP are answered with REFUSED. When no transfer
// block matches the query name, or no Transferer yields a channel, the request is passed on
// to the next plugin. Clients not allowed by the matching block get a REFUSED reply.
//
// Records received from the selected Transferer are batched into envelopes and streamed to
// the client through dns.Transfer.Out. If the Transferer produced nothing but a single SOA
// record (the zone is up to date for an IXFR), only that SOA is returned in a normal reply.
func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := request.Request{W: w, Req: r}
	if state.QType() != dns.TypeAXFR && state.QType() != dns.TypeIXFR {
		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
	}

	if state.Proto() != "tcp" {
		return dns.RcodeRefused, nil
	}

	x := longestMatch(t.xfrs, state.QName())
	if x == nil {
		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
	}

	if !x.allowed(state) {
		// write msg here, so logging will pick it up
		m := new(dns.Msg)
		m.SetRcode(r, dns.RcodeRefused)
		w.WriteMsg(m)
		return 0, nil
	}

	// Get serial from request if this is an IXFR.
	var serial uint32
	if state.QType() == dns.TypeIXFR {
		if len(r.Ns) != 1 {
			return dns.RcodeServerFailure, nil
		}
		soa, ok := r.Ns[0].(*dns.SOA)
		if !ok {
			return dns.RcodeServerFailure, nil
		}
		serial = soa.Serial
	}

	// Get a receiving channel from the first Transferer plugin that returns one.
	var pchan <-chan []dns.RR
	var err error
	for _, p := range t.Transferers {
		pchan, err = p.Transfer(state.QName(), serial)
		// errors.Is also recognizes a wrapped ErrNotAuthoritative.
		if errors.Is(err, ErrNotAuthoritative) {
			// plugin was not authoritative for the zone, try next plugin
			continue
		}
		if err != nil {
			return dns.RcodeServerFailure, err
		}
		break
	}

	if pchan == nil {
		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
	}

	// Send response to client
	ch := make(chan *dns.Envelope)
	tr := new(dns.Transfer)
	if r.IsTsig() != nil {
		tr.TsigSecret = t.tsigSecret
	}
	errCh := make(chan error)
	go func() {
		if err := tr.Out(w, r, ch); err != nil {
			errCh <- err
		}
		close(errCh)
	}()

	rrs := []dns.RR{}
	l := 0 // number of records already handed to the outgoing transfer
	var soa *dns.SOA
	for records := range pchan {
		// Guard against an empty batch; indexing records[0] would panic.
		if len(records) == 0 {
			continue
		}
		if first, ok := records[0].(*dns.SOA); ok && soa == nil {
			soa = first
		}
		rrs = append(rrs, records...)
		if len(rrs) > 500 {
			select {
			case ch <- &dns.Envelope{RR: rrs}:
			case err := <-errCh:
				return dns.RcodeServerFailure, err
			}
			l += len(rrs)
			rrs = []dns.RR{}
		}
	}

	// If we are here, nothing has been sent on ch yet (l == 0), and we hold exactly one
	// record with a known SOA, the zone is up to date for this IXFR. ch can be closed (and
	// waited for), and only the SOA needs to be returned to the client. The l == 0 check
	// matters: without it, a full transfer whose final leftover batch happens to be just the
	// closing SOA would be mistaken for this case.
	if l == 0 && len(rrs) == 1 && soa != nil { // soa should never be nil...
		close(ch)
		err := <-errCh
		if err != nil {
			return dns.RcodeServerFailure, err
		}

		m := new(dns.Msg)
		m.SetReply(r)
		m.Answer = []dns.RR{soa}
		w.WriteMsg(m)

		log.Infof("Outgoing noop, incremental transfer for up to date zone %q to %s for %d SOA serial", state.QName(), state.IP(), soa.Serial)
		return 0, nil
	}

	if len(rrs) > 0 {
		select {
		case ch <- &dns.Envelope{RR: rrs}:
		case err := <-errCh:
			return dns.RcodeServerFailure, err
		}
		l += len(rrs)
	}

	close(ch)     // Even though we close the channel here, we still have
	err = <-errCh // to wait before we can return and close the connection.
	if err != nil {
		return dns.RcodeServerFailure, err
	}

	logserial := uint32(0)
	if soa != nil {
		logserial = soa.Serial
	}
	log.Infof("Outgoing transfer of %d records of zone %q to %s for %d SOA serial", l, state.QName(), state.IP(), logserial)
	return 0, nil
}
// allowed reports whether the client behind state may request a transfer from x.
// A "*" destination allows everyone; otherwise the host part of a destination must equal
// the client's IP string. A destination that cannot be split into host and port ends the
// search with a refusal.
func (x xfr) allowed(state request.Request) bool {
	for _, dest := range x.to {
		if dest == "*" {
			return true
		}

		host, _, err := net.SplitHostPort(dest)
		switch {
		case err != nil:
			return false
		case host == state.IP():
			// Exact IP match. TODO(): make this work with ranges
			return true
		}
	}
	return false
}
// longestMatch returns the first transfer instance whose matching zone for name is the
// longest among all instances. When nothing matches, nil is returned.
//
// Candidates are ranked by the length of the matched zone rather than by plain string
// ordering: comparing the zone strings with ">" is lexicographic and would, for example,
// prefer "example.org." over the more specific "a.example.org.".
func longestMatch(xfrs []*xfr, name string) *xfr {
	// TODO(xxx): optimize and make it a map (or maps)
	var best *xfr
	bestLen := 0 // length of the longest zone matched so far; longest zone match wins
	for _, candidate := range xfrs {
		// An empty result means no zone of this candidate matches; its length 0 never wins.
		z := plugin.Zones(candidate.Zones).Matches(name)
		if len(z) > bestLen {
			bestLen = len(z)
			best = candidate
		}
	}
	return best
}
// Name implements the Handler interface and returns the plugin name, "transfer".
func (Transfer) Name() string {
	return "transfer"
}
//go:build gofuzz
package whoami
import (
"github.com/coredns/coredns/plugin/pkg/fuzz"
)
// Fuzz fuzzes cache.
func Fuzz(data []byte) int {
w := Whoami{}
return fuzz.Do(w, data)
}
package whoami
import (
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
// init registers the whoami plugin together with its setup function.
func init() {
	plugin.Register("whoami", setup)
}
// setup configures the whoami plugin. The directive takes no arguments; any argument is
// reported as an error. On success a Whoami handler is added to the plugin chain.
func setup(c *caddy.Controller) error {
	c.Next() // 'whoami'
	if hasArg := c.NextArg(); hasArg {
		return plugin.Error("whoami", c.ArgErr())
	}

	cfg := dnsserver.GetConfig(c)
	// The next handler is ignored: the handler returned here is always a plain Whoami.
	cfg.AddPlugin(func(_ plugin.Handler) plugin.Handler {
		return Whoami{}
	})
	return nil
}
// Package whoami implements a plugin that returns details about the resolver
// querying it.
package whoami
import (
"context"
"net"
"strconv"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// name is the plugin name, as returned by Whoami.Name.
const name = "whoami"
// Whoami is a plugin that returns your IP address, port and the protocol used for connecting
// to CoreDNS. It carries no state.
type Whoami struct{}
// ServeDNS implements the plugin.Handler interface.
//
// The reply is authoritative and carries two records in its additional section: an A or
// AAAA record (chosen by the transport family) holding the client's IP address, and an SRV
// record named "_<proto>.<qname>" whose port is the client's source port.
func (wh Whoami) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	state := request.Request{W: w, Req: r}

	reply := new(dns.Msg)
	reply.SetReply(r)
	reply.Authoritative = true

	clientIP := state.IP()

	var addr dns.RR
	switch state.Family() {
	case 1:
		addr = &dns.A{
			Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: state.QClass()},
			A:   net.ParseIP(clientIP).To4(),
		}
	case 2:
		addr = &dns.AAAA{
			Hdr:  dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeAAAA, Class: state.QClass()},
			AAAA: net.ParseIP(clientIP),
		}
	}

	// For the root zone the qname already is the separating dot.
	srvName := "_" + state.Proto() + "." + state.QName()
	if state.QName() == "." {
		srvName = "_" + state.Proto() + state.QName()
	}

	portNum, _ := strconv.ParseUint(state.Port(), 10, 16)
	srv := &dns.SRV{
		Hdr:    dns.RR_Header{Name: srvName, Rrtype: dns.TypeSRV, Class: state.QClass()},
		Port:   uint16(portNum),
		Target: ".",
	}

	reply.Extra = []dns.RR{addr, srv}
	w.WriteMsg(reply)
	return 0, nil
}
// Name implements the Handler interface and returns the plugin name.
func (wh Whoami) Name() string {
	return name
}
package request
import (
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/miekg/dns"
)
// supportedOptions returns the subset of o that is supported, in the original order.
// NSID, EXPIRE, COOKIE, TCP keepalive and PADDING are always kept; any other option is
// kept only when edns.SupportedOption accepts its code.
func supportedOptions(o []dns.EDNS0) []dns.EDNS0 {
	kept := make([]dns.EDNS0, 0, 3)
	for _, opt := range o {
		code := opt.Option()
		switch code {
		// For as long as possible try to avoid looking up in the map, because that needs an Rlock.
		case dns.EDNS0NSID, dns.EDNS0EXPIRE, dns.EDNS0COOKIE, dns.EDNS0TCPKEEPALIVE, dns.EDNS0PADDING:
			kept = append(kept, opt)
		default:
			if edns.SupportedOption(code) {
				kept = append(kept, opt)
			}
		}
	}
	return kept
}
// Package request abstracts a client's request so that all plugins will handle them in a unified way.
package request
import (
"net"
"strings"
"github.com/coredns/coredns/plugin/pkg/edns"
"github.com/miekg/dns"
)
// Request contains some connection state and is useful in plugin.
// The unexported fields are lazily filled caches; Clear resets them.
type Request struct {
Req *dns.Msg // the request message
W dns.ResponseWriter // writer for the connection; source of the remote and local addresses
// Optional lowercased zone of this query.
Zone string
// Cache size after first call to Size or Do. If size is zero nothing has been cached yet.
// Both Size and Do set these values (and cache them).
size uint16 // UDP buffer size, or 64K in case of TCP.
do bool // DNSSEC OK value
// Caches
family int8 // transport's family.
name string // lowercase qname.
ip string // client's ip.
port string // client's port.
localPort string // server's port.
localIP string // server's ip.
}
// NewWithQuestion returns a new request based on the old one, but with a new question
// section. The request message is copied; the first question of the copy is replaced by
// one for the fully qualified name, class INET and type typ.
func (r *Request) NewWithQuestion(name string, typ uint16) Request {
	msg := r.Req.Copy()
	msg.Question[0] = dns.Question{Name: dns.Fqdn(name), Qclass: dns.ClassINET, Qtype: typ}
	return Request{W: r.W, Req: msg}
}
// IP gets the (remote) IP address of the client making the request. The value is cached
// after the first call. If the remote address cannot be split into host and port, the
// whole address string is used.
func (r *Request) IP() string {
	if r.ip == "" {
		remote := r.W.RemoteAddr().String()
		if host, _, err := net.SplitHostPort(remote); err == nil {
			r.ip = host
		} else {
			r.ip = remote
		}
	}
	return r.ip
}
// LocalIP gets the (local) IP address of the server handling the request. The value is
// cached after the first call. If the local address cannot be split into host and port,
// the whole address string is used.
func (r *Request) LocalIP() string {
	if r.localIP == "" {
		local := r.W.LocalAddr().String()
		if host, _, err := net.SplitHostPort(local); err == nil {
			r.localIP = host
		} else {
			r.localIP = local
		}
	}
	return r.localIP
}
// Port gets the (remote) port of the client making the request. The value is cached
// after the first call; "0" is used when the remote address has no parsable port.
func (r *Request) Port() string {
	if r.port == "" {
		r.port = "0"
		if _, p, err := net.SplitHostPort(r.W.RemoteAddr().String()); err == nil {
			r.port = p
		}
	}
	return r.port
}
// LocalPort gets the local port of the server handling the request. The value is cached
// after the first call; "0" is used when the local address has no parsable port.
func (r *Request) LocalPort() string {
	if r.localPort == "" {
		r.localPort = "0"
		if _, p, err := net.SplitHostPort(r.W.LocalAddr().String()); err == nil {
			r.localPort = p
		}
	}
	return r.localPort
}
// RemoteAddr returns the string form of the address of the client that sent the current request.
func (r *Request) RemoteAddr() string {
	return r.W.RemoteAddr().String()
}
// LocalAddr returns the string form of the address of the server handling the current request.
func (r *Request) LocalAddr() string {
	return r.W.LocalAddr().String()
}
// Proto gets the protocol used as the transport. This will be udp or tcp; any remote
// address that is not a *net.TCPAddr is reported as "udp".
func (r *Request) Proto() string {
	switch r.W.RemoteAddr().(type) {
	case *net.TCPAddr:
		return "tcp"
	default:
		return "udp"
	}
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. The result is
// cached. A remote address that is neither UDP nor TCP is reported as 2.
func (r *Request) Family() int {
	if r.family != 0 {
		return int(r.family)
	}

	var addr net.IP
	switch remote := r.W.RemoteAddr().(type) {
	case *net.UDPAddr:
		addr = remote.IP
	case *net.TCPAddr:
		addr = remote.IP
	}

	r.family = 2
	if addr.To4() != nil {
		r.family = 1
	}
	return int(r.family)
}
// Do returns true if the request has the DO (DNSSEC OK) bit set. The bit is determined by
// Size, which is called first when nothing has been cached yet.
func (r *Request) Do() bool {
	if r.size == 0 {
		r.Size()
	}
	return r.do
}
// Len returns the length in bytes of the request message.
func (r *Request) Len() int {
	return r.Req.Len()
}
// Size returns the buffer size *advertised* in the request's OPT record, normalized by
// edns.Size for the transport in use. Or when the request was over TCP, we return the
// maximum allowed size of 64K. The result is cached, and the DO bit is recorded along the
// way when an OPT record is present.
func (r *Request) Size() int {
	if r.size != 0 {
		return int(r.size)
	}

	var advertised uint16
	if opt := r.Req.IsEdns0(); opt != nil {
		r.do = opt.Do()
		advertised = opt.UDPSize()
	}

	// normalize size
	r.size = edns.Size(r.Proto(), advertised)
	return int(r.size)
}
// SizeAndDo makes sure m carries an OPT record that reflects the intent of the request.
// It returns false, leaving m untouched, when the request has no OPT record. Otherwise it
// either normalizes the OPT record already present in m, or reuses the request's OPT
// record (with unsupported options removed) by appending it to m, and returns true.
func (r *Request) SizeAndDo(m *dns.Msg) bool {
	reqOpt := r.Req.IsEdns0()
	if reqOpt == nil {
		return false
	}

	replyOpt := m.IsEdns0()
	if replyOpt == nil {
		// Reuse the request's OPT record and tack it to m.
		reqOpt.Hdr.Name = "."
		reqOpt.Hdr.Rrtype = dns.TypeOPT
		reqOpt.SetVersion(0)
		reqOpt.Hdr.Ttl &= 0xff00 // clear flags
		if len(reqOpt.Option) > 0 {
			reqOpt.Option = supportedOptions(reqOpt.Option)
		}
		m.Extra = append(m.Extra, reqOpt)
		return true
	}

	// m already has an OPT record: normalize it in place.
	// Assume if the message m has options set, they are OK and represent what an upstream can do.
	replyOpt.Hdr.Name = "."
	replyOpt.Hdr.Rrtype = dns.TypeOPT
	replyOpt.SetVersion(0)
	replyOpt.SetUDPSize(reqOpt.UDPSize())
	replyOpt.Hdr.Ttl &= 0xff00 // clear flags
	if reqOpt.Do() {
		replyOpt.SetDo()
	}
	return true
}
// Scrub scrubs the reply message so that it will fit the client's buffer. The reply is
// truncated to the request's size; if that did not already enable compression and the
// transport is UDP, compression is switched on for replies above the fragmentation limits.
// Note, the TC bit will be set regardless of protocol, even TCP message will
// get the bit, the client should then retry with pigeons.
func (r *Request) Scrub(reply *dns.Msg) *dns.Msg {
	reply.Truncate(r.Size())

	if reply.Compress || r.Proto() != "udp" {
		return reply
	}

	// Last ditch attempt to avoid fragmentation, if the size is bigger than the v4/v6 UDP fragmentation
	// limit and sent via UDP compress it (in the hope we go under that limit). Limits taken from NSD:
	//
	//    .., 1480 (EDNS/IPv4), 1220 (EDNS/IPv6), or the advertised EDNS buffer size if that is
	//    smaller than the EDNS default.
	// See: https://open.nlnetlabs.nl/pipermail/nsd-users/2011-November/001278.html
	size := reply.Len()
	if (size > 1480 && r.Family() == 1) || (size > 1220 && r.Family() == 2) {
		reply.Compress = true
	}
	return reply
}
// Type returns the type of the question as a string. If the request is malformed (no
// message or no question) the empty string is returned.
func (r *Request) Type() string {
	if r.Req == nil || len(r.Req.Question) == 0 {
		return ""
	}
	qtype := r.Req.Question[0].Qtype
	return dns.Type(qtype).String()
}
// QType returns the type of the question as an uint16. If the request is malformed
// (no message or no question) 0 is returned.
func (r *Request) QType() uint16 {
	if r.Req == nil || len(r.Req.Question) == 0 {
		return 0
	}
	return r.Req.Question[0].Qtype
}
// Name returns the lowercased name of the question in the request, formatted via dns.Name.
// The value is cached after the first call; call Clear to drop the cache.
// If the request is malformed (no message or no question) the root zone is returned.
func (r *Request) Name() string {
	if r.name != "" {
		return r.name
	}

	if r.Req == nil || len(r.Req.Question) == 0 {
		r.name = "."
		return "."
	}

	qname := dns.Name(r.Req.Question[0].Name).String()
	r.name = strings.ToLower(qname)
	return r.name
}
// QName returns the name of the question in the request, without lowercasing or caching.
// If the request is malformed (no message or no question) the root zone is returned.
func (r *Request) QName() string {
	if r.Req == nil || len(r.Req.Question) == 0 {
		return "."
	}
	qname := r.Req.Question[0].Name
	return dns.Name(qname).String()
}
// Class returns the class of the question in the request as a string.
// If the request is malformed (no message or no question) the empty string is returned.
func (r *Request) Class() string {
	if r.Req == nil || len(r.Req.Question) == 0 {
		return ""
	}
	qclass := r.Req.Question[0].Qclass
	return dns.Class(qclass).String()
}
// QClass returns the class of the question in the request as an uint16.
// If the request is malformed (no message or no question) 0 is returned.
func (r *Request) QClass() uint16 {
	if r.Req == nil || len(r.Req.Question) == 0 {
		return 0
	}
	return r.Req.Question[0].Qclass
}
// Clear resets every cached value of the Request to its zero value, so the next accessor
// call recomputes it.
func (r *Request) Clear() {
	// Cached strings.
	r.name, r.ip, r.localIP = "", "", ""
	r.port, r.localPort = "", ""
	// Cached transport family, buffer size and DO bit.
	r.family, r.size, r.do = 0, 0, false
}
// Match checks if the reply matches the qname and qtype from the request. It returns
// false when they don't match, when reply is not flagged as a response, or when reply
// does not carry exactly one question.
func (r *Request) Match(reply *dns.Msg) bool {
	if len(reply.Question) != 1 || !reply.Response {
		return false
	}
	q := reply.Question[0]
	return strings.ToLower(q.Name) == r.Name() && q.Qtype == r.QType()
}
package request
import "github.com/miekg/dns"
// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer.
// It embeds the wrapped dns.ResponseWriter and overrides only WriteMsg.
type ScrubWriter struct {
dns.ResponseWriter
req *dns.Msg // original request
}
// NewScrubWriter returns a new and initialized ScrubWriter that wraps w and remembers req
// as the original request.
func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter {
	return &ScrubWriter{ResponseWriter: w, req: req}
}
// WriteMsg overrides the default implementation of the underlying dns.ResponseWriter.
// It applies SizeAndDo and Scrub to m, based on the original request, and then writes m
// to the client.
func (s *ScrubWriter) WriteMsg(m *dns.Msg) error {
	st := Request{W: s.ResponseWriter, Req: s.req}
	st.SizeAndDo(m)
	st.Scrub(m)
	return s.ResponseWriter.WriteMsg(m)
}