package circuit
// binarySampler contains a series of events as 0 or 1 values, e.g. errors or successes,
// within a limited, sliding window.
// count contains the actual number of events with the value of 1 within the window.
// The sampler compresses the event storage by a factor of 64 by packing the events into
// the bits of uint64 frames.
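// A minimal usage sketch (hypothetical values): a sampler of size 5 tracking the last
// five outcomes, where true marks a failure:
//
//	s := newBinarySampler(5)
//	for _, failed := range []bool{true, true, false, true, false, true} {
//		s.tick(failed)
//	}
//	// s.count == 3: the oldest failure has already left the window.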
type binarySampler struct {
size int
filled int
frames []uint64
pad uint64
count int
}
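// newBinarySampler creates a sampler for a window of size events, clamping sizes below 1 to 1.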
func newBinarySampler(size int) *binarySampler {
if size <= 0 {
size = 1
}
return &binarySampler{
size: size,
pad: 64 - uint64(size)%64,
}
}
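// highestSet reports whether the most significant in-window bit of frame is set, where pad
// is the number of unused leading bits of the frame.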
func highestSet(frame, pad uint64) bool {
return frame&(1<<(63-pad)) != 0
}
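// shift moves the whole window by one bit to the left, carrying the overflowing highest
// bit of each frame into the lowest bit of the next frame. For illustration, shifting the
// two frames [0x8000000000000000, 0x0] yields [0x0, 0x1].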
func shift(frames []uint64) {
highestFrame := len(frames) - 1
for i := highestFrame; i >= 0; i-- {
h := highestSet(frames[i], 0)
frames[i] = frames[i] << 1
if h && i < highestFrame {
frames[i+1] |= 1
}
}
}
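// tick records one event in the sliding window: it evicts the oldest event once the window
// is full, shifts all frames by one bit, and, when set is true, stores a 1 for the newest
// event and increments count.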
func (s *binarySampler) tick(set bool) {
filled := s.filled == s.size
if filled && highestSet(s.frames[len(s.frames)-1], s.pad) {
s.count--
}
if !filled {
if len(s.frames) <= s.filled/64 {
s.frames = append(s.frames, 0)
}
s.filled++
}
shift(s.frames)
if set {
s.count++
s.frames[0] |= 1
}
}
package circuit
import (
"fmt"
"strconv"
"strings"
"time"
)
// BreakerType defines the type of the breaker in use: consecutive, rate or disabled.
type BreakerType int
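// UnmarshalYAML decodes a BreakerType from its YAML string representation. A hypothetical
// snippet:
//
//	type: rate
//
// sets the value to FailureRate.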
func (b *BreakerType) UnmarshalYAML(unmarshal func(interface{}) error) error {
var value string
if err := unmarshal(&value); err != nil {
return err
}
switch value {
case "consecutive":
*b = ConsecutiveFailures
case "rate":
*b = FailureRate
case "disabled":
*b = BreakerDisabled
default:
return fmt.Errorf("invalid breaker type %v (allowed values are: consecutive, rate or disabled)", value)
}
return nil
}
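// The supported breaker types, where BreakerNone marks settings without a type set.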
const (
BreakerNone BreakerType = iota
ConsecutiveFailures
FailureRate
BreakerDisabled
)
// BreakerSettings contains the settings for individual circuit breakers.
//
// See the package overview for the detailed merging/overriding rules of the settings and for the meaning of the
// individual fields.
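// A hypothetical YAML snippet mapping to these fields (the yaml.v2 decoder used by the
// config package parses duration strings like 15s into time.Duration):
//
//	type: rate
//	host: www.example.org
//	window: 300
//	failures: 30
//	timeout: 15s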
type BreakerSettings struct {
Type BreakerType `yaml:"type"`
Host string `yaml:"host"`
Window int `yaml:"window"`
Failures int `yaml:"failures"`
Timeout time.Duration `yaml:"timeout"`
HalfOpenRequests int `yaml:"half-open-requests"`
IdleTTL time.Duration `yaml:"idle-ttl"`
}
type breakerImplementation interface {
Allow() (func(bool), bool)
}
type voidBreaker struct{}
// Breaker represents a single circuit breaker for a particular set of settings.
//
// Use the Get() method of the Registry to request fully initialized breakers.
type Breaker struct {
settings BreakerSettings
ts time.Time
impl breakerImplementation
}
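// mergeSettings fills the unset fields of to from from. For example (hypothetical values),
// merging a host setting {Type: FailureRate, Window: 300, Failures: 30} with the defaults
// {Type: ConsecutiveFailures, Failures: 5, Timeout: 15 * time.Second} keeps the rate type,
// window and failure count, and only takes the missing Timeout from the defaults.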
func (to BreakerSettings) mergeSettings(from BreakerSettings) BreakerSettings {
if to.Type == BreakerNone {
to.Type = from.Type
if from.Type == ConsecutiveFailures {
to.Failures = from.Failures
}
if from.Type == FailureRate {
to.Window = from.Window
to.Failures = from.Failures
}
}
if to.Timeout == 0 {
to.Timeout = from.Timeout
}
if to.HalfOpenRequests == 0 {
to.HalfOpenRequests = from.HalfOpenRequests
}
if to.IdleTTL == 0 {
to.IdleTTL = from.IdleTTL
}
return to
}
// String returns the string representation of a particular set of settings.
//
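// For example, the settings {Type: FailureRate, Host: "www.example.org", Window: 300,
// Failures: 30} are rendered as:
//
//	type=rate,host=www.example.org,window=300,failures=30
//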
//lint:ignore ST1016 "s" makes sense here and mergeSettings has "to"
func (s BreakerSettings) String() string {
var ss []string
switch s.Type {
case ConsecutiveFailures:
ss = append(ss, "type=consecutive")
case FailureRate:
ss = append(ss, "type=rate")
case BreakerDisabled:
return "disabled"
default:
return "none"
}
if s.Host != "" {
ss = append(ss, "host="+s.Host)
}
if s.Type == FailureRate && s.Window > 0 {
ss = append(ss, "window="+strconv.Itoa(s.Window))
}
if s.Failures > 0 {
ss = append(ss, "failures="+strconv.Itoa(s.Failures))
}
if s.Timeout > 0 {
ss = append(ss, "timeout="+s.Timeout.String())
}
if s.HalfOpenRequests > 0 {
ss = append(ss, "half-open-requests="+strconv.Itoa(s.HalfOpenRequests))
}
if s.IdleTTL > 0 {
ss = append(ss, "idle-ttl="+s.IdleTTL.String())
}
return strings.Join(ss, ",")
}
func (b voidBreaker) Allow() (func(bool), bool) {
return func(bool) {}, true
}
func newBreaker(s BreakerSettings) *Breaker {
var impl breakerImplementation
switch s.Type {
case ConsecutiveFailures:
impl = newConsecutive(s)
case FailureRate:
impl = newRate(s)
default:
impl = voidBreaker{}
}
return &Breaker{
settings: s,
impl: impl,
}
}
// Allow returns a callback function for reporting the outcome of the operation, and true,
// when the breaker is in the closed state and the request is allowed. The callback expects
// true when the outcome of the request was successful. Allow may not return a callback
// function when the breaker is not closed.
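// A minimal caller sketch (callBackend stands for a hypothetical protected operation):
//
//	if done, ok := b.Allow(); ok {
//		err := callBackend()
//		done(err == nil)
//	}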
func (b *Breaker) Allow() (func(bool), bool) {
return b.impl.Allow()
}
func (b *Breaker) idle(now time.Time) bool {
return now.Sub(b.ts) > b.settings.IdleTTL
}
package circuit
import (
log "github.com/sirupsen/logrus"
"github.com/sony/gobreaker"
)
type consecutiveBreaker struct {
settings BreakerSettings
gb *gobreaker.TwoStepCircuitBreaker
}
func newConsecutive(s BreakerSettings) *consecutiveBreaker {
b := &consecutiveBreaker{
settings: s,
}
b.gb = gobreaker.NewTwoStepCircuitBreaker(gobreaker.Settings{
Name: s.Host,
MaxRequests: uint32(s.HalfOpenRequests),
Timeout: s.Timeout,
ReadyToTrip: b.readyToTrip,
OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
log.Infof("circuit breaker %v went from %v to %v", name, from.String(), to.String())
},
})
return b
}
func (b *consecutiveBreaker) readyToTrip(c gobreaker.Counts) bool {
return int(c.ConsecutiveFailures) >= b.settings.Failures
}
func (b *consecutiveBreaker) Allow() (func(bool), bool) {
done, err := b.gb.Allow()
// this error can only indicate that the breaker is not closed
closed := err == nil
if !closed {
return nil, false
}
return done, true
}
package circuit
import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/sony/gobreaker"
)
// TODO:
// in case of the rate breaker, there are unnecessary synchronization steps due to the 3rd party gobreaker. If
// the sliding window was part of the implementation of the individual breakers, this additional synchronization
// would not be required.
type rateBreaker struct {
settings BreakerSettings
mu sync.Mutex
sampler *binarySampler
gb *gobreaker.TwoStepCircuitBreaker
}
func newRate(s BreakerSettings) *rateBreaker {
b := &rateBreaker{
settings: s,
}
b.gb = gobreaker.NewTwoStepCircuitBreaker(gobreaker.Settings{
Name: s.Host,
MaxRequests: uint32(s.HalfOpenRequests),
Timeout: s.Timeout,
ReadyToTrip: func(gobreaker.Counts) bool { return b.readyToTrip() },
OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
log.Infof("circuit breaker %v went from %v to %v", name, from.String(), to.String())
},
})
return b
}
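// readyToTrip reports whether the failure count within the sliding window reached the
// configured limit. With, for example, Window: 10 and Failures: 3, gobreaker consults it
// after each recorded failure and opens the circuit as soon as at least 3 of the last 10
// outcomes failed.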
func (b *rateBreaker) readyToTrip() bool {
b.mu.Lock()
defer b.mu.Unlock()
if b.sampler == nil {
return false
}
return b.sampler.count >= b.settings.Failures
}
// countRate counts the failures in closed and half-open state
func (b *rateBreaker) countRate(success bool) {
b.mu.Lock()
defer b.mu.Unlock()
if b.sampler == nil {
b.sampler = newBinarySampler(b.settings.Window)
}
b.sampler.tick(!success)
}
func (b *rateBreaker) Allow() (func(bool), bool) {
done, err := b.gb.Allow()
// this error can only indicate that the breaker is not closed
closed := err == nil
if !closed {
return nil, false
}
return func(success bool) {
b.countRate(success)
done(success)
}, true
}
package circuit
import (
"sync"
"time"
)
const DefaultIdleTTL = time.Hour
// Registry objects hold the active circuit breakers, ensure synchronized access to them, apply default settings
// and recycle the idle breakers.
type Registry struct {
defaults BreakerSettings
hostSettings map[string]BreakerSettings
mu sync.Mutex
lookup map[BreakerSettings]*Breaker
}
// NewRegistry initializes a registry with the provided default settings. Settings with an empty Host field are
// considered defaults. Settings with the same Host field are merged together.
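// A hypothetical registry combining a global default with a host-specific override:
//
//	r := NewRegistry(
//		BreakerSettings{Type: ConsecutiveFailures, Failures: 5},
//		BreakerSettings{Host: "www.example.org", Type: FailureRate, Window: 300, Failures: 30},
//	)
//
// Breakers for www.example.org apply the rate settings, while every other host falls back
// to 5 consecutive failures.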
func NewRegistry(settings ...BreakerSettings) *Registry {
var (
defaults BreakerSettings
hostSettings []BreakerSettings
)
for _, s := range settings {
if s.Host == "" {
defaults = defaults.mergeSettings(s)
continue
}
hostSettings = append(hostSettings, s)
}
if defaults.IdleTTL <= 0 {
defaults.IdleTTL = DefaultIdleTTL
}
hs := make(map[string]BreakerSettings)
for _, s := range hostSettings {
if sh, ok := hs[s.Host]; ok {
hs[s.Host] = s.mergeSettings(sh)
} else {
hs[s.Host] = s.mergeSettings(defaults)
}
}
return &Registry{
defaults: defaults,
hostSettings: hs,
lookup: make(map[BreakerSettings]*Breaker),
}
}
func (r *Registry) mergeDefaults(s BreakerSettings) BreakerSettings {
defaults, ok := r.hostSettings[s.Host]
if !ok {
defaults = r.defaults
}
return s.mergeSettings(defaults)
}
func (r *Registry) dropIdle(now time.Time) {
for h, b := range r.lookup {
if b.idle(now) {
delete(r.lookup, h)
}
}
}
func (r *Registry) get(s BreakerSettings) *Breaker {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
b, ok := r.lookup[s]
if !ok || b.idle(now) {
// check if there are any other idle breakers to evict, and drop them
r.dropIdle(now)
// create a new one
b = newBreaker(s)
r.lookup[s] = b
}
// set the access timestamp
b.ts = now
return b
}
// Get returns a circuit breaker for the provided settings. The BreakerSettings object is used here as a key,
// but typically it is enough to just set its Host field:
//
// r.Get(BreakerSettings{Host: backendHost})
//
// The key will be filled up with the defaults and the matching circuit breaker will be returned if it exists,
// or a new one will be created if not.
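// A fuller, hypothetical sketch, where callBackend stands for the protected operation:
//
//	if b := r.Get(BreakerSettings{Host: backendHost}); b != nil {
//		if done, ok := b.Allow(); ok {
//			err := callBackend()
//			done(err == nil)
//		}
//	}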
func (r *Registry) Get(s BreakerSettings) *Breaker {
// we check for host, because we don't want to use shared global breakers
if s.Type == BreakerDisabled || s.Host == "" {
return nil
}
s = r.mergeDefaults(s)
if s.Type == BreakerNone {
return nil
}
return r.get(s)
}
package config
import (
"errors"
"strconv"
"strings"
"time"
"github.com/zalando/skipper/circuit"
)
const breakerUsage = `set global or host specific circuit breakers, e.g. -breaker type=rate,host=www.example.org,window=300,failures=30
possible breaker properties:
type: consecutive/rate/disabled (defaults to consecutive)
host: a host name that overrides the global settings for a host
failures: the number of failures for consecutive or rate breakers
window: the number of most recent requests forming the sliding window for the rate breaker
timeout: duration string (e.g. 15s) for how long the breaker stays open
half-open-requests: the number of requests in half-open state that need to succeed before the breaker closes again
idle-ttl: duration string (e.g. 1h) after which an unused breaker is considered idle and reset
(see also: https://godoc.org/github.com/zalando/skipper/circuit)`
const enableBreakersUsage = `enable breakers to be set from filters without providing global or host settings (equivalent to: -breaker type=disabled)`
type breakerFlags []circuit.BreakerSettings
var errInvalidBreakerConfig = errors.New("invalid breaker config (allowed values are: consecutive, rate or disabled)")
func (b breakerFlags) String() string {
s := make([]string, len(b))
for i, bi := range b {
s[i] = bi.String()
}
return strings.Join(s, "\n")
}
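// Set parses one -breaker flag value, e.g. the hypothetical
//
//	type=rate,host=www.example.org,window=300,failures=30,timeout=15s
//
// and appends the resulting settings to the list; an omitted type defaults to consecutive.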
func (b *breakerFlags) Set(value string) error {
var s circuit.BreakerSettings
vs := strings.Split(value, ",")
for _, vi := range vs {
k, v, found := strings.Cut(vi, "=")
if !found {
return errInvalidBreakerConfig
}
switch k {
case "type":
switch v {
case "consecutive":
s.Type = circuit.ConsecutiveFailures
case "rate":
s.Type = circuit.FailureRate
case "disabled":
s.Type = circuit.BreakerDisabled
default:
return errInvalidBreakerConfig
}
case "host":
s.Host = v
case "window":
i, err := strconv.Atoi(v)
if err != nil {
return err
}
s.Window = i
case "failures":
i, err := strconv.Atoi(v)
if err != nil {
return err
}
s.Failures = i
case "timeout":
d, err := time.ParseDuration(v)
if err != nil {
return err
}
s.Timeout = d
case "half-open-requests":
i, err := strconv.Atoi(v)
if err != nil {
return err
}
s.HalfOpenRequests = i
case "idle-ttl":
d, err := time.ParseDuration(v)
if err != nil {
return err
}
s.IdleTTL = d
default:
return errInvalidBreakerConfig
}
}
if s.Type == circuit.BreakerNone {
s.Type = circuit.ConsecutiveFailures
}
*b = append(*b, s)
return nil
}
func (b *breakerFlags) UnmarshalYAML(unmarshal func(interface{}) error) error {
var breakerSettings circuit.BreakerSettings
if err := unmarshal(&breakerSettings); err != nil {
return err
}
*b = append(*b, breakerSettings)
return nil
}
package config
import (
"crypto/tls"
"flag"
"fmt"
"net/http"
"os"
"sort"
"strconv"
"strings"
"time"
"gopkg.in/yaml.v2"
log "github.com/sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus"
"github.com/zalando/skipper"
"github.com/zalando/skipper/dataclients/kubernetes"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters/openpolicyagent"
"github.com/zalando/skipper/net"
"github.com/zalando/skipper/proxy"
"github.com/zalando/skipper/swarm"
)
type Config struct {
ConfigFile string
Flags *flag.FlagSet
// generic:
Address string `yaml:"address"`
InsecureAddress string `yaml:"insecure-address"`
EnableTCPQueue bool `yaml:"enable-tcp-queue"`
ExpectedBytesPerRequest int `yaml:"expected-bytes-per-request"`
MaxTCPListenerConcurrency int `yaml:"max-tcp-listener-concurrency"`
MaxTCPListenerQueue int `yaml:"max-tcp-listener-queue"`
IgnoreTrailingSlash bool `yaml:"ignore-trailing-slash"`
Insecure bool `yaml:"insecure"`
ProxyPreserveHost bool `yaml:"proxy-preserve-host"`
DevMode bool `yaml:"dev-mode"`
SupportListener string `yaml:"support-listener"`
DebugListener string `yaml:"debug-listener"`
CertPathTLS string `yaml:"tls-cert"`
KeyPathTLS string `yaml:"tls-key"`
StatusChecks *listFlag `yaml:"status-checks"`
PrintVersion bool `yaml:"version"`
MaxLoopbacks int `yaml:"max-loopbacks"`
DefaultHTTPStatus int `yaml:"default-http-status"`
PluginDir string `yaml:"plugindir"`
LoadBalancerHealthCheckInterval time.Duration `yaml:"lb-healthcheck-interval"`
ReverseSourcePredicate bool `yaml:"reverse-source-predicate"`
RemoveHopHeaders bool `yaml:"remove-hop-headers"`
RfcPatchPath bool `yaml:"rfc-patch-path"`
MaxAuditBody int `yaml:"max-audit-body"`
MaxMatcherBufferSize uint64 `yaml:"max-matcher-buffer-size"`
EnableBreakers bool `yaml:"enable-breakers"`
Breakers breakerFlags `yaml:"breaker"`
EnableRatelimiters bool `yaml:"enable-ratelimits"`
Ratelimits ratelimitFlags `yaml:"ratelimits"`
EnableRouteFIFOMetrics bool `yaml:"enable-route-fifo-metrics"`
EnableRouteLIFOMetrics bool `yaml:"enable-route-lifo-metrics"`
MetricsFlavour *listFlag `yaml:"metrics-flavour"`
FilterPlugins *pluginFlag `yaml:"filter-plugin"`
PredicatePlugins *pluginFlag `yaml:"predicate-plugin"`
DataclientPlugins *pluginFlag `yaml:"dataclient-plugin"`
MultiPlugins *pluginFlag `yaml:"multi-plugin"`
CompressEncodings *listFlag `yaml:"compress-encodings"`
// logging, metrics, profiling, tracing:
EnablePrometheusMetrics bool `yaml:"enable-prometheus-metrics"`
EnablePrometheusStartLabel bool `yaml:"enable-prometheus-start-label"`
OpenTracing string `yaml:"opentracing"`
OpenTracingInitialSpan string `yaml:"opentracing-initial-span"`
OpenTracingExcludedProxyTags string `yaml:"opentracing-excluded-proxy-tags"`
OpenTracingDisableFilterSpans bool `yaml:"opentracing-disable-filter-spans"`
OpentracingLogFilterLifecycleEvents bool `yaml:"opentracing-log-filter-lifecycle-events"`
OpentracingLogStreamEvents bool `yaml:"opentracing-log-stream-events"`
OpentracingBackendNameTag bool `yaml:"opentracing-backend-name-tag"`
MetricsListener string `yaml:"metrics-listener"`
MetricsPrefix string `yaml:"metrics-prefix"`
EnableProfile bool `yaml:"enable-profile"`
BlockProfileRate int `yaml:"block-profile-rate"`
MutexProfileFraction int `yaml:"mutex-profile-fraction"`
MemProfileRate int `yaml:"memory-profile-rate"`
DebugGcMetrics bool `yaml:"debug-gc-metrics"`
RuntimeMetrics bool `yaml:"runtime-metrics"`
ServeRouteMetrics bool `yaml:"serve-route-metrics"`
ServeRouteCounter bool `yaml:"serve-route-counter"`
ServeHostMetrics bool `yaml:"serve-host-metrics"`
ServeHostCounter bool `yaml:"serve-host-counter"`
ServeMethodMetric bool `yaml:"serve-method-metric"`
ServeStatusCodeMetric bool `yaml:"serve-status-code-metric"`
BackendHostMetrics bool `yaml:"backend-host-metrics"`
ProxyRequestMetrics bool `yaml:"proxy-request-metrics"`
ProxyResponseMetrics bool `yaml:"proxy-response-metrics"`
AllFiltersMetrics bool `yaml:"all-filters-metrics"`
CombinedResponseMetrics bool `yaml:"combined-response-metrics"`
RouteResponseMetrics bool `yaml:"route-response-metrics"`
RouteBackendErrorCounters bool `yaml:"route-backend-error-counters"`
RouteStreamErrorCounters bool `yaml:"route-stream-error-counters"`
RouteBackendMetrics bool `yaml:"route-backend-metrics"`
RouteCreationMetrics bool `yaml:"route-creation-metrics"`
MetricsUseExpDecaySample bool `yaml:"metrics-exp-decay-sample"`
HistogramMetricBucketsString string `yaml:"histogram-metric-buckets"`
HistogramMetricBuckets []float64 `yaml:"-"`
DisableMetricsCompat bool `yaml:"disable-metrics-compat"`
ApplicationLog string `yaml:"application-log"`
ApplicationLogLevel log.Level `yaml:"-"`
ApplicationLogLevelString string `yaml:"application-log-level"`
ApplicationLogPrefix string `yaml:"application-log-prefix"`
ApplicationLogJSONEnabled bool `yaml:"application-log-json-enabled"`
AccessLog string `yaml:"access-log"`
AccessLogDisabled bool `yaml:"access-log-disabled"`
AccessLogJSONEnabled bool `yaml:"access-log-json-enabled"`
AccessLogStripQuery bool `yaml:"access-log-strip-query"`
SuppressRouteUpdateLogs bool `yaml:"suppress-route-update-logs"`
// route sources:
EtcdUrls string `yaml:"etcd-urls"`
EtcdPrefix string `yaml:"etcd-prefix"`
EtcdTimeout time.Duration `yaml:"etcd-timeout"`
EtcdInsecure bool `yaml:"etcd-insecure"`
EtcdOAuthToken string `yaml:"etcd-oauth-token"`
EtcdUsername string `yaml:"etcd-username"`
EtcdPassword string `yaml:"etcd-password"`
RoutesFile string `yaml:"routes-file"`
RoutesURLs *listFlag `yaml:"routes-urls"`
InlineRoutes string `yaml:"inline-routes"`
AppendFilters *defaultFiltersFlags `yaml:"default-filters-append"`
PrependFilters *defaultFiltersFlags `yaml:"default-filters-prepend"`
DisabledFilters *listFlag `yaml:"disabled-filters"`
EditRoute routeChangerConfig `yaml:"edit-route"`
CloneRoute routeChangerConfig `yaml:"clone-route"`
SourcePollTimeout int64 `yaml:"source-poll-timeout"`
WaitFirstRouteLoad bool `yaml:"wait-first-route-load"`
// Forwarded headers
ForwardedHeadersList *listFlag `yaml:"forwarded-headers"`
ForwardedHeaders net.ForwardedHeaders `yaml:"-"`
ForwardedHeadersExcludeCIDRList *listFlag `yaml:"forwarded-headers-exclude-cidrs"`
ForwardedHeadersExcludeCIDRs net.IPNets `yaml:"-"`
// host patch:
NormalizeHost bool `yaml:"normalize-host"`
HostPatch net.HostPatch `yaml:"-"`
ValidateQuery bool `yaml:"validate-query"`
ValidateQueryLog bool `yaml:"validate-query-log"`
RefusePayload multiFlag `yaml:"refuse-payload"`
// Kubernetes:
KubernetesIngress bool `yaml:"kubernetes"`
KubernetesInCluster bool `yaml:"kubernetes-in-cluster"`
KubernetesURL string `yaml:"kubernetes-url"`
KubernetesTokenFile string `yaml:"kubernetes-token-file"`
KubernetesHealthcheck bool `yaml:"kubernetes-healthcheck"`
KubernetesHTTPSRedirect bool `yaml:"kubernetes-https-redirect"`
KubernetesHTTPSRedirectCode int `yaml:"kubernetes-https-redirect-code"`
KubernetesDisableCatchAllRoutes bool `yaml:"kubernetes-disable-catchall-routes"`
KubernetesIngressClass string `yaml:"kubernetes-ingress-class"`
KubernetesRouteGroupClass string `yaml:"kubernetes-routegroup-class"`
WhitelistedHealthCheckCIDR string `yaml:"whitelisted-healthcheck-cidr"`
KubernetesPathModeString string `yaml:"kubernetes-path-mode"`
KubernetesPathMode kubernetes.PathMode `yaml:"-"`
KubernetesNamespace string `yaml:"kubernetes-namespace"`
KubernetesEnableEndpointSlices bool `yaml:"enable-kubernetes-endpointslices"`
KubernetesEnableEastWest bool `yaml:"enable-kubernetes-east-west"`
KubernetesEastWestDomain string `yaml:"kubernetes-east-west-domain"`
KubernetesEastWestRangeDomains *listFlag `yaml:"kubernetes-east-west-range-domains"`
KubernetesEastWestRangePredicatesString string `yaml:"kubernetes-east-west-range-predicates"`
KubernetesEastWestRangeAnnotationPredicatesString multiFlag `yaml:"kubernetes-east-west-range-annotation-predicates"`
KubernetesEastWestRangeAnnotationFiltersAppendString multiFlag `yaml:"kubernetes-east-west-range-annotation-filters-append"`
KubernetesAnnotationPredicatesString multiFlag `yaml:"kubernetes-annotation-predicates"`
KubernetesAnnotationFiltersAppendString multiFlag `yaml:"kubernetes-annotation-filters-append"`
KubernetesEastWestRangeAnnotationPredicates []kubernetes.AnnotationPredicates `yaml:"-"`
KubernetesEastWestRangeAnnotationFiltersAppend []kubernetes.AnnotationFilters `yaml:"-"`
KubernetesAnnotationPredicates []kubernetes.AnnotationPredicates `yaml:"-"`
KubernetesAnnotationFiltersAppend []kubernetes.AnnotationFilters `yaml:"-"`
KubernetesEastWestRangePredicates []*eskip.Predicate `yaml:"-"`
KubernetesOnlyAllowedExternalNames bool `yaml:"kubernetes-only-allowed-external-names"`
KubernetesAllowedExternalNames regexpListFlag `yaml:"kubernetes-allowed-external-names"`
KubernetesRedisServiceNamespace string `yaml:"kubernetes-redis-service-namespace"`
KubernetesRedisServiceName string `yaml:"kubernetes-redis-service-name"`
KubernetesRedisServicePort int `yaml:"kubernetes-redis-service-port"`
KubernetesBackendTrafficAlgorithmString string `yaml:"kubernetes-backend-traffic-algorithm"`
KubernetesBackendTrafficAlgorithm kubernetes.BackendTrafficAlgorithm `yaml:"-"`
KubernetesDefaultLoadBalancerAlgorithm string `yaml:"kubernetes-default-lb-algorithm"`
KubernetesForceService bool `yaml:"kubernetes-force-service"`
// Default filters
DefaultFiltersDir string `yaml:"default-filters-dir"`
// Auth:
EnableOAuth2GrantFlow bool `yaml:"enable-oauth2-grant-flow"`
Oauth2AuthURL string `yaml:"oauth2-auth-url"`
Oauth2TokenURL string `yaml:"oauth2-token-url"`
Oauth2RevokeTokenURL string `yaml:"oauth2-revoke-token-url"`
Oauth2TokeninfoURL string `yaml:"oauth2-tokeninfo-url"`
Oauth2TokeninfoTimeout time.Duration `yaml:"oauth2-tokeninfo-timeout"`
Oauth2TokeninfoCacheSize int `yaml:"oauth2-tokeninfo-cache-size"`
Oauth2TokeninfoCacheTTL time.Duration `yaml:"oauth2-tokeninfo-cache-ttl"`
Oauth2SecretFile string `yaml:"oauth2-secret-file"`
Oauth2ClientID string `yaml:"oauth2-client-id"`
Oauth2ClientSecret string `yaml:"oauth2-client-secret"`
Oauth2ClientIDFile string `yaml:"oauth2-client-id-file"`
Oauth2ClientSecretFile string `yaml:"oauth2-client-secret-file"`
Oauth2AuthURLParameters mapFlags `yaml:"oauth2-auth-url-parameters"`
Oauth2CallbackPath string `yaml:"oauth2-callback-path"`
Oauth2TokenintrospectionTimeout time.Duration `yaml:"oauth2-tokenintrospect-timeout"`
Oauth2AccessTokenHeaderName string `yaml:"oauth2-access-token-header-name"`
Oauth2TokeninfoSubjectKey string `yaml:"oauth2-tokeninfo-subject-key"`
Oauth2GrantTokeninfoKeys *listFlag `yaml:"oauth2-grant-tokeninfo-keys"`
Oauth2TokenCookieName string `yaml:"oauth2-token-cookie-name"`
Oauth2TokenCookieRemoveSubdomains int `yaml:"oauth2-token-cookie-remove-subdomains"`
Oauth2GrantInsecure bool `yaml:"oauth2-grant-insecure"`
WebhookTimeout time.Duration `yaml:"webhook-timeout"`
OidcSecretsFile string `yaml:"oidc-secrets-file"`
OIDCCookieValidity time.Duration `yaml:"oidc-cookie-validity"`
OidcDistributedClaimsTimeout time.Duration `yaml:"oidc-distributed-claims-timeout"`
OIDCCookieRemoveSubdomains int `yaml:"oidc-cookie-remove-subdomains"`
CredentialPaths *listFlag `yaml:"credentials-paths"`
CredentialsUpdateInterval time.Duration `yaml:"credentials-update-interval"`
// TLS client certs
ClientKeyFile string `yaml:"client-tls-key"`
ClientCertFile string `yaml:"client-tls-cert"`
Certificates []tls.Certificate `yaml:"-"`
// TLS version
TLSMinVersion string `yaml:"tls-min-version"`
TLSClientAuth tls.ClientAuthType `yaml:"tls-client-auth"`
// Exclude insecure cipher suites
ExcludeInsecureCipherSuites bool `yaml:"exclude-insecure-cipher-suites"`
// TLS Config
KubernetesEnableTLS bool `yaml:"kubernetes-enable-tls"`
// API Monitoring
ApiUsageMonitoringEnable bool `yaml:"enable-api-usage-monitoring"`
ApiUsageMonitoringRealmKeys string `yaml:"api-usage-monitoring-realm-keys"`
ApiUsageMonitoringClientKeys string `yaml:"api-usage-monitoring-client-keys"`
ApiUsageMonitoringDefaultClientTrackingPattern string `yaml:"api-usage-monitoring-default-client-tracking-pattern"`
ApiUsageMonitoringRealmsTrackingPattern string `yaml:"api-usage-monitoring-realms-tracking-pattern"`
// connections, timeouts:
WaitForHealthcheckInterval time.Duration `yaml:"wait-for-healthcheck-interval"`
IdleConnsPerHost int `yaml:"idle-conns-num"`
CloseIdleConnsPeriod time.Duration `yaml:"close-idle-conns-period"`
BackendFlushInterval time.Duration `yaml:"backend-flush-interval"`
ExperimentalUpgrade bool `yaml:"experimental-upgrade"`
ExperimentalUpgradeAudit bool `yaml:"experimental-upgrade-audit"`
ReadTimeoutServer time.Duration `yaml:"read-timeout-server"`
ReadHeaderTimeoutServer time.Duration `yaml:"read-header-timeout-server"`
WriteTimeoutServer time.Duration `yaml:"write-timeout-server"`
IdleTimeoutServer time.Duration `yaml:"idle-timeout-server"`
KeepaliveServer time.Duration `yaml:"keepalive-server"`
KeepaliveRequestsServer int `yaml:"keepalive-requests-server"`
MaxHeaderBytes int `yaml:"max-header-bytes"`
EnableConnMetricsServer bool `yaml:"enable-connection-metrics"`
TimeoutBackend time.Duration `yaml:"timeout-backend"`
KeepaliveBackend time.Duration `yaml:"keepalive-backend"`
EnableDualstackBackend bool `yaml:"enable-dualstack-backend"`
TlsHandshakeTimeoutBackend time.Duration `yaml:"tls-timeout-backend"`
ResponseHeaderTimeoutBackend time.Duration `yaml:"response-header-timeout-backend"`
ExpectContinueTimeoutBackend time.Duration `yaml:"expect-continue-timeout-backend"`
MaxIdleConnsBackend int `yaml:"max-idle-connection-backend"`
DisableHTTPKeepalives bool `yaml:"disable-http-keepalives"`
// swarm:
EnableSwarm bool `yaml:"enable-swarm"`
// redis based
SwarmRedisURLs *listFlag `yaml:"swarm-redis-urls"`
SwarmRedisPassword string `yaml:"swarm-redis-password"`
SwarmRedisHashAlgorithm string `yaml:"swarm-redis-hash-algorithm"`
SwarmRedisDialTimeout time.Duration `yaml:"swarm-redis-dial-timeout"`
SwarmRedisReadTimeout time.Duration `yaml:"swarm-redis-read-timeout"`
SwarmRedisWriteTimeout time.Duration `yaml:"swarm-redis-write-timeout"`
SwarmRedisPoolTimeout time.Duration `yaml:"swarm-redis-pool-timeout"`
SwarmRedisMinConns int `yaml:"swarm-redis-min-conns"`
SwarmRedisMaxConns int `yaml:"swarm-redis-max-conns"`
SwarmRedisEndpointsRemoteURL string `yaml:"swarm-redis-remote"`
// swim based
SwarmKubernetesNamespace string `yaml:"swarm-namespace"`
SwarmKubernetesLabelSelectorKey string `yaml:"swarm-label-selector-key"`
SwarmKubernetesLabelSelectorValue string `yaml:"swarm-label-selector-value"`
SwarmPort int `yaml:"swarm-port"`
SwarmMaxMessageBuffer int `yaml:"swarm-max-msg-buffer"`
SwarmLeaveTimeout time.Duration `yaml:"swarm-leave-timeout"`
SwarmStaticSelf string `yaml:"swarm-static-self"`
SwarmStaticOther string `yaml:"swarm-static-other"`
ClusterRatelimitMaxGroupShards int `yaml:"cluster-ratelimit-max-group-shards"`
LuaModules *listFlag `yaml:"lua-modules"`
LuaSources *listFlag `yaml:"lua-sources"`
EnableOpenPolicyAgent bool `yaml:"enable-open-policy-agent"`
EnableOpenPolicyAgentCustomControlLoop bool `yaml:"enable-open-policy-agent-custom-control-loop"`
OpenPolicyAgentControlLoopInterval time.Duration `yaml:"open-policy-agent-control-loop-interval"`
OpenPolicyAgentControlLoopMaxJitter time.Duration `yaml:"open-policy-agent-control-loop-max-jitter"`
EnableOpenPolicyAgentDataPreProcessingOptimization bool `yaml:"enable-open-policy-agent-data-preprocessing-optimization"`
OpenPolicyAgentConfigTemplate string `yaml:"open-policy-agent-config-template"`
OpenPolicyAgentEnvoyMetadata string `yaml:"open-policy-agent-envoy-metadata"`
OpenPolicyAgentCleanerInterval time.Duration `yaml:"open-policy-agent-cleaner-interval"`
OpenPolicyAgentStartupTimeout time.Duration `yaml:"open-policy-agent-startup-timeout"`
OpenPolicyAgentRequestBodyBufferSize int64 `yaml:"open-policy-agent-request-body-buffer-size"`
OpenPolicyAgentMaxRequestBodySize int64 `yaml:"open-policy-agent-max-request-body-size"`
OpenPolicyAgentMaxMemoryBodyParsing int64 `yaml:"open-policy-agent-max-memory-body-parsing"`
PassiveHealthCheck mapFlags `yaml:"passive-health-check"`
}
const (
// TLS
defaultMinTLSVersion = "1.2"
// environment keys:
redisPasswordEnv = "SWARM_REDIS_PASSWORD"
)
func NewConfig() *Config {
cfg := new(Config)
cfg.MetricsFlavour = commaListFlag("codahale", "prometheus")
cfg.StatusChecks = commaListFlag()
cfg.FilterPlugins = newPluginFlag()
cfg.PredicatePlugins = newPluginFlag()
cfg.DataclientPlugins = newPluginFlag()
cfg.MultiPlugins = newPluginFlag()
cfg.CredentialPaths = commaListFlag()
cfg.SwarmRedisURLs = commaListFlag()
cfg.AppendFilters = &defaultFiltersFlags{}
cfg.PrependFilters = &defaultFiltersFlags{}
cfg.DisabledFilters = commaListFlag()
cfg.CloneRoute = routeChangerConfig{}
cfg.EditRoute = routeChangerConfig{}
cfg.KubernetesEastWestRangeDomains = commaListFlag()
cfg.RoutesURLs = commaListFlag()
cfg.ForwardedHeadersList = commaListFlag()
cfg.ForwardedHeadersExcludeCIDRList = commaListFlag()
cfg.CompressEncodings = commaListFlag("gzip", "deflate", "br")
cfg.LuaModules = commaListFlag()
cfg.LuaSources = commaListFlag()
cfg.Oauth2GrantTokeninfoKeys = commaListFlag()
flag := flag.NewFlagSet("", flag.ExitOnError)
flag.StringVar(&cfg.ConfigFile, "config-file", "", "if provided the flags will be loaded/overwritten by the values on the file (yaml)")
// generic:
flag.StringVar(&cfg.Address, "address", ":9090", "network address that skipper should listen on")
flag.StringVar(&cfg.InsecureAddress, "insecure-address", "", "insecure network address that skipper should listen on when TLS is enabled")
flag.BoolVar(&cfg.EnableTCPQueue, "enable-tcp-queue", false, "enable the TCP listener queue")
flag.IntVar(&cfg.ExpectedBytesPerRequest, "expected-bytes-per-request", 50*1024, "expected bytes per request, used to calculate concurrency limits to buffer connection spikes")
flag.IntVar(&cfg.MaxTCPListenerConcurrency, "max-tcp-listener-concurrency", 0, "sets hardcoded max for TCP listener concurrency, normally calculated based on available memory cgroups with max TODO")
flag.IntVar(&cfg.MaxTCPListenerQueue, "max-tcp-listener-queue", 0, "sets hardcoded max queue size for TCP listener, normally calculated 10x concurrency with max TODO:50k")
flag.BoolVar(&cfg.IgnoreTrailingSlash, "ignore-trailing-slash", false, "flag indicating to ignore trailing slashes in paths when routing")
flag.BoolVar(&cfg.Insecure, "insecure", false, "flag indicating to ignore the verification of the TLS certificates of the backend services")
flag.BoolVar(&cfg.ProxyPreserveHost, "proxy-preserve-host", false, "flag indicating to preserve the incoming request 'Host' header in the outgoing requests")
flag.BoolVar(&cfg.DevMode, "dev-mode", false, "enables developer time behavior, like unbuffered routing updates")
flag.StringVar(&cfg.SupportListener, "support-listener", ":9911", "network address used for exposing the /metrics endpoint. An empty value disables support endpoint.")
flag.StringVar(&cfg.DebugListener, "debug-listener", "", "when this address is set, skipper starts an additional listener returning the original and transformed requests")
flag.StringVar(&cfg.CertPathTLS, "tls-cert", "", "the path on the local filesystem to the certificate file(s) (including any intermediates), multiple may be given comma separated")
flag.StringVar(&cfg.KeyPathTLS, "tls-key", "", "the path on the local filesystem to the certificate's private key file(s), multiple keys may be given comma separated - the order must match the certs")
flag.Var(cfg.StatusChecks, "status-checks", "experimental URLs to check before reporting healthy on startup")
flag.BoolVar(&cfg.PrintVersion, "version", false, "print Skipper version")
flag.IntVar(&cfg.MaxLoopbacks, "max-loopbacks", proxy.DefaultMaxLoopbacks, "maximum number of loopbacks for an incoming request, set to -1 to disable loopbacks")
flag.IntVar(&cfg.DefaultHTTPStatus, "default-http-status", http.StatusNotFound, "default HTTP status used when no route is found for a request")
flag.StringVar(&cfg.PluginDir, "plugindir", "", "set the directory to load plugins from, default is ./")
flag.DurationVar(&cfg.LoadBalancerHealthCheckInterval, "lb-healthcheck-interval", 0, "This is *deprecated* and not in use anymore")
flag.BoolVar(&cfg.ReverseSourcePredicate, "reverse-source-predicate", false, "reverse the order of finding the client IP from X-Forwarded-For header")
flag.BoolVar(&cfg.RemoveHopHeaders, "remove-hop-headers", false, "enables removal of Hop-Headers according to RFC-2616")
flag.BoolVar(&cfg.RfcPatchPath, "rfc-patch-path", false, "patches the incoming request path to preserve unencoded reserved characters according to RFC 2616 and RFC 3986")
flag.IntVar(&cfg.MaxAuditBody, "max-audit-body", 1024, "sets the maximum number of request body bytes to read and log in the audit log")
flag.Uint64Var(&cfg.MaxMatcherBufferSize, "max-matcher-buffer-size", 2097152, "sets the maximum read size of the body read by the block filter, default is 2MiB")
flag.BoolVar(&cfg.EnableBreakers, "enable-breakers", false, enableBreakersUsage)
flag.Var(&cfg.Breakers, "breaker", breakerUsage)
flag.BoolVar(&cfg.EnableRatelimiters, "enable-ratelimits", false, enableRatelimitsUsage)
flag.Var(&cfg.Ratelimits, "ratelimits", ratelimitsUsage)
flag.BoolVar(&cfg.EnableRouteFIFOMetrics, "enable-route-fifo-metrics", false, "enable metrics for the individual route FIFO queues")
flag.BoolVar(&cfg.EnableRouteLIFOMetrics, "enable-route-lifo-metrics", false, "enable metrics for the individual route LIFO queues")
flag.Var(cfg.MetricsFlavour, "metrics-flavour", "Metrics flavour is used to change the exposed metrics format. Supported metric formats: 'codahale' and 'prometheus', you can select both of them by using one option with ',' separated values")
flag.Var(cfg.FilterPlugins, "filter-plugin", "set custom filter plugins to load, a comma separated list of name and arguments")
flag.Var(cfg.PredicatePlugins, "predicate-plugin", "set custom predicate plugins to load, a comma separated list of name and arguments")
flag.Var(cfg.DataclientPlugins, "dataclient-plugin", "set custom dataclient plugins to load, a comma separated list of name and arguments")
flag.Var(cfg.MultiPlugins, "multi-plugin", "set custom multitype plugins to load, a comma separated list of name and arguments")
flag.Var(cfg.CompressEncodings, "compress-encodings", "set encodings supported for compression, the order defines priority when Accept-Header has equal quality values, see RFC 7231 section 5.3.1")
// logging, metrics, tracing:
flag.BoolVar(&cfg.EnablePrometheusMetrics, "enable-prometheus-metrics", false, "*Deprecated*: use metrics-flavour. Switch to Prometheus metrics format to expose metrics")
flag.StringVar(&cfg.OpenTracing, "opentracing", "noop", "list of arguments for opentracing (space separated), first argument is the tracer implementation")
flag.StringVar(&cfg.OpenTracingInitialSpan, "opentracing-initial-span", "ingress", "set the name of the initial, pre-routing, tracing span")
flag.StringVar(&cfg.OpenTracingExcludedProxyTags, "opentracing-excluded-proxy-tags", "", "set tags that should be excluded from spans created for proxy operation. must be a comma-separated list of strings.")
flag.BoolVar(&cfg.OpenTracingDisableFilterSpans, "opentracing-disable-filter-spans", false, "disable creation of spans representing request and response filters")
flag.BoolVar(&cfg.OpentracingLogFilterLifecycleEvents, "opentracing-log-filter-lifecycle-events", true, "enables the logs for request & response filters' lifecycle events that are marking start & end times.")
flag.BoolVar(&cfg.OpentracingLogStreamEvents, "opentracing-log-stream-events", true, "enables the logs for events marking the times response headers & payload are streamed to the client")
flag.BoolVar(&cfg.OpentracingBackendNameTag, "opentracing-backend-name-tag", false, "enables an additional tracing tag that contains a backend name for a route when it's available (e.g. for RouteGroups) (default false)")
flag.StringVar(&cfg.MetricsListener, "metrics-listener", ":9911", "network address used for exposing the /metrics endpoint. An empty value disables metrics iff support listener is also empty.")
flag.StringVar(&cfg.MetricsPrefix, "metrics-prefix", "skipper.", "allows setting a custom path prefix for metrics export")
flag.BoolVar(&cfg.EnableProfile, "enable-profile", false, "enable profile information on the metrics endpoint with path /pprof")
flag.IntVar(&cfg.BlockProfileRate, "block-profile-rate", 0, "block profile sample rate, see runtime.SetBlockProfileRate")
flag.IntVar(&cfg.MutexProfileFraction, "mutex-profile-fraction", 0, "mutex profile fraction rate, see runtime.SetMutexProfileFraction")
flag.IntVar(&cfg.MemProfileRate, "memory-profile-rate", 0, "memory profile rate, see runtime.SetMemProfileRate, keeps default 512 kB")
flag.BoolVar(&cfg.EnablePrometheusStartLabel, "enable-prometheus-start-label", false, "adds start label to each prometheus counter with the value of counter creation timestamp as unix nanoseconds")
flag.BoolVar(&cfg.DebugGcMetrics, "debug-gc-metrics", false, "enables reporting of the Go garbage collector statistics exported in debug.GCStats")
flag.BoolVar(&cfg.RuntimeMetrics, "runtime-metrics", true, "enables reporting of the Go runtime statistics exported in runtime and specifically runtime.MemStats")
flag.BoolVar(&cfg.ServeRouteMetrics, "serve-route-metrics", false, "enables reporting total serve time metrics for each route")
flag.BoolVar(&cfg.ServeRouteCounter, "serve-route-counter", false, "enables reporting counting metrics for each route. Has the route, HTTP method and status code as labels. Currently just implemented for the Prometheus metrics flavour")
flag.BoolVar(&cfg.ServeHostMetrics, "serve-host-metrics", false, "enables reporting total serve time metrics for each host")
flag.BoolVar(&cfg.ServeHostCounter, "serve-host-counter", false, "enables reporting counting metrics for each host. Has the route, HTTP method and status code as labels. Currently just implemented for the Prometheus metrics flavour")
flag.BoolVar(&cfg.ServeMethodMetric, "serve-method-metric", true, "enables the HTTP method as a domain of the total serve time metric. It affects both route and host split metrics")
flag.BoolVar(&cfg.ServeStatusCodeMetric, "serve-status-code-metric", true, "enables the HTTP response status code as a domain of the total serve time metric. It affects both route and host split metrics")
flag.BoolVar(&cfg.BackendHostMetrics, "backend-host-metrics", false, "enables reporting total serve time metrics for each backend")
flag.BoolVar(&cfg.ProxyRequestMetrics, "proxy-request-metrics", false, "enables reporting latency / time spent in handling the request part of the proxy operation i.e., the duration from entry till before the backend round trip")
flag.BoolVar(&cfg.ProxyResponseMetrics, "proxy-response-metrics", false, "enables reporting latency / time spent in handling the response part of the proxy operation i.e., the duration from after the backend round trip till the response is served")
flag.BoolVar(&cfg.AllFiltersMetrics, "all-filters-metrics", false, "enables reporting combined filter metrics for each route")
flag.BoolVar(&cfg.CombinedResponseMetrics, "combined-response-metrics", false, "enables reporting combined response time metrics")
flag.BoolVar(&cfg.RouteResponseMetrics, "route-response-metrics", false, "enables reporting response time metrics for each route")
flag.BoolVar(&cfg.RouteBackendErrorCounters, "route-backend-error-counters", false, "enables counting backend errors for each route")
flag.BoolVar(&cfg.RouteStreamErrorCounters, "route-stream-error-counters", false, "enables counting streaming errors for each route")
flag.BoolVar(&cfg.RouteBackendMetrics, "route-backend-metrics", false, "enables reporting backend response time metrics for each route")
flag.BoolVar(&cfg.RouteCreationMetrics, "route-creation-metrics", false, "enables reporting for route creation times")
flag.BoolVar(&cfg.MetricsUseExpDecaySample, "metrics-exp-decay-sample", false, "use exponentially decaying sample in metrics")
flag.StringVar(&cfg.HistogramMetricBucketsString, "histogram-metric-buckets", "", "use custom buckets for prometheus histograms, must be a comma-separated list of numbers")
flag.BoolVar(&cfg.DisableMetricsCompat, "disable-metrics-compat", false, "disables the default true value for all-filters-metrics, route-response-metrics, route-backend-error-counters and route-stream-error-counters")
flag.StringVar(&cfg.ApplicationLog, "application-log", "", "output file for the application log. When not set, /dev/stderr is used")
flag.StringVar(&cfg.ApplicationLogLevelString, "application-log-level", "INFO", "log level for application logs, possible values: PANIC, FATAL, ERROR, WARN, INFO, DEBUG")
flag.StringVar(&cfg.ApplicationLogPrefix, "application-log-prefix", "[APP]", "prefix for each log entry")
flag.BoolVar(&cfg.ApplicationLogJSONEnabled, "application-log-json-enabled", false, "when this flag is set, log in JSON format is used")
flag.StringVar(&cfg.AccessLog, "access-log", "", "output file for the access log. When not set, /dev/stderr is used")
flag.BoolVar(&cfg.AccessLogDisabled, "access-log-disabled", false, "when this flag is set, no access log is printed")
flag.BoolVar(&cfg.AccessLogJSONEnabled, "access-log-json-enabled", false, "when this flag is set, log in JSON format is used")
flag.BoolVar(&cfg.AccessLogStripQuery, "access-log-strip-query", false, "when this flag is set, the access log strips the query strings from the access log")
flag.BoolVar(&cfg.SuppressRouteUpdateLogs, "suppress-route-update-logs", false, "print only summaries on route updates/deletes")
// route sources:
flag.StringVar(&cfg.EtcdUrls, "etcd-urls", "", "urls of nodes in an etcd cluster, storing route definitions")
flag.StringVar(&cfg.EtcdPrefix, "etcd-prefix", "/skipper", "path prefix for skipper related data in etcd")
flag.DurationVar(&cfg.EtcdTimeout, "etcd-timeout", time.Second, "http client timeout duration for etcd")
flag.BoolVar(&cfg.EtcdInsecure, "etcd-insecure", false, "ignore the verification of TLS certificates for etcd")
flag.StringVar(&cfg.EtcdOAuthToken, "etcd-oauth-token", "", "optional token for OAuth authentication with etcd")
flag.StringVar(&cfg.EtcdUsername, "etcd-username", "", "optional username for basic authentication with etcd")
flag.StringVar(&cfg.EtcdPassword, "etcd-password", "", "optional password for basic authentication with etcd")
flag.StringVar(&cfg.RoutesFile, "routes-file", "", "file containing route definitions")
flag.Var(cfg.RoutesURLs, "routes-urls", "comma separated URLs to route definitions in eskip format")
flag.StringVar(&cfg.InlineRoutes, "inline-routes", "", "inline routes in eskip format")
flag.Int64Var(&cfg.SourcePollTimeout, "source-poll-timeout", int64(3000), "polling timeout of the routing data sources, in milliseconds")
flag.Var(cfg.AppendFilters, "default-filters-append", "set of default filters to apply to append to all filters of all routes")
flag.Var(cfg.PrependFilters, "default-filters-prepend", "set of default filters to apply to prepend to all filters of all routes")
flag.Var(cfg.DisabledFilters, "disabled-filters", "comma separated list of filters unavailable for use")
flag.Var(&cfg.EditRoute, "edit-route", "match and edit filters and predicates of all routes")
flag.Var(&cfg.CloneRoute, "clone-route", "clone all matching routes and replace filters and predicates of all matched routes")
flag.BoolVar(&cfg.WaitFirstRouteLoad, "wait-first-route-load", false, "prevent starting the listener before the first batch of routes were loaded")
// Forwarded headers
flag.Var(cfg.ForwardedHeadersList, "forwarded-headers", "comma separated list of headers to add to the incoming request before routing\n"+
"X-Forwarded-For sets or appends with comma the remote IP of the request to the X-Forwarded-For header value\n"+
"X-Forwarded-Host sets X-Forwarded-Host value to the request host\n"+
"X-Forwarded-Method sets X-Forwarded-Method value to the request method\n"+
"X-Forwarded-Uri sets X-Forwarded-Uri value to the requestURI\n"+
"X-Forwarded-Port=<port> sets X-Forwarded-Port value\n"+
"X-Forwarded-Proto=<http|https> sets X-Forwarded-Proto value")
flag.Var(cfg.ForwardedHeadersExcludeCIDRList, "forwarded-headers-exclude-cidrs", "disables addition of forwarded headers for the remote host IPs from the comma separated list of CIDRs")
flag.BoolVar(&cfg.NormalizeHost, "normalize-host", false, "converts request host to lowercase and removes port and trailing dot if any")
flag.BoolVar(&cfg.ValidateQuery, "validate-query", true, "Validates the HTTP Query of a request and if invalid responds with status code 400")
flag.BoolVar(&cfg.ValidateQueryLog, "validate-query-log", true, "enables logging of HTTP query validation failures")
flag.Var(&cfg.RefusePayload, "refuse-payload", "refuse requests that match configured value. Can be set multiple times")
// Kubernetes:
flag.BoolVar(&cfg.KubernetesIngress, "kubernetes", false, "enables skipper to generate routes for ingress resources in kubernetes cluster. Enables -normalize-host")
flag.BoolVar(&cfg.KubernetesInCluster, "kubernetes-in-cluster", false, "specify if skipper is running inside kubernetes cluster. It will automatically discover API server URL and service account token")
flag.StringVar(&cfg.KubernetesURL, "kubernetes-url", "", "kubernetes API server URL, ignored if kubernetes-in-cluster is set to true")
flag.StringVar(&cfg.KubernetesTokenFile, "kubernetes-token-file", "", "kubernetes token file path, ignored if kubernetes-in-cluster is set to true")
flag.BoolVar(&cfg.KubernetesHealthcheck, "kubernetes-healthcheck", true, "automatic healthcheck route for internal IPs with path /kube-system/healthz; valid only with kubernetes")
flag.BoolVar(&cfg.KubernetesHTTPSRedirect, "kubernetes-https-redirect", true, "automatic HTTP->HTTPS redirect route; valid only with kubernetes")
flag.IntVar(&cfg.KubernetesHTTPSRedirectCode, "kubernetes-https-redirect-code", 308, "overrides the default redirect code (308) when used together with -kubernetes-https-redirect")
flag.BoolVar(&cfg.KubernetesDisableCatchAllRoutes, "kubernetes-disable-catchall-routes", false, "disables creation of catchall routes")
flag.StringVar(&cfg.KubernetesIngressClass, "kubernetes-ingress-class", "", "ingress class regular expression used to filter ingress resources for kubernetes")
flag.StringVar(&cfg.KubernetesRouteGroupClass, "kubernetes-routegroup-class", "", "route group class regular expression used to filter route group resources for kubernetes")
flag.StringVar(&cfg.WhitelistedHealthCheckCIDR, "whitelisted-healthcheck-cidr", "", "sets the iprange/CIDRS to be whitelisted during healthcheck")
flag.StringVar(&cfg.KubernetesPathModeString, "kubernetes-path-mode", "kubernetes-ingress", "controls the default interpretation of Kubernetes ingress paths: <kubernetes-ingress|path-regexp|path-prefix>")
flag.StringVar(&cfg.KubernetesNamespace, "kubernetes-namespace", "", "watch only this namespace for ingresses")
flag.BoolVar(&cfg.KubernetesEnableEndpointSlices, "enable-kubernetes-endpointslices", false, "enables skipper to fetch Kubernetes endpointslices instead of endpoints, to scale beyond 1000 pods within a service")
flag.BoolVar(&cfg.KubernetesEnableEastWest, "enable-kubernetes-east-west", false, "*Deprecated*: use kubernetes-east-west-range feature. Enables east-west communication, which automatically adds routes for Ingress objects with hostname <name>.<namespace>.skipper.cluster.local")
flag.StringVar(&cfg.KubernetesEastWestDomain, "kubernetes-east-west-domain", "", "*Deprecated*: use kubernetes-east-west-range feature. Sets the east-west domain, defaults to .skipper.cluster.local")
flag.Var(cfg.KubernetesEastWestRangeDomains, "kubernetes-east-west-range-domains", "set the cluster internal domains for east west traffic. Identified routes to such domains will include the -kubernetes-east-west-range-predicates")
flag.StringVar(&cfg.KubernetesEastWestRangePredicatesString, "kubernetes-east-west-range-predicates", "", "set the predicates that will be appended to routes identified as to -kubernetes-east-west-range-domains")
flag.Var(&cfg.KubernetesAnnotationPredicatesString, "kubernetes-annotation-predicates", "configures predicates appended to non east-west routes of annotated resources. E.g. -kubernetes-annotation-predicates='zone-a=true=Foo() && Bar()' will add 'Foo() && Bar()' predicates to all non east-west routes of ingress or routegroup annotated with 'zone-a: true'. For east-west routes use -kubernetes-east-west-range-annotation-predicates.")
flag.Var(&cfg.KubernetesAnnotationFiltersAppendString, "kubernetes-annotation-filters-append", "configures filters appended to non east-west routes of annotated resources. E.g. -kubernetes-annotation-filters-append='zone-a=true=foo() -> bar()' will add 'foo() -> bar()' filters to all non east-west routes of ingress or routegroup annotated with 'zone-a: true'. For east-west routes use -kubernetes-east-west-range-annotation-filters-append.")
flag.Var(&cfg.KubernetesEastWestRangeAnnotationPredicatesString, "kubernetes-east-west-range-annotation-predicates", "similar to -kubernetes-annotation-predicates configures predicates appended to east-west routes of annotated resources. See also -kubernetes-east-west-range-domains.")
flag.Var(&cfg.KubernetesEastWestRangeAnnotationFiltersAppendString, "kubernetes-east-west-range-annotation-filters-append", "similar to -kubernetes-annotation-filters-append configures filters appended to east-west routes of annotated resources. See also -kubernetes-east-west-range-domains.")
flag.BoolVar(&cfg.KubernetesOnlyAllowedExternalNames, "kubernetes-only-allowed-external-names", false, "only accept external name services, route group network backends and route group explicit LB endpoints from an allow list defined by zero or more -kubernetes-allowed-external-name flags")
flag.Var(&cfg.KubernetesAllowedExternalNames, "kubernetes-allowed-external-name", "set zero or more regular expressions from which at least one should be matched by the external name services, route group network addresses and explicit endpoints domain names")
flag.StringVar(&cfg.KubernetesRedisServiceNamespace, "kubernetes-redis-service-namespace", "", "Sets namespace for redis to be used to lookup endpoints")
flag.StringVar(&cfg.KubernetesRedisServiceName, "kubernetes-redis-service-name", "", "Sets name for redis to be used to lookup endpoints")
flag.IntVar(&cfg.KubernetesRedisServicePort, "kubernetes-redis-service-port", 6379, "Sets the port for redis to be used to lookup endpoints")
flag.StringVar(&cfg.KubernetesBackendTrafficAlgorithmString, "kubernetes-backend-traffic-algorithm", kubernetes.TrafficPredicateAlgorithm.String(), "sets the algorithm to be used for traffic splitting between backends: traffic-predicate or traffic-segment-predicate")
flag.StringVar(&cfg.KubernetesDefaultLoadBalancerAlgorithm, "kubernetes-default-lb-algorithm", kubernetes.DefaultLoadBalancerAlgorithm, "sets the default algorithm to be used for load balancing between backend endpoints, available options: roundRobin, consistentHash, random, powerOfRandomNChoices")
flag.BoolVar(&cfg.KubernetesForceService, "kubernetes-force-service", false, "overrides default Skipper functionality and routes traffic using Kubernetes Services instead of Endpoints")
// Auth:
flag.BoolVar(&cfg.EnableOAuth2GrantFlow, "enable-oauth2-grant-flow", false, "enables OAuth2 Grant Flow filter")
flag.StringVar(&cfg.Oauth2AuthURL, "oauth2-auth-url", "", "sets the OAuth2 Auth URL to redirect the requests to when login is required")
flag.StringVar(&cfg.Oauth2TokenURL, "oauth2-token-url", "", "the url where the access code should be exchanged for the access token")
flag.StringVar(&cfg.Oauth2RevokeTokenURL, "oauth2-revoke-token-url", "", "the url where the access and refresh tokens can be revoked when logging out")
flag.StringVar(&cfg.Oauth2TokeninfoURL, "oauth2-tokeninfo-url", "", "sets the default tokeninfo URL to query information about an incoming OAuth2 token in oauth2Tokeninfo filters")
flag.StringVar(&cfg.Oauth2SecretFile, "oauth2-secret-file", "", "sets the filename with the encryption key for the authentication cookie and grant flow state stored in secrets registry")
flag.StringVar(&cfg.Oauth2ClientID, "oauth2-client-id", "", "sets the OAuth2 client id of the current service, used to exchange the access code. Falls back to env variable OAUTH2_CLIENT_ID if value is empty.")
flag.StringVar(&cfg.Oauth2ClientSecret, "oauth2-client-secret", "", "sets the OAuth2 client secret associated with the oauth2-client-id, used to exchange the access code. Falls back to env variable OAUTH2_CLIENT_SECRET if value is empty.")
flag.StringVar(&cfg.Oauth2ClientIDFile, "oauth2-client-id-file", "", "sets the path of the file containing the OAuth2 client id of the current service, used to exchange the access code. "+
"File name may contain {host} placeholder which will be replaced by the request host")
flag.StringVar(&cfg.Oauth2ClientSecretFile, "oauth2-client-secret-file", "", "sets the path of the file containing the OAuth2 client secret associated with the oauth2-client-id, used to exchange the access code. "+
"File name may contain {host} placeholder which will be replaced by the request host")
flag.StringVar(&cfg.Oauth2CallbackPath, "oauth2-callback-path", "", "sets the path where the OAuth2 callback requests with the authorization code should be redirected to")
flag.DurationVar(&cfg.Oauth2TokeninfoTimeout, "oauth2-tokeninfo-timeout", 2*time.Second, "sets the default tokeninfo request timeout duration to 2000ms")
flag.IntVar(&cfg.Oauth2TokeninfoCacheSize, "oauth2-tokeninfo-cache-size", 0, "non-zero value enables tokeninfo cache and sets the maximum number of cached tokens")
flag.DurationVar(&cfg.Oauth2TokeninfoCacheTTL, "oauth2-tokeninfo-cache-ttl", 0, "non-zero value limits the lifetime of a cached tokeninfo which otherwise equals to the tokeninfo 'expires_in' field value")
flag.DurationVar(&cfg.Oauth2TokenintrospectionTimeout, "oauth2-tokenintrospect-timeout", 2*time.Second, "sets the default tokenintrospection request timeout duration to 2000ms")
flag.Var(&cfg.Oauth2AuthURLParameters, "oauth2-auth-url-parameters", "sets additional parameters to send when calling the OAuth2 authorize or token endpoints as key-value pairs")
flag.StringVar(&cfg.Oauth2AccessTokenHeaderName, "oauth2-access-token-header-name", "", "sets the access token to a header on the request with this name")
flag.StringVar(&cfg.Oauth2TokeninfoSubjectKey, "oauth2-tokeninfo-subject-key", "uid", "sets the tokeninfo subject key")
flag.Var(cfg.Oauth2GrantTokeninfoKeys, "oauth2-grant-tokeninfo-keys", "non-empty comma separated list configures keys to preserve in OAuth2 Grant Flow tokeninfo")
flag.StringVar(&cfg.Oauth2TokenCookieName, "oauth2-token-cookie-name", "oauth2-grant", "sets the name of the cookie where the encrypted token is stored")
flag.IntVar(&cfg.Oauth2TokenCookieRemoveSubdomains, "oauth2-token-cookie-remove-subdomains", 1, "sets the number of subdomains to remove from the callback request hostname to obtain token cookie domain")
flag.BoolVar(&cfg.Oauth2GrantInsecure, "oauth2-grant-insecure", false, "omits Secure attribute of the token cookie and uses http scheme for callback url")
flag.DurationVar(&cfg.WebhookTimeout, "webhook-timeout", 2*time.Second, "sets the webhook request timeout duration")
flag.StringVar(&cfg.OidcSecretsFile, "oidc-secrets-file", "", "file storing the encryption key of the OpenID Connect token. Enables OIDC filters")
flag.DurationVar(&cfg.OIDCCookieValidity, "oidc-cookie-validity", time.Hour, "sets the cookie expiry time to +1h for OIDC filters, in case no 'exp' claim is found in the JWT token")
flag.DurationVar(&cfg.OidcDistributedClaimsTimeout, "oidc-distributed-claims-timeout", 2*time.Second, "sets the default OIDC distributed claims request timeout duration to 2000ms")
flag.IntVar(&cfg.OIDCCookieRemoveSubdomains, "oidc-cookie-remove-subdomains", 1, "sets the number of subdomains to remove from the callback request hostname to obtain token cookie domain")
flag.Var(cfg.CredentialPaths, "credentials-paths", "directories or files to watch for credentials to use by bearerinjector filter")
flag.DurationVar(&cfg.CredentialsUpdateInterval, "credentials-update-interval", 10*time.Minute, "sets the interval to update secrets")
flag.BoolVar(&cfg.EnableOpenPolicyAgent, "enable-open-policy-agent", false, "enables Open Policy Agent filters")
flag.BoolVar(&cfg.EnableOpenPolicyAgentCustomControlLoop, "enable-open-policy-agent-custom-control-loop", false, "when enabled skipper will use a custom control loop to orchestrate certain opa behaviour (like the download of new bundles) instead of relying on periodic plugin triggers")
flag.DurationVar(&cfg.OpenPolicyAgentControlLoopInterval, "open-policy-agent-control-loop-interval", openpolicyagent.DefaultControlLoopInterval, "Interval between the execution of the control loop. Only applies if the custom control loop is enabled")
flag.DurationVar(&cfg.OpenPolicyAgentControlLoopMaxJitter, "open-policy-agent-control-loop-max-jitter", openpolicyagent.DefaultControlLoopMaxJitter, "Maximum jitter to add to the control loop interval. Only applies if the custom control loop is enabled")
flag.BoolVar(&cfg.EnableOpenPolicyAgentDataPreProcessingOptimization, "enable-open-policy-agent-data-preprocessing-optimization", false, "as a latency optimization, Open Policy Agent reads values from in-memory storage as pre-converted ASTs, removing conversion overhead at evaluation time. Currently experimental; if successful, it will be enabled by default")
flag.StringVar(&cfg.OpenPolicyAgentConfigTemplate, "open-policy-agent-config-template", "", "file containing a template for an Open Policy Agent configuration file that is interpolated for each OPA filter instance")
flag.StringVar(&cfg.OpenPolicyAgentEnvoyMetadata, "open-policy-agent-envoy-metadata", "", "JSON file containing metadata passed as input for compatibility with Envoy policies")
flag.DurationVar(&cfg.OpenPolicyAgentCleanerInterval, "open-policy-agent-cleaner-interval", openpolicyagent.DefaultCleanIdlePeriod, "duration to wait before cleaning up unused OPA instances")
flag.DurationVar(&cfg.OpenPolicyAgentStartupTimeout, "open-policy-agent-startup-timeout", openpolicyagent.DefaultOpaStartupTimeout, "maximum duration to wait for the Open Policy Agent to start up; if the custom control loop is enabled, also how long to wait for the processing of each instance to finish (e.g. to download updated bundles)")
flag.Int64Var(&cfg.OpenPolicyAgentMaxRequestBodySize, "open-policy-agent-max-request-body-size", openpolicyagent.DefaultMaxRequestBodySize, "Maximum number of bytes from a http request body that are passed as input to the policy")
flag.Int64Var(&cfg.OpenPolicyAgentRequestBodyBufferSize, "open-policy-agent-request-body-buffer-size", openpolicyagent.DefaultRequestBodyBufferSize, "Read buffer size for the request body")
flag.Int64Var(&cfg.OpenPolicyAgentMaxMemoryBodyParsing, "open-policy-agent-max-memory-body-parsing", openpolicyagent.DefaultMaxMemoryBodyParsing, "Total number of bytes used to parse http request bodies across all requests. Once the limit is met, requests will be rejected.")
// TLS client certs
flag.StringVar(&cfg.ClientKeyFile, "client-tls-key", "", "TLS Key file for backend connections, multiple keys may be given comma separated - the order must match the certs")
flag.StringVar(&cfg.ClientCertFile, "client-tls-cert", "", "TLS certificate files for backend connections, multiple certificates may be given comma separated - the order must match the keys")
// TLS version
flag.StringVar(&cfg.TLSMinVersion, "tls-min-version", defaultMinTLSVersion, "minimal TLS version to be used in server, proxy and client connections")
flag.Func("tls-client-auth", "TLS client authentication policy for server, one of: "+
"NoClientCert, RequestClientCert, RequireAnyClientCert, VerifyClientCertIfGiven or RequireAndVerifyClientCert. "+
"See https://pkg.go.dev/crypto/tls#ClientAuthType for details.", cfg.setTLSClientAuth)
// Exclude insecure cipher suites
flag.BoolVar(&cfg.ExcludeInsecureCipherSuites, "exclude-insecure-cipher-suites", false, "excludes insecure cipher suites")
// API Monitoring:
flag.BoolVar(&cfg.ApiUsageMonitoringEnable, "enable-api-usage-monitoring", false, "enables the apiUsageMonitoring filter")
flag.StringVar(&cfg.ApiUsageMonitoringRealmKeys, "api-usage-monitoring-realm-keys", "", "name of the property in the JWT payload that contains the authority realm")
flag.StringVar(&cfg.ApiUsageMonitoringClientKeys, "api-usage-monitoring-client-keys", "sub", "comma separated list of names of the properties in the JWT body that contain the client ID")
flag.StringVar(&cfg.ApiUsageMonitoringDefaultClientTrackingPattern, "api-usage-monitoring-default-client-tracking-pattern", "", "*Deprecated*: set `client_tracking_pattern` directly on filter")
flag.StringVar(&cfg.ApiUsageMonitoringRealmsTrackingPattern, "api-usage-monitoring-realms-tracking-pattern", "services", "regular expression used for matching monitored realms (default is 'services')")
// Default filters:
flag.StringVar(&cfg.DefaultFiltersDir, "default-filters-dir", "", "path to directory which contains default filter configurations per service and namespace (disabled if not set)")
// Connections, timeouts:
flag.DurationVar(&cfg.WaitForHealthcheckInterval, "wait-for-healthcheck-interval", (10+5)*3*time.Second, "period to wait for this instance to become unhealthy in the load balancer pool in front of it, before a shutdown triggered by SIGINT or SIGTERM proceeds") // kube-ingress-aws-controller default
flag.IntVar(&cfg.IdleConnsPerHost, "idle-conns-num", proxy.DefaultIdleConnsPerHost, "maximum idle connections per backend host")
flag.DurationVar(&cfg.CloseIdleConnsPeriod, "close-idle-conns-period", proxy.DefaultCloseIdleConnsPeriod, "sets the time interval of closing all idle connections. Not closing when 0")
flag.DurationVar(&cfg.BackendFlushInterval, "backend-flush-interval", 20*time.Millisecond, "flush interval for upgraded proxy connections")
flag.BoolVar(&cfg.ExperimentalUpgrade, "experimental-upgrade", false, "enable experimental feature to handle upgrade protocol requests")
flag.BoolVar(&cfg.ExperimentalUpgradeAudit, "experimental-upgrade-audit", false, "enable audit logging of the request line and the messages during the experimental web socket upgrades")
flag.DurationVar(&cfg.ReadTimeoutServer, "read-timeout-server", 5*time.Minute, "set ReadTimeout for http server connections")
flag.DurationVar(&cfg.ReadHeaderTimeoutServer, "read-header-timeout-server", 60*time.Second, "set ReadHeaderTimeout for http server connections")
flag.DurationVar(&cfg.WriteTimeoutServer, "write-timeout-server", 60*time.Second, "set WriteTimeout for http server connections")
flag.DurationVar(&cfg.IdleTimeoutServer, "idle-timeout-server", 60*time.Second, "set IdleTimeout for http server connections")
flag.DurationVar(&cfg.KeepaliveServer, "keepalive-server", 0*time.Second, "sets maximum age for http server connections. The connection is closed after it existed for this duration. Default is 0 for unlimited.")
flag.IntVar(&cfg.KeepaliveRequestsServer, "keepalive-requests-server", 0, "sets maximum number of requests for http server connections. The connection is closed after serving this number of requests. Default is 0 for unlimited.")
flag.IntVar(&cfg.MaxHeaderBytes, "max-header-bytes", http.DefaultMaxHeaderBytes, "set MaxHeaderBytes for http server connections")
flag.BoolVar(&cfg.EnableConnMetricsServer, "enable-connection-metrics", false, "enables connection metrics for http server connections")
flag.DurationVar(&cfg.TimeoutBackend, "timeout-backend", 60*time.Second, "sets the TCP client connection timeout for backend connections")
flag.DurationVar(&cfg.KeepaliveBackend, "keepalive-backend", 30*time.Second, "sets the keepalive for backend connections")
flag.BoolVar(&cfg.EnableDualstackBackend, "enable-dualstack-backend", true, "enables DualStack for backend connections")
flag.DurationVar(&cfg.TlsHandshakeTimeoutBackend, "tls-timeout-backend", 60*time.Second, "sets the TLS handshake timeout for backend connections")
flag.DurationVar(&cfg.ResponseHeaderTimeoutBackend, "response-header-timeout-backend", 60*time.Second, "sets the HTTP response header timeout for backend connections")
flag.DurationVar(&cfg.ExpectContinueTimeoutBackend, "expect-continue-timeout-backend", 30*time.Second, "sets the HTTP expect continue timeout for backend connections")
flag.IntVar(&cfg.MaxIdleConnsBackend, "max-idle-connection-backend", 0, "sets the maximum idle connections for all backend connections")
flag.BoolVar(&cfg.DisableHTTPKeepalives, "disable-http-keepalives", false, "forces backend to always create a new connection")
flag.BoolVar(&cfg.KubernetesEnableTLS, "kubernetes-enable-tls", false, "enable using kubernetes resources to terminate TLS")
// Swarm:
flag.BoolVar(&cfg.EnableSwarm, "enable-swarm", false, "enable swarm communication between nodes in a skipper fleet")
flag.Var(cfg.SwarmRedisURLs, "swarm-redis-urls", "Redis URLs as comma separated list, used for building a swarm, for example in redis based cluster ratelimits.\nUse "+redisPasswordEnv+" environment variable or 'swarm-redis-password' key in config file to set redis password")
flag.StringVar(&cfg.SwarmRedisHashAlgorithm, "swarm-redis-hash-algorithm", "", "sets hash algorithm to be used in redis ring client to find the shard <jump|mpchash|rendezvous|rendezvousVnodes>, defaults to github.com/redis/go-redis default")
flag.DurationVar(&cfg.SwarmRedisDialTimeout, "swarm-redis-dial-timeout", net.DefaultDialTimeout, "set redis client dial timeout")
flag.DurationVar(&cfg.SwarmRedisReadTimeout, "swarm-redis-read-timeout", net.DefaultReadTimeout, "set redis socket read timeout")
flag.DurationVar(&cfg.SwarmRedisWriteTimeout, "swarm-redis-write-timeout", net.DefaultWriteTimeout, "set redis socket write timeout")
flag.DurationVar(&cfg.SwarmRedisPoolTimeout, "swarm-redis-pool-timeout", net.DefaultPoolTimeout, "set redis get connection from pool timeout")
flag.IntVar(&cfg.SwarmRedisMinConns, "swarm-redis-min-conns", net.DefaultMinConns, "set min number of connections to redis")
flag.IntVar(&cfg.SwarmRedisMaxConns, "swarm-redis-max-conns", net.DefaultMaxConns, "set max number of connections to redis")
flag.StringVar(&cfg.SwarmRedisEndpointsRemoteURL, "swarm-redis-remote", "", "Remote URL to pull redis endpoints from.")
flag.StringVar(&cfg.SwarmKubernetesNamespace, "swarm-namespace", swarm.DefaultNamespace, "Kubernetes namespace to find swarm peer instances")
flag.StringVar(&cfg.SwarmKubernetesLabelSelectorKey, "swarm-label-selector-key", swarm.DefaultLabelSelectorKey, "Kubernetes labelselector key to find swarm peer instances")
flag.StringVar(&cfg.SwarmKubernetesLabelSelectorValue, "swarm-label-selector-value", swarm.DefaultLabelSelectorValue, "Kubernetes labelselector value to find swarm peer instances")
flag.IntVar(&cfg.SwarmPort, "swarm-port", swarm.DefaultPort, "swarm port to use to communicate with our peers")
flag.IntVar(&cfg.SwarmMaxMessageBuffer, "swarm-max-msg-buffer", swarm.DefaultMaxMessageBuffer, "swarm max message buffer size to use for member list messages")
flag.DurationVar(&cfg.SwarmLeaveTimeout, "swarm-leave-timeout", swarm.DefaultLeaveTimeout, "swarm leave timeout to use for leaving the memberlist on timeout")
flag.StringVar(&cfg.SwarmStaticSelf, "swarm-static-self", "", "set static swarm self node, for example 127.0.0.1:9001")
flag.StringVar(&cfg.SwarmStaticOther, "swarm-static-other", "", "set static swarm all nodes, for example 127.0.0.1:9002,127.0.0.1:9003")
flag.IntVar(&cfg.ClusterRatelimitMaxGroupShards, "cluster-ratelimit-max-group-shards", 1, "sets the maximum number of group shards for the clusterRatelimit filter")
flag.Var(cfg.LuaModules, "lua-modules", "comma separated list of lua filter modules. Use <module>.<symbol> to selectively enable module symbols, for example: package,base._G,base.print,json")
flag.Var(cfg.LuaSources, "lua-sources", `comma separated list of lua input types for the lua() filter. Valid sources "", "file", "inline", "file,inline" and "none". Use "file" to only allow lua file references in lua filter. Default "" is the same as "file","inline". Use "none" to disable lua filters.`)
// Passive Health Checks
flag.Var(&cfg.PassiveHealthCheck, "passive-health-check", "sets the parameters for passive health check feature")
cfg.Flags = flag
return cfg
}
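// validate checks that all string-typed configuration values parse
// successfully, so that Parse can later assign the parsed values while
// safely ignoring the already-checked errors.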
func validate(c *Config) error {
_, err := log.ParseLevel(c.ApplicationLogLevelString)
if err != nil {
return err
}
_, err = kubernetes.ParsePathMode(c.KubernetesPathModeString)
if err != nil {
return err
}
_, err = eskip.ParsePredicates(c.KubernetesEastWestRangePredicatesString)
if err != nil {
return fmt.Errorf("invalid east-west-range-predicates: %w", err)
}
_, err = parseAnnotationPredicates(c.KubernetesAnnotationPredicatesString)
if err != nil {
return fmt.Errorf("invalid annotation predicates: %q, %w", c.KubernetesAnnotationPredicatesString, err)
}
_, err = parseAnnotationFilters(c.KubernetesAnnotationFiltersAppendString)
if err != nil {
return fmt.Errorf("invalid annotation filters: %q, %w", c.KubernetesAnnotationFiltersAppendString, err)
}
_, err = parseAnnotationPredicates(c.KubernetesEastWestRangeAnnotationPredicatesString)
if err != nil {
return fmt.Errorf("invalid east-west annotation predicates: %q, %w", c.KubernetesEastWestRangeAnnotationPredicatesString, err)
}
_, err = parseAnnotationFilters(c.KubernetesEastWestRangeAnnotationFiltersAppendString)
if err != nil {
return fmt.Errorf("invalid east-west annotation filters: %q, %w", c.KubernetesEastWestRangeAnnotationFiltersAppendString, err)
}
_, err = kubernetes.ParseBackendTrafficAlgorithm(c.KubernetesBackendTrafficAlgorithmString)
if err != nil {
return err
}
_, err = c.parseHistogramBuckets()
if err != nil {
return err
}
return c.parseForwardedHeaders()
}
func (c *Config) Parse() error {
return c.ParseArgs(os.Args[0], os.Args[1:])
}
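// ParseArgs parses the given command line arguments, and if a config file is
// set, loads it and parses the arguments a second time, so that command line
// values take precedence over config file values.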
func (c *Config) ParseArgs(progname string, args []string) error {
c.Flags.Init(progname, flag.ExitOnError)
err := c.Flags.Parse(args)
if err != nil {
return err
}
// check if arguments were correctly parsed.
if len(c.Flags.Args()) != 0 {
return fmt.Errorf("invalid arguments: %s", c.Flags.Args())
}
configKeys := make(map[string]interface{})
if c.ConfigFile != "" {
yamlFile, err := os.ReadFile(c.ConfigFile)
if err != nil {
return fmt.Errorf("invalid config file: %w", err)
}
err = yaml.Unmarshal(yamlFile, c)
if err != nil {
return fmt.Errorf("unmarshalling config file error: %w", err)
}
_ = yaml.Unmarshal(yamlFile, configKeys)
err = c.Flags.Parse(args)
if err != nil {
return err
}
}
c.checkDeprecated(configKeys,
"enable-prometheus-metrics",
"api-usage-monitoring-default-client-tracking-pattern",
"enable-kubernetes-east-west",
"kubernetes-east-west-domain",
"lb-healthcheck-interval",
)
if err := validate(c); err != nil {
return err
}
c.ApplicationLogLevel, _ = log.ParseLevel(c.ApplicationLogLevelString)
c.KubernetesPathMode, _ = kubernetes.ParsePathMode(c.KubernetesPathModeString)
c.KubernetesEastWestRangePredicates, _ = eskip.ParsePredicates(c.KubernetesEastWestRangePredicatesString)
c.KubernetesAnnotationPredicates, _ = parseAnnotationPredicates(c.KubernetesAnnotationPredicatesString)
c.KubernetesAnnotationFiltersAppend, _ = parseAnnotationFilters(c.KubernetesAnnotationFiltersAppendString)
c.KubernetesEastWestRangeAnnotationPredicates, _ = parseAnnotationPredicates(c.KubernetesEastWestRangeAnnotationPredicatesString)
c.KubernetesEastWestRangeAnnotationFiltersAppend, _ = parseAnnotationFilters(c.KubernetesEastWestRangeAnnotationFiltersAppendString)
c.KubernetesBackendTrafficAlgorithm, _ = kubernetes.ParseBackendTrafficAlgorithm(c.KubernetesBackendTrafficAlgorithmString)
c.HistogramMetricBuckets, _ = c.parseHistogramBuckets()
if c.ClientKeyFile != "" && c.ClientCertFile != "" {
certsFiles := strings.Split(c.ClientCertFile, ",")
keyFiles := strings.Split(c.ClientKeyFile, ",")
var certificates []tls.Certificate
for i := range keyFiles {
certificate, err := tls.LoadX509KeyPair(certsFiles[i], keyFiles[i])
if err != nil {
return fmt.Errorf("invalid key/cert pair: %w", err)
}
certificates = append(certificates, certificate)
}
c.Certificates = certificates
}
if c.NormalizeHost || c.KubernetesIngress {
c.HostPatch = net.HostPatch{
ToLower: true,
RemovePort: true,
RemoveTrailingDot: true,
}
}
c.parseEnv()
return nil
}
func (c *Config) ToOptions() skipper.Options {
var eus []string
if len(c.EtcdUrls) > 0 {
eus = strings.Split(c.EtcdUrls, ",")
}
var whitelistCIDRS []string
if len(c.WhitelistedHealthCheckCIDR) > 0 {
whitelistCIDRS = strings.Split(c.WhitelistedHealthCheckCIDR, ",")
}
options := skipper.Options{
// generic:
Address: c.Address,
InsecureAddress: c.InsecureAddress,
StatusChecks: c.StatusChecks.values,
EnableTCPQueue: c.EnableTCPQueue,
ExpectedBytesPerRequest: c.ExpectedBytesPerRequest,
MaxTCPListenerConcurrency: c.MaxTCPListenerConcurrency,
MaxTCPListenerQueue: c.MaxTCPListenerQueue,
IgnoreTrailingSlash: c.IgnoreTrailingSlash,
DevMode: c.DevMode,
SupportListener: c.SupportListener,
DebugListener: c.DebugListener,
CertPathTLS: c.CertPathTLS,
KeyPathTLS: c.KeyPathTLS,
TLSClientAuth: c.TLSClientAuth,
TLSMinVersion: c.getMinTLSVersion(),
CipherSuites: c.filterCipherSuites(),
MaxLoopbacks: c.MaxLoopbacks,
DefaultHTTPStatus: c.DefaultHTTPStatus,
ReverseSourcePredicate: c.ReverseSourcePredicate,
MaxAuditBody: c.MaxAuditBody,
MaxMatcherBufferSize: c.MaxMatcherBufferSize,
EnableBreakers: c.EnableBreakers,
BreakerSettings: c.Breakers,
EnableRatelimiters: c.EnableRatelimiters,
RatelimitSettings: c.Ratelimits,
EnableRouteFIFOMetrics: c.EnableRouteFIFOMetrics,
EnableRouteLIFOMetrics: c.EnableRouteLIFOMetrics,
MetricsFlavours: c.MetricsFlavour.values,
FilterPlugins: c.FilterPlugins.values,
PredicatePlugins: c.PredicatePlugins.values,
DataClientPlugins: c.DataclientPlugins.values,
Plugins: c.MultiPlugins.values,
PluginDirs: []string{skipper.DefaultPluginDir},
CompressEncodings: c.CompressEncodings.values,
// logging, metrics, profiling, tracing:
EnablePrometheusMetrics: c.EnablePrometheusMetrics,
EnablePrometheusStartLabel: c.EnablePrometheusStartLabel,
OpenTracing: strings.Split(c.OpenTracing, " "),
OpenTracingInitialSpan: c.OpenTracingInitialSpan,
OpenTracingExcludedProxyTags: strings.Split(c.OpenTracingExcludedProxyTags, ","),
OpenTracingDisableFilterSpans: c.OpenTracingDisableFilterSpans,
OpenTracingLogStreamEvents: c.OpentracingLogStreamEvents,
OpenTracingLogFilterLifecycleEvents: c.OpentracingLogFilterLifecycleEvents,
MetricsListener: c.MetricsListener,
MetricsPrefix: c.MetricsPrefix,
EnableProfile: c.EnableProfile,
BlockProfileRate: c.BlockProfileRate,
MutexProfileFraction: c.MutexProfileFraction,
EnableDebugGcMetrics: c.DebugGcMetrics,
EnableRuntimeMetrics: c.RuntimeMetrics,
EnableServeRouteMetrics: c.ServeRouteMetrics,
EnableServeRouteCounter: c.ServeRouteCounter,
EnableServeHostMetrics: c.ServeHostMetrics,
EnableServeHostCounter: c.ServeHostCounter,
EnableServeMethodMetric: c.ServeMethodMetric,
EnableServeStatusCodeMetric: c.ServeStatusCodeMetric,
EnableProxyRequestMetrics: c.ProxyRequestMetrics,
EnableProxyResponseMetrics: c.ProxyResponseMetrics,
EnableBackendHostMetrics: c.BackendHostMetrics,
EnableAllFiltersMetrics: c.AllFiltersMetrics,
EnableCombinedResponseMetrics: c.CombinedResponseMetrics,
EnableRouteResponseMetrics: c.RouteResponseMetrics,
EnableRouteBackendErrorsCounters: c.RouteBackendErrorCounters,
EnableRouteStreamingErrorsCounters: c.RouteStreamErrorCounters,
EnableRouteBackendMetrics: c.RouteBackendMetrics,
EnableRouteCreationMetrics: c.RouteCreationMetrics,
MetricsUseExpDecaySample: c.MetricsUseExpDecaySample,
HistogramMetricBuckets: c.HistogramMetricBuckets,
DisableMetricsCompatibilityDefaults: c.DisableMetricsCompat,
ApplicationLogOutput: c.ApplicationLog,
ApplicationLogPrefix: c.ApplicationLogPrefix,
ApplicationLogJSONEnabled: c.ApplicationLogJSONEnabled,
AccessLogOutput: c.AccessLog,
AccessLogDisabled: c.AccessLogDisabled,
AccessLogJSONEnabled: c.AccessLogJSONEnabled,
AccessLogStripQuery: c.AccessLogStripQuery,
SuppressRouteUpdateLogs: c.SuppressRouteUpdateLogs,
// route sources:
EtcdUrls: eus,
EtcdPrefix: c.EtcdPrefix,
EtcdWaitTimeout: c.EtcdTimeout,
EtcdInsecure: c.EtcdInsecure,
EtcdOAuthToken: c.EtcdOAuthToken,
EtcdUsername: c.EtcdUsername,
EtcdPassword: c.EtcdPassword,
WatchRoutesFile: c.RoutesFile,
RoutesURLs: c.RoutesURLs.values,
InlineRoutes: c.InlineRoutes,
DefaultFilters: &eskip.DefaultFilters{
Prepend: c.PrependFilters.filters,
Append: c.AppendFilters.filters,
},
DisabledFilters: c.DisabledFilters.values,
SourcePollTimeout: time.Duration(c.SourcePollTimeout) * time.Millisecond,
WaitFirstRouteLoad: c.WaitFirstRouteLoad,
// Kubernetes:
Kubernetes: c.KubernetesIngress,
KubernetesInCluster: c.KubernetesInCluster,
KubernetesURL: c.KubernetesURL,
KubernetesTokenFile: c.KubernetesTokenFile,
KubernetesHealthcheck: c.KubernetesHealthcheck,
KubernetesHTTPSRedirect: c.KubernetesHTTPSRedirect,
KubernetesHTTPSRedirectCode: c.KubernetesHTTPSRedirectCode,
KubernetesDisableCatchAllRoutes: c.KubernetesDisableCatchAllRoutes,
KubernetesIngressClass: c.KubernetesIngressClass,
KubernetesRouteGroupClass: c.KubernetesRouteGroupClass,
WhitelistedHealthCheckCIDR: whitelistCIDRS,
KubernetesPathMode: c.KubernetesPathMode,
KubernetesNamespace: c.KubernetesNamespace,
KubernetesEnableEndpointslices: c.KubernetesEnableEndpointSlices,
KubernetesEnableEastWest: c.KubernetesEnableEastWest,
KubernetesEastWestDomain: c.KubernetesEastWestDomain,
KubernetesEastWestRangeDomains: c.KubernetesEastWestRangeDomains.values,
KubernetesEastWestRangePredicates: c.KubernetesEastWestRangePredicates,
KubernetesEastWestRangeAnnotationPredicates: c.KubernetesEastWestRangeAnnotationPredicates,
KubernetesEastWestRangeAnnotationFiltersAppend: c.KubernetesEastWestRangeAnnotationFiltersAppend,
KubernetesAnnotationPredicates: c.KubernetesAnnotationPredicates,
KubernetesAnnotationFiltersAppend: c.KubernetesAnnotationFiltersAppend,
KubernetesOnlyAllowedExternalNames: c.KubernetesOnlyAllowedExternalNames,
KubernetesAllowedExternalNames: c.KubernetesAllowedExternalNames,
KubernetesRedisServiceNamespace: c.KubernetesRedisServiceNamespace,
KubernetesRedisServiceName: c.KubernetesRedisServiceName,
KubernetesRedisServicePort: c.KubernetesRedisServicePort,
KubernetesBackendTrafficAlgorithm: c.KubernetesBackendTrafficAlgorithm,
KubernetesDefaultLoadBalancerAlgorithm: c.KubernetesDefaultLoadBalancerAlgorithm,
KubernetesForceService: c.KubernetesForceService,
// API Monitoring:
ApiUsageMonitoringEnable: c.ApiUsageMonitoringEnable,
ApiUsageMonitoringRealmKeys: c.ApiUsageMonitoringRealmKeys,
ApiUsageMonitoringClientKeys: c.ApiUsageMonitoringClientKeys,
ApiUsageMonitoringRealmsTrackingPattern: c.ApiUsageMonitoringRealmsTrackingPattern,
// Default filters:
DefaultFiltersDir: c.DefaultFiltersDir,
// Auth:
EnableOAuth2GrantFlow: c.EnableOAuth2GrantFlow,
OAuth2AuthURL: c.Oauth2AuthURL,
OAuth2TokenURL: c.Oauth2TokenURL,
OAuth2RevokeTokenURL: c.Oauth2RevokeTokenURL,
OAuthTokeninfoURL: c.Oauth2TokeninfoURL,
OAuthTokeninfoTimeout: c.Oauth2TokeninfoTimeout,
OAuthTokeninfoCacheSize: c.Oauth2TokeninfoCacheSize,
OAuthTokeninfoCacheTTL: c.Oauth2TokeninfoCacheTTL,
OAuth2SecretFile: c.Oauth2SecretFile,
OAuth2ClientID: c.Oauth2ClientID,
OAuth2ClientSecret: c.Oauth2ClientSecret,
OAuth2ClientIDFile: c.Oauth2ClientIDFile,
OAuth2ClientSecretFile: c.Oauth2ClientSecretFile,
OAuth2CallbackPath: c.Oauth2CallbackPath,
OAuthTokenintrospectionTimeout: c.Oauth2TokenintrospectionTimeout,
OAuth2AuthURLParameters: c.Oauth2AuthURLParameters.values,
OAuth2AccessTokenHeaderName: c.Oauth2AccessTokenHeaderName,
OAuth2TokeninfoSubjectKey: c.Oauth2TokeninfoSubjectKey,
OAuth2GrantTokeninfoKeys: c.Oauth2GrantTokeninfoKeys.values,
OAuth2TokenCookieName: c.Oauth2TokenCookieName,
OAuth2TokenCookieRemoveSubdomains: c.Oauth2TokenCookieRemoveSubdomains,
OAuth2GrantInsecure: c.Oauth2GrantInsecure,
WebhookTimeout: c.WebhookTimeout,
OIDCSecretsFile: c.OidcSecretsFile,
OIDCCookieValidity: c.OIDCCookieValidity,
OIDCDistributedClaimsTimeout: c.OidcDistributedClaimsTimeout,
OIDCCookieRemoveSubdomains: c.OIDCCookieRemoveSubdomains,
CredentialsPaths: c.CredentialPaths.values,
CredentialsUpdateInterval: c.CredentialsUpdateInterval,
// connections, timeouts:
WaitForHealthcheckInterval: c.WaitForHealthcheckInterval,
IdleConnectionsPerHost: c.IdleConnsPerHost,
CloseIdleConnsPeriod: c.CloseIdleConnsPeriod,
BackendFlushInterval: c.BackendFlushInterval,
ExperimentalUpgrade: c.ExperimentalUpgrade,
ExperimentalUpgradeAudit: c.ExperimentalUpgradeAudit,
ReadTimeoutServer: c.ReadTimeoutServer,
ReadHeaderTimeoutServer: c.ReadHeaderTimeoutServer,
WriteTimeoutServer: c.WriteTimeoutServer,
IdleTimeoutServer: c.IdleTimeoutServer,
KeepaliveServer: c.KeepaliveServer,
KeepaliveRequestsServer: c.KeepaliveRequestsServer,
MaxHeaderBytes: c.MaxHeaderBytes,
EnableConnMetricsServer: c.EnableConnMetricsServer,
TimeoutBackend: c.TimeoutBackend,
KeepAliveBackend: c.KeepaliveBackend,
DualStackBackend: c.EnableDualstackBackend,
TLSHandshakeTimeoutBackend: c.TlsHandshakeTimeoutBackend,
ResponseHeaderTimeoutBackend: c.ResponseHeaderTimeoutBackend,
ExpectContinueTimeoutBackend: c.ExpectContinueTimeoutBackend,
MaxIdleConnsBackend: c.MaxIdleConnsBackend,
DisableHTTPKeepalives: c.DisableHTTPKeepalives,
KubernetesEnableTLS: c.KubernetesEnableTLS,
// swarm:
EnableSwarm: c.EnableSwarm,
// redis based
SwarmRedisURLs: c.SwarmRedisURLs.values,
SwarmRedisPassword: c.SwarmRedisPassword,
SwarmRedisHashAlgorithm: c.SwarmRedisHashAlgorithm,
SwarmRedisDialTimeout: c.SwarmRedisDialTimeout,
SwarmRedisReadTimeout: c.SwarmRedisReadTimeout,
SwarmRedisWriteTimeout: c.SwarmRedisWriteTimeout,
SwarmRedisPoolTimeout: c.SwarmRedisPoolTimeout,
SwarmRedisMinIdleConns: c.SwarmRedisMinConns,
SwarmRedisMaxIdleConns: c.SwarmRedisMaxConns,
SwarmRedisEndpointsRemoteURL: c.SwarmRedisEndpointsRemoteURL,
// swim based
SwarmKubernetesNamespace: c.SwarmKubernetesNamespace,
SwarmKubernetesLabelSelectorKey: c.SwarmKubernetesLabelSelectorKey,
SwarmKubernetesLabelSelectorValue: c.SwarmKubernetesLabelSelectorValue,
SwarmPort: c.SwarmPort,
SwarmMaxMessageBuffer: c.SwarmMaxMessageBuffer,
SwarmLeaveTimeout: c.SwarmLeaveTimeout,
// swim on localhost for testing
SwarmStaticSelf: c.SwarmStaticSelf,
SwarmStaticOther: c.SwarmStaticOther,
ClusterRatelimitMaxGroupShards: c.ClusterRatelimitMaxGroupShards,
LuaModules: c.LuaModules.values,
LuaSources: c.LuaSources.values,
EnableOpenPolicyAgent: c.EnableOpenPolicyAgent,
EnableOpenPolicyAgentCustomControlLoop: c.EnableOpenPolicyAgentCustomControlLoop,
OpenPolicyAgentControlLoopInterval: c.OpenPolicyAgentControlLoopInterval,
OpenPolicyAgentControlLoopMaxJitter: c.OpenPolicyAgentControlLoopMaxJitter,
EnableOpenPolicyAgentDataPreProcessingOptimization: c.EnableOpenPolicyAgentDataPreProcessingOptimization,
OpenPolicyAgentConfigTemplate: c.OpenPolicyAgentConfigTemplate,
OpenPolicyAgentEnvoyMetadata: c.OpenPolicyAgentEnvoyMetadata,
OpenPolicyAgentCleanerInterval: c.OpenPolicyAgentCleanerInterval,
OpenPolicyAgentStartupTimeout: c.OpenPolicyAgentStartupTimeout,
OpenPolicyAgentMaxRequestBodySize: c.OpenPolicyAgentMaxRequestBodySize,
OpenPolicyAgentRequestBodyBufferSize: c.OpenPolicyAgentRequestBodyBufferSize,
OpenPolicyAgentMaxMemoryBodyParsing: c.OpenPolicyAgentMaxMemoryBodyParsing,
PassiveHealthCheck: c.PassiveHealthCheck.values,
}
for _, rcci := range c.CloneRoute {
eskipClone := eskip.NewClone(rcci.Reg, rcci.Repl)
options.CloneRoute = append(options.CloneRoute, eskipClone)
}
for _, rcci := range c.EditRoute {
eskipEdit := eskip.NewEditor(rcci.Reg, rcci.Repl)
options.EditRoute = append(options.EditRoute, eskipEdit)
}
if c.PluginDir != "" {
options.PluginDirs = append(options.PluginDirs, c.PluginDir)
}
if c.Insecure {
options.ProxyFlags |= proxy.Insecure
}
if c.ProxyPreserveHost {
options.ProxyFlags |= proxy.PreserveHost
}
if c.RemoveHopHeaders {
options.ProxyFlags |= proxy.HopHeadersRemoval
}
if c.RfcPatchPath {
options.ProxyFlags |= proxy.PatchPath
}
if len(c.Certificates) > 0 {
options.ClientTLS = &tls.Config{
Certificates: c.Certificates,
MinVersion: c.getMinTLSVersion(),
}
}
var wrappers []func(handler http.Handler) http.Handler
options.CustomHttpHandlerWrap = func(handler http.Handler) http.Handler {
for _, wrapper := range wrappers {
handler = wrapper(handler)
}
return handler
}
if c.ForwardedHeaders != (net.ForwardedHeaders{}) {
wrappers = append(wrappers, func(handler http.Handler) http.Handler {
return &net.ForwardedHeadersHandler{
Headers: c.ForwardedHeaders,
Exclude: c.ForwardedHeadersExcludeCIDRs,
Handler: handler,
}
})
}
if c.HostPatch != (net.HostPatch{}) {
wrappers = append(wrappers, func(handler http.Handler) http.Handler {
return &net.HostPatchHandler{
Patch: c.HostPatch,
Handler: handler,
}
})
}
if len(c.RefusePayload) > 0 {
wrappers = append(wrappers, func(handler http.Handler) http.Handler {
return &net.RequestMatchHandler{
Match: c.RefusePayload,
Handler: handler,
}
})
}
if c.ValidateQuery {
wrappers = append(wrappers, func(handler http.Handler) http.Handler {
return &net.ValidateQueryHandler{
Handler: handler,
}
})
}
if c.ValidateQueryLog {
wrappers = append(wrappers, func(handler http.Handler) http.Handler {
return &net.ValidateQueryLogHandler{
Handler: handler,
}
})
}
return options
}
func (c *Config) getMinTLSVersion() uint16 {
tlsVersionTable := map[string]uint16{
"1.3": tls.VersionTLS13,
"13": tls.VersionTLS13,
"1.2": tls.VersionTLS12,
"12": tls.VersionTLS12,
"1.1": tls.VersionTLS11,
"11": tls.VersionTLS11,
"1.0": tls.VersionTLS10,
"10": tls.VersionTLS10,
}
if v, ok := tlsVersionTable[c.TLSMinVersion]; ok {
return v
}
log.Infof("No valid minimal TLS version confiured (set to '%s'), fallback to default: %s", c.TLSMinVersion, defaultMinTLSVersion)
return tlsVersionTable[defaultMinTLSVersion]
}
func (c *Config) setTLSClientAuth(s string) error {
var ok bool
c.TLSClientAuth, ok = map[string]tls.ClientAuthType{
"NoClientCert": tls.NoClientCert,
"RequestClientCert": tls.RequestClientCert,
"RequireAnyClientCert": tls.RequireAnyClientCert,
"VerifyClientCertIfGiven": tls.VerifyClientCertIfGiven,
"RequireAndVerifyClientCert": tls.RequireAndVerifyClientCert,
}[s]
if !ok {
return fmt.Errorf("unsupported TLS client authentication type")
}
return nil
}
func (c *Config) filterCipherSuites() []uint16 {
if !c.ExcludeInsecureCipherSuites {
return nil
}
cipherSuites := make([]uint16, 0)
for _, suite := range tls.CipherSuites() {
cipherSuites = append(cipherSuites, suite.ID)
}
return cipherSuites
}
func (c *Config) parseHistogramBuckets() ([]float64, error) {
if c.HistogramMetricBucketsString == "" {
return prometheus.DefBuckets, nil
}
var result []float64
thresholds := strings.Split(c.HistogramMetricBucketsString, ",")
for _, v := range thresholds {
bucket, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
if err != nil {
return nil, fmt.Errorf("unable to parse histogram-metric-buckets: %w", err)
}
result = append(result, bucket)
}
sort.Float64s(result)
return result, nil
}
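// A minimal sketch of how histogram buckets are parsed (the flag value below
// is hypothetical):
//
//	c := &Config{HistogramMetricBucketsString: "2.5,0.25,0.5,1"}
//	buckets, err := c.parseHistogramBuckets()
//	// buckets == []float64{0.25, 0.5, 1, 2.5} (sorted), err == nil
//	// an empty input falls back to prometheus.DefBuckets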
func (c *Config) parseForwardedHeaders() error {
for _, header := range c.ForwardedHeadersList.values {
switch {
case header == "X-Forwarded-For":
c.ForwardedHeaders.For = true
case header == "X-Forwarded-Host":
c.ForwardedHeaders.Host = true
case header == "X-Forwarded-Method":
c.ForwardedHeaders.Method = true
case header == "X-Forwarded-Uri":
c.ForwardedHeaders.Uri = true
case strings.HasPrefix(header, "X-Forwarded-Port="):
c.ForwardedHeaders.Port = strings.TrimPrefix(header, "X-Forwarded-Port=")
case header == "X-Forwarded-Proto=http":
c.ForwardedHeaders.Proto = "http"
case header == "X-Forwarded-Proto=https":
c.ForwardedHeaders.Proto = "https"
default:
return fmt.Errorf("invalid forwarded header: %s", header)
}
}
cidrs, err := net.ParseCIDRs(c.ForwardedHeadersExcludeCIDRList.values)
if err != nil {
return fmt.Errorf("invalid forwarded headers exclude CIDRs: %w", err)
}
c.ForwardedHeadersExcludeCIDRs = cidrs
return nil
}
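// A minimal sketch of the accepted forwarded header values (the list below is
// hypothetical):
//
//	c.ForwardedHeadersList.values = []string{"X-Forwarded-For", "X-Forwarded-Proto=https", "X-Forwarded-Port=443"}
//	err := c.parseForwardedHeaders()
//	// c.ForwardedHeaders.For == true, Proto == "https", Port == "443"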
func (c *Config) parseEnv() {
// Set Redis password from environment variable if not set earlier (configuration file)
if c.SwarmRedisPassword == "" {
c.SwarmRedisPassword = os.Getenv(redisPasswordEnv)
}
}
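// checkDeprecated logs a warning (the flag's usage text) for every listed
// option that was set either on the command line or in the config file.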
func (c *Config) checkDeprecated(configKeys map[string]interface{}, options ...string) {
flagKeys := make(map[string]bool)
c.Flags.Visit(func(f *flag.Flag) { flagKeys[f.Name] = true })
for _, name := range options {
_, ck := configKeys[name]
_, fk := flagKeys[name]
if ck || fk {
f := c.Flags.Lookup(name)
log.Warnf("%s: %s", f.Name, f.Usage)
}
}
}
func parseAnnotationPredicates(s []string) ([]kubernetes.AnnotationPredicates, error) {
return parseAnnotationConfig(s, func(annotationKey, annotationValue, value string) (kubernetes.AnnotationPredicates, error) {
predicates, err := eskip.ParsePredicates(value)
if err != nil {
var zero kubernetes.AnnotationPredicates
return zero, err
}
return kubernetes.AnnotationPredicates{
Key: annotationKey,
Value: annotationValue,
Predicates: predicates,
}, nil
})
}
func parseAnnotationFilters(s []string) ([]kubernetes.AnnotationFilters, error) {
return parseAnnotationConfig(s, func(annotationKey, annotationValue, value string) (kubernetes.AnnotationFilters, error) {
filters, err := eskip.ParseFilters(value)
if err != nil {
var zero kubernetes.AnnotationFilters
return zero, err
}
return kubernetes.AnnotationFilters{
Key: annotationKey,
Value: annotationValue,
Filters: filters,
}, nil
})
}
// parseAnnotationConfig parses a slice of strings in the "annotationKey=annotationValue=value" format
// by calling parseValue function to convert (annotationKey, annotationValue, value) tuple into T.
// Empty input strings are skipped and duplicate annotationKey-annotationValue pairs are rejected with error.
func parseAnnotationConfig[T any](kvvs []string, parseValue func(annotationKey, annotationValue, value string) (T, error)) ([]T, error) {
var result []T
seenKVs := make(map[string]struct{})
for _, kvv := range kvvs {
if kvv == "" {
continue
}
annotationKey, rest, found := strings.Cut(kvv, "=")
if !found {
return nil, fmt.Errorf("invalid annotation flag: %q, failed to get annotation key", kvv)
}
annotationValue, value, found := strings.Cut(rest, "=")
if !found {
return nil, fmt.Errorf("invalid annotation flag: %q, failed to get annotation value", kvv)
}
v, err := parseValue(annotationKey, annotationValue, value)
if err != nil {
return nil, fmt.Errorf("invalid annotation flag value: %q, %w", kvv, err)
}
// Reject duplicate annotation key-value pairs
kv := annotationKey + "=" + annotationValue
if _, ok := seenKVs[kv]; ok {
return nil, fmt.Errorf("invalid annotation flag: %q, duplicate annotation key-value %q", kvv, kv)
} else {
seenKVs[kv] = struct{}{}
}
result = append(result, v)
}
return result, nil
}
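// A minimal sketch of the expected "annotationKey=annotationValue=value"
// format (the annotation key and value below are hypothetical):
//
//	aps, err := parseAnnotationPredicates([]string{`zalando.org/skipper-enabled=true=Weight(100)`})
//	// aps[0].Key == "zalando.org/skipper-enabled"
//	// aps[0].Value == "true"
//	// aps[0].Predicates holds the parsed Weight(100) predicate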
package config
import (
"fmt"
"strings"
"github.com/zalando/skipper/eskip"
)
type defaultFiltersFlags struct {
values []string
filters []*eskip.Filter
}
func (dpf defaultFiltersFlags) String() string {
return strings.Join(dpf.values, " -> ")
}
func (dpf *defaultFiltersFlags) Set(value string) error {
fs, err := eskip.ParseFilters(value)
if err != nil {
return fmt.Errorf("failed to parse default filters: %w", err)
}
if len(fs) > 0 {
dpf.values = append(dpf.values, value)
dpf.filters = append(dpf.filters, fs...)
}
return nil
}
func (dpf *defaultFiltersFlags) UnmarshalYAML(unmarshal func(interface{}) error) error {
values := make([]string, 1)
if err := unmarshal(&values); err != nil {
// Try to unmarshal as string for backwards compatibility.
// UnmarshalYAML allows calling unmarshal more than once.
if err := unmarshal(&values[0]); err != nil {
return err
}
}
dpf.values = nil
dpf.filters = nil
for _, v := range values {
if err := dpf.Set(v); err != nil {
return err
}
}
return nil
}
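// A minimal sketch of the two YAML value forms accepted by
// defaultFiltersFlags (the key name below is illustrative):
//
//	some-filters-key:
//	  - enableAccessLog()
//	  - compress()
//
// or, for backwards compatibility, a single string:
//
//	some-filters-key: enableAccessLog() -> compress()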
package config
import (
"fmt"
"strings"
)
type listFlag struct {
sep string
allowed map[string]bool
value string
values []string
}
func newListFlag(sep string, allowed ...string) *listFlag {
lf := &listFlag{
sep: sep,
allowed: make(map[string]bool),
}
for _, a := range allowed {
lf.allowed[a] = true
}
return lf
}
func commaListFlag(allowed ...string) *listFlag {
return newListFlag(",", allowed...)
}
func (lf *listFlag) Set(value string) error {
if lf == nil {
return nil
}
if value == "" {
lf.value = ""
lf.values = nil
} else {
lf.value = value
lf.values = strings.Split(value, lf.sep)
}
if err := lf.validate(); err != nil {
return err
}
return nil
}
func (lf *listFlag) UnmarshalYAML(unmarshal func(interface{}) error) error {
var values []string
if err := unmarshal(&values); err != nil {
return err
}
lf.value = strings.Join(values, lf.sep)
lf.values = values
if err := lf.validate(); err != nil {
return err
}
return nil
}
func (lf *listFlag) validate() error {
if len(lf.allowed) == 0 {
return nil
}
for _, v := range lf.values {
if !lf.allowed[v] {
return fmt.Errorf("value not allowed: %s", v)
}
}
return nil
}
func (lf listFlag) String() string { return lf.value }
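// A minimal usage sketch (the allowed values below are hypothetical):
//
//	lf := commaListFlag("file", "inline")
//	_ = lf.Set("file,inline") // lf.values == []string{"file", "inline"}
//	err := lf.Set("http")     // fails: value not allowed: http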
package config
import (
"fmt"
"strings"
)
// mapFlags are generic string key-value pair flags.
// Use when option keys are not predetermined.
type mapFlags struct {
values map[string]string
}
const formatErrorString = "invalid map key-value pair, expected format key=value but got: '%v'"
func newMapFlags() *mapFlags {
return &mapFlags{
values: make(map[string]string),
}
}
func (m *mapFlags) String() string {
var pairs []string
for k, v := range m.values {
pairs = append(pairs, fmt.Sprint(k, "=", v))
}
return strings.Join(pairs, ",")
}
func (m *mapFlags) Set(value string) error {
if m == nil {
return nil
}
m.values = make(map[string]string)
vs := strings.Split(value, ",")
for _, vi := range vs {
k, v, found := strings.Cut(vi, "=")
if !found {
return fmt.Errorf(formatErrorString, vi)
}
k = strings.TrimSpace(k)
v = strings.TrimSpace(v)
if k == "" || v == "" {
return fmt.Errorf(formatErrorString, vi)
}
m.values[k] = v
}
return nil
}
func (m *mapFlags) UnmarshalYAML(unmarshal func(interface{}) error) error {
var values = make(map[string]string)
if err := unmarshal(&values); err != nil {
return err
}
m.values = values
return nil
}
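// A minimal usage sketch (the keys and values below are hypothetical):
//
//	m := newMapFlags()
//	err := m.Set("key1=value1,key2=value2")
//	// m.values == map[string]string{"key1": "value1", "key2": "value2"}
//	// a pair without '=' or with an empty key or value returns an error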
package config
import (
"strings"
)
type multiFlag []string
func (f *multiFlag) String() string {
return strings.Join(*f, " ")
}
func (f *multiFlag) Set(value string) error {
*f = append(*f, value)
return nil
}
func (f *multiFlag) UnmarshalYAML(unmarshal func(interface{}) error) error {
var values []string
if err := unmarshal(&values); err != nil {
return err
}
*f = values
return nil
}
package config
import (
"strings"
)
type pluginFlag struct {
listFlag *listFlag
values [][]string
}
func newPluginFlag() *pluginFlag {
return &pluginFlag{listFlag: newListFlag(" ")}
}
func (f pluginFlag) String() string {
if f.listFlag == nil {
return ""
}
return f.listFlag.String()
}
func (f *pluginFlag) Set(value string) error {
if err := f.listFlag.Set(value); err != nil {
return err
}
for _, v := range f.listFlag.values {
f.values = append(f.values, strings.Split(v, ","))
}
return nil
}
func (f *pluginFlag) UnmarshalYAML(unmarshal func(interface{}) error) error {
var value map[string][]string
if err := unmarshal(&value); err != nil {
return err
}
for k, values := range value {
plugin := append([]string{k}, values...)
f.values = append(f.values, plugin)
}
return nil
}
package config
import (
"errors"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/ratelimit"
)
const ratelimitsUsage = `set global rate limit settings, e.g. -ratelimits type=client,max-hits=20,time-window=60s
possible ratelimit properties:
type: client/service/clusterClient/clusterService/disabled (defaults to disabled)
max-hits: the maximum number of hits allowed within the time window
time-window: the duration of the sliding window for the rate limiter
group: defines the ratelimit group, which can be the same for different routes.
(see also: https://godoc.org/github.com/zalando/skipper/ratelimit)`
const enableRatelimitsUsage = `enable ratelimits`
type ratelimitFlags []ratelimit.Settings
var errInvalidRatelimitConfig = errors.New("invalid ratelimit config (allowed values are: client, service or disabled)")
func (r ratelimitFlags) String() string {
s := make([]string, len(r))
for i, ri := range r {
s[i] = ri.String()
}
return strings.Join(s, "\n")
}
func (r *ratelimitFlags) Set(value string) error {
var s ratelimit.Settings
vs := strings.Split(value, ",")
for _, vi := range vs {
k, v, found := strings.Cut(vi, "=")
if !found {
return errInvalidRatelimitConfig
}
switch k {
case "type":
switch v {
case "local":
log.Warning("LocalRatelimit is deprecated, please use ClientRatelimit instead")
fallthrough
case "client":
s.Type = ratelimit.ClientRatelimit
case "service":
s.Type = ratelimit.ServiceRatelimit
case "clusterClient":
s.Type = ratelimit.ClusterClientRatelimit
case "clusterService":
s.Type = ratelimit.ClusterServiceRatelimit
case "disabled":
s.Type = ratelimit.DisableRatelimit
default:
return errInvalidRatelimitConfig
}
case "max-hits":
i, err := strconv.Atoi(v)
if err != nil {
return err
}
s.MaxHits = i
case "time-window":
d, err := time.ParseDuration(v)
if err != nil {
return err
}
s.TimeWindow = d
s.CleanInterval = d * 10
case "group":
s.Group = v
default:
return errInvalidRatelimitConfig
}
}
if s.Type == ratelimit.NoRatelimit {
s.Type = ratelimit.DisableRatelimit
}
*r = append(*r, s)
return nil
}
func (r *ratelimitFlags) UnmarshalYAML(unmarshal func(interface{}) error) error {
var rateLimitSettings ratelimit.Settings
if err := unmarshal(&rateLimitSettings); err != nil {
return err
}
rateLimitSettings.CleanInterval = rateLimitSettings.TimeWindow * 10
*r = append(*r, rateLimitSettings)
return nil
}
package config
import (
"regexp"
"strings"
)
type regexpListFlag []*regexp.Regexp
func (r regexpListFlag) String() string {
s := make([]string, len(r))
for i, ri := range r {
s[i] = ri.String()
}
return strings.Join(s, "\n")
}
func (r *regexpListFlag) Set(value string) error {
rx, err := regexp.Compile(value)
if err != nil {
return err
}
*r = append(*r, rx)
return nil
}
func (r *regexpListFlag) UnmarshalYAML(unmarshal func(interface{}) error) error {
var m map[string][]string
if err := unmarshal(&m); err != nil {
return err
}
for _, value := range m {
for _, item := range value {
rx, err := regexp.Compile(item)
if err != nil {
return err
}
*r = append(*r, rx)
}
}
return nil
}
package config
import (
"fmt"
"regexp"
"strings"
)
type routeChangerConfigItem struct {
Reg *regexp.Regexp `yaml:"reg"`
Repl string `yaml:"repl"`
Sep string `yaml:"sep"`
}
func (rcci routeChangerConfigItem) String() string {
return rcci.Sep + rcci.Reg.String() + rcci.Sep + rcci.Repl + rcci.Sep
}
type routeChangerConfig []routeChangerConfigItem
func (rcc routeChangerConfig) String() string {
var b strings.Builder
for i, rcci := range rcc {
if i > 0 {
b.WriteString("\n")
}
b.WriteString(rcci.String())
}
return b.String()
}
func (rcc *routeChangerConfig) Set(value string) error {
if len(value) == 0 {
return fmt.Errorf("empty string as an argument is not allowed")
}
firstSym := value[:1]
a := strings.Split(value, firstSym)
if len(a) != 4 {
return fmt.Errorf("unexpected size of string split: %d", len(a))
}
reg, repl := a[1], a[2]
regex, err := regexp.Compile(reg)
if err != nil {
return err
}
rcci := routeChangerConfigItem{
Reg: regex,
Repl: repl,
Sep: firstSym,
}
*rcc = append(*rcc, rcci)
return nil
}
func (rcc *routeChangerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
var value string
if err := unmarshal(&value); err != nil {
return err
}
return rcc.Set(value)
}
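// A minimal usage sketch: the first character of the argument is used as the
// separator between the regexp and the replacement (the values below are
// hypothetical):
//
//	var rcc routeChangerConfig
//	err := rcc.Set("#Host[(](.+)[)]#HostAny($1)#")
//	// rcc[0].Reg matches `Host[(](.+)[)]`, rcc[0].Repl == "HostAny($1)"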
package config
import (
"fmt"
"gopkg.in/yaml.v2"
)
type yamlFlag[T any] struct {
Ptr **T
value string // only for Set
}
func newYamlFlag[T any](ptr **T) *yamlFlag[T] {
return &yamlFlag[T]{Ptr: ptr}
}
func (yf *yamlFlag[T]) Set(value string) error {
var opts T
if err := yaml.Unmarshal([]byte(value), &opts); err != nil {
return fmt.Errorf("failed to parse yaml: %w", err)
}
*yf.Ptr = &opts
yf.value = value
return nil
}
func (yf *yamlFlag[T]) UnmarshalYAML(unmarshal func(interface{}) error) error {
var opts T
if err := unmarshal(&opts); err != nil {
return err
}
*yf.Ptr = &opts
return nil
}
func (yf *yamlFlag[T]) String() string {
if yf == nil {
return ""
}
return yf.value
}
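// A minimal usage sketch (the options type and YAML value below are
// hypothetical):
//
//	var opts *fooOptions // type fooOptions struct{ Field string `yaml:"field"` }
//	f := newYamlFlag(&opts)
//	err := f.Set("field: value")
//	// opts.Field == "value", f.String() returns the raw input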
package kubernetes
import (
"github.com/zalando/skipper/eskip"
)
type AnnotationPredicates struct {
Key string
Value string
Predicates []*eskip.Predicate
}
type AnnotationFilters struct {
Key string
Value string
Filters []*eskip.Filter
}
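// appendAnnotationPredicates appends the configured predicates to the route
// when the resource carries a matching annotation key-value pair.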
func appendAnnotationPredicates(annotationPredicates []AnnotationPredicates, annotations map[string]string, r *eskip.Route) {
for _, ap := range annotationPredicates {
if objAnnotationVal, ok := annotations[ap.Key]; ok && ap.Value == objAnnotationVal {
// since this annotation is managed by the skipper operator, we can safely
// assume that the predicate is valid and append it to the route
r.Predicates = append(r.Predicates, ap.Predicates...)
}
}
}
func appendAnnotationFilters(annotationFilters []AnnotationFilters, annotations map[string]string, r *eskip.Route) {
for _, af := range annotationFilters {
if objAnnotationVal, ok := annotations[af.Key]; ok && af.Value == objAnnotationVal {
r.Filters = append(r.Filters, af.Filters...)
}
}
}
package kubernetes
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"regexp"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/secrets"
"github.com/zalando/skipper/secrets/certregistry"
)
const (
ingressClassKey = "kubernetes.io/ingress.class"
IngressesV1ClusterURI = "/apis/networking.k8s.io/v1/ingresses"
ZalandoResourcesClusterURI = "/apis/zalando.org/v1"
RouteGroupsName = "routegroups"
RouteGroupsClusterURI = "/apis/zalando.org/v1/routegroups"
routeGroupClassKey = "zalando.org/routegroup.class"
ServicesClusterURI = "/api/v1/services"
EndpointsClusterURI = "/api/v1/endpoints"
EndpointSlicesClusterURI = "/apis/discovery.k8s.io/v1/endpointslices"
SecretsClusterURI = "/api/v1/secrets"
defaultKubernetesURL = "http://localhost:8001"
IngressesV1NamespaceFmt = "/apis/networking.k8s.io/v1/namespaces/%s/ingresses"
RouteGroupsNamespaceFmt = "/apis/zalando.org/v1/namespaces/%s/routegroups"
ServicesNamespaceFmt = "/api/v1/namespaces/%s/services"
EndpointsNamespaceFmt = "/api/v1/namespaces/%s/endpoints"
EndpointSlicesNamespaceFmt = "/apis/discovery.k8s.io/v1/namespaces/%s/endpointslices"
SecretsNamespaceFmt = "/api/v1/namespaces/%s/secrets"
serviceAccountDir = "/var/run/secrets/kubernetes.io/serviceaccount/"
serviceAccountTokenKey = "token"
serviceAccountRootCAKey = "ca.crt"
labelSelectorFmt = "%s=%s"
labelSelectorQueryFmt = "?labelSelector=%s"
)
const RouteGroupsNotInstalledMessage = `RouteGroups CRD is not installed in the cluster.
See: https://opensource.zalando.com/skipper/kubernetes/routegroups/#installation`
type clusterClient struct {
ingressesURI string
routeGroupsURI string
servicesURI string
endpointsURI string
endpointSlicesURI string
secretsURI string
tokenProvider secrets.SecretsProvider
tokenFile string
apiURL string
certificateRegistry *certregistry.CertRegistry
routeGroupClass *regexp.Regexp
ingressClass *regexp.Regexp
httpClient *http.Client
ingressLabelSelectors string
servicesLabelSelectors string
endpointsLabelSelectors string
endpointSlicesLabelSelectors string
secretsLabelSelectors string
routeGroupsLabelSelectors string
enableEndpointSlices bool
loggedMissingRouteGroups bool
routeGroupValidator *definitions.RouteGroupValidator
ingressValidator *definitions.IngressV1Validator
}
var (
errResourceNotFound = errors.New("resource not found")
errServiceNotFound = errors.New("service not found")
errAPIServerURLNotFound = errors.New("kubernetes API server URL could not be constructed from env vars")
errInvalidCertificate = errors.New("invalid CA")
)
func buildHTTPClient(certFilePath string, inCluster bool, quit <-chan struct{}) (*http.Client, error) {
if !inCluster {
return http.DefaultClient, nil
}
rootCA, err := os.ReadFile(certFilePath)
if err != nil {
return nil, err
}
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(rootCA) {
return nil, errInvalidCertificate
}
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
ExpectContinueTimeout: 30 * time.Second,
MaxIdleConns: 5,
MaxIdleConnsPerHost: 5,
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: certPool,
},
}
// regularly force closing idle connections
go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
transport.CloseIdleConnections()
case <-quit:
return
}
}
}()
return &http.Client{
Transport: transport,
}, nil
}
func newClusterClient(o Options, apiURL, ingCls, rgCls string, quit <-chan struct{}) (*clusterClient, error) {
httpClient, err := buildHTTPClient(serviceAccountDir+serviceAccountRootCAKey, o.KubernetesInCluster, quit)
if err != nil {
return nil, err
}
ingClsRx, err := regexp.Compile(ingCls)
if err != nil {
return nil, err
}
rgClsRx, err := regexp.Compile(rgCls)
if err != nil {
return nil, err
}
c := &clusterClient{
ingressesURI: IngressesV1ClusterURI,
routeGroupsURI: RouteGroupsClusterURI,
servicesURI: ServicesClusterURI,
endpointsURI: EndpointsClusterURI,
endpointSlicesURI: EndpointSlicesClusterURI,
secretsURI: SecretsClusterURI,
ingressClass: ingClsRx,
ingressLabelSelectors: toLabelSelectorQuery(o.IngressLabelSelectors),
servicesLabelSelectors: toLabelSelectorQuery(o.ServicesLabelSelectors),
endpointsLabelSelectors: toLabelSelectorQuery(o.EndpointsLabelSelectors),
endpointSlicesLabelSelectors: toLabelSelectorQuery(o.EndpointSlicesLabelSelectors),
secretsLabelSelectors: toLabelSelectorQuery(o.SecretsLabelSelectors),
routeGroupsLabelSelectors: toLabelSelectorQuery(o.RouteGroupsLabelSelectors),
routeGroupClass: rgClsRx,
httpClient: httpClient,
apiURL: apiURL,
certificateRegistry: o.CertificateRegistry,
routeGroupValidator: &definitions.RouteGroupValidator{},
ingressValidator: &definitions.IngressV1Validator{},
enableEndpointSlices: o.KubernetesEnableEndpointslices,
}
if o.KubernetesInCluster {
c.tokenProvider = secrets.NewSecretPaths(time.Minute)
c.tokenFile = serviceAccountDir + serviceAccountTokenKey
} else if o.TokenFile != "" {
c.tokenProvider = secrets.NewSecretPaths(time.Minute)
c.tokenFile = o.TokenFile
}
if c.tokenProvider != nil {
if err := c.tokenProvider.Add(c.tokenFile); err != nil {
return nil, fmt.Errorf("failed to add secret %s: %w", c.tokenFile, err)
}
if b, ok := c.tokenProvider.GetSecret(c.tokenFile); ok {
log.Debugf("Got secret %d bytes from %s", len(b), c.tokenFile)
} else {
return nil, fmt.Errorf("failed to get secret %s", c.tokenFile)
}
}
if o.KubernetesNamespace != "" {
c.setNamespace(o.KubernetesNamespace)
}
return c, nil
}
// toLabelSelectorQuery serializes a given map of label selectors to a query string that can be appended to a request URI to kubernetes.
// Examples (note that the resulting value in the query is URL escaped; for readability this is not done in the examples):
//
// [] becomes ``
// ["label": ""] becomes `?labelSelector=label`
// ["label": "value"] becomes `?labelSelector=label=value`
// ["label": "value", "label2": "value2"] becomes `?labelSelector=label=value,label2=value2`
func toLabelSelectorQuery(selectors map[string]string) string {
if len(selectors) == 0 {
return ""
}
var strs []string
for k, v := range selectors {
if v == "" {
strs = append(strs, k)
} else {
strs = append(strs, fmt.Sprintf(labelSelectorFmt, k, v))
}
}
return fmt.Sprintf(labelSelectorQueryFmt, url.QueryEscape(strings.Join(strs, ",")))
}
func (c *clusterClient) setNamespace(namespace string) {
c.ingressesURI = fmt.Sprintf(IngressesV1NamespaceFmt, namespace)
c.routeGroupsURI = fmt.Sprintf(RouteGroupsNamespaceFmt, namespace)
c.servicesURI = fmt.Sprintf(ServicesNamespaceFmt, namespace)
c.endpointsURI = fmt.Sprintf(EndpointsNamespaceFmt, namespace)
c.endpointSlicesURI = fmt.Sprintf(EndpointSlicesNamespaceFmt, namespace)
c.secretsURI = fmt.Sprintf(SecretsNamespaceFmt, namespace)
}
func (c *clusterClient) createRequest(uri string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest("GET", c.apiURL+uri, body)
if err != nil {
return nil, err
}
if c.tokenProvider != nil {
token, ok := c.tokenProvider.GetSecret(c.tokenFile)
if !ok {
return nil, fmt.Errorf("secret not found: %v", c.tokenFile)
}
req.Header.Set("Authorization", "Bearer "+string(token))
}
return req, nil
}
func (c *clusterClient) getJSON(uri string, a interface{}) error {
log.Tracef("making request to: %s", uri)
req, err := c.createRequest(uri, nil)
if err != nil {
return err
}
rsp, err := c.httpClient.Do(req)
if err != nil {
log.Tracef("request to %s failed: %v", uri, err)
return err
}
log.Tracef("request to %s succeeded", uri)
defer rsp.Body.Close()
if rsp.StatusCode == http.StatusNotFound {
return errResourceNotFound
}
if rsp.StatusCode != http.StatusOK {
log.Tracef("request failed, status: %d, %s", rsp.StatusCode, rsp.Status)
return fmt.Errorf("request to %s failed, status: %d, %s", uri, rsp.StatusCode, rsp.Status)
}
b := bytes.NewBuffer(nil)
if _, err = io.Copy(b, rsp.Body); err != nil {
log.Tracef("reading response body failed: %v", err)
return err
}
err = json.Unmarshal(b.Bytes(), a)
if err != nil {
log.Tracef("invalid response format: %v", err)
}
return err
}
func (c *clusterClient) clusterHasRouteGroups() (bool, error) {
var crl ClusterResourceList
if err := c.getJSON(ZalandoResourcesClusterURI, &crl); err != nil { // TODO: this lookup should probably be retried once on failure
return false, err
}
for _, cr := range crl.Items {
if cr.Name == RouteGroupsName {
return true, nil
}
}
return false, nil
}
func (c *clusterClient) ingressClassMissmatch(m *definitions.Metadata) bool {
// No Metadata is the same as no annotations for us
if m != nil {
cls, ok := m.Annotations[ingressClassKey]
// A mismatch requires the class annotation to be set, non-empty and not matching the configured ingress class; a missing or empty annotation is valid.
return ok && cls != "" && !c.ingressClass.MatchString(cls)
}
return false
}
// filterIngressesV1ByClass keeps only the ingresses with a valid class: these are
// the ones with the configured class, an empty class, or no class at all.
func (c *clusterClient) filterIngressesV1ByClass(items []*definitions.IngressV1Item) []*definitions.IngressV1Item {
validIngs := []*definitions.IngressV1Item{}
for _, ing := range items {
// v1beta1 style
if c.ingressClassMissmatch(ing.Metadata) {
continue
}
// v1 style, TODO(sszuecs): we also need to fetch the ingressclass object and check what should be done
if ing.Spec == nil || ing.Spec.IngressClassName == "" || c.ingressClass.MatchString(ing.Spec.IngressClassName) {
validIngs = append(validIngs, ing)
}
}
return validIngs
}
func sortByMetadata(slice interface{}, getMetadata func(int) *definitions.Metadata) {
sort.Slice(slice, func(i, j int) bool {
mI := getMetadata(i)
mJ := getMetadata(j)
if mI == nil && mJ != nil {
return true
} else if mJ == nil {
return false
}
nsI := mI.Namespace
nsJ := mJ.Namespace
if nsI != nsJ {
return nsI < nsJ
}
return mI.Name < mJ.Name
})
}
func (c *clusterClient) loadIngressesV1() ([]*definitions.IngressV1Item, error) {
var il definitions.IngressV1List
if err := c.getJSON(c.ingressesURI+c.ingressLabelSelectors, &il); err != nil {
log.Debugf("requesting all ingresses failed: %v", err)
return nil, err
}
log.Debugf("all ingresses received: %d", len(il.Items))
fItems := c.filterIngressesV1ByClass(il.Items)
log.Debugf("filtered ingresses by ingress class: %d", len(fItems))
sortByMetadata(fItems, func(i int) *definitions.Metadata { return fItems[i].Metadata })
validatedItems := make([]*definitions.IngressV1Item, 0, len(fItems))
for _, i := range fItems {
if err := c.ingressValidator.Validate(i); err != nil {
log.Errorf("[ingress] %v", err)
continue
}
validatedItems = append(validatedItems, i)
}
return validatedItems, nil
}
func (c *clusterClient) LoadRouteGroups() ([]*definitions.RouteGroupItem, error) {
var rgl definitions.RouteGroupList
if err := c.getJSON(c.routeGroupsURI+c.routeGroupsLabelSelectors, &rgl); err != nil {
return nil, err
}
log.Debugf("all routegroups received: %d", len(rgl.Items))
rgs := make([]*definitions.RouteGroupItem, 0, len(rgl.Items))
for _, i := range rgl.Items {
// Validate RouteGroup item.
if err := c.routeGroupValidator.Validate(i); err != nil {
log.Errorf("[routegroup] %v", err)
continue
}
// Check the RouteGroup has a valid class annotation.
// Not defined, or empty are ok too.
if i.Metadata != nil {
cls, ok := i.Metadata.Annotations[routeGroupClassKey]
if ok && cls != "" && !c.routeGroupClass.MatchString(cls) {
continue
}
}
rgs = append(rgs, i)
}
log.Debugf("filtered valid routegroups by routegroups class: %d", len(rgs))
sortByMetadata(rgs, func(i int) *definitions.Metadata { return rgs[i].Metadata })
return rgs, nil
}
func (c *clusterClient) loadServices() (map[definitions.ResourceID]*service, error) {
var services serviceList
if err := c.getJSON(c.servicesURI+c.servicesLabelSelectors, &services); err != nil {
log.Debugf("requesting all services failed: %v", err)
return nil, err
}
log.Debugf("all services received: %d", len(services.Items))
result := make(map[definitions.ResourceID]*service)
var hasInvalidService bool
for _, service := range services.Items {
if service == nil || service.Meta == nil || service.Spec == nil {
hasInvalidService = true
continue
}
result[service.Meta.ToResourceID()] = service
}
if hasInvalidService {
log.Errorf("Invalid service resource detected.")
}
return result, nil
}
func (c *clusterClient) loadSecrets() (map[definitions.ResourceID]*secret, error) {
var secrets secretList
if err := c.getJSON(c.secretsURI+c.secretsLabelSelectors, &secrets); err != nil {
log.Debugf("requesting all secrets failed: %v", err)
return nil, err
}
log.Debugf("all secrets received: %d", len(secrets.Items))
result := make(map[definitions.ResourceID]*secret)
for _, secret := range secrets.Items {
if secret == nil || secret.Metadata == nil {
continue
}
result[secret.Metadata.ToResourceID()] = secret
}
return result, nil
}
func (c *clusterClient) loadEndpoints() (map[definitions.ResourceID]*endpoint, error) {
var endpoints endpointList
if err := c.getJSON(c.endpointsURI+c.endpointsLabelSelectors, &endpoints); err != nil {
log.Debugf("requesting all endpoints failed: %v", err)
return nil, err
}
log.Debugf("all endpoints received: %d", len(endpoints.Items))
result := make(map[definitions.ResourceID]*endpoint)
for _, endpoint := range endpoints.Items {
// guard against malformed items, like the other load functions do
if endpoint == nil || endpoint.Meta == nil {
continue
}
resID := endpoint.Meta.ToResourceID()
result[resID] = endpoint
}
return result, nil
}
// loadEndpointSlices is different from the other load$Kind()
// functions, because there are 1..N endpointslices created for a given
// service. Endpointslices need to be deduplicated and their state needs
// to be checked. We read all endpointslices and create de-duplicated
// business objects [skipperEndpointSlice] instead of raw Kubernetes
// objects, because we only need a clean list of load balancer
// members. The returned map contains the full list of ready,
// non-terminating endpoints that should be in the load balancer of a
// given service, see [endpointSlice.ToResourceID].
func (c *clusterClient) loadEndpointSlices() (map[definitions.ResourceID]*skipperEndpointSlice, error) {
var endpointSlices endpointSliceList
if err := c.getJSON(c.endpointSlicesURI+c.endpointSlicesLabelSelectors, &endpointSlices); err != nil {
log.Debugf("requesting all endpointslices failed: %v", err)
return nil, err
}
log.Debugf("all endpointslices received: %d", len(endpointSlices.Items))
return collectReadyEndpoints(&endpointSlices), nil
}
func collectReadyEndpoints(endpointSlices *endpointSliceList) map[definitions.ResourceID]*skipperEndpointSlice {
mapSlices := make(map[definitions.ResourceID][]*endpointSlice)
for _, endpointSlice := range endpointSlices.Items {
// https://github.com/zalando/skipper/issues/3151
// endpointslices can have nil ports
if endpointSlice.Ports != nil {
resID := endpointSlice.ToResourceID() // service resource ID
mapSlices[resID] = append(mapSlices[resID], endpointSlice)
}
}
result := make(map[definitions.ResourceID]*skipperEndpointSlice)
for resID, epSlices := range mapSlices {
if len(epSlices) == 0 {
continue
}
result[resID] = &skipperEndpointSlice{
Meta: epSlices[0].Meta,
}
terminatingEps := make(map[string]struct{})
resEps := make(map[string]*skipperEndpoint)
for i := range epSlices {
for _, ep := range epSlices[i].Endpoints {
// Addresses contains [1..100] addresses of the same AddressType;
// like kube-proxy, we use the first one,
// see also https://github.com/kubernetes/kubernetes/issues/106267
address := ep.Addresses[0]
if _, ok := terminatingEps[address]; ok {
// already known to be terminating, skip
} else if ep.isTerminating() {
terminatingEps[address] = struct{}{}
// if we saw this address earlier with a non-terminating
// condition, delete it: due to eventual consistency it is
// actually terminating
delete(resEps, address)
} else if ep.Conditions == nil {
// if conditions are nil, we need to treat it as ready
resEps[address] = &skipperEndpoint{
Address: address,
Zone: ep.Zone,
}
} else if ep.isReady() {
resEps[address] = &skipperEndpoint{
Address: address,
Zone: ep.Zone,
}
}
}
result[resID].Ports = epSlices[i].Ports
}
for _, o := range resEps {
result[resID].Endpoints = append(result[resID].Endpoints, o)
}
}
return result
}
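// Illustrative scenario for the precedence rules above, assuming two
// endpointslices of the same service: if address 10.2.13.9 is reported
// ready by one slice and terminating by the other, the terminating
// condition wins regardless of processing order and the address is left
// out of the result. Endpoints with nil conditions count as ready.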
// loadEndpointAddresses returns the list of all addresses for the given service using endpoints or endpointslices API.
func (c *clusterClient) loadEndpointAddresses(namespace, name string) ([]string, error) {
var result []string
if c.enableEndpointSlices {
url := fmt.Sprintf(EndpointSlicesNamespaceFmt, namespace) +
toLabelSelectorQuery(map[string]string{endpointSliceServiceNameLabel: name})
var endpointSlices endpointSliceList
if err := c.getJSON(url, &endpointSlices); err != nil {
return nil, fmt.Errorf("requesting endpointslices for %s/%s failed: %w", namespace, name, err)
}
ready := collectReadyEndpoints(&endpointSlices)
if len(ready) != 1 {
return nil, fmt.Errorf("unexpected number of endpoint slices for %s/%s: %d", namespace, name, len(ready))
}
for _, eps := range ready {
result = eps.addresses()
break
}
} else {
url := fmt.Sprintf(EndpointsNamespaceFmt, namespace) + "/" + name
var ep endpoint
if err := c.getJSON(url, &ep); err != nil {
return nil, fmt.Errorf("requesting endpoints for %s/%s failed: %w", namespace, name, err)
}
result = ep.addresses()
}
sort.Strings(result)
return result, nil
}
func (c *clusterClient) logMissingRouteGroupsOnce() {
if c.loggedMissingRouteGroups {
return
}
c.loggedMissingRouteGroups = true
log.Warn(RouteGroupsNotInstalledMessage)
}
func (c *clusterClient) fetchClusterState() (*clusterState, error) {
ingressesV1, err := c.loadIngressesV1()
if err != nil {
return nil, err
}
var routeGroups []*definitions.RouteGroupItem
if hasRouteGroups, err := c.clusterHasRouteGroups(); errors.Is(err, errResourceNotFound) {
c.logMissingRouteGroupsOnce()
} else if err != nil {
log.Errorf("Error while checking known resource types: %v.", err)
} else if hasRouteGroups {
c.loggedMissingRouteGroups = false
if routeGroups, err = c.LoadRouteGroups(); err != nil {
return nil, err
}
}
services, err := c.loadServices()
if err != nil {
return nil, err
}
state := &clusterState{
ingressesV1: ingressesV1,
routeGroups: routeGroups,
services: services,
cachedEndpoints: make(map[endpointID][]string),
enableEndpointSlices: c.enableEndpointSlices,
}
if c.enableEndpointSlices {
state.endpointSlices, err = c.loadEndpointSlices()
if err != nil {
return nil, err
}
} else {
state.endpoints, err = c.loadEndpoints()
if err != nil {
return nil, err
}
}
if c.certificateRegistry != nil {
state.secrets, err = c.loadSecrets()
if err != nil {
return nil, err
}
}
return state, nil
}
package kubernetes
import (
"fmt"
"sort"
"sync"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
)
type clusterState struct {
mu sync.Mutex
ingressesV1 []*definitions.IngressV1Item
routeGroups []*definitions.RouteGroupItem
services map[definitions.ResourceID]*service
endpoints map[definitions.ResourceID]*endpoint
endpointSlices map[definitions.ResourceID]*skipperEndpointSlice
secrets map[definitions.ResourceID]*secret
cachedEndpoints map[endpointID][]string
enableEndpointSlices bool
}
func (state *clusterState) getService(namespace, name string) (*service, error) {
state.mu.Lock()
defer state.mu.Unlock()
s, ok := state.services[newResourceID(namespace, name)]
if !ok {
return nil, errServiceNotFound
}
if s.Spec == nil {
log.Debug("invalid service datagram, missing spec")
return nil, errServiceNotFound
}
return s, nil
}
func (state *clusterState) getServiceRG(namespace, name string) (*service, error) {
state.mu.Lock()
defer state.mu.Unlock()
s, ok := state.services[newResourceID(namespace, name)]
if !ok {
return nil, fmt.Errorf("service not found: %s/%s", namespace, name)
}
return s, nil
}
// GetEndpointsByService returns the skipper endpoints for kubernetes endpoints or endpointslices.
func (state *clusterState) GetEndpointsByService(namespace, name, protocol string, servicePort *servicePort) []string {
epID := endpointID{
ResourceID: newResourceID(namespace, name),
Protocol: protocol,
TargetPort: servicePort.TargetPort.String(),
}
state.mu.Lock()
defer state.mu.Unlock()
if cached, ok := state.cachedEndpoints[epID]; ok {
return cached
}
var targets []string
if state.enableEndpointSlices {
if eps, ok := state.endpointSlices[epID.ResourceID]; ok {
targets = eps.targetsByServicePort("TCP", protocol, servicePort)
} else {
return nil
}
} else {
if ep, ok := state.endpoints[epID.ResourceID]; ok {
targets = ep.targetsByServicePort(protocol, servicePort)
} else {
return nil
}
}
sort.Strings(targets)
state.cachedEndpoints[epID] = targets
return targets
}
// getEndpointAddresses returns the list of all addresses for the given service using endpoints or endpointslices.
func (state *clusterState) getEndpointAddresses(namespace, name string) []string {
rID := newResourceID(namespace, name)
state.mu.Lock()
defer state.mu.Unlock()
var addresses []string
if state.enableEndpointSlices {
if eps, ok := state.endpointSlices[rID]; ok {
addresses = eps.addresses()
} else {
return nil
}
} else {
if ep, ok := state.endpoints[rID]; ok {
addresses = ep.addresses()
} else {
return nil
}
}
sort.Strings(addresses)
return addresses
}
// GetEndpointsByTarget returns the skipper endpoints for kubernetes endpoints or endpointslices.
func (state *clusterState) GetEndpointsByTarget(namespace, name, protocol, scheme string, target *definitions.BackendPort) []string {
epID := endpointID{
ResourceID: newResourceID(namespace, name),
Protocol: protocol,
TargetPort: target.String(),
}
state.mu.Lock()
defer state.mu.Unlock()
if cached, ok := state.cachedEndpoints[epID]; ok {
return cached
}
var targets []string
if state.enableEndpointSlices {
if eps, ok := state.endpointSlices[epID.ResourceID]; ok {
targets = eps.targetsByServiceTarget(protocol, scheme, target)
} else {
return nil
}
} else {
if ep, ok := state.endpoints[epID.ResourceID]; ok {
targets = ep.targetsByServiceTarget(scheme, target)
} else {
return nil
}
}
sort.Strings(targets)
state.cachedEndpoints[epID] = targets
return targets
}
package kubernetes
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
)
type filterSet struct {
text string
filters []*eskip.Filter
parsed bool
err error
}
type defaultFilters map[definitions.ResourceID]*filterSet
func readDefaultFilters(dir string) (defaultFilters, error) {
files, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
filters := make(defaultFilters)
for _, f := range files {
r := strings.Split(f.Name(), ".") // format: {service}.{namespace}
if len(r) != 2 {
log.WithField("file", f.Name()).Debug("malformed file name")
continue
}
info, err := f.Info()
if err != nil || !(f.Type().IsRegular() || f.Type()&os.ModeSymlink != 0) || info.Size() > maxFileSize {
log.WithError(err).WithField("file", f.Name()).Debug("incompatible file")
continue
}
file := filepath.Join(dir, f.Name())
config, err := os.ReadFile(file)
if err != nil {
log.WithError(err).WithField("file", file).Debug("could not read file")
continue
}
filters[definitions.ResourceID{Name: r[0], Namespace: r[1]}] = &filterSet{text: string(config)}
}
return filters, nil
}
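// Illustrative layout (hypothetical names): a file "my-app.default"
// containing
//
//	ratelimit(20, "1m") -> lifo(100, 150, "10s")
//
// defines default filters for the service "my-app" in the namespace
// "default"; the filter text is parsed lazily by filterSet.parse.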
func (fs *filterSet) parse() {
if fs.parsed {
return
}
fs.filters, fs.err = eskip.ParseFilters(fs.text)
if fs.err != nil {
fs.err = fmt.Errorf("[eskip] default filters: %w", fs.err)
}
fs.parsed = true
}
func (df defaultFilters) get(serviceID definitions.ResourceID) ([]*eskip.Filter, error) {
fs, ok := df[serviceID]
if !ok {
return nil, nil
}
fs.parse()
if fs.err != nil {
return nil, fs.err
}
f := make([]*eskip.Filter, len(fs.filters))
copy(f, fs.filters)
return f, nil
}
func (df defaultFilters) getNamed(namespace, serviceName string) ([]*eskip.Filter, error) {
return df.get(definitions.ResourceID{Namespace: namespace, Name: serviceName})
}
package definitions
import (
"errors"
"time"
)
var errInvalidMetadata = errors.New("invalid metadata")
type Metadata struct {
Namespace string `json:"namespace"`
Name string `json:"name"`
Created time.Time `json:"creationTimestamp"`
Uid string `json:"uid"`
Annotations map[string]string `json:"annotations"`
Labels map[string]string `json:"labels"`
}
func (meta *Metadata) ToResourceID() ResourceID {
return ResourceID{
Namespace: namespaceString(meta.Namespace),
Name: meta.Name,
}
}
func validate(meta *Metadata) error {
if meta == nil || meta.Name == "" {
return errInvalidMetadata
}
return nil
}
func namespaceString(ns string) string {
if ns == "" {
return "default"
}
return ns
}
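// Illustrative example: a Metadata with Name "svc" and an empty
// Namespace maps to ResourceID{Namespace: "default", Name: "svc"},
// because an empty namespace defaults to "default".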
type WeightedBackend interface {
GetName() string
GetWeight() float64
}
package definitions
import (
"encoding/json"
"errors"
"strconv"
)
const (
IngressFilterAnnotation = "zalando.org/skipper-filter"
IngressPredicateAnnotation = "zalando.org/skipper-predicate"
IngressRoutesAnnotation = "zalando.org/skipper-routes"
)
var errInvalidPortType = errors.New("invalid port type")
type IngressV1List struct {
Items []*IngressV1Item `json:"items"`
}
type IngressV1Item struct {
Metadata *Metadata `json:"metadata"`
Spec *IngressV1Spec `json:"spec"`
}
// IngressV1Spec https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.22/#ingressspec-v1-networking-k8s-io
type IngressV1Spec struct {
DefaultBackend *BackendV1 `json:"defaultBackend,omitempty"`
IngressClassName string `json:"ingressClassName,omitempty"`
Rules []*RuleV1 `json:"rules"`
IngressTLS []*TLSV1 `json:"tls,omitempty"`
}
// BackendV1 https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.22/#ingressbackend-v1-networking-k8s-io
type BackendV1 struct {
Service Service `json:"service,omitempty"` // can be empty, because the alternative TypedLocalObjectReference is not supported
// Resource TypedLocalObjectReference is not supported https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.22/#typedlocalobjectreference-v1-core
}
// Service https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.22/#ingressservicebackend-v1-networking-k8s-io
type Service struct {
Name string `json:"name"`
Port BackendPortV1 `json:"port"`
}
type BackendPortV1 struct {
Name string `json:"name"`
Number int `json:"number"`
}
func (p BackendPortV1) String() string {
if p.Number != 0 {
return strconv.Itoa(p.Number)
}
return p.Name
}
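// Illustrative example: BackendPortV1{Number: 8080}.String() yields
// "8080", BackendPortV1{Name: "http"}.String() yields "http"; a
// non-zero number takes precedence over the name.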
type RuleV1 struct {
Host string `json:"host"`
Http *HTTPRuleV1 `json:"http"`
}
type HTTPRuleV1 struct {
Paths []*PathRuleV1 `json:"paths"`
}
// PathRuleV1 https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.22/#httpingresspath-v1-networking-k8s-io
type PathRuleV1 struct {
Path string `json:"path"`
PathType string `json:"pathType"`
Backend *BackendV1 `json:"backend"`
}
type TLSV1 struct {
Hosts []string `json:"hosts"`
SecretName string `json:"secretName"`
}
// ResourceID is a stripped down version of Metadata used to identify resources in a cache map
type ResourceID struct {
Namespace string
Name string
}
// BackendPort is used for TargetPort similar to Kubernetes intOrString type
type BackendPort struct {
Value interface{}
}
// ParseIngressV1JSON parses JSON into an IngressV1List
func ParseIngressV1JSON(d []byte) (IngressV1List, error) {
var il IngressV1List
err := json.Unmarshal(d, &il)
return il, err
}
func GetHostsFromIngressRulesV1(ing *IngressV1Item) []string {
hostList := make([]string, 0)
// the class filter lets ingresses without a spec pass through
if ing.Spec == nil {
return hostList
}
for _, i := range ing.Spec.Rules {
hostList = append(hostList, i.Host)
}
return hostList
}
// String converts BackendPort to string
func (p BackendPort) String() string {
switch v := p.Value.(type) {
case string:
return v
case int:
return strconv.Itoa(v)
default:
return ""
}
}
// Number converts BackendPort to int
func (p BackendPort) Number() (int, bool) {
i, ok := p.Value.(int)
return i, ok
}
func (p *BackendPort) UnmarshalJSON(value []byte) error {
if value[0] == '"' {
var s string
if err := json.Unmarshal(value, &s); err != nil {
return err
}
p.Value = s
return nil
}
var i int
if err := json.Unmarshal(value, &i); err != nil {
return err
}
p.Value = i
return nil
}
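// Illustrative sketch of the intOrString behavior:
//
//	var p BackendPort
//	_ = json.Unmarshal([]byte(`"http"`), &p) // p.Value is the string "http"
//	_ = json.Unmarshal([]byte(`8080`), &p)   // p.Value is the int 8080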
func (p BackendPort) MarshalJSON() ([]byte, error) {
switch p.Value.(type) {
case string, int:
return json.Marshal(p.Value)
default:
return nil, errInvalidPortType
}
}
package definitions
import (
"errors"
"fmt"
"github.com/zalando/skipper/eskip"
)
type IngressV1Validator struct{}
func (igv *IngressV1Validator) Validate(item *IngressV1Item) error {
var errs []error
errs = append(errs, igv.validateFilterAnnotation(item.Metadata.Annotations))
errs = append(errs, igv.validatePredicateAnnotation(item.Metadata.Annotations))
errs = append(errs, igv.validateRoutesAnnotation(item.Metadata.Annotations))
return errors.Join(errs...)
}
func (igv *IngressV1Validator) validateFilterAnnotation(annotations map[string]string) error {
if filters, ok := annotations[IngressFilterAnnotation]; ok {
_, err := eskip.ParseFilters(filters)
if err != nil {
err = fmt.Errorf("invalid \"%s\" annotation: %w", IngressFilterAnnotation, err)
}
return err
}
return nil
}
func (igv *IngressV1Validator) validatePredicateAnnotation(annotations map[string]string) error {
if predicates, ok := annotations[IngressPredicateAnnotation]; ok {
_, err := eskip.ParsePredicates(predicates)
if err != nil {
err = fmt.Errorf("invalid \"%s\" annotation: %w", IngressPredicateAnnotation, err)
}
return err
}
return nil
}
func (igv *IngressV1Validator) validateRoutesAnnotation(annotations map[string]string) error {
if routes, ok := annotations[IngressRoutesAnnotation]; ok {
_, err := eskip.Parse(routes)
if err != nil {
err = fmt.Errorf("invalid \"%s\" annotation: %w", IngressRoutesAnnotation, err)
}
return err
}
return nil
}
package definitions
import (
"encoding/json"
"errors"
"fmt"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/loadbalancer"
)
// Adding Kubernetes specific backend types here. To be discussed.
// The main reason is to differentiate between service and external, in a way
// that also lets us use the current global option to decide whether the service
// should be converted to LB. Or shall we expect the route group to already
// contain the pod endpoints, and ignore the global option for skipper?
// --> As CRD we have to look up endpoints ourselves, maybe via kube.go
const (
ServiceBackend = eskip.LBBackend + 1 + iota
)
var (
errRouteGroupWithoutBackend = errors.New("route group without backend")
errRouteGroupWithoutName = errors.New("route group without name")
errRouteGroupWithoutSpec = errors.New("route group without spec")
errInvalidRouteSpec = errors.New("invalid route spec")
errInvalidMethod = errors.New("invalid method")
errBothPathAndPathSubtree = errors.New("path and path subtree in the same route")
errMissingBackendReference = errors.New("missing backend reference")
errUnnamedBackend = errors.New("unnamed backend")
errUnnamedBackendReference = errors.New("unnamed backend reference")
)
type RouteGroupList struct {
Items []*RouteGroupItem `json:"items"`
}
type RouteGroupItem struct {
Metadata *Metadata `json:"metadata"`
Spec *RouteGroupSpec `json:"spec"`
}
type RouteGroupSpec struct {
// Hosts specifies the host headers that will be matched for
// all routes created by this route group.
Hosts []string `json:"hosts,omitempty"`
// Backends specify the list of backends that can be
// referenced from routes or DefaultBackends.
Backends []*SkipperBackend `json:"backends"`
// DefaultBackends is, in most cases, a single default backend
// applied to all routes that do not override it. A special case
// is traffic switching, which has more than one default backend
// definition.
DefaultBackends BackendReferences `json:"defaultBackends,omitempty"`
// Routes specifies the list of routes based on path, method
// and predicates.
Routes []*RouteSpec `json:"routes,omitempty"`
// TLS specifies the list of Kubernetes TLS secrets to
// be used to terminate the TLS connection
TLS []*RouteTLSSpec `json:"tls,omitempty"`
}
// SkipperBackend is the type safe version of skipperBackendParser
type SkipperBackend struct {
// Name is the backendName that can be referenced as backendReference
Name string
// Type is the parsed backend type
Type eskip.BackendType
// Address is required for Type network. Address follows the
// URL spec without path, query and fragment. For example
// https://user:password@example.org
Address string
// ServiceName is required for Type service
ServiceName string
// ServicePort is required for Type service
ServicePort int
// Algorithm is required for Type lb
Algorithm loadbalancer.Algorithm
// Endpoints is required for Type lb
Endpoints []string
parseError error
}
// skipperBackendParser is an intermediate type required for parsing
// SkipperBackend, adding type safety for the Algorithm and Type
// fields.
type skipperBackendParser struct {
// Name is the backendName that can be referenced as backendReference
Name string `json:"name"`
// Type is one of "service|shunt|loopback|dynamic|lb|network"
Type string `json:"type"`
// Address is required for Type network
Address string `json:"address"`
// Algorithm is required for Type lb
Algorithm string `json:"algorithm"`
// Endpoints is required for Type lb
Endpoints []string `json:"endpoints"`
// ServiceName is required for Type service
ServiceName string `json:"serviceName"`
// ServicePort is required for Type service
ServicePort int `json:"servicePort"`
}
type BackendReference struct {
// BackendName references the skipperBackend by name
BackendName string `json:"backendName"`
// Weight defines the traffic weight, if there are 2 or more
// default backends
Weight int `json:"weight"`
}
type BackendReferences []*BackendReference
var _ WeightedBackend = &BackendReference{}
func (br *BackendReference) GetName() string { return br.BackendName }
func (br *BackendReference) GetWeight() float64 { return float64(br.Weight) }
type RouteSpec struct {
// Path specifies Path predicate, only one of Path or PathSubtree is allowed
Path string `json:"path,omitempty"`
// PathSubtree specifies PathSubtree predicate, only one of Path or PathSubtree is allowed
PathSubtree string `json:"pathSubtree,omitempty"`
// PathRegexp can be added additionally
PathRegexp string `json:"pathRegexp,omitempty"`
// Backends specifies the list of backendReference that should
// be applied to override the defaultBackends
Backends BackendReferences `json:"backends,omitempty"`
// Filters specifies the list of filters applied to the RouteSpec
Filters []string `json:"filters,omitempty"`
// Predicates specifies the list of predicates applied to the RouteSpec
Predicates []string `json:"predicates,omitempty"`
// Methods defines valid HTTP methods for the specified RouteSpec
Methods []string `json:"methods,omitempty"`
}
type RouteTLSSpec struct {
// Hosts specifies the list of hosts included in the
// TLS certificate
Hosts []string `json:"hosts,omitempty"`
// SecretName specifies the Kubernetes TLS secret to be
// used to terminate the TLS SNI connection
SecretName string `json:"secretName,omitempty"`
}
func backendsWithDuplicateName(name string) error {
return fmt.Errorf("backends with duplicate name: %s", name)
}
func invalidBackend(name string, err error) error {
return fmt.Errorf("invalid backend: %s, %w", name, err)
}
func invalidBackendReference(name string) error {
return fmt.Errorf("invalid backend reference: %s", name)
}
func duplicateBackendReference(name string) error {
return fmt.Errorf("duplicate backend reference: %s", name)
}
func invalidBackendWeight(name string, w int) error {
return fmt.Errorf("invalid weight in backend: %s, %d", name, w)
}
func invalidRoute(index int, err error) error {
return fmt.Errorf("invalid route at %d, %w", index, err)
}
func missingAddress(backendName string) error {
return fmt.Errorf("address missing in backend: %s", backendName)
}
func missingServiceName(backendName string) error {
return fmt.Errorf("service name missing in backend: %s", backendName)
}
func invalidServicePort(backendName string, p int) error {
return fmt.Errorf("invalid service port in backend: %s, %d", backendName, p)
}
func missingEndpoints(backendName string) error {
return fmt.Errorf("missing LB endpoints in backend: %s", backendName)
}
func routeGroupError(m *Metadata, err error) error {
return fmt.Errorf("error in route group %s/%s: %w", namespaceString(m.Namespace), m.Name, err)
}
// UnmarshalJSON creates a new SkipperBackend; it is safe to call on a nil pointer
func (sb *SkipperBackend) UnmarshalJSON(value []byte) error {
if sb == nil {
return nil
}
var p skipperBackendParser
if err := json.Unmarshal(value, &p); err != nil {
return err
}
var perr error
bt, err := backendTypeFromString(p.Type)
if err != nil {
// we cannot return an error here, because then the parsing of
// all route groups would fail. We'll report the error in the
// validation phase, only for the containing route group
perr = err
}
a, err := loadbalancer.AlgorithmFromString(p.Algorithm)
if err != nil {
// same as above: we report the error in the validation phase
perr = err
}
var b SkipperBackend
b.Name = p.Name
b.Type = bt
b.Address = p.Address
b.ServiceName = p.ServiceName
b.ServicePort = p.ServicePort
b.Algorithm = a
b.Endpoints = p.Endpoints
b.parseError = perr
*sb = b
return nil
}
func (rg *RouteGroupSpec) UniqueHosts() []string {
return uniqueStrings(rg.Hosts)
}
func (r *RouteSpec) UniqueMethods() []string {
return uniqueStrings(r.Methods)
}
// ParseRouteGroupsJSON parses a JSON list of RouteGroups into a RouteGroupList
func ParseRouteGroupsJSON(d []byte) (RouteGroupList, error) {
var rl RouteGroupList
err := json.Unmarshal(d, &rl)
return rl, err
}
func uniqueStrings(s []string) []string {
u := make([]string, 0, len(s))
m := make(map[string]bool)
for _, si := range s {
if m[si] {
continue
}
m[si] = true
u = append(u, si)
}
return u
}
func backendTypeFromString(s string) (eskip.BackendType, error) {
switch s {
case "", "service":
return ServiceBackend, nil
default:
return eskip.BackendTypeFromString(s)
}
}
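// Illustrative mapping, derived from the switch above: "" and "service"
// both resolve to ServiceBackend, everything else (e.g. "lb", "network",
// "shunt") is resolved by eskip.BackendTypeFromString.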
func hasEmpty(s []string) bool {
for _, si := range s {
if si == "" {
return true
}
}
return false
}
package definitions
import (
"errors"
"fmt"
"net/url"
"github.com/zalando/skipper/eskip"
)
type RouteGroupValidator struct{}
var (
errSingleFilterExpected = errors.New("single filter expected")
errSinglePredicateExpected = errors.New("single predicate expected")
)
var defaultRouteGroupValidator = &RouteGroupValidator{}
// ValidateRouteGroup validates a RouteGroupItem
func ValidateRouteGroup(rg *RouteGroupItem) error {
return defaultRouteGroupValidator.Validate(rg)
}
func ValidateRouteGroups(rgl *RouteGroupList) error {
var errs []error
for _, rg := range rgl.Items {
errs = append(errs, defaultRouteGroupValidator.Validate(rg))
}
return errors.Join(errs...)
}
func (rgv *RouteGroupValidator) Validate(item *RouteGroupItem) error {
err := rgv.basicValidation(item)
if err != nil {
return err
}
var errs []error
errs = append(errs, rgv.validateFilters(item))
errs = append(errs, rgv.validatePredicates(item))
errs = append(errs, rgv.validateBackends(item))
errs = append(errs, rgv.validateHosts(item))
return errors.Join(errs...)
}
// TODO: we need to pass namespace/name in all errors
func (rgv *RouteGroupValidator) basicValidation(r *RouteGroupItem) error {
// has metadata and name:
if r == nil || validate(r.Metadata) != nil {
return errRouteGroupWithoutName
}
// has spec:
if r.Spec == nil {
return routeGroupError(r.Metadata, errRouteGroupWithoutSpec)
}
if err := r.Spec.validate(); err != nil {
return routeGroupError(r.Metadata, err)
}
return nil
}
func (rgv *RouteGroupValidator) validateFilters(item *RouteGroupItem) error {
var errs []error
for _, r := range item.Spec.Routes {
for _, f := range r.Filters {
filters, err := eskip.ParseFilters(f)
if err != nil {
errs = append(errs, err)
} else if len(filters) != 1 {
errs = append(errs, fmt.Errorf("%w at %q", errSingleFilterExpected, f))
}
}
}
return errors.Join(errs...)
}
func (rgv *RouteGroupValidator) validatePredicates(item *RouteGroupItem) error {
var errs []error
for _, r := range item.Spec.Routes {
for _, p := range r.Predicates {
predicates, err := eskip.ParsePredicates(p)
if err != nil {
errs = append(errs, err)
} else if len(predicates) != 1 {
errs = append(errs, fmt.Errorf("%w at %q", errSinglePredicateExpected, p))
}
}
}
return errors.Join(errs...)
}
func (rgv *RouteGroupValidator) validateBackends(item *RouteGroupItem) error {
var errs []error
for _, backend := range item.Spec.Backends {
if backend.Type == eskip.NetworkBackend {
address, err := url.Parse(backend.Address)
if err != nil {
errs = append(errs, fmt.Errorf("failed to parse backend address %q: %w", backend.Address, err))
} else {
if address.Path != "" || address.RawQuery != "" || address.Scheme == "" || address.Host == "" {
errs = append(errs, fmt.Errorf("backend address %q does not match scheme://host format", backend.Address))
}
}
}
}
return errors.Join(errs...)
}
func (rgv *RouteGroupValidator) validateHosts(item *RouteGroupItem) error {
var errs []error
uniqueHosts := make(map[string]struct{}, len(item.Spec.Hosts))
for _, host := range item.Spec.Hosts {
if _, ok := uniqueHosts[host]; ok {
errs = append(errs, fmt.Errorf("duplicate host %q", host))
}
uniqueHosts[host] = struct{}{}
}
return errors.Join(errs...)
}
// TODO: we need to pass namespace/name in all errors
func (rg *RouteGroupSpec) validate() error {
// has at least one backend:
if len(rg.Backends) == 0 {
return errRouteGroupWithoutBackend
}
// backends valid and have unique names:
backends := make(map[string]bool)
for _, b := range rg.Backends {
if backends[b.Name] {
return backendsWithDuplicateName(b.Name)
}
backends[b.Name] = true
if err := b.validate(); err != nil {
return invalidBackend(b.Name, err)
}
}
hasDefault := len(rg.DefaultBackends) > 0
if err := rg.DefaultBackends.validate(backends); err != nil {
return err
}
if !hasDefault && len(rg.Routes) == 0 {
return errMissingBackendReference
}
for i, r := range rg.Routes {
if err := r.validate(hasDefault, backends); err != nil {
return invalidRoute(i, err)
}
}
return nil
}
// TODO: we need to pass namespace/name in all errors
func (r *RouteSpec) validate(hasDefault bool, backends map[string]bool) error {
if r == nil {
return errInvalidRouteSpec
}
if !hasDefault && len(r.Backends) == 0 {
return errMissingBackendReference
}
if err := r.Backends.validate(backends); err != nil {
return err
}
if r.Path != "" && r.PathSubtree != "" {
return errBothPathAndPathSubtree
}
if hasEmpty(r.Methods) {
return errInvalidMethod
}
return nil
}
func (br *BackendReference) validate(backends map[string]bool) error {
if br == nil || br.BackendName == "" {
return errUnnamedBackendReference
}
if !backends[br.BackendName] {
return invalidBackendReference(br.BackendName)
}
if br.Weight < 0 {
return invalidBackendWeight(br.BackendName, br.Weight)
}
return nil
}
func (brs BackendReferences) validate(backends map[string]bool) error {
if brs == nil {
return nil
}
names := make(map[string]struct{}, len(brs))
for _, br := range brs {
if _, ok := names[br.BackendName]; ok {
return duplicateBackendReference(br.BackendName)
}
names[br.BackendName] = struct{}{}
if err := br.validate(backends); err != nil {
return err
}
}
return nil
}
func (sb *SkipperBackend) validate() error {
// the nil check must come before dereferencing sb.parseError
if sb == nil {
return errUnnamedBackend
}
if sb.parseError != nil {
return sb.parseError
}
if sb.Name == "" {
return errUnnamedBackend
}
switch {
case sb.Type == eskip.NetworkBackend && sb.Address == "":
return missingAddress(sb.Name)
case sb.Type == ServiceBackend && sb.ServiceName == "":
return missingServiceName(sb.Name)
case sb.Type == ServiceBackend &&
(sb.ServicePort == 0 || sb.ServicePort != int(uint16(sb.ServicePort))):
// the uint16 round-trip rejects ports outside of 1..65535
return invalidServicePort(sb.Name, sb.ServicePort)
case sb.Type == eskip.LBBackend && len(sb.Endpoints) == 0:
return missingEndpoints(sb.Name)
}
return nil
}
package kubernetes
import (
"fmt"
"strings"
"github.com/zalando/skipper/eskip"
)
func eastWestRouteID(rid string) string {
return "kubeew" + rid[len(ingressRouteIDPrefix):]
}
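// Illustrative example (hypothetical route ID):
// eastWestRouteID("kube_default__my_app") returns
// "kubeew_default__my_app", swapping the ingress route ID prefix for
// the east-west prefix.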
func createEastWestRouteIng(eastWestDomain, name, ns string, r *eskip.Route) *eskip.Route {
if strings.HasPrefix(r.Id, "kubeew") || ns == "" || name == "" {
return nil
}
ewR := *r
ewR.HostRegexps = []string{createHostRx(name + "." + ns + "." + eastWestDomain)}
ewR.Id = eastWestRouteID(r.Id)
return &ewR
}
func createEastWestRouteRG(name, ns, postfix string, r *eskip.Route) *eskip.Route {
hostRx := createHostRx(fmt.Sprintf("%s.%s.%s", name, ns, postfix))
ewr := eskip.Copy(r)
ewr.Id = eastWestRouteID(ewr.Id)
ewr.HostRegexps = nil
p := make([]*eskip.Predicate, 0, len(ewr.Predicates))
for _, pi := range ewr.Predicates {
if pi.Name != "Host" {
p = append(p, pi)
}
}
p = append(p, &eskip.Predicate{
Name: "Host",
Args: []interface{}{hostRx},
})
ewr.Predicates = p
return ewr
}
func applyEastWestRange(domains []string, predicates []*eskip.Predicate, host string, routes []*eskip.Route) {
for _, d := range domains {
if strings.HasSuffix(host, d) {
applyEastWestRangePredicates(routes, predicates)
}
}
}
func applyEastWestRangePredicates(routes []*eskip.Route, predicates []*eskip.Predicate) {
for _, route := range routes {
route.Predicates = append(route.Predicates, predicates...)
}
}
package kubernetes
import (
"net"
"strconv"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
)
type endpointID struct {
definitions.ResourceID
TargetPort string
Protocol string
}
type endpoint struct {
Meta *definitions.Metadata `json:"metadata"`
Subsets []*subset `json:"subsets"`
}
type endpointList struct {
Items []*endpoint `json:"items"`
}
func formatEndpointString(ip, scheme string, port int) string {
return scheme + "://" + net.JoinHostPort(ip, strconv.Itoa(port))
}
func formatEndpoint(a *address, p *port, scheme string) string {
return formatEndpointString(a.IP, scheme, p.Port)
}
func formatEndpointsForSubsetAddresses(addresses []*address, port *port, scheme string) []string {
result := make([]string, 0, len(addresses))
for _, address := range addresses {
result = append(result, formatEndpoint(address, port, scheme))
}
return result
}
func (ep *endpoint) targetsByServicePort(protocol string, servicePort *servicePort) []string {
for _, s := range ep.Subsets {
// If only one port exists in the endpoint, use it
if len(s.Ports) == 1 {
return formatEndpointsForSubsetAddresses(s.Addresses, s.Ports[0], protocol)
}
// Otherwise match port by name
for _, p := range s.Ports {
if p.Name != servicePort.Name {
continue
}
return formatEndpointsForSubsetAddresses(s.Addresses, p, protocol)
}
}
return nil
}
func (ep *endpoint) targetsByServiceTarget(scheme string, serviceTarget *definitions.BackendPort) []string {
portName, named := serviceTarget.Value.(string)
portValue, byValue := serviceTarget.Value.(int)
for _, s := range ep.Subsets {
for _, p := range s.Ports {
if named && p.Name != portName || byValue && p.Port != portValue {
continue
}
var result []string
for _, a := range s.Addresses {
result = append(result, formatEndpoint(a, p, scheme))
}
return result
}
}
return nil
}
func (ep *endpoint) addresses() []string {
result := make([]string, 0)
for _, s := range ep.Subsets {
for _, a := range s.Addresses {
result = append(result, a.IP)
}
}
return result
}
type subset struct {
Addresses []*address `json:"addresses"`
Ports []*port `json:"ports"`
}
type address struct {
IP string `json:"ip"`
Node string `json:"nodeName"`
}
type port struct {
Name string `json:"name"`
Port int `json:"port"`
Protocol string `json:"protocol"`
}
package kubernetes
import (
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
)
const endpointSliceServiceNameLabel = "kubernetes.io/service-name"
// There are [1..N] Kubernetes endpointslices created for a single Kubernetes service.
// Kubernetes endpointslices of a given service can have duplicates with different states.
// Therefore Kubernetes endpointslices need to be de-duplicated before usage.
// The business object skipperEndpointSlice is a de-duplicated endpoint list that merges all endpointslices of a given service into one object.
type skipperEndpointSlice struct {
Meta *definitions.Metadata
Endpoints []*skipperEndpoint
Ports []*endpointSlicePort
}
// Conditions have to be evaluated before creation
type skipperEndpoint struct {
Address string
Zone string
}
func (eps *skipperEndpointSlice) getPort(protocol, pName string, pValue int) int {
var port int
for _, p := range eps.Ports {
if protocol != "" && p.Protocol != protocol {
continue
}
// https://pkg.go.dev/k8s.io/api/core/v1#ServicePort
// The port name is optional if only one ServicePort is defined
// on the service, therefore an empty name match is fine.
if p.Name == pName {
port = p.Port
break
}
if pValue != 0 && p.Port == pValue {
port = pValue
break
}
}
return port
}
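// Illustrative example (hypothetical port list): with Ports
// [{Name: "http", Port: 8080, Protocol: "TCP"}], both
// getPort("TCP", "http", 0) and getPort("TCP", "", 8080) return 8080,
// while a mismatched protocol, name or value yields 0.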
func (eps *skipperEndpointSlice) targetsByServicePort(protocol, scheme string, servicePort *servicePort) []string {
var port int
if servicePort.Name != "" {
port = eps.getPort(protocol, servicePort.Name, servicePort.Port)
} else if servicePort.TargetPort != nil {
var ok bool
port, ok = servicePort.TargetPort.Number()
if !ok {
port = eps.getPort(protocol, servicePort.Name, servicePort.Port)
}
} else {
port = eps.getPort(protocol, servicePort.Name, servicePort.Port)
}
result := make([]string, 0, len(eps.Endpoints))
for _, ep := range eps.Endpoints {
result = append(result, formatEndpointString(ep.Address, scheme, port))
}
return result
}
func (eps *skipperEndpointSlice) targetsByServiceTarget(protocol, scheme string, serviceTarget *definitions.BackendPort) []string {
pName, _ := serviceTarget.Value.(string)
pValue, _ := serviceTarget.Value.(int)
port := eps.getPort(protocol, pName, pValue)
result := make([]string, 0, len(eps.Endpoints))
for _, ep := range eps.Endpoints {
result = append(result, formatEndpointString(ep.Address, scheme, port))
}
return result
}
func (eps *skipperEndpointSlice) addresses() []string {
result := make([]string, 0, len(eps.Endpoints))
for _, ep := range eps.Endpoints {
result = append(result, ep.Address)
}
return result
}
type endpointSliceList struct {
Meta *definitions.Metadata
Items []*endpointSlice `json:"items"`
}
// see https://kubernetes.io/docs/reference/kubernetes-api/service-resources/endpoint-slice-v1/#EndpointSlice
type endpointSlice struct {
Meta *definitions.Metadata `json:"metadata"`
AddressType string `json:"addressType"` // "IPv4"
Endpoints []*EndpointSliceEndpoints `json:"endpoints"`
Ports []*endpointSlicePort `json:"ports"` // contains all ports like 9999/9911
}
// ToResourceID returns the same resource ID for all endpointslices created for the same service
func (eps *endpointSlice) ToResourceID() definitions.ResourceID {
svcName := eps.Meta.Labels[endpointSliceServiceNameLabel]
namespace := eps.Meta.Namespace
return newResourceID(namespace, svcName)
}
// EndpointSliceEndpoints is the single endpoint definition
type EndpointSliceEndpoints struct {
// Addresses contains [1..100] addresses of the same AddressType, see also https://github.com/kubernetes/kubernetes/issues/106267
// In our case it practically always has a single entry, and using more than one likely makes no sense.
// Picking the first or a random one are both possible; skipper picks the first.
// If you need something else, please create an issue https://github.com/zalando/skipper/issues/new/choose
Addresses []string `json:"addresses"` // [ "10.2.13.9" ]
// Conditions are used for deciding to drop out of load balancer or fade into the load balancer.
Conditions *endpointsliceCondition `json:"conditions"`
// Zone is used for zone aware traffic routing, please see also
// https://kubernetes.io/docs/concepts/services-networking/topology-aware-routing/#constraints
// https://kubernetes.io/docs/concepts/services-networking/topology-aware-routing/#safeguards
// Zone aware routing will be available if https://github.com/zalando/skipper/issues/1446 is closed.
Zone string `json:"zone"` // "eu-central-1c"
}
type endpointsliceCondition struct {
Ready *bool `json:"ready"` // ready endpoint -> put into endpoints unless terminating
Serving *bool `json:"serving"` // serving endpoint
Terminating *bool `json:"terminating"` // terminating pod -> drop out of endpoints
}
type endpointSlicePort struct {
Name string `json:"name"` // "http"
Port int `json:"port"` // 8080
Protocol string `json:"protocol"` // "TCP"
// AppProtocol is not used, but would make it possible to optimize H2C and websocket connections
AppProtocol string `json:"appProtocol"` // "kubernetes.io/h2c", "kubernetes.io/ws", "kubernetes.io/wss"
}
func (ep *EndpointSliceEndpoints) isTerminating() bool {
// see also https://github.com/kubernetes/kubernetes/blob/91aca10d5984313c1c5858979d4946ff9446615f/pkg/proxy/endpointslicecache.go#L137C39-L139
return ep.Conditions != nil && ep.Conditions.Terminating != nil && *ep.Conditions.Terminating
}
func (ep *EndpointSliceEndpoints) isReady() bool {
if ep.isTerminating() {
return false
}
// defaults to ready, also for nil conditions, see https://github.com/kubernetes/kubernetes/blob/91aca10d5984313c1c5858979d4946ff9446615f/pkg/proxy/endpointslicecache.go#L137C39-L139
// we ignore serving because of https://github.com/zalando/skipper/issues/2684
return ep.Conditions == nil || ep.Conditions.Ready == nil || *ep.Conditions.Ready
}
package kubernetes
import (
"net/url"
"regexp"
"strings"
"github.com/zalando/skipper/eskip"
)
func createHostRx(hosts ...string) string {
if len(hosts) == 0 {
return ""
}
hrx := make([]string, len(hosts))
for i, host := range hosts {
// trailing dots and ports are not allowed in the kube
// ingress spec, so we can append the optional suffix
// without further checks
hrx[i] = strings.ReplaceAll(host, ".", "[.]") + "[.]?(:[0-9]+)?"
}
return "^(" + strings.Join(hrx, "|") + ")$"
}
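// Illustrative example: createHostRx("foo.example.org") returns
// "^(foo[.]example[.]org[.]?(:[0-9]+)?)$", i.e. the literal host with
// an optional trailing dot and an optional port.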
// hostCatchAllRoutes creates catch-all routes for those hosts that only have routes with
// a Host predicate and at least one additional predicate.
//
// currently only used for RouteGroups
func hostCatchAllRoutes(hostRoutes map[string][]*eskip.Route, createID func(string) string) []*eskip.Route {
var catchAll []*eskip.Route
for h, r := range hostRoutes {
var hasHostOnlyRoute bool
for _, ri := range r {
ct := eskip.Canonical(ri)
var hasNonHostPredicate bool
for _, p := range ct.Predicates {
if p.Name != "Host" {
hasNonHostPredicate = true
break
}
}
if !hasNonHostPredicate {
hasHostOnlyRoute = true
break
}
}
if !hasHostOnlyRoute {
catchAll = append(catchAll, &eskip.Route{
Id: createID(h),
Predicates: []*eskip.Predicate{{
Name: "Host",
Args: []interface{}{createHostRx(h)},
}},
BackendType: eskip.ShuntBackend,
})
}
}
return catchAll
}
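// Illustrative scenario: if every route of host "foo.example.org"
// combines Host with at least one more predicate, e.g. Path, a shunt
// catch-all route is created for that host; if any route matches on the
// host alone, no catch-all is added.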
func isExternalDomainAllowed(allowedDomains []*regexp.Regexp, domain string) bool {
for _, a := range allowedDomains {
if a.MatchString(domain) {
return true
}
}
return false
}
func isExternalAddressAllowed(allowedDomains []*regexp.Regexp, address string) bool {
u, err := url.Parse(address)
if err != nil {
return false
}
return isExternalDomainAllowed(allowedDomains, u.Hostname())
}
func isEastWestHost(host string, eastWestRangeDomains []string) bool {
for _, domainSuffix := range eastWestRangeDomains {
if strings.HasSuffix(host, domainSuffix) {
return true
}
}
return false
}
package kubernetes
import (
"crypto/tls"
b64 "encoding/base64"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/secrets/certregistry"
)
const (
ingressRouteIDPrefix = "kube"
backendWeightsAnnotationKey = "zalando.org/backend-weights"
ratelimitAnnotationKey = "zalando.org/ratelimit"
skipperLoadBalancerAnnotationKey = "zalando.org/skipper-loadbalancer"
skipperBackendProtocolAnnotationKey = "zalando.org/skipper-backend-protocol"
pathModeAnnotationKey = "zalando.org/skipper-ingress-path-mode"
ingressOriginName = "ingress"
tlsSecretType = "kubernetes.io/tls"
tlsSecretDataCrt = "tls.crt"
tlsSecretDataKey = "tls.key"
)
type ingressContext struct {
state *clusterState
ingressV1 *definitions.IngressV1Item
logger *logger
annotationFilters []*eskip.Filter
annotationPredicate string
extraRoutes []*eskip.Route
backendWeights map[string]float64
pathMode PathMode
redirect *redirectInfo
hostRoutes map[string][]*eskip.Route
defaultFilters defaultFilters
certificateRegistry *certregistry.CertRegistry
calculateTraffic func([]*weightedIngressBackend) map[string]backendTraffic
}
type ingress struct {
eastWestRangeDomains []string
eastWestRangePredicates []*eskip.Predicate
allowedExternalNames []*regexp.Regexp
kubernetesEastWestDomain string
pathMode PathMode
httpsRedirectCode int
kubernetesEnableEastWest bool
provideHTTPSRedirect bool
disableCatchAllRoutes bool
forceKubernetesService bool
backendTrafficAlgorithm BackendTrafficAlgorithm
defaultLoadBalancerAlgorithm string
kubernetesAnnotationPredicates []AnnotationPredicates
kubernetesAnnotationFiltersAppend []AnnotationFilters
kubernetesEastWestRangeAnnotationPredicates []AnnotationPredicates
kubernetesEastWestRangeAnnotationFiltersAppend []AnnotationFilters
}
var nonWord = regexp.MustCompile(`\W`)
var errNotAllowedExternalName = errors.New("ingress with not allowed external name service")
func (ic *ingressContext) addHostRoute(host string, route *eskip.Route) {
ic.hostRoutes[host] = append(ic.hostRoutes[host], route)
}
func newIngress(o Options) *ingress {
return &ingress{
provideHTTPSRedirect: o.ProvideHTTPSRedirect,
httpsRedirectCode: o.HTTPSRedirectCode,
disableCatchAllRoutes: o.DisableCatchAllRoutes,
pathMode: o.PathMode,
kubernetesEnableEastWest: o.KubernetesEnableEastWest,
kubernetesEastWestDomain: o.KubernetesEastWestDomain,
eastWestRangeDomains: o.KubernetesEastWestRangeDomains,
eastWestRangePredicates: o.KubernetesEastWestRangePredicates,
allowedExternalNames: o.AllowedExternalNames,
forceKubernetesService: o.ForceKubernetesService,
backendTrafficAlgorithm: o.BackendTrafficAlgorithm,
defaultLoadBalancerAlgorithm: o.DefaultLoadBalancerAlgorithm,
kubernetesAnnotationPredicates: o.KubernetesAnnotationPredicates,
kubernetesAnnotationFiltersAppend: o.KubernetesAnnotationFiltersAppend,
kubernetesEastWestRangeAnnotationPredicates: o.KubernetesEastWestRangeAnnotationPredicates,
kubernetesEastWestRangeAnnotationFiltersAppend: o.KubernetesEastWestRangeAnnotationFiltersAppend,
}
}
func getLoadBalancerAlgorithm(m *definitions.Metadata, defaultAlgorithm string) string {
algorithm := defaultAlgorithm
if algorithmAnnotationValue, ok := m.Annotations[skipperLoadBalancerAnnotationKey]; ok {
algorithm = algorithmAnnotationValue
}
return algorithm
}
// TODO: find a nicer way to autogenerate route IDs
func routeID(namespace, name, host, path, backend string) string {
namespace = nonWord.ReplaceAllString(namespace, "_")
name = nonWord.ReplaceAllString(name, "_")
host = nonWord.ReplaceAllString(host, "_")
path = nonWord.ReplaceAllString(path, "_")
backend = nonWord.ReplaceAllString(backend, "_")
return fmt.Sprintf("%s_%s__%s__%s__%s__%s", ingressRouteIDPrefix, namespace, name, host, path, backend)
}
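// Illustrative example (hypothetical resource names):
// routeID("default", "my-app", "example.org", "/api", "svc") returns
// "kube_default__my_app__example_org___api__svc", with every non-word
// character replaced by an underscore.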
// routeIDForCustom generates a route id for a custom route of an ingress
// resource.
func routeIDForCustom(namespace, name, id, host string, index int) string {
name = name + "_" + id + "_" + strconv.Itoa(index)
return routeID(namespace, name, host, "", "")
}
func externalNameRoute(
ns, name, idHost string,
hostRegexps []string,
svc *service,
servicePort *servicePort,
allowedNames []*regexp.Regexp,
) (*eskip.Route, error) {
if !isExternalDomainAllowed(allowedNames, svc.Spec.ExternalName) {
return nil, fmt.Errorf("%w: %s", errNotAllowedExternalName, svc.Spec.ExternalName)
}
scheme := "https"
if n, _ := servicePort.TargetPort.Number(); n != 443 {
scheme = "http"
}
u := fmt.Sprintf("%s://%s:%s", scheme, svc.Spec.ExternalName, servicePort.TargetPort)
f, err := eskip.ParseFilters(fmt.Sprintf(`setRequestHeader("Host", "%s")`, svc.Spec.ExternalName))
if err != nil {
return nil, err
}
return &eskip.Route{
Id: routeID(ns, name, idHost, "", svc.Spec.ExternalName),
BackendType: eskip.NetworkBackend,
Backend: u,
Filters: f,
HostRegexps: hostRegexps,
}, nil
}
func applyAnnotationPredicates(m PathMode, r *eskip.Route, annotation string) error {
if annotation == "" {
return nil
}
predicates, err := eskip.ParsePredicates(annotation)
if err != nil {
return err
}
// to avoid conflict, give precedence for those predicates that come
// from annotations
if m == PathPrefix {
for _, p := range predicates {
if p.Name != "Path" && p.Name != "PathSubtree" {
continue
}
r.Path = ""
for i, p := range r.Predicates {
if p.Name != "PathSubtree" && p.Name != "Path" {
continue
}
copy(r.Predicates[i:], r.Predicates[i+1:])
r.Predicates[len(r.Predicates)-1] = nil
r.Predicates = r.Predicates[:len(r.Predicates)-1]
break
}
}
}
r.Predicates = append(r.Predicates, predicates...)
return nil
}
func (ing *ingress) addExtraRoutes(ic *ingressContext, ruleHost, path, pathType string) {
hosts := []string{createHostRx(ruleHost)}
name := ic.ingressV1.Metadata.Name
ns := ic.ingressV1.Metadata.Namespace
eastWestDomain := ing.kubernetesEastWestDomain
enableEastWest := ing.kubernetesEnableEastWest
ewHost := isEastWestHost(ruleHost, ing.eastWestRangeDomains)
// add extra routes from optional annotation
for extraIndex, r := range ic.extraRoutes {
route := *r
route.HostRegexps = hosts
route.Id = routeIDForCustom(
ns,
name,
route.Id,
ruleHost+strings.ReplaceAll(path, "/", "_"),
extraIndex)
setPathV1(ic.pathMode, &route, pathType, path)
if n := countPathPredicates(&route); n <= 1 {
if ewHost {
appendAnnotationPredicates(ing.kubernetesEastWestRangeAnnotationPredicates, ic.ingressV1.Metadata.Annotations, &route)
appendAnnotationFilters(ing.kubernetesEastWestRangeAnnotationFiltersAppend, ic.ingressV1.Metadata.Annotations, &route)
} else {
appendAnnotationPredicates(ing.kubernetesAnnotationPredicates, ic.ingressV1.Metadata.Annotations, &route)
appendAnnotationFilters(ing.kubernetesAnnotationFiltersAppend, ic.ingressV1.Metadata.Annotations, &route)
}
ic.addHostRoute(ruleHost, &route)
ic.redirect.updateHost(ruleHost)
} else {
ic.logger.Errorf("Ignoring route due to multiple path predicates: %d path predicates, route: %v", n, route)
}
if enableEastWest {
ewRoute := createEastWestRouteIng(eastWestDomain, name, ns, &route)
// a distinct name, not to shadow the ewHost flag declared above
ewRouteHost := fmt.Sprintf("%s.%s.%s", name, ns, eastWestDomain)
ic.addHostRoute(ewRouteHost, ewRoute)
}
}
}
func countPathPredicates(r *eskip.Route) int {
i := 0
for _, p := range r.Predicates {
if p.Name == "PathSubtree" || p.Name == "Path" {
i++
}
}
if r.Path != "" {
i++
}
return i
}
// parse filter and ratelimit annotation
func annotationFilter(m *definitions.Metadata, logger *logger) []*eskip.Filter {
var annotationFilter string
if ratelimitAnnotationValue, ok := m.Annotations[ratelimitAnnotationKey]; ok {
annotationFilter = ratelimitAnnotationValue
}
if val, ok := m.Annotations[definitions.IngressFilterAnnotation]; ok {
if annotationFilter != "" {
annotationFilter += " -> "
}
annotationFilter += val
}
if annotationFilter != "" {
annotationFilters, err := eskip.ParseFilters(annotationFilter)
if err == nil {
return annotationFilters
}
logger.Errorf("Can not parse annotation filters: %v", err)
}
return nil
}
// parse predicate annotation
func annotationPredicate(m *definitions.Metadata) string {
var annotationPredicate string
if val, ok := m.Annotations[definitions.IngressPredicateAnnotation]; ok {
annotationPredicate = val
}
return annotationPredicate
}
// parse routes annotation
func extraRoutes(m *definitions.Metadata) []*eskip.Route {
var extraRoutes []*eskip.Route
annotationRoutes := m.Annotations[definitions.IngressRoutesAnnotation]
if annotationRoutes != "" {
extraRoutes, _ = eskip.Parse(annotationRoutes) // We ignore the error here because it should be handled by the validator object
}
return extraRoutes
}
// parse backend-weights annotation if it exists
func backendWeights(m *definitions.Metadata, logger *logger) map[string]float64 {
var backendWeights map[string]float64
if backends, ok := m.Annotations[backendWeightsAnnotationKey]; ok {
err := json.Unmarshal([]byte(backends), &backendWeights)
if err != nil {
logger.Errorf("Error while parsing backend-weights annotation: %v", err)
}
}
return backendWeights
}
// parse pathmode from annotation or fallback to global default
func pathMode(m *definitions.Metadata, globalDefault PathMode, logger *logger) PathMode {
pathMode := globalDefault
if pathModeString, ok := m.Annotations[pathModeAnnotationKey]; ok {
if p, err := ParsePathMode(pathModeString); err != nil {
logger.Errorf("Failed to get path mode: %v", err)
} else {
logger.Debugf("Set pathMode to %s", p)
pathMode = p
}
}
return pathMode
}
func (ing *ingress) addCatchAllRoutes(host string, r *eskip.Route, redirect *redirectInfo) []*eskip.Route {
catchAll := &eskip.Route{
Id: routeID("", "catchall", host, "", ""),
HostRegexps: r.HostRegexps,
BackendType: eskip.ShuntBackend,
}
routes := []*eskip.Route{catchAll}
if ing.kubernetesEnableEastWest {
if ew := createEastWestRouteIng(ing.kubernetesEastWestDomain, r.Name, r.Namespace, catchAll); ew != nil {
routes = append(routes, ew)
}
}
if code, ok := redirect.setHostCode[host]; ok {
routes = append(routes, createIngressEnableHTTPSRedirect(catchAll, code))
}
if redirect.disableHost[host] {
routes = append(routes, createIngressDisableHTTPSRedirect(catchAll))
}
return routes
}
// hasCatchAllRoutes returns true if one of the routes in the list has a catchAll
// path expression.
//
// TODO: this should also consider path types exact and subtree
func hasCatchAllRoutes(routes []*eskip.Route) bool {
for _, route := range routes {
if len(route.PathRegexps) == 0 {
return true
}
for _, exp := range route.PathRegexps {
if exp == "^/" {
return true
}
}
}
return false
}
// convert logs invalid ingresses and proceeds with the valid ones.
// Reporting failures in the Ingress status is not possible, because
// the Ingress status field only supports IP and Hostname as string.
func (ing *ingress) convert(state *clusterState, df defaultFilters, r *certregistry.CertRegistry, loggingEnabled bool) ([]*eskip.Route, error) {
var ewIngInfo map[string][]string // r.Id -> {namespace, name}
if ing.kubernetesEnableEastWest {
ewIngInfo = make(map[string][]string)
}
routes := make([]*eskip.Route, 0, len(state.ingressesV1))
hostRoutes := make(map[string][]*eskip.Route)
redirect := createRedirectInfo(ing.provideHTTPSRedirect, ing.httpsRedirectCode)
for _, i := range state.ingressesV1 {
// use a distinct name, not to shadow the certificate registry parameter r
route, err := ing.ingressV1Route(i, redirect, state, hostRoutes, df, r, loggingEnabled)
if err != nil {
return nil, err
}
if route != nil {
routes = append(routes, route)
if ing.kubernetesEnableEastWest {
ewIngInfo[route.Id] = []string{i.Metadata.Namespace, i.Metadata.Name}
}
}
}
for host, rs := range hostRoutes {
if len(rs) == 0 {
continue
}
applyEastWestRange(ing.eastWestRangeDomains, ing.eastWestRangePredicates, host, rs)
routes = append(routes, rs...)
if !ing.disableCatchAllRoutes {
// if routes were configured, but there is no catchall route
// defined for the host name, create a route which returns 404
if !hasCatchAllRoutes(rs) {
routes = append(routes, ing.addCatchAllRoutes(host, rs[0], redirect)...)
}
}
}
if ing.kubernetesEnableEastWest && len(routes) > 0 && len(ewIngInfo) > 0 {
ewroutes := make([]*eskip.Route, 0, len(routes))
for _, r := range routes {
if v, ok := ewIngInfo[r.Id]; ok {
ewroutes = append(ewroutes, createEastWestRouteIng(ing.kubernetesEastWestDomain, v[0], v[1], r))
}
}
l := len(routes)
routes = append(routes, ewroutes...)
log.Infof("Enabled east west routes: %d %d %d %d", l, len(routes), len(ewroutes), len(hostRoutes))
}
return routes, nil
}
func generateTLSCertFromSecret(secret *secret) (*tls.Certificate, error) {
// check the secret type first, before doing the expensive key pair work
if secret.Type != tlsSecretType {
return nil, fmt.Errorf("secret %s is not of type %s", secret.Metadata.Name, tlsSecretType)
}
if secret.Data[tlsSecretDataCrt] == "" || secret.Data[tlsSecretDataKey] == "" {
return nil, fmt.Errorf("secret must contain %s and %s in data field", tlsSecretDataCrt, tlsSecretDataKey)
}
crt, err := b64.StdEncoding.DecodeString(secret.Data[tlsSecretDataCrt])
if err != nil {
return nil, fmt.Errorf("failed to decode %s from secret %s", tlsSecretDataCrt, secret.Metadata.Name)
}
key, err := b64.StdEncoding.DecodeString(secret.Data[tlsSecretDataKey])
if err != nil {
return nil, fmt.Errorf("failed to decode %s from secret %s", tlsSecretDataKey, secret.Metadata.Name)
}
cert, err := tls.X509KeyPair([]byte(crt), []byte(key))
if err != nil {
return nil, fmt.Errorf("failed to create tls certificate from secret %s", secret.Metadata.Name)
}
return &cert, nil
}
package kubernetes
import (
"errors"
"fmt"
"regexp"
"strings"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/secrets/certregistry"
)
type weightedIngressBackend struct {
name string
weight float64
}
var _ definitions.WeightedBackend = &weightedIngressBackend{}
func (b *weightedIngressBackend) GetName() string { return b.name }
func (b *weightedIngressBackend) GetWeight() float64 { return b.weight }
func setPathOld(m PathMode, r *eskip.Route, p string) {
switch m {
case PathPrefix:
r.Predicates = append(r.Predicates, &eskip.Predicate{
Name: "PathSubtree",
Args: []interface{}{p},
})
case PathRegexp:
r.PathRegexps = []string{p}
default:
if p == "/" {
r.PathRegexps = []string{"^/"}
} else {
r.PathRegexps = []string{"^(" + p + ")"}
}
}
}
func setPathV1(m PathMode, r *eskip.Route, pathType, path string) {
if path == "" {
return
}
// see https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.22/#httpingresspath-v1-networking-k8s-io
switch pathType {
case "Exact":
r.Predicates = append(r.Predicates, &eskip.Predicate{
Name: "Path",
Args: []interface{}{path},
})
case "Prefix":
r.Predicates = append(r.Predicates, &eskip.Predicate{
Name: "PathSubtree",
Args: []interface{}{path},
})
default:
setPathOld(m, r, path)
}
}
func convertPathRuleV1(
ic *ingressContext,
host string,
prule *definitions.PathRuleV1,
traffic backendTraffic,
allowedExternalNames []*regexp.Regexp,
forceKubernetesService bool,
defaultLoadBalancerAlgorithm string,
) (*eskip.Route, error) {
state := ic.state
metadata := ic.ingressV1.Metadata
pathMode := ic.pathMode
ns := metadata.Namespace
name := metadata.Name
if prule.Backend == nil {
return nil, fmt.Errorf("invalid path rule, missing backend in: %s/%s/%s", ns, name, host)
}
var (
eps []string
err error
svc *service
)
var hostRegexp []string
if host != "" {
hostRegexp = []string{createHostRx(host)}
}
svcPort := prule.Backend.Service.Port
svcName := prule.Backend.Service.Name
svc, err = state.getService(ns, svcName)
if err != nil {
ic.logger.Errorf("Failed to get service %s, %s", svcName, svcPort)
return nil, err
}
servicePort, err := svc.getServicePortV1(svcPort)
if err != nil {
// service definition is wrong or no pods
err = nil
if len(eps) > 0 {
// should never happen
ic.logger.Errorf("Failed to find target port for service %s, but %d endpoints exist", svcName, len(eps))
}
} else if svc.Spec.Type == "ExternalName" {
return externalNameRoute(ns, name, host, hostRegexp, svc, servicePort, allowedExternalNames)
} else if forceKubernetesService {
eps = []string{serviceNameBackend(svcName, ns, servicePort)}
} else {
protocol := "http"
if p, ok := metadata.Annotations[skipperBackendProtocolAnnotationKey]; ok {
protocol = p
}
eps = state.GetEndpointsByService(ns, svcName, protocol, servicePort)
}
if len(eps) == 0 {
ic.logger.Tracef("Target endpoints not found, shuntroute for %s:%s", svcName, svcPort)
r := &eskip.Route{
Id: routeID(ns, name, host, prule.Path, svcName),
HostRegexps: hostRegexp,
}
setPathV1(pathMode, r, prule.PathType, prule.Path)
traffic.apply(r)
shuntRoute(r)
return r, nil
}
ic.logger.Tracef("Found %d endpoints for %s:%s", len(eps), svcName, svcPort)
if len(eps) == 1 {
r := &eskip.Route{
Id: routeID(ns, name, host, prule.Path, svcName),
Backend: eps[0],
BackendType: eskip.NetworkBackend,
HostRegexps: hostRegexp,
}
setPathV1(pathMode, r, prule.PathType, prule.Path)
traffic.apply(r)
return r, nil
}
r := &eskip.Route{
Id: routeID(ns, name, host, prule.Path, svcName),
BackendType: eskip.LBBackend,
LBEndpoints: eps,
LBAlgorithm: getLoadBalancerAlgorithm(metadata, defaultLoadBalancerAlgorithm),
HostRegexps: hostRegexp,
}
setPathV1(pathMode, r, prule.PathType, prule.Path)
traffic.apply(r)
return r, nil
}
func (ing *ingress) addEndpointsRuleV1(ic *ingressContext, host string, prule *definitions.PathRuleV1, traffic backendTraffic) error {
meta := ic.ingressV1.Metadata
endpointsRoute, err := convertPathRuleV1(
ic,
host,
prule,
traffic,
ing.allowedExternalNames,
ing.forceKubernetesService,
ing.defaultLoadBalancerAlgorithm,
)
if err != nil {
// if the service is not found the route should be removed
if err == errServiceNotFound || err == errResourceNotFound {
return nil
}
// TODO: this error checking should not be necessary; the ingress error handling
// should be refactored such that a single ingress's error doesn't block the
// processing of the other, independent ingresses.
if errors.Is(err, errNotAllowedExternalName) {
ic.logger.Infof("Not allowed external name: %v", err)
return nil
}
// Ingress status field does not support errors
return fmt.Errorf("error while getting service: %w", err)
}
if endpointsRoute.BackendType != eskip.ShuntBackend {
// safe prepend, see: https://play.golang.org/p/zg5aGKJpRyK
filters := make([]*eskip.Filter, len(endpointsRoute.Filters)+len(ic.annotationFilters))
copy(filters, ic.annotationFilters)
copy(filters[len(ic.annotationFilters):], endpointsRoute.Filters)
endpointsRoute.Filters = filters
}
// add pre-configured default filters
df, err := ic.defaultFilters.getNamed(meta.Namespace, prule.Backend.Service.Name)
if err != nil {
ic.logger.Errorf("Failed to retrieve default filters: %v", err)
} else {
// it's safe to prepend, because type defaultFilters copies the slice during get()
endpointsRoute.Filters = append(df, endpointsRoute.Filters...)
}
err = applyAnnotationPredicates(ic.pathMode, endpointsRoute, ic.annotationPredicate)
if err != nil {
ic.logger.Errorf("Failed to apply annotation predicates: %v", err)
}
ic.addHostRoute(host, endpointsRoute)
redirect := ic.redirect
ewRangeMatch := isEastWestHost(host, ing.eastWestRangeDomains)
if !(ewRangeMatch || (strings.HasSuffix(host, ing.kubernetesEastWestDomain) && ing.kubernetesEastWestDomain != "")) {
switch {
case redirect.ignore:
// no redirect
case redirect.enable:
ic.addHostRoute(host, createIngressEnableHTTPSRedirect(endpointsRoute, redirect.code))
redirect.setHost(host)
case redirect.disable:
ic.addHostRoute(host, createIngressDisableHTTPSRedirect(endpointsRoute))
redirect.setHostDisabled(host)
case redirect.defaultEnabled:
ic.addHostRoute(host, createIngressEnableHTTPSRedirect(endpointsRoute, redirect.code))
redirect.setHost(host)
}
appendAnnotationPredicates(ing.kubernetesAnnotationPredicates, meta.Annotations, endpointsRoute)
appendAnnotationFilters(ing.kubernetesAnnotationFiltersAppend, meta.Annotations, endpointsRoute)
} else {
appendAnnotationPredicates(ing.kubernetesEastWestRangeAnnotationPredicates, meta.Annotations, endpointsRoute)
appendAnnotationFilters(ing.kubernetesEastWestRangeAnnotationFiltersAppend, meta.Annotations, endpointsRoute)
}
if ing.kubernetesEnableEastWest {
ewRoute := createEastWestRouteIng(ing.kubernetesEastWestDomain, meta.Name, meta.Namespace, endpointsRoute)
ewHost := fmt.Sprintf("%s.%s.%s", meta.Name, meta.Namespace, ing.kubernetesEastWestDomain)
ic.addHostRoute(ewHost, ewRoute)
}
return nil
}
// computeBackendWeightsV1 computes backend traffic weights for the rule backends grouped by path rule.
func computeBackendWeightsV1(calculateTraffic func([]*weightedIngressBackend) map[string]backendTraffic, backendWeights map[string]float64, rule *definitions.RuleV1) map[*definitions.PathRuleV1]backendTraffic {
backendsPerPath := make(map[string][]*weightedIngressBackend)
for _, prule := range rule.Http.Paths {
b := &weightedIngressBackend{
name: prule.Backend.Service.Name,
weight: backendWeights[prule.Backend.Service.Name],
}
backendsPerPath[prule.Path] = append(backendsPerPath[prule.Path], b)
}
trafficPerPath := make(map[string]map[string]backendTraffic, len(backendsPerPath))
for path, b := range backendsPerPath {
trafficPerPath[path] = calculateTraffic(b)
}
trafficPerPathRule := make(map[*definitions.PathRuleV1]backendTraffic)
for _, prule := range rule.Http.Paths {
trafficPerPathRule[prule] = trafficPerPath[prule.Path][prule.Backend.Service.Name]
}
return trafficPerPathRule
}
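// A hedged sketch of the grouping above: two path rules for the same path
// "/api", referencing services "v1" and "v2" with annotation weights 80 and
// 20, end up in a single backendsPerPath bucket; calculateTraffic then decides
// the per-backend split, which apply() later writes onto each route. The
// names below are hypothetical:
//
//	weights := map[string]float64{"v1": 80, "v2": 20}
//	perRule := computeBackendWeightsV1(calc, weights, rule)
//	perRule[prule].apply(route) // e.g. adds a Traffic() predicate for "v1"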
// TODO: default filters not applied to 'extra' routes from the custom route annotations. Is it on purpose?
// https://github.com/zalando/skipper/issues/1287
func (ing *ingress) addSpecRuleV1(ic *ingressContext, ru *definitions.RuleV1) error {
if ru.Http == nil {
ic.logger.Infof("Skipping rule without http definition")
return nil
}
trafficPerPathRule := computeBackendWeightsV1(ic.calculateTraffic, ic.backendWeights, ru)
for _, prule := range ru.Http.Paths {
ing.addExtraRoutes(ic, ru.Host, prule.Path, prule.PathType)
if trafficPerPathRule[prule].allowed() {
err := ing.addEndpointsRuleV1(ic, ru.Host, prule, trafficPerPathRule[prule])
if err != nil {
return err
}
}
}
return nil
}
// addSpecIngressTLSV1 is used to add TLS Certificates from Ingress resources. Certificates will be added
// only if the Ingress rule host matches a host in TLS config
func (ing *ingress) addSpecIngressTLSV1(ic *ingressContext, ingtls *definitions.TLSV1) {
ingressHosts := definitions.GetHostsFromIngressRulesV1(ic.ingressV1)
// Hosts in the tls section need to explicitly match the host in the rules section.
hostlist := compareStringList(ingtls.Hosts, ingressHosts)
if len(hostlist) == 0 {
ic.logger.Errorf("No matching tls hosts found - tls hosts: %s, ingress hosts: %s", ingtls.Hosts, ingressHosts)
return
} else if len(hostlist) != len(ingtls.Hosts) {
ic.logger.Infof("Hosts in TLS and Ingress don't match: tls hosts: %s, ingress hosts: %s", ingtls.Hosts, definitions.GetHostsFromIngressRulesV1(ic.ingressV1))
}
// Skip adding certs to registry since no certs defined
if ingtls.SecretName == "" {
ic.logger.Debugf("No tls secret defined for hosts - %s", ingtls.Hosts)
return
}
// Secrets should always reside in same namespace as the Ingress
secretID := definitions.ResourceID{Name: ingtls.SecretName, Namespace: ic.ingressV1.Metadata.Namespace}
secret, ok := ic.state.secrets[secretID]
if !ok {
ic.logger.Errorf("Failed to find secret %s in namespace %s", secretID.Name, secretID.Namespace)
return
}
addTLSCertToRegistry(ic.certificateRegistry, ic.logger, hostlist, secret)
}
// converts the default backend if any
func (ing *ingress) convertDefaultBackendV1(
ic *ingressContext,
forceKubernetesService bool,
) (*eskip.Route, bool, error) {
state := ic.state
i := ic.ingressV1
// The usage of the default backend depends on what we want: if there are no
// rules defined, we can generate a hostname for it based on shared rules and
// instructions in annotations. This is a flaw in the ingress API design,
// because the default backend is not defined on the host level; the spec says
// it matches when no rule matches. That means there is no matching rule on this
// ingress, and if there are multiple ingress items, there is a race between them.
if i.Spec.DefaultBackend == nil {
return nil, false, nil
}
var (
eps []string
err error
ns = i.Metadata.Namespace
name = i.Metadata.Name
svcName = i.Spec.DefaultBackend.Service.Name
svcPort = i.Spec.DefaultBackend.Service.Port
)
svc, err := state.getService(ns, svcName)
if err != nil {
ic.logger.Errorf("Failed to get service %s, %s", svcName, svcPort)
return nil, false, err
}
servicePort, err := svc.getServicePortV1(svcPort)
if err != nil {
ic.logger.Errorf("Failed to find target port %v, %s, for service %s add shuntroute: %v", svc.Spec.Ports, svcPort, svcName, err)
err = nil
} else if svc.Spec.Type == "ExternalName" {
r, err := externalNameRoute(ns, name, "default", nil, svc, servicePort, ing.allowedExternalNames)
return r, err == nil, err
} else if forceKubernetesService {
eps = []string{serviceNameBackend(svcName, ns, servicePort)}
} else {
ic.logger.Debugf("Found target port %v, for service %s", servicePort.TargetPort, svcName)
protocol := "http"
if p, ok := i.Metadata.Annotations[skipperBackendProtocolAnnotationKey]; ok {
protocol = p
}
eps = state.GetEndpointsByService(
ns,
svcName,
protocol,
servicePort,
)
ic.logger.Debugf("Found %d endpoints for %s: %v", len(eps), svcName, err)
}
if len(eps) == 0 {
ic.logger.Tracef("Target endpoints not found, shuntroute for %s:%s", svcName, svcPort)
r := &eskip.Route{
Id: routeID(ns, name, "", "", ""),
}
shuntRoute(r)
return r, true, nil
} else if len(eps) == 1 {
return &eskip.Route{
Id: routeID(ns, name, "", "", ""),
Backend: eps[0],
BackendType: eskip.NetworkBackend,
}, true, nil
}
return &eskip.Route{
Id: routeID(ns, name, "", "", ""),
BackendType: eskip.LBBackend,
LBEndpoints: eps,
LBAlgorithm: getLoadBalancerAlgorithm(i.Metadata, ing.defaultLoadBalancerAlgorithm),
}, true, nil
}
func serviceNameBackend(svcName, svcNamespace string, servicePort *servicePort) string {
scheme := "https"
if n, _ := servicePort.TargetPort.Number(); n != 443 {
scheme = "http"
}
return fmt.Sprintf("%s://%s.%s.svc.cluster.local:%s", scheme, svcName, svcNamespace, servicePort.TargetPort)
}
func (ing *ingress) ingressV1Route(
i *definitions.IngressV1Item,
redirect *redirectInfo,
state *clusterState,
hostRoutes map[string][]*eskip.Route,
df defaultFilters,
r *certregistry.CertRegistry,
loggingEnabled bool,
) (*eskip.Route, error) {
if i.Metadata == nil || i.Metadata.Namespace == "" || i.Metadata.Name == "" || i.Spec == nil {
log.Error("invalid ingress item: missing Metadata or Spec")
return nil, nil
}
logger := newLogger("Ingress", i.Metadata.Namespace, i.Metadata.Name, loggingEnabled)
redirect.initCurrent(i.Metadata)
ic := &ingressContext{
state: state,
ingressV1: i,
logger: logger,
annotationFilters: annotationFilter(i.Metadata, logger),
annotationPredicate: annotationPredicate(i.Metadata),
extraRoutes: extraRoutes(i.Metadata),
backendWeights: backendWeights(i.Metadata, logger),
pathMode: pathMode(i.Metadata, ing.pathMode, logger),
redirect: redirect,
hostRoutes: hostRoutes,
defaultFilters: df,
certificateRegistry: r,
calculateTraffic: getBackendTrafficCalculator[*weightedIngressBackend](ing.backendTrafficAlgorithm),
}
var route *eskip.Route
if dr, ok, err := ing.convertDefaultBackendV1(ic, ing.forceKubernetesService); ok {
route = dr
} else if err != nil {
ic.logger.Errorf("Failed to convert default backend: %v", err)
}
for _, rule := range i.Spec.Rules {
err := ing.addSpecRuleV1(ic, rule)
if err != nil {
return nil, err
}
}
if ic.certificateRegistry != nil {
for _, ingtls := range i.Spec.IngressTLS {
ing.addSpecIngressTLSV1(ic, ingtls)
}
}
return route, nil
}
package kubernetes
import (
"fmt"
"net"
"net/http"
"os"
"regexp"
"strings"
"sync"
"text/template"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/loadbalancer"
"github.com/zalando/skipper/secrets/certregistry"
)
const DefaultLoadBalancerAlgorithm = "roundRobin"
const (
defaultIngressClass = "skipper"
defaultRouteGroupClass = "skipper"
serviceHostEnvVar = "KUBERNETES_SERVICE_HOST"
servicePortEnvVar = "KUBERNETES_SERVICE_PORT"
httpRedirectRouteID = "kube__redirect"
defaultEastWestDomain = "skipper.cluster.local"
)
// PathMode values are used to control the ingress path interpretation. The path mode can
// be set globally for all ingress paths, and it can be overruled by the individual ingress
// rules using the zalando.org/skipper-ingress-path-mode annotation. When path mode is not
// set, the Kubernetes ingress specification is used, accepting regular expressions with a
// mandatory leading "/", automatically prepended by the "^" control character.
//
// When PathPrefix is used, the path matching becomes deterministic in cases
// where a request could otherwise match more than one ingress route.
type PathMode int
const (
// KubernetesIngressMode is the default path mode. Expects regular expressions
// with a mandatory leading "/". The expressions are automatically prepended by
// the "^" control character.
KubernetesIngressMode PathMode = iota
// PathRegexp is like KubernetesIngressMode but is not prepended by the "^"
// control character.
PathRegexp
// PathPrefix is like the PathSubtree predicate. E.g. "/foo/bar" will match
// "/foo/bar" or "/foo/bar/baz", but won't match "/foo/barooz".
//
// In this mode, when a Path or a PathSubtree predicate is set in an annotation,
// the value from the annotation has precedence over the standard ingress path.
PathPrefix
)
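// For illustration, a hedged sketch of how a single ingress path "/foo" is
// translated per mode (see setPathOld and setPathV1):
//
//	KubernetesIngressMode: r.PathRegexps = []string{"^(/foo)"} // anchored regexp
//	PathRegexp:            r.PathRegexps = []string{"/foo"}    // regexp as-is
//	PathPrefix:            PathSubtree("/foo")                 // matches /foo and /foo/bar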
const (
kubernetesIngressModeString = "kubernetes-ingress"
pathRegexpString = "path-regexp"
pathPrefixString = "path-prefix"
)
const maxFileSize = 1024 * 1024 // 1MB
var internalIPs = []interface{}{
"10.0.0.0/8",
"192.168.0.0/16",
"172.16.0.0/12",
"127.0.0.1/8",
"fd00::/8",
"::1/128",
}
// Options is used to initialize the Kubernetes DataClient.
type Options struct {
// KubernetesInCluster defines if skipper is deployed and running in the kubernetes cluster.
// This makes authentication with the API server happen through the service account, rather
// than through a kubectl proxy running alongside skipper.
KubernetesInCluster bool
// KubernetesURL is used as the base URL for Kubernetes API requests. Defaults to http://localhost:8001.
// (TBD: support in-cluster operation by taking the address and certificate from the standard Kubernetes
// environment variables.)
KubernetesURL string
// TokenFile configures path to the token file.
// Defaults to /var/run/secrets/kubernetes.io/serviceaccount/token when running in-cluster.
TokenFile string
// KubernetesNamespace is used to switch between finding ingresses in the cluster-scope or limit
// the ingresses to only those in the specified namespace. Defaults to "" which means monitor ingresses
// in the cluster-scope.
KubernetesNamespace string
// KubernetesEnableEndpointslices, if set, makes skipper fetch EndpointSlices
// instead of Endpoints, in order to scale beyond 1000 endpoints within a service
KubernetesEnableEndpointslices bool
// *DEPRECATED* KubernetesEnableEastWest if set automatically adds routes
// with the "%s.%s.skipper.cluster.local" domain pattern
KubernetesEnableEastWest bool
// ProvideHealthcheck, when set, tells the data client to append a healthcheck route to the ingress
// routes in case of successfully receiving the ingress items from the API (even if individual ingress
// items may be invalid), or a failing healthcheck route when the API communication fails. The
// healthcheck endpoint can be accessed from internal IPs on any hostname, with the path
// /kube-system/healthz.
//
// When used in a custom configuration, the current filter registry needs to include the status()
// filter, and the available predicates need to include the Source() predicate.
ProvideHealthcheck bool
// ProvideHTTPSRedirect, when set, tells the data client to append an HTTPS redirect route to the
// ingress routes. This route will detect the X-Forwarded-Proto=http and respond with a 301 message
// to the HTTPS equivalent of the same request (using the redirectTo(301, "https:") filter). The
// X-Forwarded-Proto and X-Forwarded-Port headers are expected to be set by the load balancer.
//
// (See also https://github.com/zalando-incubator/kube-ingress-aws-controller as part of the
// https://github.com/zalando-incubator/kubernetes-on-aws project.)
ProvideHTTPSRedirect bool
// HTTPSRedirectCode, when set, defines which redirect code to use for redirecting from HTTP to HTTPS.
// By default, 308 StatusPermanentRedirect is used.
HTTPSRedirectCode int
// DisableCatchAllRoutes, when set, tells the data client to not create catchall routes.
DisableCatchAllRoutes bool
// IngressClass is a regular expression to filter only those ingresses that match. If an ingress does
// not have a class annotation or the annotation is an empty string, skipper will load it. The default
// value for the ingress class is 'skipper'.
//
// For further information see:
// https://github.com/nginxinc/kubernetes-ingress/tree/master/examples/multiple-ingress-controllers
IngressClass string
// RouteGroupClass is a regular expression to filter only those RouteGroups that match. If a RouteGroup
// does not have the required annotation (zalando.org/routegroup.class) or the annotation is an empty string,
// skipper will load it. The default value for the RouteGroup class is 'skipper'.
RouteGroupClass string
// IngressLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. A label and its value on an Ingress must match exactly for the Ingress to be loaded by Skipper.
// If the value is irrelevant for a given configuration, it can be left empty. The default
// value is no labels required.
// Examples:
// Config [] will load all Ingresses.
// Config ["skipper-enabled": ""] will load only Ingresses with a label "skipper-enabled", no matter the value.
// Config ["skipper-enabled": "true"] will load only Ingresses with a label "skipper-enabled: true"
// Config ["skipper-enabled": "", "foo": "bar"] will load only Ingresses with both labels while label "foo" must have a value "bar".
IngressLabelSelectors map[string]string
// ServicesLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
ServicesLabelSelectors map[string]string
// EndpointsLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
EndpointsLabelSelectors map[string]string
// EndpointSlicesLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
EndpointSlicesLabelSelectors map[string]string
// SecretsLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
SecretsLabelSelectors map[string]string
// RouteGroupsLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
RouteGroupsLabelSelectors map[string]string
// ReverseSourcePredicate, when set to true, does the Source IP
// whitelisting for the heartbeat endpoint correctly behind AWS load
// balancers. Amazon's ALB writes the client IP to the last item of the
// string list of the X-Forwarded-For header; in this case you
// want to set this to true.
ReverseSourcePredicate bool
// Noop, WIP.
ForceFullUpdatePeriod time.Duration
// WhitelistedHealthCheckCIDR contains CIDRs to be appended to the default internal IP range used by the healthcheck endpoint
WhitelistedHealthCheckCIDR []string
// PathMode controls the default interpretation of ingress paths in cases when the ingress doesn't
// specify it with an annotation.
PathMode PathMode
// *DEPRECATED* KubernetesEastWestDomain sets the DNS domain to be
// used for east west traffic, defaults to "skipper.cluster.local"
KubernetesEastWestDomain string
// KubernetesEastWestRangeDomains set the cluster internal domains for
// east west traffic. Identified routes to such domains will include
// the KubernetesEastWestRangePredicates.
KubernetesEastWestRangeDomains []string
// KubernetesEastWestRangePredicates set the Predicates that will be
// appended to routes identified as belonging to KubernetesEastWestRangeDomains.
KubernetesEastWestRangePredicates []*eskip.Predicate
// KubernetesEastWestRangeAnnotationPredicates same as KubernetesAnnotationPredicates but will append to
// routes that have a KubernetesEastWestRangeDomains suffix.
KubernetesEastWestRangeAnnotationPredicates []AnnotationPredicates
// KubernetesEastWestRangeAnnotationFiltersAppend same as KubernetesAnnotationFiltersAppend but will append to
// routes that have a KubernetesEastWestRangeDomains suffix.
KubernetesEastWestRangeAnnotationFiltersAppend []AnnotationFilters
// KubernetesAnnotationPredicates sets predicates to append for each annotation key and value
KubernetesAnnotationPredicates []AnnotationPredicates
// KubernetesAnnotationFiltersAppend sets filters to append for each annotation key and value
KubernetesAnnotationFiltersAppend []AnnotationFilters
// DefaultFiltersDir enables default filters mechanism and sets the location of the default filters.
// The provided filters are then applied to all routes.
DefaultFiltersDir string
// OriginMarker is *deprecated* and not used anymore. It will be deleted in v1.
OriginMarker bool
// If the OpenTracing tag containing RouteGroup backend name
// (using tracingTag filter) should be added to all routes
BackendNameTracingTag bool
// OnlyAllowedExternalNames enables validation of ingress external names, route group
// network backend addresses and explicit LB endpoints against the list of patterns in
// AllowedExternalNames.
OnlyAllowedExternalNames bool
// AllowedExternalNames contains regexp patterns of those domain names that are allowed to be
// used with external name services (type=ExternalName).
AllowedExternalNames []*regexp.Regexp
CertificateRegistry *certregistry.CertRegistry
// ForceKubernetesService overrides the default Skipper behavior of routing to individual
// Kubernetes Endpoints and instead routes traffic to the Kubernetes Service address.
ForceKubernetesService bool
// BackendTrafficAlgorithm specifies the algorithm to calculate the backend traffic.
BackendTrafficAlgorithm BackendTrafficAlgorithm
// DefaultLoadBalancerAlgorithm sets the default algorithm to be used for load balancing between backend endpoints,
// available options: roundRobin, consistentHash, random, powerOfRandomNChoices
DefaultLoadBalancerAlgorithm string
}
// Client is a Skipper DataClient implementation used to create routes based on Kubernetes Ingress settings.
type Client struct {
mu sync.Mutex
ClusterClient *clusterClient
ingress *ingress
routeGroups *routeGroups
provideHealthcheck bool
provideHTTPSRedirect bool
reverseSourcePredicate bool
httpsRedirectCode int
current map[string]*eskip.Route
quit chan struct{}
defaultFiltersDir string
state *clusterState
loggingInterval time.Duration
loggingLastEnabled time.Time
}
// New creates and initializes a Kubernetes DataClient.
func New(o Options) (*Client, error) {
if o.OriginMarker {
log.Warning("OriginMarker is deprecated")
}
quit := make(chan struct{})
apiURL, err := buildAPIURL(o)
if err != nil {
return nil, err
}
ingCls := defaultIngressClass
if o.IngressClass != "" {
ingCls = o.IngressClass
}
rgCls := defaultRouteGroupClass
if o.RouteGroupClass != "" {
rgCls = o.RouteGroupClass
}
log.Debugf(
"running in-cluster: %t. api server url: %s. provide health check: %t. ingress.class filter: %s. routegroup.class filter: %s. namespace: %s",
o.KubernetesInCluster, apiURL, o.ProvideHealthcheck, ingCls, rgCls, o.KubernetesNamespace,
)
if len(o.WhitelistedHealthCheckCIDR) > 0 {
whitelistCIDRS := make([]interface{}, len(o.WhitelistedHealthCheckCIDR))
for i, v := range o.WhitelistedHealthCheckCIDR {
whitelistCIDRS[i] = v
}
internalIPs = append(internalIPs, whitelistCIDRS...)
log.Debugf("new internal ips are: %s", internalIPs)
}
if o.HTTPSRedirectCode <= 0 {
o.HTTPSRedirectCode = http.StatusPermanentRedirect
}
if o.KubernetesEnableEastWest {
if o.KubernetesEastWestDomain == "" {
o.KubernetesEastWestDomain = defaultEastWestDomain
} else {
o.KubernetesEastWestDomain = strings.Trim(o.KubernetesEastWestDomain, ".")
}
}
clusterClient, err := newClusterClient(o, apiURL, ingCls, rgCls, quit)
if err != nil {
return nil, err
}
if !o.OnlyAllowedExternalNames {
o.AllowedExternalNames = []*regexp.Regexp{regexp.MustCompile(".*")}
}
if algo, err := loadbalancer.AlgorithmFromString(o.DefaultLoadBalancerAlgorithm); err != nil || algo == loadbalancer.None {
o.DefaultLoadBalancerAlgorithm = DefaultLoadBalancerAlgorithm
}
ing := newIngress(o)
rg := newRouteGroups(o)
return &Client{
ClusterClient: clusterClient,
ingress: ing,
routeGroups: rg,
provideHealthcheck: o.ProvideHealthcheck,
provideHTTPSRedirect: o.ProvideHTTPSRedirect,
httpsRedirectCode: o.HTTPSRedirectCode,
current: make(map[string]*eskip.Route),
reverseSourcePredicate: o.ReverseSourcePredicate,
quit: quit,
defaultFiltersDir: o.DefaultFiltersDir,
loggingInterval: 1 * time.Minute,
}, nil
}
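// A minimal usage sketch for the data client, assuming in-cluster deployment
// (error handling abbreviated):
//
//	client, err := New(Options{KubernetesInCluster: true})
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer client.Close()
//	routes, _ := client.LoadAll()              // full initial route set
//	updated, deleted, _ := client.LoadUpdate() // subsequent diffs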
func buildAPIURL(o Options) (string, error) {
if !o.KubernetesInCluster {
if o.KubernetesURL == "" {
return defaultKubernetesURL, nil
}
return o.KubernetesURL, nil
}
host, port := os.Getenv(serviceHostEnvVar), os.Getenv(servicePortEnvVar)
if host == "" || port == "" {
return "", errAPIServerURLNotFound
}
return "https://" + net.JoinHostPort(host, port), nil
}
// String returns the string representation of the path mode, the same
// values that are used in the path mode annotation.
func (m PathMode) String() string {
switch m {
case PathRegexp:
return pathRegexpString
case PathPrefix:
return pathPrefixString
default:
return kubernetesIngressModeString
}
}
// ParsePathMode parses the string representations of the different
// path modes.
func ParsePathMode(s string) (PathMode, error) {
switch s {
case kubernetesIngressModeString:
return KubernetesIngressMode, nil
case pathRegexpString:
return PathRegexp, nil
case pathPrefixString:
return PathPrefix, nil
default:
return 0, fmt.Errorf("invalid path mode string: %s", s)
}
}
func mapRoutes(routes []*eskip.Route) (map[string]*eskip.Route, []*eskip.Route) {
var uniqueRoutes []*eskip.Route
routesById := make(map[string]*eskip.Route)
for _, route := range routes {
if existing, ok := routesById[route.Id]; ok {
existingEskip, routeEskip := existing.String(), route.String()
if existingEskip != routeEskip {
log.Errorf("Ignoring route with the same id %s, existing: %s, ignored: %s", route.Id, existingEskip, routeEskip)
}
} else {
routesById[route.Id] = route
uniqueRoutes = append(uniqueRoutes, route)
}
}
return routesById, uniqueRoutes
}
func (c *Client) loadAndConvert() ([]*eskip.Route, error) {
c.mu.Lock()
state, err := c.ClusterClient.fetchClusterState()
if err != nil {
c.mu.Unlock()
return nil, err
}
c.state = state
loggingEnabled := log.GetLevel() >= log.DebugLevel || time.Since(c.loggingLastEnabled) >= c.loggingInterval
if loggingEnabled {
c.loggingLastEnabled = time.Now()
}
c.mu.Unlock()
defaultFilters := c.fetchDefaultFilterConfigs()
ri, err := c.ingress.convert(state, defaultFilters, c.ClusterClient.certificateRegistry, loggingEnabled)
if err != nil {
return nil, err
}
rg, err := c.routeGroups.convert(state, defaultFilters, loggingEnabled, c.ClusterClient.certificateRegistry)
if err != nil {
return nil, err
}
r := append(ri, rg...)
if c.provideHealthcheck {
r = append(r, healthcheckRoutes(c.reverseSourcePredicate)...)
}
if c.provideHTTPSRedirect {
r = append(r, globalRedirectRoute(c.httpsRedirectCode))
}
return r, nil
}
// shuntRoute creates a route that returns a 502 status code when there are no endpoints found,
// see https://github.com/zalando/skipper/issues/1525
func shuntRoute(r *eskip.Route) {
r.Filters = []*eskip.Filter{
{
Name: filters.StatusName,
Args: []interface{}{502.0},
},
{
Name: filters.InlineContentName,
Args: []interface{}{"no endpoints"},
},
}
r.BackendType = eskip.ShuntBackend
r.Backend = ""
}
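// For reference, a route passed through shuntRoute serializes roughly as
// (hedged sketch, route id and predicates elided):
//
//	... -> status(502) -> inlineContent("no endpoints") -> <shunt>;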
func healthcheckRoutes(reverseSourcePredicate bool) []*eskip.Route {
template := template.Must(template.New("healthcheck").Parse(`
kube__healthz_up: Path("/kube-system/healthz") && {{.Source}}({{.SourceCIDRs}}) -> {{.DisableAccessLog}} status(200) -> <shunt>;
kube__healthz_down: Path("/kube-system/healthz") && {{.Source}}({{.SourceCIDRs}}) && Shutdown() -> status(503) -> <shunt>;
`))
params := struct {
Source string
SourceCIDRs string
DisableAccessLog string
}{}
if reverseSourcePredicate {
params.Source = "SourceFromLast"
} else {
params.Source = "Source"
}
if !log.IsLevelEnabled(log.DebugLevel) {
params.DisableAccessLog = "disableAccessLog(200) ->"
}
cidrs := new(strings.Builder)
for i, ip := range internalIPs {
if i > 0 {
cidrs.WriteString(", ")
}
cidrs.WriteString(fmt.Sprintf("%q", ip))
}
params.SourceCIDRs = cidrs.String()
out := new(strings.Builder)
_ = template.Execute(out, params)
routes, _ := eskip.Parse(out.String())
return routes
}
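// With the default internalIPs and reverseSourcePredicate=false, the rendered
// healthcheck routes look roughly like this (hedged sketch, CIDR list
// abbreviated; disableAccessLog is only rendered when the log level is below debug):
//
//	kube__healthz_up: Path("/kube-system/healthz") && Source("10.0.0.0/8", ...)
//	  -> disableAccessLog(200) -> status(200) -> <shunt>;
//	kube__healthz_down: Path("/kube-system/healthz") && Source("10.0.0.0/8", ...) && Shutdown()
//	  -> status(503) -> <shunt>;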
func (c *Client) LoadAll() ([]*eskip.Route, error) {
log.Debug("loading all")
r, err := c.loadAndConvert()
if err != nil {
return nil, fmt.Errorf("failed to load cluster state: %w", err)
}
c.current, r = mapRoutes(r)
log.Debugf("all routes loaded and mapped: %d", len(r))
return r, nil
}
// LoadUpdate returns all known eskip.Route, a list of route IDs
// scheduled for delete and an error.
//
// TODO: implement a force reset after some time.
func (c *Client) LoadUpdate() ([]*eskip.Route, []string, error) {
log.Debugf("polling for updates")
r, err := c.loadAndConvert()
if err != nil {
log.Errorf("polling for updates failed: %v", err)
return nil, nil, err
}
next, _ := mapRoutes(r)
log.Debugf("next version of routes loaded and mapped")
var (
updatedRoutes []*eskip.Route
deletedIDs []string
)
for id := range c.current {
// TODO: use eskip.Eq()
if r, ok := next[id]; ok && r.String() != c.current[id].String() {
updatedRoutes = append(updatedRoutes, r)
} else if !ok {
deletedIDs = append(deletedIDs, id)
}
}
for id, r := range next {
if _, ok := c.current[id]; !ok {
updatedRoutes = append(updatedRoutes, r)
}
}
if len(updatedRoutes) > 0 || len(deletedIDs) > 0 {
log.Infof("diff taken, inserts/updates: %d, deletes: %d", len(updatedRoutes), len(deletedIDs))
}
c.current = next
return updatedRoutes, deletedIDs, nil
}
func (c *Client) Close() {
if c != nil && c.quit != nil {
close(c.quit)
}
}
func (c *Client) fetchDefaultFilterConfigs() defaultFilters {
if c.defaultFiltersDir == "" {
log.Debug("default filters are disabled")
return nil
}
filters, err := readDefaultFilters(c.defaultFiltersDir)
if err != nil {
log.WithError(err).Error("could not fetch default filter configurations")
return nil
}
log.WithField("#configs", len(filters)).Debug("default filter configurations loaded")
return filters
}
// GetEndpointAddresses returns the list of all addresses for the given service
// loaded by previous call to LoadAll or LoadUpdate.
func (c *Client) GetEndpointAddresses(ns, name string) []string {
c.mu.Lock()
defer c.mu.Unlock()
if c.state == nil {
return nil
}
return c.state.getEndpointAddresses(ns, name)
}
// LoadEndpointAddresses returns the list of all addresses for the given service.
func (c *Client) LoadEndpointAddresses(namespace, name string) ([]string, error) {
return c.ClusterClient.loadEndpointAddresses(namespace, name)
}
// compareStringList returns the intersection of a and b: the elements of a
// that are also present in b.
func compareStringList(a, b []string) []string {
c := make([]string, 0)
for i := len(a) - 1; i >= 0; i-- {
for _, vD := range b {
if a[i] == vD {
c = append(c, vD)
break
}
}
}
return c
}
// addTLSCertToRegistry adds a TLS certificate to the certificate registry per host using the provided
// Kubernetes TLS secret
func addTLSCertToRegistry(cr *certregistry.CertRegistry, logger *logger, hosts []string, secret *secret) {
cert, err := generateTLSCertFromSecret(secret)
if err != nil {
logger.Errorf("Failed to generate TLS certificate from secret: %v", err)
return
}
for _, host := range hosts {
err := cr.ConfigureCertificate(host, cert)
if err != nil {
logger.Errorf("Failed to configure certificate: %v", err)
}
}
}
package kubernetes
import (
"fmt"
"sync"
log "github.com/sirupsen/logrus"
)
type logger struct {
logger *log.Entry
mu sync.Mutex
history map[string]struct{}
}
// newLogger creates a logger that logs each unique message once
// for the resource identified by kind, namespace and name.
// It logs nothing when disabled.
func newLogger(kind, namespace, name string, enabled bool) *logger {
if !enabled {
return nil
}
return &logger{logger: log.WithFields(log.Fields{"kind": kind, "ns": namespace, "name": name})}
}
func (l *logger) Tracef(format string, args ...any) {
if l != nil {
l.once(log.TraceLevel, format, args...)
}
}
func (l *logger) Debugf(format string, args ...any) {
if l != nil {
l.once(log.DebugLevel, format, args...)
}
}
func (l *logger) Infof(format string, args ...any) {
if l != nil {
l.once(log.InfoLevel, format, args...)
}
}
func (l *logger) Errorf(format string, args ...any) {
if l != nil {
l.once(log.ErrorLevel, format, args...)
}
}
func (l *logger) once(level log.Level, format string, args ...any) {
if !l.logger.Logger.IsLevelEnabled(level) {
return
}
l.mu.Lock()
defer l.mu.Unlock()
if l.history == nil {
l.history = make(map[string]struct{})
}
msg := fmt.Sprintf(format, args...)
key := fmt.Sprintf("%s %s", level, msg)
if _, ok := l.history[key]; !ok {
l.logger.Log(level, msg)
l.history[key] = struct{}{}
}
}
package kubernetes
import (
"fmt"
"net/http"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/predicates"
)
const (
redirectAnnotationKey = "zalando.org/skipper-ingress-redirect"
redirectCodeAnnotationKey = "zalando.org/skipper-ingress-redirect-code"
forwardedProtoHeader = "X-Forwarded-Proto"
)
type redirectInfo struct {
defaultEnabled, enable, disable, ignore bool
defaultCode, code int
setHostCode map[string]int
disableHost map[string]bool
}
func createRedirectInfo(defaultEnabled bool, defaultCode int) *redirectInfo {
return &redirectInfo{
defaultEnabled: defaultEnabled,
defaultCode: defaultCode,
setHostCode: make(map[string]int),
disableHost: make(map[string]bool),
}
}
func (r *redirectInfo) initCurrent(m *definitions.Metadata) {
r.enable = m.Annotations[redirectAnnotationKey] == "true"
r.disable = m.Annotations[redirectAnnotationKey] == "false"
r.ignore = strings.Contains(m.Annotations[definitions.IngressPredicateAnnotation], `Header("X-Forwarded-Proto"`) || strings.Contains(m.Annotations[definitions.IngressRoutesAnnotation], `Header("X-Forwarded-Proto"`)
r.code = r.defaultCode
if annotationCode, ok := m.Annotations[redirectCodeAnnotationKey]; ok {
// err holds either the Atoi parse error or, when the parsed code is out of the
// redirect range, the raw annotation value, so the log below reports whichever applies.
var err interface{}
if r.code, err = strconv.Atoi(annotationCode); err != nil ||
r.code < http.StatusMultipleChoices ||
r.code >= http.StatusBadRequest {
if err == nil {
err = annotationCode
}
log.Error("invalid redirect code annotation:", err)
r.code = r.defaultCode
}
}
}
func (r *redirectInfo) setHost(host string) {
r.setHostCode[host] = r.code
}
func (r *redirectInfo) setHostDisabled(host string) {
r.disableHost[host] = true
}
func (r *redirectInfo) updateHost(host string) {
switch {
case r.enable:
r.setHost(host)
case r.disable:
r.setHostDisabled(host)
case r.defaultEnabled:
r.setHost(host)
}
}
func routeIDForRedirectRoute(baseID string, enable bool) string {
f := "%s_https_redirect"
if !enable {
f = "%s_disable_https_redirect"
}
return fmt.Sprintf(f, baseID)
}
func initRedirectRoute(r *eskip.Route, code int) {
if r.Headers == nil {
r.Headers = make(map[string]string)
}
r.Headers[forwardedProtoHeader] = "http"
// Give this route a higher weight so that it will get precedence over existing routes
r.Predicates = append([]*eskip.Predicate{{
Name: predicates.WeightName,
Args: []interface{}{float64(1000)},
}}, r.Predicates...)
// remove all filters and just set redirect filter
r.Filters = []*eskip.Filter{
{
Name: "redirectTo",
Args: []interface{}{float64(code), "https:"},
},
}
r.BackendType = eskip.ShuntBackend
r.Backend = ""
}
func initDisableRedirectRoute(r *eskip.Route) {
if r.Headers == nil {
r.Headers = make(map[string]string)
}
r.Headers[forwardedProtoHeader] = "http"
// Give this route a higher weight so that it will get precedence over existing routes
r.Predicates = append([]*eskip.Predicate{{
Name: predicates.WeightName,
Args: []interface{}{float64(1000)},
}}, r.Predicates...)
}
func globalRedirectRoute(code int) *eskip.Route {
r := &eskip.Route{Id: httpRedirectRouteID}
initRedirectRoute(r, code)
return r
}
func createIngressEnableHTTPSRedirect(r *eskip.Route, code int) *eskip.Route {
rr := *r
rr.Id = routeIDForRedirectRoute(rr.Id, true)
initRedirectRoute(&rr, code)
return &rr
}
func createIngressDisableHTTPSRedirect(r *eskip.Route) *eskip.Route {
rr := *r
rr.Id = routeIDForRedirectRoute(rr.Id, false)
initDisableRedirectRoute(&rr)
return &rr
}
func hasProtoPredicate(r *eskip.Route) bool {
if r.Headers != nil {
for name := range r.Headers {
if http.CanonicalHeaderKey(name) == forwardedProtoHeader {
return true
}
}
}
if r.HeaderRegexps != nil {
for name := range r.HeaderRegexps {
if http.CanonicalHeaderKey(name) == forwardedProtoHeader {
return true
}
}
}
for _, p := range r.Predicates {
if p.Name != "Header" && p.Name != "HeaderRegexp" {
continue
}
if len(p.Args) > 0 && p.Args[0] == forwardedProtoHeader {
return true
}
}
return false
}
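// For illustration, both of these hypothetical routes are reported as already
// handling the forwarded proto header (hedged sketch):
//
//	r1 := &eskip.Route{Headers: map[string]string{"X-Forwarded-Proto": "https"}}
//	r2 := &eskip.Route{Predicates: []*eskip.Predicate{{
//		Name: "Header",
//		Args: []interface{}{"X-Forwarded-Proto", "https"},
//	}}}
//	// hasProtoPredicate(r1) == true, hasProtoPredicate(r2) == true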
func createHTTPSRedirect(code int, r *eskip.Route) *eskip.Route {
// copy to avoid unexpected mutations
rr := eskip.Copy(r)
rr.Id = routeIDForRedirectRoute(rr.Id, true)
rr.BackendType = eskip.ShuntBackend
rr.Predicates = append(rr.Predicates, &eskip.Predicate{
Name: "Header",
Args: []interface{}{forwardedProtoHeader, "http"},
})
// remove all filters and just set redirect filter
rr.Filters = []*eskip.Filter{
{
Name: "redirectTo",
Args: []interface{}{float64(code), "https:"},
},
}
return rr
}
package kubernetes
import "github.com/zalando/skipper/dataclients/kubernetes/definitions"
func newResourceID(namespace, name string) definitions.ResourceID {
return definitions.ResourceID{Namespace: namespace, Name: name}
}
type ClusterResource struct {
Name string `json:"name"`
}
type ClusterResourceList struct {
// Items, aka "resources".
Items []*ClusterResource `json:"resources"`
}
type secret struct {
Metadata *definitions.Metadata `json:"metadata"`
Type string `json:"type"`
Data map[string]string `json:"data"`
}
type secretList struct {
Items []*secret `json:"items"`
}
package kubernetes
import (
"fmt"
"regexp"
"strings"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/loadbalancer"
"github.com/zalando/skipper/secrets/certregistry"
)
const backendNameTracingTagName = "skipper.backend_name"
// TODO:
// - consider catchall for east-west routes
type routeGroups struct {
options Options
}
type routeGroupContext struct {
state *clusterState
routeGroup *definitions.RouteGroupItem
logger *logger
hosts []string
allowedExternalNames []*regexp.Regexp
hostRx string
eastWestDomain string
hostRoutes map[string][]*eskip.Route
defaultBackendTraffic map[string]backendTraffic
defaultFilters defaultFilters
httpsRedirectCode int
backendsByName map[string]*definitions.SkipperBackend
eastWestEnabled bool
hasEastWestHost bool
backendNameTracingTag bool
internal bool
provideHTTPSRedirect bool
calculateTraffic func([]*definitions.BackendReference) map[string]backendTraffic
defaultLoadBalancerAlgorithm string
certificateRegistry *certregistry.CertRegistry
}
type routeContext struct {
group *routeGroupContext
groupRoute *definitions.RouteSpec
id string
method string
backend *definitions.SkipperBackend
}
func eskipError(typ, e string, err error) error {
if len(e) > 48 {
e = e[:48]
}
return fmt.Errorf("[eskip] %s, '%s'; %w", typ, e, err)
}
func targetPortNotFound(serviceName string, servicePort int) error {
return fmt.Errorf("target port not found: %s:%d", serviceName, servicePort)
}
func newRouteGroups(o Options) *routeGroups {
return &routeGroups{options: o}
}
func namespaceString(ns string) string {
if ns == "" {
return "default"
}
return ns
}
func notSupportedServiceType(s *service) error {
return fmt.Errorf(
"not supported service type in service/%s/%s: %s",
namespaceString(s.Meta.Namespace),
s.Meta.Name,
s.Spec.Type,
)
}
func defaultFiltersError(m *definitions.Metadata, service string, err error) error {
return fmt.Errorf(
"error while applying default filters for route group and service: %s/%s %s, %w",
namespaceString(m.Namespace),
m.Name,
service,
err,
)
}
func hasEastWestHost(eastWestPostfix string, hosts []string) bool {
for _, h := range hosts {
if strings.HasSuffix(h, eastWestPostfix) {
return true
}
}
return false
}
func toSymbol(p string) string {
b := []byte(p)
for i := range b {
if b[i] == '_' ||
b[i] >= '0' && b[i] <= '9' ||
b[i] >= 'a' && b[i] <= 'z' ||
b[i] >= 'A' && b[i] <= 'Z' {
continue
}
b[i] = '_'
}
return string(b)
}
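// For example (hedged):
//
//	toSymbol("foo.example.org") // "foo_example_org"
//	toSymbol("my-rg")           // "my_rg"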
func rgRouteID(namespace, name, subName string, index, subIndex int, internal bool) string {
if internal {
namespace = "internal_" + namespace
}
return fmt.Sprintf(
"kube_rg__%s__%s__%s__%d_%d",
namespace,
name,
subName,
index,
subIndex,
)
}
func crdRouteID(m *definitions.Metadata, method string, routeIndex, backendIndex int, internal bool) string {
return rgRouteID(
toSymbol(namespaceString(m.Namespace)),
toSymbol(m.Name),
toSymbol(method),
routeIndex,
backendIndex,
internal,
)
}
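// For illustration, a hypothetical RouteGroup "my-group" in namespace "team-a"
// yields IDs like (hedged sketch):
//
//	crdRouteID(m, "get", 1, 0, false) // "kube_rg__team_a__my_group__get__1_0"
//	crdRouteID(m, "get", 1, 0, true)  // "kube_rg__internal_team_a__my_group__get__1_0"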
func mapBackends(backends []*definitions.SkipperBackend) map[string]*definitions.SkipperBackend {
m := make(map[string]*definitions.SkipperBackend)
for _, b := range backends {
m[b.Name] = b
}
return m
}
func getBackendService(ctx *routeGroupContext, backend *definitions.SkipperBackend) (*service, error) {
s, err := ctx.state.getServiceRG(
namespaceString(ctx.routeGroup.Metadata.Namespace),
backend.ServiceName,
)
if err != nil {
return nil, err
}
if strings.ToLower(s.Spec.Type) != "clusterip" {
return nil, notSupportedServiceType(s)
}
return s, nil
}
func applyServiceBackend(ctx *routeGroupContext, backend *definitions.SkipperBackend, r *eskip.Route) error {
protocol := "http"
if p, ok := ctx.routeGroup.Metadata.Annotations[skipperBackendProtocolAnnotationKey]; ok {
protocol = p
}
s, err := getBackendService(ctx, backend)
if err != nil {
return err
}
targetPort, ok := s.getTargetPortByValue(backend.ServicePort)
if !ok {
return targetPortNotFound(backend.ServiceName, backend.ServicePort)
}
eps := ctx.state.GetEndpointsByTarget(
namespaceString(ctx.routeGroup.Metadata.Namespace),
s.Meta.Name,
"TCP",
protocol,
targetPort,
)
if len(eps) == 0 {
ctx.logger.Tracef("Target endpoints not found, shuntroute for %s:%d", backend.ServiceName, backend.ServicePort)
shuntRoute(r)
return nil
}
if len(eps) == 1 {
r.BackendType = eskip.NetworkBackend
r.Backend = eps[0]
return nil
}
r.BackendType = eskip.LBBackend
r.LBEndpoints = eps
r.LBAlgorithm = ctx.defaultLoadBalancerAlgorithm
if backend.Algorithm != loadbalancer.None {
r.LBAlgorithm = backend.Algorithm.String()
}
return nil
}
func applyDefaultFilters(ctx *routeGroupContext, serviceName string, r *eskip.Route) error {
f, err := ctx.defaultFilters.getNamed(namespaceString(ctx.routeGroup.Metadata.Namespace), serviceName)
if err != nil {
return defaultFiltersError(ctx.routeGroup.Metadata, serviceName, err)
}
// safe to prepend as defaultFilters.get() copies the slice:
r.Filters = append(f, r.Filters...)
return nil
}
func appendFilter(f []*eskip.Filter, name string, args ...interface{}) []*eskip.Filter {
return append(f, &eskip.Filter{
Name: name,
Args: args,
})
}
func applyBackend(ctx *routeGroupContext, backend *definitions.SkipperBackend, r *eskip.Route) error {
r.BackendType = backend.Type
switch r.BackendType {
case definitions.ServiceBackend:
if err := applyServiceBackend(ctx, backend, r); err != nil {
return err
}
case eskip.NetworkBackend:
if !isExternalAddressAllowed(ctx.allowedExternalNames, backend.Address) {
return fmt.Errorf(
"routegroup with not allowed network backend: %s",
backend.Address,
)
}
r.Backend = backend.Address
case eskip.LBBackend:
for _, ep := range backend.Endpoints {
if !isExternalAddressAllowed(ctx.allowedExternalNames, ep) {
return fmt.Errorf(
"routegroup with not allowed explicit LB endpoint: %s",
ep,
)
}
}
r.LBEndpoints = backend.Endpoints
r.LBAlgorithm = ctx.defaultLoadBalancerAlgorithm
if backend.Algorithm != loadbalancer.None {
r.LBAlgorithm = backend.Algorithm.String()
}
}
if ctx.backendNameTracingTag {
r.Filters = appendFilter(r.Filters, "tracingTag", backendNameTracingTagName, backend.Name)
}
return nil
}
func appendPredicate(p []*eskip.Predicate, name string, args ...interface{}) []*eskip.Predicate {
return append(p, &eskip.Predicate{
Name: name,
Args: args,
})
}
func storeHostRoute(ctx *routeGroupContext, r *eskip.Route) {
for _, h := range ctx.hosts {
ctx.hostRoutes[h] = append(ctx.hostRoutes[h], r)
}
}
func appendEastWest(ctx *routeGroupContext, routes []*eskip.Route, current *eskip.Route) []*eskip.Route {
if !ctx.eastWestEnabled || ctx.hasEastWestHost {
return routes
}
ewr := createEastWestRouteRG(
ctx.routeGroup.Metadata.Name,
namespaceString(ctx.routeGroup.Metadata.Namespace),
ctx.eastWestDomain,
current,
)
return append(routes, ewr)
}
func appendHTTPSRedirect(ctx *routeGroupContext, routes []*eskip.Route, current *eskip.Route) []*eskip.Route {
// in case a route explicitly handles the forwarded proto header, we
// don't shadow it
if !ctx.internal && ctx.provideHTTPSRedirect && !hasProtoPredicate(current) {
hsr := createHTTPSRedirect(ctx.httpsRedirectCode, current)
routes = append(routes, hsr)
}
return routes
}
// implicitGroupRoutes creates routes for those route groups where the `route`
// field is not defined, and the routes are derived from the default backends.
func implicitGroupRoutes(ctx *routeGroupContext) ([]*eskip.Route, error) {
rg := ctx.routeGroup
var routes []*eskip.Route
for backendIndex, beref := range rg.Spec.DefaultBackends {
be := ctx.backendsByName[beref.BackendName]
rid := crdRouteID(rg.Metadata, "all", 0, backendIndex, ctx.internal)
ri := &eskip.Route{Id: rid}
if err := applyBackend(ctx, be, ri); err != nil {
return nil, err
}
if ctx.hostRx != "" {
ri.Predicates = appendPredicate(ri.Predicates, "Host", ctx.hostRx)
}
ctx.defaultBackendTraffic[beref.BackendName].apply(ri)
if be.Type == definitions.ServiceBackend {
if err := applyDefaultFilters(ctx, be.ServiceName, ri); err != nil {
ctx.logger.Errorf("Failed to retrieve default filters: %v", err)
}
}
storeHostRoute(ctx, ri)
routes = append(routes, ri)
routes = appendEastWest(ctx, routes, ri)
routes = appendHTTPSRedirect(ctx, routes, ri)
}
return routes, nil
}
func transformExplicitGroupRoute(ctx *routeContext) (*eskip.Route, error) {
gr := ctx.groupRoute
r := &eskip.Route{Id: ctx.id}
// Path or PathSubtree: prefer Path if both are set, because it is more specific
if gr.Path != "" {
r.Predicates = appendPredicate(r.Predicates, "Path", gr.Path)
} else if gr.PathSubtree != "" {
r.Predicates = appendPredicate(r.Predicates, "PathSubtree", gr.PathSubtree)
}
if gr.PathRegexp != "" {
r.Predicates = appendPredicate(r.Predicates, "PathRegexp", gr.PathRegexp)
}
if ctx.group.hostRx != "" {
r.Predicates = appendPredicate(r.Predicates, "Host", ctx.group.hostRx)
}
if ctx.method != "" {
r.Predicates = appendPredicate(r.Predicates, "Method", strings.ToUpper(ctx.method))
}
for _, pi := range gr.Predicates {
ppi, err := eskip.ParsePredicates(pi)
if err != nil {
return nil, eskipError("predicate", pi, err)
}
r.Predicates = append(r.Predicates, ppi...)
}
var f []*eskip.Filter
for _, fi := range gr.Filters {
ffi, err := eskip.ParseFilters(fi)
if err != nil {
return nil, eskipError("filter", fi, err)
}
f = append(f, ffi...)
}
r.Filters = f
err := applyBackend(ctx.group, ctx.backend, r)
if err != nil {
return nil, err
}
if ctx.backend.Type == definitions.ServiceBackend {
if err := applyDefaultFilters(ctx.group, ctx.backend.ServiceName, r); err != nil {
ctx.group.logger.Errorf("Failed to retrieve default filters: %v", err)
}
}
return r, nil
}
// explicitGroupRoutes creates routes for those route groups that have the
// `route` field explicitly defined.
func explicitGroupRoutes(ctx *routeGroupContext) ([]*eskip.Route, error) {
var result []*eskip.Route
rg := ctx.routeGroup
nextRoute:
for routeIndex, rgr := range rg.Spec.Routes {
var routes []*eskip.Route
if len(rgr.Methods) == 0 {
rgr.Methods = []string{""}
}
backendRefs := rg.Spec.DefaultBackends
backendTraffic := ctx.defaultBackendTraffic
if len(rgr.Backends) != 0 {
backendRefs = rgr.Backends
backendTraffic = ctx.calculateTraffic(rgr.Backends)
}
for _, method := range rgr.UniqueMethods() {
for backendIndex, bref := range backendRefs {
be := ctx.backendsByName[bref.BackendName]
idMethod := strings.ToLower(method)
if idMethod == "" {
idMethod = "all"
}
r, err := transformExplicitGroupRoute(&routeContext{
group: ctx,
groupRoute: rgr,
id: crdRouteID(rg.Metadata, idMethod, routeIndex, backendIndex, ctx.internal),
method: strings.ToUpper(method),
backend: be,
})
if err != nil {
ctx.logger.Errorf("Ignoring route: %v", err)
continue nextRoute
}
backendTraffic[bref.BackendName].apply(r)
storeHostRoute(ctx, r)
routes = append(routes, r)
routes = appendEastWest(ctx, routes, r)
routes = appendHTTPSRedirect(ctx, routes, r)
}
}
result = append(result, routes...)
}
return result, nil
}
func transformRouteGroup(ctx *routeGroupContext) ([]*eskip.Route, error) {
ctx.defaultBackendTraffic = ctx.calculateTraffic(ctx.routeGroup.Spec.DefaultBackends)
if len(ctx.routeGroup.Spec.Routes) == 0 {
return implicitGroupRoutes(ctx)
}
return explicitGroupRoutes(ctx)
}
// splitHosts partitions hosts into internal hosts (those with one of the given
// east-west range domains as suffix) and external hosts.
func splitHosts(hosts []string, domains []string) ([]string, []string) {
internalHosts := []string{}
externalHosts := []string{}
for _, host := range hosts {
if isEastWestHost(host, domains) {
internalHosts = append(internalHosts, host)
} else {
externalHosts = append(externalHosts, host)
}
}
return internalHosts, externalHosts
}
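// For example (hedged):
//
//	internal, external := splitHosts(
//		[]string{"a.ingress.cluster.local", "a.example.org"},
//		[]string{"ingress.cluster.local"},
//	)
//	// internal == ["a.ingress.cluster.local"], external == ["a.example.org"]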
// addRouteGroupTLS compares the RouteGroup host list and the RouteGroup TLS host list
// and adds the TLS secret to the registry if a match is found.
func (r *routeGroups) addRouteGroupTLS(ctx *routeGroupContext, tls *definitions.RouteTLSSpec) {
// Hosts in the tls section need to explicitly match the hosts in the RouteGroup
hostlist := compareStringList(tls.Hosts, ctx.routeGroup.Spec.UniqueHosts())
if len(hostlist) == 0 {
ctx.logger.Errorf("No matching tls hosts found - tls hosts: %s, routegroup hosts: %s", tls.Hosts, ctx.routeGroup.Spec.UniqueHosts())
return
} else if len(hostlist) != len(tls.Hosts) {
ctx.logger.Infof("Hosts in TLS and RouteGroup don't match: tls hosts: %s, routegroup hosts: %s", tls.Hosts, ctx.routeGroup.Spec.UniqueHosts())
}
// Skip adding certs to registry since no certs defined
if tls.SecretName == "" {
ctx.logger.Debugf("No tls secret defined for hosts - %s", tls.Hosts)
return
}
// Secrets should always reside in the same namespace as the RouteGroup
secretID := definitions.ResourceID{Name: tls.SecretName, Namespace: ctx.routeGroup.Metadata.Namespace}
secret, ok := ctx.state.secrets[secretID]
if !ok {
ctx.logger.Errorf("Failed to find secret %s in namespace %s", secretID.Name, secretID.Namespace)
return
}
addTLSCertToRegistry(ctx.certificateRegistry, ctx.logger, hostlist, secret)
}
func (r *routeGroups) convert(s *clusterState, df defaultFilters, loggingEnabled bool, cr *certregistry.CertRegistry) ([]*eskip.Route, error) {
var rs []*eskip.Route
redirect := createRedirectInfo(r.options.ProvideHTTPSRedirect, r.options.HTTPSRedirectCode)
for _, rg := range s.routeGroups {
logger := newLogger("RouteGroup", rg.Metadata.Namespace, rg.Metadata.Name, loggingEnabled)
redirect.initCurrent(rg.Metadata)
var internalHosts []string
var externalHosts []string
hosts := rg.Spec.UniqueHosts()
if len(r.options.KubernetesEastWestRangeDomains) == 0 {
externalHosts = hosts
} else {
internalHosts, externalHosts = splitHosts(hosts, r.options.KubernetesEastWestRangeDomains)
}
backends := mapBackends(rg.Spec.Backends)
// If there are no hosts at all, or if there are any external
// hosts, create the external routes.
if len(externalHosts) != 0 || len(hosts) == 0 {
var provideRedirect bool
switch {
case redirect.enable:
provideRedirect = true
case redirect.disable:
provideRedirect = false
case redirect.defaultEnabled:
provideRedirect = true
}
ctx := &routeGroupContext{
state: s,
routeGroup: rg,
logger: logger,
defaultFilters: df,
hosts: externalHosts,
hostRx: createHostRx(externalHosts...),
hostRoutes: make(map[string][]*eskip.Route),
hasEastWestHost: hasEastWestHost(r.options.KubernetesEastWestDomain, externalHosts),
eastWestEnabled: r.options.KubernetesEnableEastWest,
eastWestDomain: r.options.KubernetesEastWestDomain,
provideHTTPSRedirect: provideRedirect,
httpsRedirectCode: r.options.HTTPSRedirectCode,
backendsByName: backends,
backendNameTracingTag: r.options.BackendNameTracingTag,
internal: false,
allowedExternalNames: r.options.AllowedExternalNames,
calculateTraffic: getBackendTrafficCalculator[*definitions.BackendReference](r.options.BackendTrafficAlgorithm),
defaultLoadBalancerAlgorithm: r.options.DefaultLoadBalancerAlgorithm,
certificateRegistry: cr,
}
ri, err := transformRouteGroup(ctx)
if err != nil {
ctx.logger.Errorf("Error transforming external hosts: %v", err)
continue
}
if !r.options.DisableCatchAllRoutes {
catchAll := hostCatchAllRoutes(ctx.hostRoutes, func(host string) string {
// "catchall" won't conflict with any HTTP method
return rgRouteID("", toSymbol(host), "catchall", 0, 0, false)
})
ri = append(ri, catchAll...)
}
if ctx.certificateRegistry != nil {
for _, ctxTls := range rg.Spec.TLS {
r.addRouteGroupTLS(ctx, ctxTls)
}
}
for _, route := range ri {
appendAnnotationPredicates(r.options.KubernetesAnnotationPredicates, rg.Metadata.Annotations, route)
appendAnnotationFilters(r.options.KubernetesAnnotationFiltersAppend, rg.Metadata.Annotations, route)
}
rs = append(rs, ri...)
}
// Internal hosts
if len(internalHosts) > 0 {
internalCtx := &routeGroupContext{
state: s,
routeGroup: rg,
logger: logger,
defaultFilters: df,
hosts: internalHosts,
hostRx: createHostRx(internalHosts...),
hostRoutes: make(map[string][]*eskip.Route),
backendsByName: backends,
backendNameTracingTag: r.options.BackendNameTracingTag,
internal: true,
allowedExternalNames: r.options.AllowedExternalNames,
calculateTraffic: getBackendTrafficCalculator[*definitions.BackendReference](r.options.BackendTrafficAlgorithm),
defaultLoadBalancerAlgorithm: r.options.DefaultLoadBalancerAlgorithm,
certificateRegistry: cr,
}
internalRi, err := transformRouteGroup(internalCtx)
if err != nil {
internalCtx.logger.Errorf("Error transforming internal hosts: %v", err)
continue
}
if !r.options.DisableCatchAllRoutes {
catchAll := hostCatchAllRoutes(internalCtx.hostRoutes, func(host string) string {
// "catchall" won't conflict with any HTTP method
return rgRouteID("", toSymbol(host), "catchall", 0, 0, true)
})
internalRi = append(internalRi, catchAll...)
}
applyEastWestRangePredicates(internalRi, r.options.KubernetesEastWestRangePredicates)
for _, route := range internalRi {
appendAnnotationPredicates(r.options.KubernetesEastWestRangeAnnotationPredicates, rg.Metadata.Annotations, route)
appendAnnotationFilters(r.options.KubernetesEastWestRangeAnnotationFiltersAppend, rg.Metadata.Annotations, route)
}
if internalCtx.certificateRegistry != nil {
for _, ctxTls := range rg.Spec.TLS {
r.addRouteGroupTLS(internalCtx, ctxTls)
}
}
rs = append(rs, internalRi...)
}
}
return rs, nil
}
package kubernetes
import (
"fmt"
"strconv"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
)
type servicePort struct {
Name string `json:"name"`
Port int `json:"port"`
TargetPort *definitions.BackendPort `json:"targetPort"` // string or int
}
func (sp *servicePort) matchingPort(svcPort definitions.BackendPort) bool {
s := svcPort.String()
spt := strconv.Itoa(sp.Port)
return s != "" && (spt == s || sp.Name == s)
}
func (sp *servicePort) matchingPortV1(svcPort definitions.BackendPortV1) bool {
s := svcPort.String()
spt := strconv.Itoa(sp.Port)
return s != "" && (spt == s || sp.Name == s)
}
func (sp *servicePort) String() string {
return fmt.Sprintf("%s %d %s", sp.Name, sp.Port, sp.TargetPort)
}
type serviceSpec struct {
Type string `json:"type"`
ClusterIP string `json:"clusterIP"`
ExternalName string `json:"externalName"`
Ports []*servicePort `json:"ports"`
}
type service struct {
Meta *definitions.Metadata `json:"Metadata"`
Spec *serviceSpec `json:"spec"`
}
type serviceList struct {
Items []*service `json:"items"`
}
func (s *service) getServicePortV1(port definitions.BackendPortV1) (*servicePort, error) {
for _, sp := range s.Spec.Ports {
if sp.matchingPortV1(port) && sp.TargetPort != nil {
return sp, nil
}
}
return nil, fmt.Errorf("getServicePortV1: service port not found %v given %v", s.Spec.Ports, port)
}
func (s *service) getTargetPortByValue(p int) (*definitions.BackendPort, bool) {
for _, pi := range s.Spec.Ports {
if pi.Port == p {
return pi.TargetPort, true
}
}
return nil, false
}
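// A minimal sketch (not from the source) of the matching semantics: a
// backend port reference matches a service port either by its number or
// by its name. The literal values are hypothetical, and the construction
// of the definitions.BackendPortV1 reference is elided:
//
//	sp := &servicePort{Name: "http", Port: 8080}
//	// sp.matchingPortV1(ref) is true if ref.String() is "8080" or "http",
//	// and false if ref.String() is empty.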
package kubernetes
import (
"fmt"
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/predicates"
)
// BackendTrafficAlgorithm specifies the algorithm for backend traffic calculation
type BackendTrafficAlgorithm int
const (
// TrafficPredicateAlgorithm is the default algorithm for backend traffic calculation.
// It uses Traffic and True predicates to distribute traffic between backends.
TrafficPredicateAlgorithm BackendTrafficAlgorithm = iota
// TrafficSegmentPredicateAlgorithm uses TrafficSegment predicate to distribute traffic between backends
TrafficSegmentPredicateAlgorithm
)
func (a BackendTrafficAlgorithm) String() string {
switch a {
case TrafficPredicateAlgorithm:
return "traffic-predicate"
case TrafficSegmentPredicateAlgorithm:
return "traffic-segment-predicate"
default:
return "unknown" // should never happen
}
}
// ParseBackendTrafficAlgorithm parses a string into a BackendTrafficAlgorithm
func ParseBackendTrafficAlgorithm(name string) (BackendTrafficAlgorithm, error) {
switch name {
case "traffic-predicate":
return TrafficPredicateAlgorithm, nil
case "traffic-segment-predicate":
return TrafficSegmentPredicateAlgorithm, nil
default:
return -1, fmt.Errorf("invalid backend traffic algorithm: %s", name)
}
}
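// A minimal usage sketch of the name/constant round trip:
//
//	a, err := ParseBackendTrafficAlgorithm("traffic-segment-predicate")
//	// a == TrafficSegmentPredicateAlgorithm, err == nil
//	_ = a.String() // "traffic-segment-predicate"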
// backendTraffic specifies whether a given backend is allowed to receive any traffic and
// modifies route to receive the desired traffic portion
type backendTraffic interface {
allowed() bool
apply(*eskip.Route)
}
// getBackendTrafficCalculator returns a function that calculates backendTraffic for each backend using specified algorithm
func getBackendTrafficCalculator[T definitions.WeightedBackend](algorithm BackendTrafficAlgorithm) func(b []T) map[string]backendTraffic {
switch algorithm {
case TrafficSegmentPredicateAlgorithm:
return trafficSegmentPredicateCalculator[T]
case TrafficPredicateAlgorithm:
return trafficPredicateCalculator[T]
}
return nil // should never happen
}
// trafficPredicate implements backendTraffic using Traffic() and True() predicates
type trafficPredicate struct {
value float64
balance int
}
var _ backendTraffic = &trafficPredicate{}
// trafficPredicateCalculator calculates argument for the Traffic() predicate and
// the number of True() predicates to be added to the routes based on the weights of the backends.
//
// The Traffic() argument is calculated based on the following rules:
//
// - if no weight is defined for a backend, it gets weight 0.
// - if none of the backends defines a weight, then traffic is
//   distributed equally.
//
// Each Traffic() argument is relative to the sum of the weights of the
// remaining backends,
// e.g. if the weight is specified as:
//
// backend-1: 0.1
// backend-2: 0.2
// backend-3: 0.3
// backend-4: 0.4
//
// then Traffic() predicate arguments will be:
//
// backend-1: Traffic(0.1) == 0.1 / (0.1 + 0.2 + 0.3 + 0.4)
// backend-2: Traffic(0.222) == 0.2 / (0.2 + 0.3 + 0.4)
// backend-3: Traffic(0.428) == 0.3 / (0.3 + 0.4)
// backend-4: Traffic(1.0) == 0.4 / (0.4)
//
// The weight of the backend routes will be adjusted by a number of True() predicates
// equal to the number of remaining backends minus one, e.g. for the above example:
//
// backend-1: Traffic(0.1) && True() && True() -> ...
// backend-2: Traffic(0.222) && True() -> ...
// backend-3: Traffic(0.428) -> ...
// backend-4: Traffic(1.0) -> ...
//
// Traffic(1.0) is handled in a special way, see trafficPredicate.apply().
func trafficPredicateCalculator[T definitions.WeightedBackend](b []T) map[string]backendTraffic {
sum := 0.0
weights := make([]float64, len(b))
for i, bi := range b {
w := bi.GetWeight()
weights[i] = w
sum += w
}
if sum == 0 {
sum = float64(len(weights))
for i := range weights {
weights[i] = 1
}
}
var lastWithWeight int
for i, w := range weights {
if w > 0 {
lastWithWeight = i
}
}
bt := make(map[string]backendTraffic)
for i, bi := range b {
ct := &trafficPredicate{}
bt[bi.GetName()] = ct
switch {
case i == lastWithWeight:
ct.value = 1
case weights[i] == 0:
ct.value = 0
default:
ct.value = weights[i] / sum
}
sum -= weights[i]
ct.balance = len(b) - i - 2
}
return bt
}
func (tp *trafficPredicate) allowed() bool {
return tp.value > 0
}
// apply adds Traffic() and True() predicates to the route.
// For the value of 1.0 no predicates will be added.
func (tp *trafficPredicate) apply(r *eskip.Route) {
if tp.value == 1.0 {
return
}
r.Predicates = appendPredicate(r.Predicates, predicates.TrafficName, tp.value)
for i := 0; i < tp.balance; i++ {
r.Predicates = appendPredicate(r.Predicates, predicates.TrueName)
}
}
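// A minimal sketch reproducing the worked example from the
// trafficPredicateCalculator doc comment. The exampleBackend type is
// hypothetical; it only needs to satisfy definitions.WeightedBackend:
//
//	type exampleBackend struct {
//		name   string
//		weight float64
//	}
//
//	func (b exampleBackend) GetName() string    { return b.name }
//	func (b exampleBackend) GetWeight() float64 { return b.weight }
//
//	bt := trafficPredicateCalculator([]exampleBackend{
//		{"backend-1", 0.1}, {"backend-2", 0.2},
//		{"backend-3", 0.3}, {"backend-4", 0.4},
//	})
//	// backend-1: Traffic(0.1) plus two True() predicates appended by apply()
//	// backend-4: value 1.0, so apply() adds no predicates at all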
package kubernetes
import (
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/predicates"
)
type trafficSegmentPredicate struct {
min, max float64
}
func trafficSegmentPredicateCalculator[T definitions.WeightedBackend](b []T) map[string]backendTraffic {
bt := make(map[string]backendTraffic, len(b))
sum := 0.0
for _, bi := range b {
if _, ok := bt[bi.GetName()].(*trafficSegmentPredicate); ok {
// ignore duplicate backends
continue
}
p := &trafficSegmentPredicate{}
bt[bi.GetName()] = p
p.min = sum
sum += bi.GetWeight()
p.max = sum
}
if sum == 0 {
// evenly split traffic between backends
// range over b instead of bt for stable order
for _, bi := range b {
p := bt[bi.GetName()].(*trafficSegmentPredicate)
p.min = sum
sum += 1
p.max = sum
}
}
// normalize segments
for _, v := range bt {
p := v.(*trafficSegmentPredicate)
p.min /= sum
// last segment always ends up with p.max equal to one because
// dividing a finite non-zero value by itself always produces one,
// see https://stackoverflow.com/questions/63439390/does-ieee-754-float-division-or-subtraction-by-itself-always-result-in-the-same
p.max /= sum
}
return bt
}
func (ts *trafficSegmentPredicate) allowed() bool {
return ts.min != ts.max
}
func (ts *trafficSegmentPredicate) apply(r *eskip.Route) {
if ts.min == 0 && ts.max == 1 {
return
}
r.Predicates = appendPredicate(r.Predicates, predicates.TrafficSegmentName, ts.min, ts.max)
}
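// A minimal sketch of the resulting segments, reusing the hypothetical
// exampleBackend type from the sketch above, with weights 1, 2 and 1:
//
//	bt := trafficSegmentPredicateCalculator([]exampleBackend{
//		{"a", 1}, {"b", 2}, {"c", 1},
//	})
//	// a: TrafficSegment(0.0, 0.25)
//	// b: TrafficSegment(0.25, 0.75)
//	// c: TrafficSegment(0.75, 1.0)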
// Package routestring provides a DataClient implementation for
// setting route configuration in form of simple eskip string.
//
// Usage from the command line:
//
// skipper -inline-routes '* -> inlineContent("Hello, world!") -> <shunt>'
package routestring
import (
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/routing"
)
type routes struct {
parsed []*eskip.Route
}
// New creates a data client that parses a string of eskip routes and
// serves it for the routing package.
func New(r string) (routing.DataClient, error) {
parsed, err := eskip.Parse(r)
if err != nil {
return nil, err
}
return &routes{parsed: parsed}, nil
}
func (r *routes) LoadAll() ([]*eskip.Route, error) {
return r.parsed, nil
}
func (*routes) LoadUpdate() ([]*eskip.Route, []string, error) {
return nil, nil, nil
}
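// A minimal usage sketch, mirroring the command line example from the
// package documentation:
//
//	dc, err := routestring.New(`* -> inlineContent("Hello, world!") -> <shunt>`)
//	if err != nil {
//		// handle the parse error
//	}
//	routes, _ := dc.LoadAll()
//	_ = routes // normally the data client is passed to the routing package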
package eskip
func copyArgs(a []interface{}) []interface{} {
// we don't need deep copy of the items for the supported values
c := make([]interface{}, len(a))
copy(c, a)
return c
}
// CopyPredicate creates a copy of the input predicate.
func CopyPredicate(p *Predicate) *Predicate {
if p == nil {
return nil
}
c := &Predicate{}
c.Name = p.Name
c.Args = copyArgs(p.Args)
return c
}
// CopyPredicates creates a new slice with the copy of each predicate in the input slice.
func CopyPredicates(p []*Predicate) []*Predicate {
c := make([]*Predicate, len(p))
for i, pi := range p {
c[i] = CopyPredicate(pi)
}
return c
}
// CopyFilter creates a copy of the input filter.
func CopyFilter(f *Filter) *Filter {
if f == nil {
return nil
}
c := &Filter{}
c.Name = f.Name
c.Args = copyArgs(f.Args)
return c
}
// CopyFilters creates a new slice with the copy of each filter in the input slice.
func CopyFilters(f []*Filter) []*Filter {
c := make([]*Filter, len(f))
for i, fi := range f {
c[i] = CopyFilter(fi)
}
return c
}
// Copy creates a canonical copy of the input route. See also Canonical().
func Copy(r *Route) *Route {
if r == nil {
return nil
}
r = Canonical(r)
c := &Route{}
c.Id = r.Id
c.Predicates = CopyPredicates(r.Predicates)
c.Filters = CopyFilters(r.Filters)
c.BackendType = r.BackendType
c.Backend = r.Backend
c.LBAlgorithm = r.LBAlgorithm
c.LBEndpoints = make([]string, len(r.LBEndpoints))
copy(c.LBEndpoints, r.LBEndpoints)
return c
}
// CopyRoutes creates a new slice with the canonical copy of each route in the input slice.
func CopyRoutes(r []*Route) []*Route {
c := make([]*Route, len(r))
for i, ri := range r {
c[i] = Copy(ri)
}
return c
}
package eskip
import "sort"
// used for sorting:
func compareRouteID(r []*Route) func(int, int) bool {
return func(i, j int) bool {
return r[i].Id < r[j].Id
}
}
// used for sorting:
func comparePredicateName(p []*Predicate) func(int, int) bool {
return func(i, j int) bool {
return p[i].Name < p[j].Name
}
}
func hasDuplicateID(r []*Route) bool {
for i := 1; i < len(r); i++ {
if r[i-1].Id == r[i].Id {
return true
}
}
return false
}
func eqArgs(left, right []interface{}) bool {
if len(left) != len(right) {
return false
}
for i := range left {
if left[i] != right[i] {
return false
}
}
return true
}
func eqStrings(left, right []string) bool {
if len(left) != len(right) {
return false
}
for i := range left {
if left[i] != right[i] {
return false
}
}
return true
}
func eq2(left, right *Route) bool {
if left == nil && right == nil {
return true
}
if left == nil || right == nil {
return false
}
lc, rc := Canonical(left), Canonical(right)
if lc.Id != rc.Id {
return false
}
if len(lc.Predicates) != len(rc.Predicates) {
return false
}
for i := range lc.Predicates {
lp, rp := lc.Predicates[i], rc.Predicates[i]
if lp.Name != rp.Name || !eqArgs(lp.Args, rp.Args) {
return false
}
}
if len(lc.Filters) != len(rc.Filters) {
return false
}
for i := range lc.Filters {
lf, rf := lc.Filters[i], rc.Filters[i]
if lf.Name != rf.Name || !eqArgs(lf.Args, rf.Args) {
return false
}
}
if lc.BackendType != rc.BackendType {
return false
}
if lc.Backend != rc.Backend {
return false
}
if lc.LBAlgorithm != rc.LBAlgorithm {
return false
}
if !eqStrings(lc.LBEndpoints, rc.LBEndpoints) {
return false
}
return true
}
func eq2Lists(left, right []*Route) bool {
if len(left) != len(right) {
return false
}
for i := range left {
if !eq2(left[i], right[i]) {
return false
}
}
return true
}
// Eq implements canonical equivalence comparison of routes based on
// Skipper semantics.
//
// Duplicate IDs are considered invalid for Eq, and it returns false
// in this case.
//
// The Name and Namespace fields are ignored.
//
// If there are multiple method predicates, only the last one is
// considered, to reproduce the route matching behavior (even if the way
// multiple method predicates are evaluated may not be what one would
// expect).
func Eq(r ...*Route) bool {
for i := 1; i < len(r); i++ {
if !eq2(r[i-1], r[i]) {
return false
}
}
return true
}
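// A minimal sketch: a route using the deprecated Path field and one using
// a Path() predicate are equal under Eq, because both are compared in
// their canonical form:
//
//	legacy := &Route{Id: "r", Path: "/api", Backend: "https://example.org"}
//	modern := &Route{
//		Id:         "r",
//		Predicates: []*Predicate{{Name: "Path", Args: []interface{}{"/api"}}},
//		Backend:    "https://example.org",
//	}
//	Eq(legacy, modern) // true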
// EqLists compares lists of routes. It returns true if the routes contained
// by each list are equal by Eq(). Repeated route IDs are considered invalid
// and EqLists always returns false in this case. The order of the routes in
// the lists doesn't matter.
func EqLists(r ...[]*Route) bool {
rc := make([][]*Route, len(r))
for i := range rc {
rc[i] = make([]*Route, len(r[i]))
copy(rc[i], r[i])
sort.Slice(rc[i], compareRouteID(rc[i]))
if hasDuplicateID(rc[i]) {
return false
}
}
for i := 1; i < len(rc); i++ {
if !eq2Lists(rc[i-1], rc[i]) {
return false
}
}
return true
}
// Canonical returns the canonical representation of a route, that uses the
// standard, non-legacy representation of the predicates and the backends.
// Canonical creates a copy of the route, but doesn't necessarily create a
// copy of every field. See also Copy().
func Canonical(r *Route) *Route {
if r == nil {
return nil
}
c := &Route{}
c.Id = r.Id
c.Predicates = make([]*Predicate, len(r.Predicates))
copy(c.Predicates, r.Predicates)
// legacy path:
var hasPath bool
for _, p := range c.Predicates {
if p.Name == "Path" {
hasPath = true
break
}
}
if r.Path != "" && !hasPath {
c.Predicates = append(c.Predicates, &Predicate{Name: "Path", Args: []interface{}{r.Path}})
}
// legacy host:
for _, h := range r.HostRegexps {
c.Predicates = append(c.Predicates, &Predicate{Name: "Host", Args: []interface{}{h}})
}
// legacy path regexp:
for _, p := range r.PathRegexps {
c.Predicates = append(c.Predicates, &Predicate{Name: "PathRegexp", Args: []interface{}{p}})
}
// legacy method:
if r.Method != "" {
// prepend the method, so that the canonical []Predicates take priority in
// case of duplicates, and to imitate how the routing evaluates multiple
// method predicates, even if that behavior is unintuitive
c.Predicates = append(
[]*Predicate{{Name: "Method", Args: []interface{}{r.Method}}},
c.Predicates...,
)
}
// legacy header:
for name, value := range r.Headers {
c.Predicates = append(
c.Predicates,
&Predicate{Name: "Header", Args: []interface{}{name, value}},
)
}
// legacy header regexp:
for name, values := range r.HeaderRegexps {
for _, value := range values {
c.Predicates = append(
c.Predicates,
&Predicate{Name: "HeaderRegexp", Args: []interface{}{name, value}},
)
}
}
if len(c.Predicates) == 0 {
c.Predicates = nil
}
sort.Slice(c.Predicates, comparePredicateName(c.Predicates))
c.Filters = r.Filters
c.BackendType = r.BackendType
switch c.BackendType {
case NetworkBackend:
// default overridden by legacy shunt:
if r.Shunt {
c.BackendType = ShuntBackend
} else {
c.Backend = r.Backend
}
case LBBackend:
// use the LB fields only when they apply:
c.LBAlgorithm = r.LBAlgorithm
c.LBEndpoints = make([]string, len(r.LBEndpoints))
copy(c.LBEndpoints, r.LBEndpoints)
sort.Strings(c.LBEndpoints)
}
// Name and Namespace stripped
return c
}
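// A minimal sketch of the legacy field conversion: after Canonical, the
// deprecated matcher fields appear as predicates, sorted by name, and the
// legacy shunt flag is reflected in the backend type:
//
//	r := &Route{Id: "r", Method: "GET", Path: "/api", Shunt: true}
//	c := Canonical(r)
//	// c.Predicates: Method("GET"), Path("/api")
//	// c.BackendType: ShuntBackend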
// CanonicalList returns the canonical form of each route in the list,
// keeping the order. The returned slice is a new slice of the input
// slice but the routes in the slice and their fields are not necessarily
// all copied. See more at CopyRoutes() and Canonical().
func CanonicalList(l []*Route) []*Route {
if len(l) == 0 {
return nil
}
cl := make([]*Route, len(l))
for i := range l {
cl[i] = Canonical(l[i])
}
return cl
}
package eskip
//go:generate goyacc -l -v "" -o parser.go -p eskip parser.y
import (
"errors"
"fmt"
"math/rand"
"net/url"
"regexp"
"strings"
"sync"
log "github.com/sirupsen/logrus"
)
const duplicateHeaderPredicateErrorFmt = "duplicate header predicate: %s"
var (
errDuplicatePathTreePredicate = errors.New("duplicate path tree predicate")
errDuplicateMethodPredicate = errors.New("duplicate method predicate")
)
// NewEditor creates an Editor PreProcessor, that matches routes and
// replaces the content. For example to replace Source predicates with
// ClientIP predicates you can use
// --edit-route='/Source[(](.*)[)]/ClientIP($1)/', which will change
// routes as you can see:
//
// # input
// r0: Source("127.0.0.1/8", "10.0.0.0/8") -> inlineContent("OK") -> <shunt>;
// # actual route
// edit_r0: ClientIP("127.0.0.1/8", "10.0.0.0/8") -> inlineContent("OK") -> <shunt>;
func NewEditor(reg *regexp.Regexp, repl string) *Editor {
return &Editor{
reg: reg,
repl: repl,
}
}
type Editor struct {
reg *regexp.Regexp
repl string
}
// NewClone creates a Clone PreProcessor, that matches routes and
// replaces the content of the cloned routes. For example to migrate from Source to
// ClientIP predicates you can use
// --clone-route='/Source[(](.*)[)]/ClientIP($1)/', which will change
// routes as you can see:
//
// # input
// r0: Source("127.0.0.1/8", "10.0.0.0/8") -> inlineContent("OK") -> <shunt>;
// # actual route
// clone_r0: ClientIP("127.0.0.1/8", "10.0.0.0/8") -> inlineContent("OK") -> <shunt>;
// r0: Source("127.0.0.1/8", "10.0.0.0/8") -> inlineContent("OK") -> <shunt>;
func NewClone(reg *regexp.Regexp, repl string) *Clone {
return &Clone{
reg: reg,
repl: repl,
}
}
type Clone struct {
reg *regexp.Regexp
repl string
}
func (e *Editor) Do(routes []*Route) []*Route {
if e.reg == nil {
return routes
}
canonicalRoutes := CanonicalList(routes)
for i, r := range canonicalRoutes {
rr := new(Route)
*rr = *r
if doOneRoute(e.reg, e.repl, rr) {
routes[i] = rr
}
}
return routes
}
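// A minimal sketch of the Editor, equivalent to the --edit-route example
// from the NewEditor doc comment (the caller needs to import regexp):
//
//	editor := NewEditor(regexp.MustCompile(`Source[(](.*)[)]`), "ClientIP($1)")
//	routes := editor.Do(MustParse(`r0: Source("127.0.0.1/8") -> <shunt>`))
//	// the route now matches with ClientIP("127.0.0.1/8")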
func (c *Clone) Do(routes []*Route) []*Route {
if c.reg == nil {
return routes
}
canonicalRoutes := CanonicalList(routes)
result := make([]*Route, len(routes), 2*len(routes))
copy(result, routes)
for _, r := range canonicalRoutes {
rr := new(Route)
*rr = *r
rr.Id = "clone_" + rr.Id
predicates := make([]*Predicate, len(r.Predicates))
for k, p := range r.Predicates {
q := *p
predicates[k] = &q
}
rr.Predicates = predicates
filters := make([]*Filter, len(r.Filters))
for k, f := range r.Filters {
ff := *f
filters[k] = &ff
}
rr.Filters = filters
if doOneRoute(c.reg, c.repl, rr) {
result = append(result, rr)
}
}
return result
}
func doOneRoute(rx *regexp.Regexp, repl string, r *Route) bool {
if rx == nil {
return false
}
var changed bool
for i, p := range r.Predicates {
ps := p.String()
sps := rx.ReplaceAllString(ps, repl)
if ps == sps {
continue
}
pp, err := ParsePredicates(sps)
if err != nil {
log.Errorf("Failed to parse predicate: %v", err)
continue
}
r.Predicates[i] = pp[0]
changed = true
}
for i, f := range r.Filters {
fs := f.String()
sfs := rx.ReplaceAllString(fs, repl)
if fs == sfs {
continue
}
ff, err := ParseFilters(sfs)
if err != nil {
log.Errorf("Failed to parse filter: %v", err)
continue
}
r.Filters[i] = ff[0]
changed = true
}
return changed
}
// DefaultFilters implements the routing.PreProcessor interface and
// should be used with the routing package.
type DefaultFilters struct {
Prepend []*Filter
Append []*Filter
}
// Do implements the interface routing.PreProcessor. It appends and
// prepends filters stored to incoming routes and returns the modified
// version of routes.
func (df *DefaultFilters) Do(routes []*Route) []*Route {
pn := len(df.Prepend)
an := len(df.Append)
if pn == 0 && an == 0 {
return routes
}
nextRoutes := make([]*Route, len(routes))
for i, r := range routes {
nextRoutes[i] = new(Route)
*nextRoutes[i] = *r
fn := len(r.Filters)
filters := make([]*Filter, fn+pn+an)
copy(filters[:pn], df.Prepend)
copy(filters[pn:pn+fn], r.Filters)
copy(filters[pn+fn:], df.Append)
nextRoutes[i].Filters = filters
}
return nextRoutes
}
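// A minimal sketch: with one prepended and one appended filter, a route
// with filters f1 -> f2 comes out as pre -> f1 -> f2 -> post. The filter
// names below are illustrative:
//
//	df := &DefaultFilters{
//		Prepend: MustParseFilters(`setRequestHeader("X-Pre", "1")`),
//		Append:  MustParseFilters(`setResponseHeader("X-Post", "1")`),
//	}
//	routes = df.Do(routes)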
// BackendType indicates whether a route is a network backend, a shunt or a loopback.
type BackendType int
const (
NetworkBackend = iota
ShuntBackend
LoopBackend
DynamicBackend
LBBackend
)
var errMixedProtocols = errors.New("loadbalancer endpoints cannot have mixed protocols")
// Route definition used while the parser processes the raw routing
// document.
type parsedRoute struct {
id string
predicates []*Predicate
filters []*Filter
shunt bool
loopback bool
dynamic bool
lbBackend bool
backend string
lbAlgorithm string
lbEndpoints []string
}
// A Predicate object represents a parsed, in-memory, route matching predicate
// that is defined by extensions.
type Predicate struct {
// The name of the custom predicate as referenced
// in the route definition. E.g. 'Foo'.
Name string `json:"name"`
// The arguments of the predicate as defined in the
// route definition. The arguments can be of type
// float64 or string (string for both strings and
// regular expressions).
Args []interface{} `json:"args"`
}
func (p *Predicate) String() string {
return fmt.Sprintf("%s(%s)", p.Name, argsString(p.Args))
}
// A Filter object represents a parsed, in-memory filter expression.
type Filter struct {
// name of the filter specification
Name string `json:"name"`
// filter args applied within a particular route
Args []interface{} `json:"args"`
}
func (f *Filter) String() string {
return fmt.Sprintf("%s(%s)", f.Name, argsString(f.Args))
}
// A Route object represents a parsed, in-memory route definition.
type Route struct {
// Id of the route definition.
// E.g. route1: ...
Id string
// Deprecated, use Predicate instances with the name "Path".
//
// Exact path to be matched.
// E.g. Path("/some/path")
Path string
// Host regular expressions to match.
// E.g. Host(/[.]example[.]org/)
HostRegexps []string
// Path regular expressions to match.
// E.g. PathRegexp(/\/api\//)
PathRegexps []string
// Method to match.
// E.g. Method("HEAD")
Method string
// Exact header definitions to match.
// E.g. Header("Accept", "application/json")
Headers map[string]string
// Header regular expressions to match.
// E.g. HeaderRegexp("Accept", /\Wapplication\/json\W/)
HeaderRegexps map[string][]string
// Custom predicates to match.
// E.g. Traffic(.3)
Predicates []*Predicate
// Set of filters in a particular route.
// E.g. redirect(302, "https://www.example.org/hello")
Filters []*Filter
// Indicates that the parsed route has a shunt backend.
// (<shunt>, no forwarding to a backend)
//
// Deprecated, use the BackendType field instead.
Shunt bool
// Indicates that the parsed route is a shunt, loopback or
// it is forwarding to a network backend.
BackendType BackendType
// The address of a backend for a parsed route.
// E.g. "https://www.example.org"
Backend string
// LBAlgorithm stores the name of the load balancing algorithm
// in case of load balancing backends.
LBAlgorithm string
// LBEndpoints stores one or more backend endpoint in case of
// load balancing backends.
LBEndpoints []string
// Name is deprecated and not used.
Name string
// Namespace is deprecated and not used.
Namespace string
}
type RoutePredicate func(*Route) bool
// RouteInfo contains a route id, plus the loaded and parsed route or
// the parse error in case of failure.
type RouteInfo struct {
// The route id, plus the route data if parsing was successful.
Route
// The parsing error if the parsing failed.
ParseError error
}
// Copy copies a filter to a new filter instance. The argument values are copied in a shallow way.
func (f *Filter) Copy() *Filter {
c := *f
c.Args = make([]interface{}, len(f.Args))
copy(c.Args, f.Args)
return &c
}
// Copy copies a predicate to a new predicate instance. The argument values are copied in a shallow way.
func (p *Predicate) Copy() *Predicate {
c := *p
c.Args = make([]interface{}, len(p.Args))
copy(c.Args, p.Args)
return &c
}
// Copy copies a route to a new route instance with all the slice and map fields copied deep.
func (r *Route) Copy() *Route {
c := *r
if len(r.HostRegexps) > 0 {
c.HostRegexps = make([]string, len(r.HostRegexps))
copy(c.HostRegexps, r.HostRegexps)
}
if len(r.PathRegexps) > 0 {
c.PathRegexps = make([]string, len(r.PathRegexps))
copy(c.PathRegexps, r.PathRegexps)
}
if len(r.Headers) > 0 {
c.Headers = make(map[string]string)
for k, v := range r.Headers {
c.Headers[k] = v
}
}
if len(r.HeaderRegexps) > 0 {
c.HeaderRegexps = make(map[string][]string)
for k, vs := range r.HeaderRegexps {
c.HeaderRegexps[k] = make([]string, len(vs))
copy(c.HeaderRegexps[k], vs)
}
}
if len(r.Predicates) > 0 {
c.Predicates = make([]*Predicate, len(r.Predicates))
for i, p := range r.Predicates {
c.Predicates[i] = p.Copy()
}
}
if len(r.Filters) > 0 {
c.Filters = make([]*Filter, len(r.Filters))
for i, p := range r.Filters {
c.Filters[i] = p.Copy()
}
}
if len(r.LBEndpoints) > 0 {
c.LBEndpoints = make([]string, len(r.LBEndpoints))
copy(c.LBEndpoints, r.LBEndpoints)
}
return &c
}
// BackendTypeFromString parses the string representation of a backend type definition.
func BackendTypeFromString(s string) (BackendType, error) {
switch s {
case "", "network":
return NetworkBackend, nil
case "shunt":
return ShuntBackend, nil
case "loopback":
return LoopBackend, nil
case "dynamic":
return DynamicBackend, nil
case "lb":
return LBBackend, nil
default:
return -1, fmt.Errorf("unsupported backend type: %s", s)
}
}
// String returns the string representation of a backend type definition.
func (t BackendType) String() string {
switch t {
case NetworkBackend:
return "network"
case ShuntBackend:
return "shunt"
case LoopBackend:
return "loopback"
case DynamicBackend:
return "dynamic"
case LBBackend:
return "lb"
default:
return "unknown"
}
}
// Expects exactly n arguments of type string, or fails.
func getStringArgs(p *Predicate, n int) ([]string, error) {
failure := func() ([]string, error) {
if n == 1 {
return nil, fmt.Errorf("%s predicate expects 1 string argument", p.Name)
} else {
return nil, fmt.Errorf("%s predicate expects %d string arguments", p.Name, n)
}
}
if len(p.Args) != n {
return failure()
}
sargs := make([]string, n)
for i, a := range p.Args {
if sa, ok := a.(string); ok {
sargs[i] = sa
} else {
return failure()
}
}
return sargs, nil
}
// Checks and sets the different predicates taken from the yacc result.
// As the syntax is getting stabilized, this logic soon should be defined as
// yacc rules. (https://github.com/zalando/skipper/issues/89)
func applyPredicates(route *Route, proute *parsedRoute) error {
var (
err error
args []string
pathSet bool
methodSet bool
)
for _, p := range proute.predicates {
switch p.Name {
case "Path":
if pathSet {
return errDuplicatePathTreePredicate
}
if args, err = getStringArgs(p, 1); err == nil {
route.Path = args[0]
pathSet = true
}
case "Host":
if args, err = getStringArgs(p, 1); err == nil {
route.HostRegexps = append(route.HostRegexps, args[0])
}
case "PathRegexp":
if args, err = getStringArgs(p, 1); err == nil {
route.PathRegexps = append(route.PathRegexps, args[0])
}
case "Method":
if methodSet {
return errDuplicateMethodPredicate
}
if args, err = getStringArgs(p, 1); err == nil {
route.Method = args[0]
methodSet = true
}
case "HeaderRegexp":
if args, err = getStringArgs(p, 2); err == nil {
if route.HeaderRegexps == nil {
route.HeaderRegexps = make(map[string][]string)
}
route.HeaderRegexps[args[0]] = append(route.HeaderRegexps[args[0]], args[1])
}
case "Header":
if args, err = getStringArgs(p, 2); err == nil {
if route.Headers == nil {
route.Headers = make(map[string]string)
}
if _, ok := route.Headers[args[0]]; ok {
return fmt.Errorf(duplicateHeaderPredicateErrorFmt, args[0])
}
route.Headers[args[0]] = args[1]
}
case "*", "Any":
// void
default:
route.Predicates = append(route.Predicates, p)
}
if err != nil {
return fmt.Errorf("invalid route %q: %w", proute.id, err)
}
}
return nil
}
// Converts a parsed route object to the exported route definition with
// pre-processed but not validated matchers.
func newRouteDefinition(r *parsedRoute) (*Route, error) {
if len(r.lbEndpoints) > 0 {
scheme := ""
for _, e := range r.lbEndpoints {
eu, err := url.ParseRequestURI(e)
if err != nil {
return nil, err
}
if scheme != "" && scheme != eu.Scheme {
return nil, errMixedProtocols
}
scheme = eu.Scheme
}
}
rd := &Route{}
rd.Id = r.id
rd.Filters = r.filters
rd.Shunt = r.shunt
rd.Backend = r.backend
rd.LBAlgorithm = r.lbAlgorithm
rd.LBEndpoints = r.lbEndpoints
switch {
case r.shunt:
rd.BackendType = ShuntBackend
case r.loopback:
rd.BackendType = LoopBackend
case r.dynamic:
rd.BackendType = DynamicBackend
case r.lbBackend:
rd.BackendType = LBBackend
default:
rd.BackendType = NetworkBackend
}
err := applyPredicates(rd, r)
return rd, err
}
type eskipLexParser struct {
lexer eskipLex
parser eskipParserImpl
}
var parserPool = sync.Pool{
New: func() interface{} {
return new(eskipLexParser)
},
}
func parseDocument(code string) ([]*parsedRoute, error) {
routes, _, _, err := parse(start_document, code)
return routes, err
}
func parsePredicates(code string) ([]*Predicate, error) {
_, predicates, _, err := parse(start_predicates, code)
return predicates, err
}
func parseFilters(code string) ([]*Filter, error) {
_, _, filters, err := parse(start_filters, code)
return filters, err
}
func parse(start int, code string) ([]*parsedRoute, []*Predicate, []*Filter, error) {
lp := parserPool.Get().(*eskipLexParser)
defer func() {
*lp = eskipLexParser{}
parserPool.Put(lp)
}()
lexer := &lp.lexer
lexer.init(start, code)
lp.parser.Parse(lexer)
// Do not return lexer to avoid reading lexer fields after returning eskipLexParser to the pool.
// Let the caller decide which of the return values to use based on the start token.
return lexer.routes, lexer.predicates, lexer.filters, lexer.err
}
// Parses a route expression or a routing document to a set of route definitions.
func Parse(code string) ([]*Route, error) {
parsedRoutes, err := parseDocument(code)
if err != nil {
return nil, err
}
routeDefinitions := make([]*Route, len(parsedRoutes))
for i, r := range parsedRoutes {
rd, err := newRouteDefinition(r)
if err != nil {
return nil, err
}
routeDefinitions[i] = rd
}
return routeDefinitions, nil
}
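// A minimal usage sketch:
//
//	routes, err := Parse(`hello: Path("/hello") -> inlineContent("Hello") -> <shunt>`)
//	if err != nil {
//		// handle the syntax error
//	}
//	// routes[0].Id == "hello", routes[0].Path == "/hello",
//	// routes[0].BackendType == ShuntBackend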
// MustParse parses a route expression or a routing document to a set of route definitions and
// panics in case of error
func MustParse(code string) []*Route {
r, err := Parse(code)
if err != nil {
panic(err)
}
return r
}
// MustParsePredicates parses a set of predicates (combined by '&&') into
// a list of parsed predicate definitions and panics in case of error
func MustParsePredicates(s string) []*Predicate {
p, err := ParsePredicates(s)
if err != nil {
panic(err)
}
return p
}
// MustParseFilters parses a set of filters (combined by '->') into
// a list of parsed filter definitions and panics in case of error
func MustParseFilters(s string) []*Filter {
p, err := ParseFilters(s)
if err != nil {
panic(err)
}
return p
}
// Parses a filter chain into a list of parsed filter definitions.
func ParseFilters(f string) ([]*Filter, error) {
f = strings.TrimSpace(f)
if f == "" {
return nil, nil
}
return parseFilters(f)
}
// ParsePredicates parses a set of predicates (combined by '&&') into
// a list of parsed predicate definitions.
func ParsePredicates(p string) ([]*Predicate, error) {
p = strings.TrimSpace(p)
if p == "" {
return nil, nil
}
rs, err := parsePredicates(p)
if err != nil {
return nil, err
}
if len(rs) == 0 {
return nil, nil
}
ps := make([]*Predicate, 0, len(rs))
for _, p := range rs {
if p.Name != "*" {
ps = append(ps, p)
}
}
if len(ps) == 0 {
ps = nil
}
return ps, nil
}
const (
randomIdLength = 16
// the alphabet contains no underscore, to keep the output compatible with the previously used flow id generator
alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
// GenerateIfNeeded generates a weak random id for a route if it doesn't
// have one.
//
// Deprecated: do not use; generate a valid route id that matches
// [a-zA-Z_] yourself.
func GenerateIfNeeded(existingId string) string {
if existingId != "" {
return existingId
}
var sb strings.Builder
sb.WriteString("route")
for i := 0; i < randomIdLength; i++ {
ai := rand.Intn(len(alphabet))
sb.WriteByte(alphabet[ai])
}
return sb.String()
}
package eskip
import (
"bytes"
"encoding/json"
)
type jsonNameArgs struct {
Name string `json:"name"`
Args []interface{} `json:"args,omitempty"`
}
type jsonBackend struct {
Type string `json:"type"`
Address string `json:"address,omitempty"`
Algorithm string `json:"algorithm,omitempty"`
Endpoints []string `json:"endpoints,omitempty"`
}
type jsonRoute struct {
ID string `json:"id,omitempty"`
Backend *jsonBackend `json:"backend,omitempty"`
Predicates []*Predicate `json:"predicates,omitempty"`
Filters []*Filter `json:"filters,omitempty"`
}
func newJSONRoute(r *Route) *jsonRoute {
cr := Canonical(r)
jr := &jsonRoute{
ID: cr.Id,
Predicates: cr.Predicates,
Filters: cr.Filters,
}
if cr.BackendType != NetworkBackend || cr.Backend != "" {
jr.Backend = &jsonBackend{
Type: cr.BackendType.String(),
Address: cr.Backend,
Algorithm: cr.LBAlgorithm,
Endpoints: cr.LBEndpoints,
}
}
return jr
}
func marshalJSONNoEscape(v interface{}) ([]byte, error) {
var buf bytes.Buffer
e := json.NewEncoder(&buf)
e.SetEscapeHTML(false)
if err := e.Encode(v); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (f *Filter) MarshalJSON() ([]byte, error) {
return marshalJSONNoEscape(&jsonNameArgs{Name: f.Name, Args: f.Args})
}
func (p *Predicate) MarshalJSON() ([]byte, error) {
return marshalJSONNoEscape(&jsonNameArgs{Name: p.Name, Args: p.Args})
}
func (r *Route) MarshalJSON() ([]byte, error) {
return marshalJSONNoEscape(newJSONRoute(r))
}
func (r *Route) UnmarshalJSON(b []byte) error {
jr := &jsonRoute{}
if err := json.Unmarshal(b, jr); err != nil {
return err
}
r.Id = jr.ID
var bts string
if jr.Backend != nil {
bts = jr.Backend.Type
}
bt, err := BackendTypeFromString(bts)
if err != nil {
return err
}
r.BackendType = bt
switch bt {
case NetworkBackend:
if jr.Backend != nil {
r.Backend = jr.Backend.Address
}
case LBBackend:
r.LBAlgorithm = jr.Backend.Algorithm
r.LBEndpoints = jr.Backend.Endpoints
if len(r.LBEndpoints) == 0 {
r.LBEndpoints = nil
}
}
r.Filters = jr.Filters
if len(r.Filters) == 0 {
r.Filters = nil
}
r.Predicates = jr.Predicates
if len(r.Predicates) == 0 {
r.Predicates = nil
}
return nil
}
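// A minimal sketch of the JSON round trip: a route marshals to its
// canonical representation and unmarshals back to an equivalent route:
//
//	r := MustParse(`r: Path("/api") -> "https://example.org"`)[0]
//	b, _ := json.Marshal(r)
//	// {"id":"r","backend":{"type":"network","address":"https://example.org"},
//	//  "predicates":[{"name":"Path","args":["/api"]}]}
//	var back Route
//	_ = json.Unmarshal(b, &back)
//	Eq(r, &back) // true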
package eskip
import (
"errors"
"fmt"
"strings"
)
type token struct {
id int
val string
}
type charPredicate func(byte) bool
type eskipLex struct {
start int
code string
lastToken string
lastRouteID string
err error
initialLength int
routes []*parsedRoute
predicates []*Predicate
filters []*Filter
}
type fixedScanner token
const (
escapeChar = '\\'
decimalChar = '.'
newlineChar = '\n'
underscore = '_'
)
var (
errInvalidCharacter = errors.New("invalid character")
errIncompleteToken = errors.New("incomplete token")
errUnexpectedToken = errors.New("unexpected token")
errVoid = errors.New("void")
errEOF = errors.New("eof")
)
var (
andToken = &fixedScanner{and, "&&"}
anyToken = &fixedScanner{any, "*"}
arrowToken = &fixedScanner{arrow, "->"}
closeparenToken = &fixedScanner{closeparen, ")"}
colonToken = &fixedScanner{colon, ":"}
commaToken = &fixedScanner{comma, ","}
openparenToken = &fixedScanner{openparen, "("}
semicolonToken = &fixedScanner{semicolon, ";"}
openarrowToken = &fixedScanner{openarrow, "<"}
closearrowToken = &fixedScanner{closearrow, ">"}
)
var openarrowPrefixedTokens = []*fixedScanner{
{shunt, "<shunt>"},
{loopback, "<loopback>"},
{dynamic, "<dynamic>"},
}
func (fs *fixedScanner) scan(code string) (t token, rest string, err error) {
return token(*fs), code[len(fs.val):], nil
}
func (l *eskipLex) init(start int, code string) {
l.start = start
l.code = code
l.initialLength = len(code)
}
func isNewline(c byte) bool { return c == newlineChar }
func isUnderscore(c byte) bool { return c == underscore }
func isAlpha(c byte) bool { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') }
func isDigit(c byte) bool { return c >= '0' && c <= '9' }
func isSymbolChar(c byte) bool { return isAlpha(c) || isDigit(c) || isUnderscore(c) }
func isDecimalChar(c byte) bool { return c == decimalChar }
func isNumberChar(c byte) bool { return isDigit(c) || isDecimalChar(c) }
func scanWhile(code string, p charPredicate) (string, string) {
for i := 0; i < len(code); i++ {
if !p(code[i]) {
return code[0:i], code[i:]
}
}
return code, ""
}
func scanVoid(code string, p charPredicate) string {
_, rest := scanWhile(code, p)
return rest
}
func scanEscaped(delimiter byte, code string) (string, string) {
// fast path: check if there is a delimiter without preceding escape character
for i := 0; i < len(code); i++ {
c := code[i]
if c == delimiter {
// make a copy to avoid referencing the possibly large underlying data array
return strings.Clone(code[:i]), code[i:]
} else if c == escapeChar {
break
}
}
var sb strings.Builder
escaped := false
for len(code) > 0 {
c := code[0]
if escaped {
switch c {
case 'a':
c = '\a'
case 'b':
c = '\b'
case 'f':
c = '\f'
case 'n':
c = '\n'
case 'r':
c = '\r'
case 't':
c = '\t'
case 'v':
c = '\v'
case delimiter:
case escapeChar:
default:
sb.WriteByte(escapeChar)
}
sb.WriteByte(c)
escaped = false
} else {
if c == delimiter {
return sb.String(), code
}
if c == escapeChar {
escaped = true
} else {
sb.WriteByte(c)
}
}
code = code[1:]
}
return sb.String(), code
}
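// A minimal sketch of the escape handling: the escape sequence is decoded
// and scanning stops at the first unescaped delimiter:
//
//	val, rest := scanEscaped('"', "hello\\nworld\" tail")
//	// val  == "hello\nworld"
//	// rest == "\" tail"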
func scanRegexp(code string) (string, string) {
var sb strings.Builder
escaped := false
var insideGroup = false
for len(code) > 0 {
c := code[0]
isDelimiter := c == '/'
isEscapeChar := c == escapeChar
// Track entering ([) and leaving (]) a character group, ignoring escaped group characters, i.e. \[ or \].
if !escaped && !insideGroup && c == '[' {
insideGroup = true
} else if !escaped && insideGroup && c == ']' {
insideGroup = false
}
if escaped {
// an escaped delimiter, e.g. the \/ in PathRegexp(/\//), does not end the regexp
if !isDelimiter && !isEscapeChar {
sb.WriteByte(escapeChar)
}
sb.WriteByte(c)
escaped = false
} else {
if isDelimiter && !insideGroup {
return sb.String(), code
}
if isEscapeChar {
escaped = true
} else {
sb.WriteByte(c)
}
}
code = code[1:]
}
return sb.String(), code
}
func scanRegexpLiteral(code string) (t token, rest string, err error) {
t.id = regexpliteral
t.val, rest = scanRegexp(code[1:])
if len(rest) == 0 {
err = errIncompleteToken
return
}
rest = rest[1:]
return
}
func scanRegexpOrComment(code string) (t token, rest string, err error) {
if len(code) < 2 {
rest = code
err = errInvalidCharacter
return
}
if code[1] == '/' {
rest = scanComment(code)
err = errVoid
return
}
return scanRegexpLiteral(code)
}
func scanStringLiteral(delimiter byte, code string) (t token, rest string, err error) {
t.id = stringliteral
t.val, rest = scanEscaped(delimiter, code[1:])
if len(rest) == 0 {
err = errIncompleteToken
return
}
rest = rest[1:]
return
}
func scanWhitespace(code string) string {
start := 0
for ; start < len(code); start++ {
c := code[start]
// check frequent values first
if c != ' ' && c != '\n' && c != '\t' && c != '\v' && c != '\f' && c != '\r' && c != 0x85 && c != 0xA0 {
break
}
}
return code[start:]
}
func scanComment(code string) string {
return scanVoid(code, func(c byte) bool { return !isNewline(c) })
}
func scanDoubleQuote(code string) (token, string, error) { return scanStringLiteral('"', code) }
func scanBacktick(code string) (token, string, error) { return scanStringLiteral('`', code) }
func scanNumber(code string) (t token, rest string, err error) {
t.id = number
decimal := false
t.val, rest = scanWhile(code, func(c byte) bool {
if isDecimalChar(c) {
if decimal {
return false
}
decimal = true
return true
}
return isDigit(c)
})
if isDecimalChar(t.val[len(t.val)-1]) {
err = errIncompleteToken
return
}
return
}
func scanSymbol(code string) (t token, rest string, err error) {
t.id = symbol
for i := 0; i < len(code); i++ {
if !isSymbolChar(code[i]) {
// make a copy to avoid referencing the possibly large underlying data array
t.val, rest = strings.Clone(code[:i]), code[i:]
return
}
}
t.val, rest = code, ""
return
}
func scan(code string) (token, string, error) {
switch code[0] {
case ',':
return commaToken.scan(code)
case ')':
return closeparenToken.scan(code)
case '(':
return openparenToken.scan(code)
case ':':
return colonToken.scan(code)
case ';':
return semicolonToken.scan(code)
case '>':
return closearrowToken.scan(code)
case '*':
return anyToken.scan(code)
case '&':
if len(code) >= 2 && code[1] == '&' {
return andToken.scan(code)
}
case '-':
if len(code) >= 2 && code[1] == '>' {
return arrowToken.scan(code)
}
case '/':
return scanRegexpOrComment(code)
case '"':
return scanDoubleQuote(code)
case '`':
return scanBacktick(code)
case '<':
for _, tok := range openarrowPrefixedTokens {
if strings.HasPrefix(code, tok.val) {
return tok.scan(code)
}
}
return openarrowToken.scan(code)
}
if isNumberChar(code[0]) {
return scanNumber(code)
}
if isAlpha(code[0]) || isUnderscore(code[0]) {
return scanSymbol(code)
}
return token{}, "", errUnexpectedToken
}
func (l *eskipLex) next() (token, error) {
l.code = scanWhitespace(l.code)
if len(l.code) == 0 {
return token{}, errEOF
}
t, rest, err := scan(l.code)
if err == errUnexpectedToken {
return token{}, err
}
l.code = rest
if err == errVoid {
return l.next()
}
if err == nil {
l.lastToken = t.val
}
return t, err
}
func (l *eskipLex) Lex(lval *eskipSymType) int {
// first emit the start token
if l.start != 0 {
start := l.start
l.start = 0
return start
}
t, err := l.next()
if err == errEOF {
return -1
}
if err != nil {
l.Error(err.Error())
return -1
}
lval.token = t.val
return t.id
}
func (l *eskipLex) Error(err string) {
lastRouteID := ""
if l.lastRouteID != "" {
lastRouteID = ", last route id: " + l.lastRouteID
}
l.err = fmt.Errorf(
"parse failed after token %s%s, position %d: %s",
l.lastToken, lastRouteID, l.initialLength-len(l.code), err)
}
// Code generated by goyacc -l -v -o parser.go -p eskip parser.y. DO NOT EDIT.
package eskip
import __yyfmt__ "fmt"
import "strconv"
// the conversion error is ignored: the tokenizer has already validated the number format
func convertNumber(s string) float64 {
n, _ := strconv.ParseFloat(s, 64)
return n
}
type eskipSymType struct {
yys int
token string
route *parsedRoute
routes []*parsedRoute
predicates []*Predicate
predicate *Predicate
filter *Filter
filters []*Filter
args []interface{}
arg interface{}
backend string
shunt bool
loopback bool
dynamic bool
lbBackend bool
numval float64
stringvals []string
lbAlgorithm string
lbEndpoints []string
}
const and = 57346
const any = 57347
const arrow = 57348
const closeparen = 57349
const colon = 57350
const comma = 57351
const number = 57352
const openparen = 57353
const regexpliteral = 57354
const semicolon = 57355
const shunt = 57356
const loopback = 57357
const dynamic = 57358
const stringliteral = 57359
const symbol = 57360
const openarrow = 57361
const closearrow = 57362
const start_document = 57363
const start_predicates = 57364
const start_filters = 57365
var eskipToknames = [...]string{
"$end",
"error",
"$unk",
"and",
"any",
"arrow",
"closeparen",
"colon",
"comma",
"number",
"openparen",
"regexpliteral",
"semicolon",
"shunt",
"loopback",
"dynamic",
"stringliteral",
"symbol",
"openarrow",
"closearrow",
"start_document",
"start_predicates",
"start_filters",
}
var eskipStatenames = [...]string{}
const eskipEofCode = 1
const eskipErrCode = 2
const eskipInitialStackSize = 16
var eskipExca = [...]int8{
-1, 1,
1, -1,
-2, 0,
}
const eskipPrivate = 57344
const eskipLast = 68
var eskipAct = [...]int8{
48, 39, 17, 29, 38, 2, 3, 4, 32, 33,
34, 31, 18, 36, 11, 55, 16, 43, 8, 42,
50, 49, 13, 18, 41, 50, 13, 28, 44, 59,
23, 45, 19, 24, 26, 15, 37, 30, 27, 12,
24, 7, 53, 51, 52, 52, 56, 57, 23, 44,
54, 21, 22, 20, 58, 46, 25, 21, 60, 9,
35, 47, 40, 14, 10, 6, 5, 1,
}
var eskipPact = [...]int16{
-16, -1000, 21, 17, 5, -1000, 19, -1000, -1000, 47,
17, -1000, 22, -1000, 53, 29, 50, -1000, 23, 9,
-6, 17, -1000, -1000, 7, 5, 7, -1000, 40, -1000,
49, -1000, -1000, -1000, -1000, -1000, 3, -1000, 36, -1000,
-1000, -1000, -1000, -1000, -1000, 35, -6, -5, 37, 38,
-1000, -1000, 7, -1000, -1000, -1000, 12, 8, -1000, -1000,
37,
}
var eskipPgo = [...]int8{
0, 67, 66, 59, 16, 65, 41, 18, 64, 3,
14, 4, 2, 1, 62, 0, 61, 60,
}
var eskipR1 = [...]int8{
0, 1, 1, 1, 1, 1, 2, 2, 5, 5,
5, 5, 7, 8, 6, 6, 3, 3, 10, 10,
4, 4, 12, 11, 11, 11, 13, 13, 13, 15,
15, 16, 16, 17, 9, 9, 9, 9, 9, 14,
}
var eskipR2 = [...]int8{
0, 2, 1, 2, 1, 2, 1, 1, 0, 1,
3, 2, 2, 2, 3, 5, 1, 3, 1, 4,
1, 3, 4, 0, 1, 3, 1, 1, 1, 1,
3, 1, 3, 3, 1, 1, 1, 1, 1, 1,
}
var eskipChk = [...]int16{
-1000, -1, 21, 22, 23, -2, -5, -6, -7, -3,
-8, -10, 18, 5, -3, 18, -4, -12, 18, 13,
6, 4, -6, 8, 11, 6, 11, -7, 18, -9,
-4, 17, 14, 15, 16, -17, 19, -10, -11, -13,
-14, 17, 12, 10, -12, -11, 6, -16, -15, 18,
17, 7, 9, 7, -9, 20, 9, 9, -13, 17,
-15,
}
var eskipDef = [...]int8{
0, -2, 8, 2, 4, 1, 6, 7, 9, 0,
0, 16, 0, 18, 3, 0, 5, 20, 0, 11,
0, 0, 12, 13, 23, 0, 23, 10, 0, 14,
0, 34, 35, 36, 37, 38, 0, 17, 0, 24,
26, 27, 28, 39, 21, 0, 0, 0, 31, 0,
29, 19, 0, 22, 15, 33, 0, 0, 25, 30,
32,
}
var eskipTok1 = [...]int8{
1,
}
var eskipTok2 = [...]int8{
2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23,
}
var eskipTok3 = [...]int8{
0,
}
var eskipErrorMessages = [...]struct {
state int
token int
msg string
}{}
/* parser for yacc output */
var (
eskipDebug = 0
eskipErrorVerbose = false
)
type eskipLexer interface {
Lex(lval *eskipSymType) int
Error(s string)
}
type eskipParser interface {
Parse(eskipLexer) int
Lookahead() int
}
type eskipParserImpl struct {
lval eskipSymType
stack [eskipInitialStackSize]eskipSymType
char int
}
func (p *eskipParserImpl) Lookahead() int {
return p.char
}
func eskipNewParser() eskipParser {
return &eskipParserImpl{}
}
const eskipFlag = -1000
func eskipTokname(c int) string {
if c >= 1 && c-1 < len(eskipToknames) {
if eskipToknames[c-1] != "" {
return eskipToknames[c-1]
}
}
return __yyfmt__.Sprintf("tok-%v", c)
}
func eskipStatname(s int) string {
if s >= 0 && s < len(eskipStatenames) {
if eskipStatenames[s] != "" {
return eskipStatenames[s]
}
}
return __yyfmt__.Sprintf("state-%v", s)
}
func eskipErrorMessage(state, lookAhead int) string {
const TOKSTART = 4
if !eskipErrorVerbose {
return "syntax error"
}
for _, e := range eskipErrorMessages {
if e.state == state && e.token == lookAhead {
return "syntax error: " + e.msg
}
}
res := "syntax error: unexpected " + eskipTokname(lookAhead)
// To match Bison, suggest at most four expected tokens.
expected := make([]int, 0, 4)
// Look for shiftable tokens.
base := int(eskipPact[state])
for tok := TOKSTART; tok-1 < len(eskipToknames); tok++ {
if n := base + tok; n >= 0 && n < eskipLast && int(eskipChk[int(eskipAct[n])]) == tok {
if len(expected) == cap(expected) {
return res
}
expected = append(expected, tok)
}
}
if eskipDef[state] == -2 {
i := 0
for eskipExca[i] != -1 || int(eskipExca[i+1]) != state {
i += 2
}
// Look for tokens that we accept or reduce.
for i += 2; eskipExca[i] >= 0; i += 2 {
tok := int(eskipExca[i])
if tok < TOKSTART || eskipExca[i+1] == 0 {
continue
}
if len(expected) == cap(expected) {
return res
}
expected = append(expected, tok)
}
// If the default action is to accept or reduce, give up.
if eskipExca[i+1] != 0 {
return res
}
}
for i, tok := range expected {
if i == 0 {
res += ", expecting "
} else {
res += " or "
}
res += eskipTokname(tok)
}
return res
}
func eskiplex1(lex eskipLexer, lval *eskipSymType) (char, token int) {
token = 0
char = lex.Lex(lval)
if char <= 0 {
token = int(eskipTok1[0])
goto out
}
if char < len(eskipTok1) {
token = int(eskipTok1[char])
goto out
}
if char >= eskipPrivate {
if char < eskipPrivate+len(eskipTok2) {
token = int(eskipTok2[char-eskipPrivate])
goto out
}
}
for i := 0; i < len(eskipTok3); i += 2 {
token = int(eskipTok3[i+0])
if token == char {
token = int(eskipTok3[i+1])
goto out
}
}
out:
if token == 0 {
token = int(eskipTok2[1]) /* unknown char */
}
if eskipDebug >= 3 {
__yyfmt__.Printf("lex %s(%d)\n", eskipTokname(token), uint(char))
}
return char, token
}
func eskipParse(eskiplex eskipLexer) int {
return eskipNewParser().Parse(eskiplex)
}
func (eskiprcvr *eskipParserImpl) Parse(eskiplex eskipLexer) int {
var eskipn int
var eskipVAL eskipSymType
var eskipDollar []eskipSymType
_ = eskipDollar // silence set and not used
eskipS := eskiprcvr.stack[:]
Nerrs := 0 /* number of errors */
Errflag := 0 /* error recovery flag */
eskipstate := 0
eskiprcvr.char = -1
eskiptoken := -1 // eskiprcvr.char translated into internal numbering
defer func() {
// Make sure we report no lookahead when not parsing.
eskipstate = -1
eskiprcvr.char = -1
eskiptoken = -1
}()
eskipp := -1
goto eskipstack
ret0:
return 0
ret1:
return 1
eskipstack:
/* put a state and value onto the stack */
if eskipDebug >= 4 {
__yyfmt__.Printf("char %v in %v\n", eskipTokname(eskiptoken), eskipStatname(eskipstate))
}
eskipp++
if eskipp >= len(eskipS) {
nyys := make([]eskipSymType, len(eskipS)*2)
copy(nyys, eskipS)
eskipS = nyys
}
eskipS[eskipp] = eskipVAL
eskipS[eskipp].yys = eskipstate
eskipnewstate:
eskipn = int(eskipPact[eskipstate])
if eskipn <= eskipFlag {
goto eskipdefault /* simple state */
}
if eskiprcvr.char < 0 {
eskiprcvr.char, eskiptoken = eskiplex1(eskiplex, &eskiprcvr.lval)
}
eskipn += eskiptoken
if eskipn < 0 || eskipn >= eskipLast {
goto eskipdefault
}
eskipn = int(eskipAct[eskipn])
if int(eskipChk[eskipn]) == eskiptoken { /* valid shift */
eskiprcvr.char = -1
eskiptoken = -1
eskipVAL = eskiprcvr.lval
eskipstate = eskipn
if Errflag > 0 {
Errflag--
}
goto eskipstack
}
eskipdefault:
/* default state action */
eskipn = int(eskipDef[eskipstate])
if eskipn == -2 {
if eskiprcvr.char < 0 {
eskiprcvr.char, eskiptoken = eskiplex1(eskiplex, &eskiprcvr.lval)
}
/* look through exception table */
xi := 0
for {
if eskipExca[xi+0] == -1 && int(eskipExca[xi+1]) == eskipstate {
break
}
xi += 2
}
for xi += 2; ; xi += 2 {
eskipn = int(eskipExca[xi+0])
if eskipn < 0 || eskipn == eskiptoken {
break
}
}
eskipn = int(eskipExca[xi+1])
if eskipn < 0 {
goto ret0
}
}
if eskipn == 0 {
/* error ... attempt to resume parsing */
switch Errflag {
case 0: /* brand new error */
eskiplex.Error(eskipErrorMessage(eskipstate, eskiptoken))
Nerrs++
if eskipDebug >= 1 {
__yyfmt__.Printf("%s", eskipStatname(eskipstate))
__yyfmt__.Printf(" saw %s\n", eskipTokname(eskiptoken))
}
fallthrough
case 1, 2: /* incompletely recovered error ... try again */
Errflag = 3
/* find a state where "error" is a legal shift action */
for eskipp >= 0 {
eskipn = int(eskipPact[eskipS[eskipp].yys]) + eskipErrCode
if eskipn >= 0 && eskipn < eskipLast {
eskipstate = int(eskipAct[eskipn]) /* simulate a shift of "error" */
if int(eskipChk[eskipstate]) == eskipErrCode {
goto eskipstack
}
}
/* the current p has no shift on "error", pop stack */
if eskipDebug >= 2 {
__yyfmt__.Printf("error recovery pops state %d\n", eskipS[eskipp].yys)
}
eskipp--
}
/* there is no state on the stack with an error shift ... abort */
goto ret1
case 3: /* no shift yet; clobber input char */
if eskipDebug >= 2 {
__yyfmt__.Printf("error recovery discards %s\n", eskipTokname(eskiptoken))
}
if eskiptoken == eskipEofCode {
goto ret1
}
eskiprcvr.char = -1
eskiptoken = -1
goto eskipnewstate /* try again in the same state */
}
}
/* reduction by production eskipn */
if eskipDebug >= 2 {
__yyfmt__.Printf("reduce %v in:\n\t%v\n", eskipn, eskipStatname(eskipstate))
}
eskipnt := eskipn
eskippt := eskipp
_ = eskippt // guard against "declared and not used"
eskipp -= int(eskipR2[eskipn])
// eskipp is now the index of $0. Perform the default action. Iff the
// reduced production is ε, $1 is possibly out of range.
if eskipp+1 >= len(eskipS) {
nyys := make([]eskipSymType, len(eskipS)*2)
copy(nyys, eskipS)
eskipS = nyys
}
eskipVAL = eskipS[eskipp+1]
/* consult goto table to find next state */
eskipn = int(eskipR1[eskipn])
eskipg := int(eskipPgo[eskipn])
eskipj := eskipg + eskipS[eskipp].yys + 1
if eskipj >= eskipLast {
eskipstate = int(eskipAct[eskipg])
} else {
eskipstate = int(eskipAct[eskipj])
if int(eskipChk[eskipstate]) != -eskipn {
eskipstate = int(eskipAct[eskipg])
}
}
// dummy call; replaced with literal code
switch eskipnt {
case 1:
eskipDollar = eskipS[eskippt-2 : eskippt+1]
{
eskiplex.(*eskipLex).routes = eskipDollar[2].routes
}
case 2:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
// allow empty or comments only
eskiplex.(*eskipLex).predicates = nil
}
case 3:
eskipDollar = eskipS[eskippt-2 : eskippt+1]
{
eskiplex.(*eskipLex).predicates = eskipDollar[2].predicates
}
case 4:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
// allow empty or comments only
eskiplex.(*eskipLex).filters = nil
}
case 5:
eskipDollar = eskipS[eskippt-2 : eskippt+1]
{
eskiplex.(*eskipLex).filters = eskipDollar[2].filters
}
case 6:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.routes = eskipDollar[1].routes
}
case 7:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.routes = []*parsedRoute{eskipDollar[1].route}
}
case 9:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.routes = []*parsedRoute{eskipDollar[1].route}
}
case 10:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.routes = eskipDollar[1].routes
eskipVAL.routes = append(eskipVAL.routes, eskipDollar[3].route)
}
case 11:
eskipDollar = eskipS[eskippt-2 : eskippt+1]
{
eskipVAL.routes = eskipDollar[1].routes
}
case 12:
eskipDollar = eskipS[eskippt-2 : eskippt+1]
{
eskipVAL.route = eskipDollar[2].route
eskipVAL.route.id = eskipDollar[1].token
}
case 13:
eskipDollar = eskipS[eskippt-2 : eskippt+1]
{
// match symbol and colon to get route id early even if route parsing fails later
eskipVAL.token = eskipDollar[1].token
eskiplex.(*eskipLex).lastRouteID = eskipDollar[1].token
}
case 14:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.route = &parsedRoute{
predicates: eskipDollar[1].predicates,
backend: eskipDollar[3].backend,
shunt: eskipDollar[3].shunt,
loopback: eskipDollar[3].loopback,
dynamic: eskipDollar[3].dynamic,
lbBackend: eskipDollar[3].lbBackend,
lbAlgorithm: eskipDollar[3].lbAlgorithm,
lbEndpoints: eskipDollar[3].lbEndpoints,
}
eskipDollar[1].predicates = nil
eskipDollar[3].lbEndpoints = nil
}
case 15:
eskipDollar = eskipS[eskippt-5 : eskippt+1]
{
eskipVAL.route = &parsedRoute{
predicates: eskipDollar[1].predicates,
filters: eskipDollar[3].filters,
backend: eskipDollar[5].backend,
shunt: eskipDollar[5].shunt,
loopback: eskipDollar[5].loopback,
dynamic: eskipDollar[5].dynamic,
lbBackend: eskipDollar[5].lbBackend,
lbAlgorithm: eskipDollar[5].lbAlgorithm,
lbEndpoints: eskipDollar[5].lbEndpoints,
}
eskipDollar[1].predicates = nil
eskipDollar[3].filters = nil
eskipDollar[5].lbEndpoints = nil
}
case 16:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.predicates = []*Predicate{eskipDollar[1].predicate}
}
case 17:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.predicates = eskipDollar[1].predicates
eskipVAL.predicates = append(eskipVAL.predicates, eskipDollar[3].predicate)
}
case 18:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.predicate = &Predicate{"*", nil}
}
case 19:
eskipDollar = eskipS[eskippt-4 : eskippt+1]
{
eskipVAL.predicate = &Predicate{eskipDollar[1].token, eskipDollar[3].args}
eskipDollar[3].args = nil
}
case 20:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.filters = []*Filter{eskipDollar[1].filter}
}
case 21:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.filters = eskipDollar[1].filters
eskipVAL.filters = append(eskipVAL.filters, eskipDollar[3].filter)
}
case 22:
eskipDollar = eskipS[eskippt-4 : eskippt+1]
{
eskipVAL.filter = &Filter{
Name: eskipDollar[1].token,
Args: eskipDollar[3].args}
eskipDollar[3].args = nil
}
case 24:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.args = []interface{}{eskipDollar[1].arg}
}
case 25:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.args = eskipDollar[1].args
eskipVAL.args = append(eskipVAL.args, eskipDollar[3].arg)
}
case 26:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.arg = eskipDollar[1].numval
}
case 27:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.arg = eskipDollar[1].token
}
case 28:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.arg = eskipDollar[1].token
}
case 29:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.stringvals = []string{eskipDollar[1].token}
}
case 30:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.stringvals = eskipDollar[1].stringvals
eskipVAL.stringvals = append(eskipVAL.stringvals, eskipDollar[3].token)
}
case 31:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.lbEndpoints = eskipDollar[1].stringvals
}
case 32:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.lbAlgorithm = eskipDollar[1].token
eskipVAL.lbEndpoints = eskipDollar[3].stringvals
}
case 33:
eskipDollar = eskipS[eskippt-3 : eskippt+1]
{
eskipVAL.lbAlgorithm = eskipDollar[2].lbAlgorithm
eskipVAL.lbEndpoints = eskipDollar[2].lbEndpoints
}
case 34:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.backend = eskipDollar[1].token
eskipVAL.shunt = false
eskipVAL.loopback = false
eskipVAL.dynamic = false
eskipVAL.lbBackend = false
}
case 35:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.shunt = true
eskipVAL.loopback = false
eskipVAL.dynamic = false
eskipVAL.lbBackend = false
}
case 36:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.shunt = false
eskipVAL.loopback = true
eskipVAL.dynamic = false
eskipVAL.lbBackend = false
}
case 37:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.shunt = false
eskipVAL.loopback = false
eskipVAL.dynamic = true
eskipVAL.lbBackend = false
}
case 38:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.shunt = false
eskipVAL.loopback = false
eskipVAL.dynamic = false
eskipVAL.lbBackend = true
eskipVAL.lbAlgorithm = eskipDollar[1].lbAlgorithm
eskipVAL.lbEndpoints = eskipDollar[1].lbEndpoints
}
case 39:
eskipDollar = eskipS[eskippt-1 : eskippt+1]
{
eskipVAL.numval = convertNumber(eskipDollar[1].token)
}
}
goto eskipstack /* stack new state and value */
}
package eskip
import (
"bytes"
"fmt"
"io"
"math"
"sort"
"strings"
)
type PrettyPrintInfo struct {
Pretty bool
IndentStr string
}
func escape(s string, chars string) string {
s = strings.ReplaceAll(s, `\`, `\\`) // escape backslash before others
s = strings.ReplaceAll(s, "\a", `\a`)
s = strings.ReplaceAll(s, "\b", `\b`)
s = strings.ReplaceAll(s, "\f", `\f`)
s = strings.ReplaceAll(s, "\n", `\n`)
s = strings.ReplaceAll(s, "\r", `\r`)
s = strings.ReplaceAll(s, "\t", `\t`)
s = strings.ReplaceAll(s, "\v", `\v`)
for i := 0; i < len(chars); i++ {
c := chars[i : i+1]
s = strings.ReplaceAll(s, c, `\`+c)
}
return s
}
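// Illustrative sketch (not part of the original source): escaping a made-up
// string for embedding in a double-quoted eskip literal.
//
//	escape("say \"hi\"\n", `"`) // yields `say \"hi\"\n`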
func appendFmt(s []string, format string, args ...interface{}) []string {
return append(s, fmt.Sprintf(format, args...))
}
func appendFmtEscape(s []string, format string, escapeChars string, args ...interface{}) []string {
eargs := make([]interface{}, len(args))
for i, arg := range args {
eargs[i] = escape(fmt.Sprintf("%v", arg), escapeChars)
}
return appendFmt(s, format, eargs...)
}
func argsString(args []interface{}) string {
var sargs []string
for _, a := range args {
switch v := a.(type) {
case int:
sargs = appendFmt(sargs, "%d", a)
case float64:
f := "%g"
// imprecise elimination of 0 decimals
// TODO: better fix this issue on parsing side
if math.Floor(v) == v {
f = "%.0f"
}
sargs = appendFmt(sargs, f, a)
case string:
sargs = appendFmtEscape(sargs, `"%s"`, `"`, a)
default:
if m, ok := a.(interface{ MarshalText() ([]byte, error) }); ok {
t, err := m.MarshalText()
if err != nil {
sargs = append(sargs, `"[error]"`)
} else {
sargs = appendFmtEscape(sargs, `"%s"`, `"`, string(t))
}
} else {
sargs = appendFmtEscape(sargs, `"%s"`, `"`, a)
}
}
}
return strings.Join(sargs, ", ")
}
func sortTail(s []string, from int) {
if len(s)-from > 1 {
sort.Strings(s[from:])
}
}
func (r *Route) predicateString() string {
var predicates []string
if r.Path != "" {
predicates = appendFmtEscape(predicates, `Path("%s")`, `"`, r.Path)
}
for _, h := range r.HostRegexps {
predicates = appendFmtEscape(predicates, "Host(/%s/)", "/", h)
}
for _, p := range r.PathRegexps {
predicates = appendFmtEscape(predicates, "PathRegexp(/%s/)", "/", p)
}
if r.Method != "" {
predicates = appendFmtEscape(predicates, `Method("%s")`, `"`, r.Method)
}
from := len(predicates)
for k, v := range r.Headers {
predicates = appendFmtEscape(predicates, `Header("%s", "%s")`, `"`, k, v)
}
sortTail(predicates, from)
from = len(predicates)
for k, rxs := range r.HeaderRegexps {
for _, rx := range rxs {
predicates = appendFmt(predicates, `HeaderRegexp("%s", /%s/)`, escape(k, `"`), escape(rx, "/"))
}
}
sortTail(predicates, from)
for _, p := range r.Predicates {
if p.Name != "Any" {
predicates = appendFmt(predicates, "%s(%s)", p.Name, argsString(p.Args))
}
}
if len(predicates) == 0 {
predicates = append(predicates, "*")
}
return strings.Join(predicates, " && ")
}
func (r *Route) filterString(prettyPrintInfo PrettyPrintInfo) string {
var sfilters []string
for _, f := range r.Filters {
sfilters = appendFmt(sfilters, "%s(%s)", f.Name, argsString(f.Args))
}
if prettyPrintInfo.Pretty {
return strings.Join(sfilters, "\n"+prettyPrintInfo.IndentStr+"-> ")
}
return strings.Join(sfilters, " -> ")
}
func (r *Route) backendString() string {
switch {
case r.Shunt, r.BackendType == ShuntBackend:
return "<shunt>"
case r.BackendType == LoopBackend:
return "<loopback>"
case r.BackendType == DynamicBackend:
return "<dynamic>"
default:
return r.Backend
}
}
func lbBackendString(r *Route) string {
var b strings.Builder
b.WriteByte('<')
if r.LBAlgorithm != "" {
b.WriteString(r.LBAlgorithm)
b.WriteString(", ")
}
for i, ep := range r.LBEndpoints {
if i > 0 {
b.WriteString(", ")
}
b.WriteByte('"')
b.WriteString(ep)
b.WriteByte('"')
}
b.WriteByte('>')
return b.String()
}
func (r *Route) backendStringQuoted() string {
s := r.backendString()
switch {
case r.BackendType == NetworkBackend && !r.Shunt:
return fmt.Sprintf(`"%s"`, s)
case r.BackendType == LBBackend:
return lbBackendString(r)
default:
return s
}
}
// String serializes a route expression, omitting the route id, if any.
func (r *Route) String() string {
return r.Print(PrettyPrintInfo{Pretty: false, IndentStr: ""})
}
func (r *Route) Print(prettyPrintInfo PrettyPrintInfo) string {
s := []string{r.predicateString()}
fs := r.filterString(prettyPrintInfo)
if fs != "" {
s = append(s, fs)
}
s = append(s, r.backendStringQuoted())
separator := " -> "
if prettyPrintInfo.Pretty {
separator = "\n" + prettyPrintInfo.IndentStr + "-> "
}
return strings.Join(s, separator)
}
// String is the same as Print, defaulting to Pretty: false.
func String(routes ...*Route) string {
return Print(PrettyPrintInfo{Pretty: false, IndentStr: ""}, routes...)
}
// Print serializes a set of routes into a string. If there is only a single
// route and its ID is not set, it prints just the route expression. Otherwise
// it prints full route definitions with the IDs, separated by ';'.
func Print(pretty PrettyPrintInfo, routes ...*Route) string {
var buf bytes.Buffer
Fprint(&buf, pretty, routes...)
return buf.String()
}
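// Illustrative sketch (not part of the original source): round-tripping a
// made-up route definition through Parse and Print.
//
//	routes, err := Parse(`r1: Path("/api") -> setRequestHeader("X-Example", "1") -> "https://backend.example.org";`)
//	if err != nil {
//		// handle the parse error
//	}
//	out := Print(PrettyPrintInfo{Pretty: true, IndentStr: "  "}, routes...)
//	// out renders the definition with each filter on its own "-> " line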
func isDefinition(route *Route) bool {
return route.Id != ""
}
func fprintExpression(w io.Writer, route *Route, prettyPrintInfo PrettyPrintInfo) {
fmt.Fprint(w, route.Print(prettyPrintInfo))
}
func fprintDefinition(w io.Writer, route *Route, prettyPrintInfo PrettyPrintInfo) {
fmt.Fprintf(w, "%s: %s", route.Id, route.Print(prettyPrintInfo))
}
func fprintDefinitions(w io.Writer, routes []*Route, prettyPrintInfo PrettyPrintInfo) {
for i, r := range routes {
if i > 0 {
fmt.Fprint(w, "\n")
if prettyPrintInfo.Pretty {
fmt.Fprint(w, "\n")
}
}
fprintDefinition(w, r, prettyPrintInfo)
fmt.Fprint(w, ";")
}
}
func Fprint(w io.Writer, prettyPrintInfo PrettyPrintInfo, routes ...*Route) {
if len(routes) == 0 {
return
}
if len(routes) == 1 && !isDefinition(routes[0]) {
fprintExpression(w, routes[0], prettyPrintInfo)
return
}
fprintDefinitions(w, routes, prettyPrintInfo)
}
// The Template type defined in this file provides a simple templating
// solution reusable in filters.
//
// (Note that the current template syntax is EXPERIMENTAL and may change in
// the near future.)
package eskip
import (
"net"
"net/http"
"regexp"
"strings"
snet "github.com/zalando/skipper/net"
)
var placeholderRegexp = regexp.MustCompile(`\$\{([^{}]+)\}`)
// TemplateGetter functions return the value for a template parameter name.
type TemplateGetter func(string) string
// Template represents a string template with named placeholders.
type Template struct {
template string
placeholders []string
}
type TemplateContext interface {
PathParam(string) string
Request() *http.Request
Response() *http.Response
}
// NewTemplate parses a template string and returns a reusable *Template object.
// The template string can contain named placeholders of the format:
//
// Hello, ${who}!
func NewTemplate(template string) *Template {
matches := placeholderRegexp.FindAllStringSubmatch(template, -1)
placeholders := make([]string, len(matches))
for index, placeholder := range matches {
placeholders[index] = placeholder[1]
}
return &Template{template: template, placeholders: placeholders}
}
// Apply evaluates the template using a TemplateGetter function to resolve the
// placeholders.
func (t *Template) Apply(get TemplateGetter) string {
if get == nil {
return t.template
}
result, _ := t.apply(get)
return result
}
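// Illustrative sketch (not part of the original source): resolving placeholders
// with a map-backed getter; the template text and parameter map are made up.
//
//	t := NewTemplate("Hello, ${who}!")
//	params := map[string]string{"who": "world"}
//	s := t.Apply(func(key string) string { return params[key] })
//	// s == "Hello, world!"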
// ApplyContext evaluates the template using template context to resolve the
// placeholders. Returns true if all placeholders resolved to non-empty values.
func (t *Template) ApplyContext(ctx TemplateContext) (string, bool) {
return t.apply(func(key string) string {
if h := strings.TrimPrefix(key, "request.header."); h != key {
return ctx.Request().Header.Get(h)
}
if q := strings.TrimPrefix(key, "request.query."); q != key {
return ctx.Request().URL.Query().Get(q)
}
if c := strings.TrimPrefix(key, "request.cookie."); c != key {
if cookie, err := ctx.Request().Cookie(c); err == nil {
return cookie.Value
}
return ""
}
switch key {
case "request.method":
return ctx.Request().Method
case "request.host":
return ctx.Request().Host
case "request.path":
return ctx.Request().URL.Path
case "request.rawQuery":
return ctx.Request().URL.RawQuery
case "request.source":
return snet.RemoteHost(ctx.Request()).String()
case "request.sourceFromLast":
return snet.RemoteHostFromLast(ctx.Request()).String()
case "request.clientIP":
if host, _, err := net.SplitHostPort(ctx.Request().RemoteAddr); err == nil {
return host
}
}
if ctx.Response() != nil {
if h := strings.TrimPrefix(key, "response.header."); h != key {
return ctx.Response().Header.Get(h)
}
}
return ctx.PathParam(key)
})
}
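// Illustrative sketch (not part of the original source): a minimal
// TemplateContext implementation; the demoContext type and the request values
// are made up.
//
//	type demoContext struct{ req *http.Request }
//
//	func (c demoContext) PathParam(string) string  { return "" }
//	func (c demoContext) Request() *http.Request   { return c.req }
//	func (c demoContext) Response() *http.Response { return nil }
//
//	req, _ := http.NewRequest("GET", "https://example.org/foo?q=1", nil)
//	s, ok := NewTemplate("${request.method} ${request.query.q}").ApplyContext(demoContext{req})
//	// s == "GET 1", ok == true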
// apply evaluates the template using a TemplateGetter function to resolve the
// placeholders. Returns true if all placeholders resolved to non-empty values.
func (t *Template) apply(get TemplateGetter) (string, bool) {
result := t.template
missing := false
for _, placeholder := range t.placeholders {
value := get(placeholder)
if value == "" {
missing = true
}
result = strings.ReplaceAll(result, "${"+placeholder+"}", value)
}
return result, !missing
}
package eskipfile
import (
"os"
"github.com/zalando/skipper/eskip"
)
// Client contains the route definitions from an eskip file, not implementing file watch. Use the Open function
// to create instances of it.
type Client struct{ routes []*eskip.Route }
// Open opens an eskip file and parses it, returning a DataClient implementation. If reading or parsing
// the file fails, it returns an error. This implementation doesn't provide file watch.
func Open(path string) (*Client, error) {
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}
routes, err := eskip.Parse(string(content))
if err != nil {
return nil, err
}
return &Client{routes}, nil
}
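// Illustrative sketch (not part of the original source): loading routes from a
// made-up file name.
//
//	client, err := Open("routes.eskip")
//	if err != nil {
//		// handle the error
//	}
//	routes, _ := client.LoadAll()
//	// routes holds the parsed []*eskip.Route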
// LoadAndParseAll returns the parsed route definitions found in the file, wrapped in RouteInfo values.
func (c Client) LoadAndParseAll() (routeInfos []*eskip.RouteInfo, err error) {
for _, route := range c.routes {
routeInfos = append(routeInfos, &eskip.RouteInfo{Route: *route})
}
return
}
// LoadAll returns the parsed route definitions found in the file.
func (c Client) LoadAll() ([]*eskip.Route, error) { return c.routes, nil }
// LoadUpdate: noop. The current implementation doesn't support watching the eskip file for changes.
func (c Client) LoadUpdate() ([]*eskip.Route, []string, error) { return nil, nil, nil }
package eskipfile
import (
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/net"
"github.com/zalando/skipper/routing"
log "github.com/sirupsen/logrus"
)
var errContentNotChanged = errors.New("content in cache did not change, 304 response status code")
type remoteEskipFile struct {
once sync.Once
preloaded bool
remotePath string
localPath string
eskipFileClient *WatchClient
threshold int
verbose bool
http *net.Client
etag string
}
type RemoteWatchOptions struct {
// URL of the route file
RemoteFile string
// Verbose mode for the dataClient
Verbose bool
// Number of changed routes above which a warning is logged after updates
Threshold int
// FailOnStartup makes RemoteWatch perform an initial download and parse of the remote routes, and
// return an error if that fails
FailOnStartup bool
// HTTPTimeout is the generic timeout for any phase of a single HTTP request to RemoteFile.
HTTPTimeout time.Duration
}
// RemoteWatch creates a route configuration client with (remote) file watching. Watch doesn't follow file system nodes,
// it always reads (or re-downloads) from the file identified by the initially provided file name.
func RemoteWatch(o *RemoteWatchOptions) (routing.DataClient, error) {
if !isFileRemote(o.RemoteFile) {
return Watch(o.RemoteFile), nil
}
tempFile, err := os.CreateTemp("", "routes")
if err != nil {
return nil, err
}
// only the file name is needed below; the file is reopened on demand
if err := tempFile.Close(); err != nil {
return nil, err
}
dataClient := &remoteEskipFile{
once: sync.Once{},
remotePath: o.RemoteFile,
localPath: tempFile.Name(),
threshold: o.Threshold,
verbose: o.Verbose,
http: net.NewClient(net.Options{Timeout: o.HTTPTimeout}),
}
if o.FailOnStartup {
err = dataClient.DownloadRemoteFile()
if err != nil {
dataClient.http.Close()
return nil, err
}
} else {
f, err := os.OpenFile(tempFile.Name(), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err == nil {
err = f.Close()
}
if err != nil {
dataClient.http.Close()
return nil, err
}
dataClient.preloaded = true
}
dataClient.eskipFileClient = Watch(tempFile.Name())
return dataClient, nil
}
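// Illustrative sketch (not part of the original source): watching a made-up
// remote route file URL.
//
//	dc, err := RemoteWatch(&RemoteWatchOptions{
//		RemoteFile:    "https://config.example.org/routes.eskip",
//		FailOnStartup: true,
//		HTTPTimeout:   3 * time.Second,
//	})
//	if err != nil {
//		// handle the error
//	}
//	routes, _ := dc.LoadAll()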
// LoadAll returns the parsed route definitions found in the file.
func (client *remoteEskipFile) LoadAll() ([]*eskip.Route, error) {
var err error
if client.preloaded {
client.preloaded = false
} else {
err = client.DownloadRemoteFile()
}
if err != nil {
log.Errorf("LoadAll from remote %s failed. Continue using the last loaded routes", client.remotePath)
return nil, err
}
if client.verbose {
log.Infof("New routes file %s was downloaded", client.remotePath)
}
return client.eskipFileClient.LoadAll()
}
// LoadUpdate returns differential updates when a remote file has changed.
func (client *remoteEskipFile) LoadUpdate() ([]*eskip.Route, []string, error) {
err := client.DownloadRemoteFile()
if err != nil {
log.Errorf("LoadUpdate from remote %s failed. Trying to LoadAll", client.remotePath)
return nil, nil, err
}
newRoutes, deletedRoutes, err := client.eskipFileClient.LoadUpdate()
if err != nil {
log.Errorf("RemoteEskipFile LoadUpdate %s failed. Skipper continues to serve the last successfully updated routes. Error: %s",
client.remotePath, err)
return newRoutes, deletedRoutes, err
}
if client.verbose {
log.Infof("New routes were loaded. New: %d; deleted: %d", len(newRoutes), len(deletedRoutes))
if client.threshold > 0 {
if len(newRoutes)+len(deletedRoutes) > client.threshold {
log.Warnf("Significant amount of routes was updated. New: %d; deleted: %d", len(newRoutes), len(deletedRoutes))
}
}
}
return newRoutes, deletedRoutes, err
}
func (client *remoteEskipFile) Close() {
client.once.Do(func() {
client.http.Close()
client.eskipFileClient.Close()
})
}
func isFileRemote(remotePath string) bool {
return strings.HasPrefix(remotePath, "http://") || strings.HasPrefix(remotePath, "https://")
}
func (client *remoteEskipFile) DownloadRemoteFile() error {
resBody, err := client.getRemoteData()
if err != nil {
if errors.Is(err, errContentNotChanged) {
return nil
}
return err
}
defer resBody.Close()
outFile, err := os.OpenFile(client.localPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
if _, err = io.Copy(outFile, resBody); err != nil {
outFile.Close()
return err
}
return outFile.Close()
}
func (client *remoteEskipFile) getRemoteData() (io.ReadCloser, error) {
req, err := http.NewRequest("GET", client.remotePath, nil)
if err != nil {
return nil, err
}
if client.etag != "" {
req.Header.Set("If-None-Match", client.etag)
}
resp, err := client.http.Do(req)
if err != nil {
return nil, err
}
if client.etag != "" && resp.StatusCode == http.StatusNotModified {
resp.Body.Close()
return nil, errContentNotChanged
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("failed to download remote file %s, status code: %d", client.remotePath, resp.StatusCode)
}
client.etag = resp.Header.Get("ETag")
return resp.Body, nil
}
package eskipfile
import (
"bytes"
"os"
"reflect"
"sync"
"github.com/zalando/skipper/eskip"
)
type watchResponse struct {
routes []*eskip.Route
deletedIDs []string
err error
}
// WatchClient implements a route configuration client with file watching. Use the Watch function to initialize
// instances of it.
type WatchClient struct {
fileName string
lastContent []byte
routes map[string]*eskip.Route
getAll chan (chan<- watchResponse)
getUpdates chan (chan<- watchResponse)
quit chan struct{}
once sync.Once
}
// Watch creates a route configuration client with file watching. Watch doesn't follow file system nodes, it
// always reads from the file identified by the initially provided file name.
func Watch(name string) *WatchClient {
c := &WatchClient{
fileName: name,
getAll: make(chan (chan<- watchResponse)),
getUpdates: make(chan (chan<- watchResponse)),
quit: make(chan struct{}),
once: sync.Once{},
}
go c.watch()
return c
}
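// Illustrative sketch (not part of the original source): polling a made-up
// local file for route changes.
//
//	c := Watch("routes.eskip")
//	defer c.Close()
//	routes, _ := c.LoadAll()                  // initial, full set
//	upserted, deletedIDs, _ := c.LoadUpdate() // differential on later calls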
func mapRoutes(r []*eskip.Route) map[string]*eskip.Route {
m := make(map[string]*eskip.Route)
for i := range r {
m[r[i].Id] = r[i]
}
return m
}
func (c *WatchClient) storeRoutes(r []*eskip.Route) {
c.routes = mapRoutes(r)
}
func (c *WatchClient) diffStoreRoutes(r []*eskip.Route) (upsert []*eskip.Route, deletedIDs []string) {
for i := range r {
if !reflect.DeepEqual(r[i], c.routes[r[i].Id]) {
upsert = append(upsert, r[i])
}
}
m := mapRoutes(r)
for id := range c.routes {
if _, keep := m[id]; !keep {
deletedIDs = append(deletedIDs, id)
}
}
c.routes = m
return
}
func (c *WatchClient) deleteAllListIDs() []string {
var ids []string
for id := range c.routes {
ids = append(ids, id)
}
c.routes = nil
return ids
}
func cloneRoutes(r []*eskip.Route) []*eskip.Route {
if len(r) == 0 {
return nil
}
c := make([]*eskip.Route, len(r))
for i, ri := range r {
c[i] = ri.Copy()
}
return c
}
func (c *WatchClient) loadAll() watchResponse {
content, err := os.ReadFile(c.fileName)
if err != nil {
c.lastContent = nil
return watchResponse{err: err}
}
r, err := eskip.Parse(string(content))
if err != nil {
c.lastContent = nil
return watchResponse{err: err}
}
c.storeRoutes(r)
c.lastContent = content
return watchResponse{routes: cloneRoutes(r)}
}
func (c *WatchClient) loadUpdates() watchResponse {
content, err := os.ReadFile(c.fileName)
if err != nil {
c.lastContent = nil
if os.IsNotExist(err) {
deletedIDs := c.deleteAllListIDs()
return watchResponse{deletedIDs: deletedIDs}
}
return watchResponse{err: err}
}
if bytes.Equal(content, c.lastContent) {
return watchResponse{}
}
r, err := eskip.Parse(string(content))
if err != nil {
c.lastContent = nil
return watchResponse{err: err}
}
upsert, del := c.diffStoreRoutes(r)
c.lastContent = content
return watchResponse{routes: cloneRoutes(upsert), deletedIDs: del}
}
func (c *WatchClient) watch() {
for {
select {
case req := <-c.getAll:
req <- c.loadAll()
case req := <-c.getUpdates:
req <- c.loadUpdates()
case <-c.quit:
return
}
}
}
// LoadAll returns the parsed route definitions found in the file.
func (c *WatchClient) LoadAll() ([]*eskip.Route, error) {
req := make(chan watchResponse)
select {
case c.getAll <- req:
case <-c.quit:
return nil, nil
}
rsp := <-req
return rsp.routes, rsp.err
}
// LoadUpdate returns differential updates when a watched file has changed.
func (c *WatchClient) LoadUpdate() ([]*eskip.Route, []string, error) {
req := make(chan watchResponse)
select {
case c.getUpdates <- req:
case <-c.quit:
return nil, nil, nil
}
rsp := <-req
return rsp.routes, rsp.deletedIDs, rsp.err
}
// Close stops watching the configured file and providing updates.
func (c *WatchClient) Close() {
c.once.Do(func() {
close(c.quit)
})
}
/*
Package etcd implements a DataClient for reading the skipper route
definitions from an etcd service.
(See the DataClient interface in the skipper/routing package.)
etcd is a generic, distributed configuration service:
https://github.com/coreos/etcd. The route definitions are stored under
individual keys as eskip route expressions. When loaded from etcd, the
routes will get the etcd key as id.
In addition to the DataClient implementation, type Client provides
methods to Upsert and Delete routes.
*/
package etcd
import (
"bytes"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"path"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
)
const (
routesPath = "/routes"
etcdIndexHeader = "X-Etcd-Index"
defaultTimeout = time.Second
)
// etcd serialization objects
type (
node struct {
Key string `json:"key"`
Value string `json:"value"`
Dir bool `json:"dir"`
ModifiedIndex uint64 `json:"modifiedIndex"`
Nodes []*node `json:"nodes"`
}
response struct {
etcdIndex uint64
Action string `json:"action"`
Node *node `json:"node"`
}
)
// common error object for errors coming from multiple
// etcd instances
type endpointErrors struct {
errors []error
}
func (ee *endpointErrors) Error() string {
err := "request to one or more endpoints failed"
for _, e := range ee.errors {
err = err + ";" + e.Error()
}
return err
}
func (ee *endpointErrors) String() string {
return ee.Error()
}
// Initialization options.
type Options struct {
// A slice of etcd endpoint addresses.
// (Schema and host.)
Endpoints []string
// Etcd path to a directory where the
// Skipper related settings are stored.
Prefix string
// A timeout value for etcd long-polling.
// The default timeout is 1 second.
Timeout time.Duration
// Skip TLS certificate check.
Insecure bool
// Optional OAuth-Token
OAuthToken string
// Optional username for basic auth
Username string
// Optional password for basic auth
Password string
}
// A Client is used to load the whole set of routes and the updates from an
// etcd store.
type Client struct {
endpoints []string
routesRoot string
client *http.Client
etcdIndex uint64
oauthToken string
username string
password string
}
var (
errMissingEtcdEndpoint = errors.New("missing etcd endpoint")
errMissingRouteId = errors.New("missing route id")
errInvalidNode = errors.New("invalid node")
errUnexpectedHttpResponse = errors.New("unexpected http response")
errNotFound = errors.New("not found")
errInvalidResponseDocument = errors.New("invalid response document")
)
// Creates a new Client with the provided options.
func New(o Options) (*Client, error) {
if len(o.Endpoints) == 0 {
return nil, errMissingEtcdEndpoint
}
if o.Timeout == 0 {
o.Timeout = defaultTimeout
}
httpClient := &http.Client{Timeout: o.Timeout}
if o.Insecure {
httpClient.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
/* #nosec */
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
}
return &Client{
endpoints: o.Endpoints,
routesRoot: o.Prefix + routesPath,
client: httpClient,
etcdIndex: 0,
oauthToken: o.OAuthToken,
username: o.Username,
password: o.Password}, nil
}
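// Illustrative sketch (not part of the original source): creating a client
// with a made-up endpoint and prefix.
//
//	c, err := New(Options{
//		Endpoints: []string{"http://etcd.example.org:2379"},
//		Prefix:    "/skipper",
//	})
//	if err != nil {
//		// handle the error
//	}
//	routes, _ := c.LoadAll()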
func isTimeout(err error) bool {
nerr, ok := err.(net.Error)
return ok && nerr.Timeout()
}
// Makes a request to an etcd endpoint. If it fails due to connection problems,
// it makes a new request to the next available endpoint, until all endpoints
// are tried. It returns the response to the first successful request.
func (c *Client) tryEndpoints(mreq func(string) (*http.Request, error)) (*http.Response, error) {
var (
req *http.Request
rsp *http.Response
err error
endpointErrs []error
)
for index, endpoint := range c.endpoints {
req, err = mreq(endpoint + "/v2/keys")
if err != nil {
return nil, err
}
rsp, err = c.client.Do(req)
isTimeoutError := false
if err != nil {
isTimeoutError = isTimeout(err)
if !isTimeoutError {
uerr, ok := err.(*url.Error)
if ok && isTimeout(uerr.Err) {
isTimeoutError = true
err = uerr.Err
}
}
}
if err == nil || isTimeoutError {
if index != 0 {
c.endpoints = append(c.endpoints[index:], c.endpoints[:index]...)
}
return rsp, err
}
endpointErrs = append(endpointErrs, err)
}
return nil, &endpointErrors{endpointErrs}
}
// Converts an http response to a parsed etcd response object.
func parseResponse(rsp *http.Response) (*response, error) {
d, err := io.ReadAll(rsp.Body)
if err != nil {
return nil, err
}
r := &response{}
err = json.Unmarshal(d, &r)
if err != nil {
return nil, err
}
if r.Node == nil || r.Node.Key == "" {
return nil, errInvalidResponseDocument
}
r.etcdIndex, err = strconv.ParseUint(rsp.Header.Get(etcdIndexHeader), 10, 64)
return r, err
}
// Converts a non-success http status code into an in-memory error object.
// As the first argument, returns true in case of error.
func httpError(code int) (bool, error) {
if code == http.StatusNotFound {
return true, errNotFound
}
if code < http.StatusOK || code >= http.StatusMultipleChoices {
return true, errUnexpectedHttpResponse
}
return false, nil
}
// Makes a request to an available etcd endpoint, with retries in case of
// failure, and converts the http response to a parsed etcd response object.
func (c *Client) etcdRequest(method, path, data string) (*response, error) {
rsp, err := c.tryEndpoints(func(a string) (*http.Request, error) {
var body io.Reader
if data != "" {
v := make(url.Values)
v.Add("value", data)
body = bytes.NewBufferString(v.Encode())
}
r, err := http.NewRequest(method, a+path, body)
if err != nil {
return nil, err
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Give oauth priority over basic auth
if c.oauthToken != "" {
r.Header.Set("Authorization", "Bearer "+c.oauthToken)
} else if c.username != "" && c.password != "" {
credentials := base64.StdEncoding.EncodeToString([]byte(c.username + ":" + c.password))
r.Header.Set("Authorization", "Basic "+credentials)
}
return r, nil
})
if err != nil {
return nil, err
}
defer rsp.Body.Close()
if hasErr, err := httpError(rsp.StatusCode); hasErr {
return nil, err
}
return parseResponse(rsp)
}
func (c *Client) etcdGet() (*response, error) {
return c.etcdRequest("GET", c.routesRoot, "")
}
// Calls etcd 'watch' but with a timeout configured for
// the http client.
func (c *Client) etcdGetUpdates() (*response, error) {
return c.etcdRequest("GET",
fmt.Sprintf("%s?wait=true&waitIndex=%d&recursive=true",
c.routesRoot, c.etcdIndex+1), "")
}
func (c *Client) etcdSet(r *eskip.Route) error {
_, err := c.etcdRequest("PUT", c.routesRoot+"/"+r.Id, r.String())
return err
}
func (c *Client) etcdDelete(id string) error {
_, err := c.etcdRequest("DELETE", c.routesRoot+"/"+id, "")
return err
}
// Finds all route expressions in the containing directory node.
// Returns a map where the keys are the etcd keys and the values are the
// eskip route expressions.
func (c *Client) iterateNodes(dir *node, highestIndex uint64) (map[string]string, uint64) {
routes := make(map[string]string)
for _, n := range dir.Nodes {
if n.Dir {
continue
}
routes[path.Base(n.Key)] = n.Value
if n.ModifiedIndex > highestIndex {
highestIndex = n.ModifiedIndex
}
}
return routes, highestIndex
}
// Parses a single route expression, failing if the data contains more
// than one expression.
func parseOne(data string) (*eskip.Route, error) {
r, err := eskip.Parse(data)
if err != nil {
return nil, err
}
if len(r) != 1 {
return nil, errors.New("invalid route entry: multiple route expressions")
}
return r[0], nil
}
// Parses a set of eskip routes.
func parseRoutes(data map[string]string) []*eskip.RouteInfo {
allInfo := make([]*eskip.RouteInfo, 0, len(data))
for id, d := range data {
info := &eskip.RouteInfo{}
r, err := parseOne(d)
if err == nil {
info.Route = *r
} else {
info.ParseError = err
}
info.Id = id
allInfo = append(allInfo, info)
}
return allInfo
}
// Converts route info to route objects, logging those whose
// parsing failed.
func infoToRoutesLogged(info []*eskip.RouteInfo) []*eskip.Route {
var routes []*eskip.Route
for i := range info {
ri := info[i]
if ri.ParseError == nil {
routes = append(routes, &ri.Route)
} else {
log.Println("error while parsing routes", ri.Id, ri.ParseError)
}
}
return routes
}
// Returns all the route definitions currently stored in etcd,
// or the parsing error in case of failure.
func (c *Client) LoadAndParseAll() ([]*eskip.RouteInfo, error) {
response, err := c.etcdGet()
if err == errNotFound {
return nil, nil
}
if err != nil {
return nil, err
}
if !response.Node.Dir {
return nil, errInvalidNode
}
data, etcdIndex := c.iterateNodes(response.Node, 0)
if response.etcdIndex > etcdIndex {
etcdIndex = response.etcdIndex
}
c.etcdIndex = etcdIndex
return parseRoutes(data), nil
}
// Returns all the route definitions currently stored in etcd.
func (c *Client) LoadAll() ([]*eskip.Route, error) {
routeInfo, err := c.LoadAndParseAll()
if err != nil {
return nil, err
}
return infoToRoutesLogged(routeInfo), nil
}
// Returns the updates (upserts and deletes) since the last initial request
// or update.
//
// It uses etcd's watch functionality, which blocks this call until the
// next change is detected in etcd or the configured timeout is reached.
func (c *Client) LoadUpdate() ([]*eskip.Route, []string, error) {
updates := make(map[string]string)
deletes := make(map[string]bool)
for {
response, err := c.etcdGetUpdates()
if isTimeout(err) {
break
} else if err != nil {
return nil, nil, err
} else if response.Node.Dir {
if response.Node.ModifiedIndex > c.etcdIndex {
c.etcdIndex = response.Node.ModifiedIndex
}
continue
}
id := path.Base(response.Node.Key)
if response.Action == "delete" || response.Action == "expire" {
deletes[id] = true
delete(updates, id)
} else {
updates[id] = response.Node.Value
deletes[id] = false
}
if response.Node.ModifiedIndex > c.etcdIndex {
c.etcdIndex = response.Node.ModifiedIndex
}
}
routeInfo := parseRoutes(updates)
routes := infoToRoutesLogged(routeInfo)
deletedIds := make([]string, 0, len(deletes))
for id, deleted := range deletes {
if deleted {
deletedIds = append(deletedIds, id)
}
}
return routes, deletedIds, nil
}
// Inserts or updates a route in etcd.
func (c *Client) Upsert(r *eskip.Route) error {
if r.Id == "" {
return errMissingRouteId
}
return c.etcdSet(r)
}
// Deletes a route from etcd.
func (c *Client) Delete(id string) error {
if id == "" {
return errMissingRouteId
}
err := c.etcdDelete(id)
if err == errNotFound {
err = nil
}
return err
}
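// Illustrative sketch (not part of the original source): storing and removing
// a made-up route, assuming c is a *Client created with New.
//
//	rr, _ := eskip.Parse(`hello: Path("/hello") -> "https://backend.example.org";`)
//	if err := c.Upsert(rr[0]); err != nil {
//		// handle the error
//	}
//	_ = c.Delete("hello")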
func (c *Client) UpsertAll(routes []*eskip.Route) error {
for _, r := range routes {
//lint:ignore SA1019 due to backward compatibility
r.Id = eskip.GenerateIfNeeded(r.Id)
err := c.Upsert(r)
if err != nil {
return err
}
}
return nil
}
func (c *Client) DeleteAllIf(routes []*eskip.Route, cond eskip.RoutePredicate) error {
for _, r := range routes {
if !cond(r) {
continue
}
err := c.Delete(r.Id)
if err != nil {
return err
}
}
return nil
}
package accesslog
import "github.com/zalando/skipper/filters"
const (
// Deprecated: use filters.DisableAccessLogName instead
DisableAccessLogName = filters.DisableAccessLogName
// Deprecated: use filters.EnableAccessLogName instead
EnableAccessLogName = filters.EnableAccessLogName
// AccessLogEnabledKey is the key used in the state bag to pass the access log state to the proxy.
AccessLogEnabledKey = "statebag:access_log:proxy:enabled"
// AccessLogAdditionalDataKey is the key used in the state bag to pass extra data to access log
AccessLogAdditionalDataKey = "statebag:access_log:additional"
)
// Common filter struct for holding access log state
type AccessLogFilter struct {
Enable bool
Prefixes []int
}
func (al *AccessLogFilter) Request(ctx filters.FilterContext) {
bag := ctx.StateBag()
bag[AccessLogEnabledKey] = al
}
func (*AccessLogFilter) Response(filters.FilterContext) {}
func extractFilterValues(args []interface{}, enable bool) (filters.Filter, error) {
prefixes := make([]int, 0)
for _, prefix := range args {
var intPref int
switch p := prefix.(type) {
case float32:
intPref = int(p)
case float64:
intPref = int(p)
default:
var ok bool
intPref, ok = prefix.(int)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
}
prefixes = append(prefixes, intPref)
}
return &AccessLogFilter{Enable: enable, Prefixes: prefixes}, nil
}
type disableAccessLog struct{}
// NewDisableAccessLog creates a filter spec to disable access log for a specific route.
// Optionally takes in response code prefixes as arguments. When provided, access log is disabled
// only if response code matches one of the arguments.
//
// disableAccessLog() or
// disableAccessLog(1, 20, 301) to disable logs for 1xx, 20x and 301 codes
func NewDisableAccessLog() filters.Spec {
return &disableAccessLog{}
}
func (*disableAccessLog) Name() string { return filters.DisableAccessLogName }
func (al *disableAccessLog) CreateFilter(args []interface{}) (filters.Filter, error) {
return extractFilterValues(args, false)
}
type enableAccessLog struct{}
// NewEnableAccessLog creates a filter spec to enable access log for a specific route.
// Optionally takes in response code prefixes as arguments. When provided, access log is enabled
// only if response code matches one of the arguments.
//
// enableAccessLog()
// enableAccessLog(1, 20, 301) to enable logs for 1xx, 20x and 301 codes
func NewEnableAccessLog() filters.Spec {
return &enableAccessLog{}
}
func (*enableAccessLog) Name() string { return filters.EnableAccessLogName }
func (al *enableAccessLog) CreateFilter(args []interface{}) (filters.Filter, error) {
return extractFilterValues(args, true)
}
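// Illustrative sketch (not part of the original source): made-up eskip routes
// using the two filter specs.
//
//	quiet: Path("/health") -> disableAccessLog(2) -> "https://backend.example.org";
//	loud: Path("/payments") -> enableAccessLog() -> "https://backend.example.org";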
package accesslog
import (
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use DisableAccessLogName or EnableAccessLogName
AccessLogDisabledName = "accessLogDisabled"
)
type accessLogDisabled struct {
disabled bool
}
// NewAccessLogDisabled creates a filter spec for overriding the state of the AccessLogDisabled setting. (By default global setting is used.)
//
// accessLogDisabled("false")
//
// Deprecated: use disableAccessLog or enableAccessLog
func NewAccessLogDisabled() filters.Spec {
return &accessLogDisabled{}
}
func (*accessLogDisabled) Name() string { return AccessLogDisabledName }
func (*accessLogDisabled) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
a, ok := args[0].(string)
if !ok || (a != "true" && a != "false") {
return nil, filters.ErrInvalidFilterParameters
}
return &accessLogDisabled{a == "true"}, nil
}
func (al *accessLogDisabled) Request(ctx filters.FilterContext) {
bag := ctx.StateBag()
bag[AccessLogEnabledKey] = &AccessLogFilter{!al.disabled, nil}
}
func (*accessLogDisabled) Response(filters.FilterContext) {}
package annotate
import (
"fmt"
"github.com/zalando/skipper/filters"
)
type (
annotateSpec struct{}
annotateFilter struct {
key, value string
}
)
const annotateStateBagKey = "filter." + filters.AnnotateName
// New creates filters to annotate a filter chain.
// It stores its key and value arguments into the filter context.
// Use [GetAnnotations] to retrieve the annotations from the context.
func New() filters.Spec {
return annotateSpec{}
}
func (annotateSpec) Name() string {
return filters.AnnotateName
}
func (as annotateSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 2 {
return nil, fmt.Errorf("requires string key and value arguments")
}
af, ok := &annotateFilter{}, false
if af.key, ok = args[0].(string); !ok {
return nil, fmt.Errorf("key argument must be a string")
}
if af.value, ok = args[1].(string); !ok {
return nil, fmt.Errorf("value argument must be a string")
}
return af, nil
}
func (af *annotateFilter) Request(ctx filters.FilterContext) {
if v, ok := ctx.StateBag()[annotateStateBagKey]; ok {
v.(map[string]string)[af.key] = af.value
} else {
ctx.StateBag()[annotateStateBagKey] = map[string]string{af.key: af.value}
}
}
func (af *annotateFilter) Response(filters.FilterContext) {}
func GetAnnotations(ctx filters.FilterContext) map[string]string {
if v, ok := ctx.StateBag()[annotateStateBagKey]; ok {
return v.(map[string]string)
}
return nil
}
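// Illustrative sketch (not part of the original source): a downstream filter
// reading the annotations; the logTeamFilter type and the eskip route are
// made up.
//
//	// r: * -> annotate("team", "checkout") -> logTeam() -> "https://backend.example.org";
//
//	func (f *logTeamFilter) Request(ctx filters.FilterContext) {
//		if a := GetAnnotations(ctx); a != nil {
//			log.Printf("serving team: %s", a["team"])
//		}
//	}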
package apiusagemonitoring
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/jwt"
)
const (
metricCountAll = "http_count"
metricCountUnknownClass = "httpxxx_count"
metricCount100s = "http1xx_count"
metricCount200s = "http2xx_count"
metricCount300s = "http3xx_count"
metricCount400s = "http4xx_count"
metricCount500s = "http5xx_count"
metricLatency = "latency"
metricLatencySum = "latency_sum"
)
const (
stateBagKey = "filter." + filters.ApiUsageMonitoringName
)
const (
authorizationHeaderName = "Authorization"
authorizationHeaderPrefix = "Bearer "
)
// apiUsageMonitoringFilter implements filters.Filter interface and is the structure
// created for every route invocation of the `apiUsageMonitoring` filter.
type apiUsageMonitoringFilter struct {
clientKeys []string
realmKeys []string
Paths []*pathInfo
UnknownPath *pathInfo
}
type apiUsageMonitoringStateBag struct {
url *url.URL
begin time.Time
}
// HandleErrorResponse opts this filter in to having Response(ctx) called
// even when the proxy fails with an error. It has to return true to
// opt in.
func (f *apiUsageMonitoringFilter) HandleErrorResponse() bool { return true }
func (f *apiUsageMonitoringFilter) Request(c filters.FilterContext) {
u := *c.Request().URL
c.StateBag()[stateBagKey] = apiUsageMonitoringStateBag{
url: &u,
begin: time.Now(),
}
}
func (f *apiUsageMonitoringFilter) Response(c filters.FilterContext) {
request, response, metrics := c.Request(), c.Response(), c.Metrics()
stateBag, stateBagPresent := c.StateBag()[stateBagKey].(apiUsageMonitoringStateBag)
path := f.UnknownPath
if stateBagPresent && stateBag.url != nil {
path = f.resolveMatchedPath(stateBag.url)
}
if path == f.UnknownPath {
path = f.resolveMatchedPath(request.URL)
}
classMetricsIndex := response.StatusCode / 100
if classMetricsIndex < 1 || classMetricsIndex > 5 {
log.Errorf(
"Response HTTP Status Code %d is invalid. Response status code metric will be %q.",
response.StatusCode, metricCountUnknownClass)
classMetricsIndex = 0 // unknown classes are tracked, not ignored
}
// Endpoint metrics
endpointMetricsNames := getEndpointMetricsNames(request, path)
metrics.IncCounter(endpointMetricsNames.countAll)
metrics.IncCounter(endpointMetricsNames.countPerStatusCodeRange[classMetricsIndex])
if stateBagPresent {
metrics.MeasureSince(endpointMetricsNames.latency, stateBag.begin)
}
log.Debugf("Pushed endpoint metrics with prefix `%s`", endpointMetricsNames.endpointPrefix)
// Client metrics
if path.ClientTracking != nil {
realmClientKey := f.getRealmClientKey(request, path)
clientMetricsNames := getClientMetricsNames(realmClientKey, path)
metrics.IncCounter(clientMetricsNames.countAll)
metrics.IncCounter(clientMetricsNames.countPerStatusCodeRange[classMetricsIndex])
if stateBagPresent {
latency := time.Since(stateBag.begin).Seconds()
metrics.IncFloatCounterBy(clientMetricsNames.latencySum, latency)
}
log.Debugf("Pushed client metrics with prefix `%s%s.`", path.ClientPrefix, realmClientKey)
}
}
func getClientMetricsNames(realmClientKey string, path *pathInfo) *clientMetricNames {
if value, ok := path.metricPrefixedPerClient.Load(realmClientKey); ok {
if prefixes, ok := value.(*clientMetricNames); ok {
return prefixes
}
}
clientPrefixForThisClient := path.ClientPrefix + realmClientKey + "."
prefixes := &clientMetricNames{
countAll: clientPrefixForThisClient + metricCountAll,
countPerStatusCodeRange: [6]string{
clientPrefixForThisClient + metricCountUnknownClass,
clientPrefixForThisClient + metricCount100s,
clientPrefixForThisClient + metricCount200s,
clientPrefixForThisClient + metricCount300s,
clientPrefixForThisClient + metricCount400s,
clientPrefixForThisClient + metricCount500s,
},
latencySum: clientPrefixForThisClient + metricLatencySum,
}
path.metricPrefixedPerClient.Store(realmClientKey, prefixes)
return prefixes
}
const unknownUnknown = unknownPlaceholder + "." + unknownPlaceholder
// getRealmClientKey generates the proper <realm>.<client> part of the client metrics name.
func (f *apiUsageMonitoringFilter) getRealmClientKey(r *http.Request, path *pathInfo) string {
// no JWT ==> {unknown}.{unknown}
jwt := parseJwtBody(r)
if jwt == nil {
return unknownUnknown
}
// no realm in JWT ==> {unknown}.{unknown}
realm, ok := jwt.getOneOfString(f.realmKeys)
if !ok {
return unknownUnknown
}
// realm is not one of the realmsTrackingPattern to be tracked ==> realm.{all}
if !path.ClientTracking.RealmsTrackingMatcher.MatchString(realm) {
return realm + ".{all}"
}
// no client in JWT ==> realm.{unknown}
client, ok := jwt.getOneOfString(f.clientKeys)
if !ok {
return realm + "." + unknownPlaceholder
}
// if client does not match ==> realm.{no-match}
matcher := path.ClientTracking.ClientTrackingMatcher
if matcher == nil || !matcher.MatchString(client) {
return realm + "." + noMatchPlaceholder
}
// all matched ==> realm.client
return realm + "." + client
}
// resolveMatchedPath tries to match the request's path with one of the configured path templates.
func (f *apiUsageMonitoringFilter) resolveMatchedPath(u *url.URL) *pathInfo {
if u != nil {
for _, p := range f.Paths {
if p.Matcher.MatchString(u.Path) {
return p
}
}
}
return f.UnknownPath
}
// getEndpointMetricsNames returns the structure with the names of the metrics for this specific context.
// It first tries the path's cache. If the names are not cached yet, they are
// generated and cached to speed up subsequent calls.
func getEndpointMetricsNames(req *http.Request, path *pathInfo) *endpointMetricNames {
method := req.Method
methodIndex, ok := methodToIndex[method]
if !ok {
methodIndex = methodIndexUnknown
method = unknownPlaceholder
}
if p := path.metricPrefixesPerMethod[methodIndex]; p != nil {
return p
}
return createAndCacheMetricsNames(path, method, methodIndex)
}
// createAndCacheMetricsNames generates metrics names and caches them.
func createAndCacheMetricsNames(path *pathInfo, method string, methodIndex int) *endpointMetricNames {
endpointPrefix := path.CommonPrefix + method + "." + path.PathTemplate + ".*.*."
prefixes := &endpointMetricNames{
endpointPrefix: endpointPrefix,
countAll: endpointPrefix + metricCountAll,
countPerStatusCodeRange: [6]string{
endpointPrefix + metricCountUnknownClass,
endpointPrefix + metricCount100s,
endpointPrefix + metricCount200s,
endpointPrefix + metricCount300s,
endpointPrefix + metricCount400s,
endpointPrefix + metricCount500s,
},
latency: endpointPrefix + metricLatency,
}
path.metricPrefixesPerMethod[methodIndex] = prefixes
return prefixes
}
// parseJwtBody parses the JWT token from an HTTP request.
// It returns `nil` if it was not possible to parse the JWT body.
func parseJwtBody(req *http.Request) jwtTokenPayload {
ahead := req.Header.Get(authorizationHeaderName)
tv := strings.TrimPrefix(ahead, authorizationHeaderPrefix)
if tv == ahead {
return nil
}
token, err := jwt.Parse(tv)
if err != nil {
return nil
}
return token.Claims
}
type jwtTokenPayload map[string]interface{}
func (j jwtTokenPayload) getOneOfString(properties []string) (value string, ok bool) {
var rawValue interface{}
for _, p := range properties {
if rawValue, ok = j[p]; ok {
value = fmt.Sprint(rawValue)
return
}
}
return
}
package apiusagemonitoring
import "github.com/zalando/skipper/filters"
type noopSpec struct {
filter filters.Filter
}
func (*noopSpec) Name() string {
return filters.ApiUsageMonitoringName
}
func (s *noopSpec) CreateFilter(config []interface{}) (filters.Filter, error) {
return s.filter, nil
}
type noopFilter struct{}
func (noopFilter) Request(filters.FilterContext) {}
func (noopFilter) Response(filters.FilterContext) {}
package apiusagemonitoring
import (
"net/http"
"regexp"
"strings"
"sync"
)
// pathInfo contains the tracking information for a specific path.
type pathInfo struct {
ApplicationId string
Tag string
ApiId string
PathTemplate string
Matcher *regexp.Regexp
ClientTracking *clientTrackingInfo
CommonPrefix string
ClientPrefix string
metricPrefixesPerMethod [methodIndexLength]*endpointMetricNames
metricPrefixedPerClient sync.Map
}
func newPathInfo(applicationId, tag, apiId, pathTemplate string, clientTracking *clientTrackingInfo) *pathInfo {
id, t, found := strings.Cut(applicationId, ":")
if found {
applicationId = id
tag = t
}
if tag == "" {
tag = noTagPlaceholder
}
commonPrefix := applicationId + "." + tag + "." + apiId + "."
return &pathInfo{
ApplicationId: applicationId,
Tag: tag,
ApiId: apiId,
PathTemplate: pathTemplate,
metricPrefixedPerClient: sync.Map{},
ClientTracking: clientTracking,
CommonPrefix: commonPrefix,
ClientPrefix: commonPrefix + "*.*.",
}
}
// pathInfoByRegExRev allows sort.Sort to reorder a slice of `pathInfo` in
// reverse alphabetical order of their matcher (Regular Expression). That way,
// the more complex RegEx will end up first in the slice.
type pathInfoByRegExRev []*pathInfo
func (s pathInfoByRegExRev) Len() int { return len(s) }
func (s pathInfoByRegExRev) Less(i, j int) bool { return s[i].Matcher.String() > s[j].Matcher.String() }
func (s pathInfoByRegExRev) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type endpointMetricNames struct {
endpointPrefix string
countAll string
countPerStatusCodeRange [6]string
latency string
}
type clientMetricNames struct {
countAll string
countPerStatusCodeRange [6]string
latencySum string
}
const (
methodIndexGet = iota // GET
methodIndexHead // HEAD
methodIndexPost // POST
methodIndexPut // PUT
methodIndexPatch // PATCH
methodIndexDelete // DELETE
methodIndexConnect // CONNECT
methodIndexOptions // OPTIONS
methodIndexTrace // TRACE
methodIndexUnknown // Value when the HTTP Method is not in the known list
methodIndexLength // Gives the constant size of the `metricPrefixesPerMethod` array.
)
var (
methodToIndex = map[string]int{
http.MethodGet: methodIndexGet,
http.MethodHead: methodIndexHead,
http.MethodPost: methodIndexPost,
http.MethodPut: methodIndexPut,
http.MethodPatch: methodIndexPatch,
http.MethodDelete: methodIndexDelete,
http.MethodConnect: methodIndexConnect,
http.MethodOptions: methodIndexOptions,
http.MethodTrace: methodIndexTrace,
}
)
type clientTrackingInfo struct {
ClientTrackingMatcher *regexp.Regexp
RealmsTrackingMatcher *regexp.Regexp
}
package apiusagemonitoring
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
"sync"
"time"
"github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.ApiUsageMonitoringName instead
Name = filters.ApiUsageMonitoringName
unknownPlaceholder = "{unknown}"
noMatchPlaceholder = "{no-match}"
noTagPlaceholder = "{no-tag}"
)
var (
log = logrus.WithField("filter", filters.ApiUsageMonitoringName)
regCache = sync.Map{}
)
// loadOrCompileRegex compiles a pattern once and caches the result. Note that
// a failed compilation is cached as a nil *regexp.Regexp, so the error is
// only returned on the first call and callers must check for a nil matcher.
func loadOrCompileRegex(pattern string) (*regexp.Regexp, error) {
var err error
var reg *regexp.Regexp
regI, ok := regCache.Load(pattern)
if !ok {
reg, err = regexp.Compile(pattern)
regCache.Store(pattern, reg)
} else {
reg = regI.(*regexp.Regexp)
}
return reg, err
}
// NewApiUsageMonitoring creates a new instance of the API Monitoring filter
// specification (its factory).
func NewApiUsageMonitoring(
enabled bool,
realmKeys string,
clientKeys string,
realmsTrackingPattern string,
) filters.Spec {
if !enabled {
log.Debugf("filter %q is not enabled. spec returns `noop` filters.", filters.ApiUsageMonitoringName)
return &noopSpec{&noopFilter{}}
}
// parse realm keys comma separated list
var realmKeyList []string
for _, key := range strings.Split(realmKeys, ",") {
strippedKey := strings.TrimSpace(key)
if strippedKey != "" {
realmKeyList = append(realmKeyList, strippedKey)
}
}
// parse client keys comma separated list
var clientKeyList []string
for _, key := range strings.Split(clientKeys, ",") {
strippedKey := strings.TrimSpace(key)
if strippedKey != "" {
clientKeyList = append(clientKeyList, strippedKey)
}
}
realmsTrackingMatcher, err := loadOrCompileRegex(realmsTrackingPattern)
if err != nil {
log.Errorf(
"api-usage-monitoring-realmsTrackingPattern-tracking-pattern (global config) ignored: error compiling regular expression %q: %v",
realmsTrackingPattern, err)
realmsTrackingMatcher = regexp.MustCompile("services")
log.Warn("defaulting to 'services' as api-usage-monitoring-realmsTrackingPattern-tracking-pattern (global config)")
}
// Create the filter Spec
var unknownPathClientTracking *clientTrackingInfo = nil // client metrics feature is disabled
if realmKeys != "" {
unknownPathClientTracking = &clientTrackingInfo{
ClientTrackingMatcher: nil, // do not match anything (track `realm.{unknown}`)
RealmsTrackingMatcher: realmsTrackingMatcher,
}
}
unknownPath := newPathInfo(
unknownPlaceholder,
noTagPlaceholder,
unknownPlaceholder,
noMatchPlaceholder,
unknownPathClientTracking,
)
spec := &apiUsageMonitoringSpec{
pathHandler: defaultPathHandler{},
realmKeys: realmKeyList,
clientKeys: clientKeyList,
unknownPath: unknownPath,
realmsTrackingMatcher: realmsTrackingMatcher,
sometimes: rate.Sometimes{First: 3, Interval: 1 * time.Minute},
filterMap: make(map[string]*apiUsageMonitoringFilter),
}
log.Debugf("created filter spec: %+v", spec)
return spec
}
// apiConfig is the structure used to parse the parameters of the filter.
type apiConfig struct {
ApplicationId string `json:"application_id"`
Tag string `json:"tag"`
ApiId string `json:"api_id"`
PathTemplates []string `json:"path_templates"`
ClientTrackingPattern string `json:"client_tracking_pattern"`
}
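// Illustrative sketch (not part of the original source): a configuration
// argument as it would appear in an eskip route using the apiUsageMonitoring
// filter; the IDs and path templates are made up.
//
//	apiUsageMonitoring(`{
//		"application_id": "my-app",
//		"tag": "staging",
//		"api_id": "orders-api",
//		"path_templates": ["orders/{order-id}", "orders/:order-id/items"],
//		"client_tracking_pattern": "services\\..*"
//	}`)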
type apiUsageMonitoringSpec struct {
pathHandler pathHandler
realmKeys []string
clientKeys []string
realmsTrackingMatcher *regexp.Regexp
unknownPath *pathInfo
sometimes rate.Sometimes
mu sync.Mutex
filterMap map[string]*apiUsageMonitoringFilter
}
func (s *apiUsageMonitoringSpec) errorf(format string, args ...interface{}) {
s.sometimes.Do(func() {
log.Errorf(format, args...)
})
}
func (s *apiUsageMonitoringSpec) warnf(format string, args ...interface{}) {
s.sometimes.Do(func() {
log.Warnf(format, args...)
})
}
func (s *apiUsageMonitoringSpec) infof(format string, args ...interface{}) {
s.sometimes.Do(func() {
log.Infof(format, args...)
})
}
func (s *apiUsageMonitoringSpec) debugf(format string, args ...interface{}) {
s.sometimes.Do(func() {
log.Debugf(format, args...)
})
}
func (s *apiUsageMonitoringSpec) Name() string {
return filters.ApiUsageMonitoringName
}
func keyFromArgs(args []interface{}) (string, error) {
var sb strings.Builder
for _, a := range args {
s, ok := a.(string)
if !ok {
sb.Reset()
return "", fmt.Errorf("failed to cast '%v' to string", a)
}
sb.WriteString(s)
}
return sb.String(), nil
}
func (s *apiUsageMonitoringSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
key, err := keyFromArgs(args)
// cache lookup
if err == nil {
s.mu.Lock()
f, ok := s.filterMap[key]
if ok {
s.mu.Unlock()
return f, nil
}
s.mu.Unlock()
}
apis := s.parseJsonConfiguration(args)
paths := s.buildPathInfoListFromConfiguration(apis)
if len(paths) == 0 {
s.errorf("no valid configurations, creating noop api usage monitoring filter")
return noopFilter{}, nil
}
f := &apiUsageMonitoringFilter{
realmKeys: s.realmKeys,
clientKeys: s.clientKeys,
Paths: paths,
UnknownPath: s.buildUnknownPathInfo(paths),
}
// cache write
s.mu.Lock()
s.filterMap[key] = f
s.mu.Unlock()
return f, nil
}
func (s *apiUsageMonitoringSpec) parseJsonConfiguration(args []interface{}) []*apiConfig {
apis := make([]*apiConfig, 0, len(args))
for i, a := range args {
rawJsonConfiguration, ok := a.(string)
if !ok {
s.errorf("args[%d] ignored: expecting a string, was %t", i, a)
continue
}
config := &apiConfig{
ClientTrackingPattern: ".*", // track all clients per default
}
decoder := json.NewDecoder(strings.NewReader(rawJsonConfiguration))
decoder.DisallowUnknownFields()
err := decoder.Decode(config)
if err != nil {
s.errorf("args[%d] ignored: error reading JSON configuration: %s", i, err)
continue
}
apis = append(apis, config)
}
return apis
}
func (s *apiUsageMonitoringSpec) buildUnknownPathInfo(paths []*pathInfo) *pathInfo {
var applicationId *string
for i := range paths {
path := paths[i]
if applicationId != nil && *applicationId != path.ApplicationId {
return s.unknownPath
}
applicationId = &path.ApplicationId
}
if applicationId != nil && *applicationId != "" {
return newPathInfo(
*applicationId,
s.unknownPath.Tag,
s.unknownPath.ApiId,
s.unknownPath.PathTemplate,
s.unknownPath.ClientTracking)
}
return s.unknownPath
}
func (s *apiUsageMonitoringSpec) buildPathInfoListFromConfiguration(apis []*apiConfig) []*pathInfo {
var paths []*pathInfo
existingPathTemplates := make(map[string]*pathInfo)
existingPathPattern := make(map[string]*pathInfo)
for apiIndex, api := range apis {
applicationId := api.ApplicationId
if applicationId == "" {
s.warnf("args[%d] ignored: does not specify an application_id", apiIndex)
continue
}
apiId := api.ApiId
if apiId == "" {
s.warnf("args[%d] ignored: does not specify an api_id", apiIndex)
continue
}
if len(api.PathTemplates) == 0 {
s.warnf("args[%d] ignored: does not specify any path template", apiIndex)
continue
}
clientTrackingInfo := s.buildClientTrackingInfo(apiIndex, api, s.realmsTrackingMatcher)
for templateIndex, template := range api.PathTemplates {
// Path Template validation
if template == "" {
s.warnf(
"args[%d].path_templates[%d] ignored: empty",
apiIndex, templateIndex)
continue
}
// Normalize path template and get regular expression path pattern
pathTemplate := s.pathHandler.normalizePathTemplate(template)
pathPattern := s.pathHandler.createPathPattern(template)
// Create new `pathInfo` with normalized PathTemplate
info := newPathInfo(applicationId, api.Tag, apiId, pathTemplate, clientTrackingInfo)
// Detect path template duplicates
if _, ok := existingPathTemplates[info.PathTemplate]; ok {
s.warnf(
"args[%d].path_templates[%d] ignored: duplicate path template %q",
apiIndex, templateIndex, info.PathTemplate)
continue
}
existingPathTemplates[info.PathTemplate] = info
// Detect regular expression duplicates
if existingMatcher, ok := existingPathPattern[pathPattern]; ok {
s.warnf(
"args[%d].path_templates[%d] ignored: two path templates yielded the same regular expression %q (%q and %q)",
apiIndex, templateIndex, pathPattern, info.PathTemplate, existingMatcher.PathTemplate)
continue
}
existingPathPattern[pathPattern] = info
pathPatternMatcher, err := loadOrCompileRegex(pathPattern)
if err != nil {
s.warnf(
"args[%d].path_templates[%d] ignored: error compiling regular expression %q for path %q: %v",
apiIndex, templateIndex, pathPattern, info.PathTemplate, err)
continue
}
if pathPatternMatcher == nil {
continue
}
info.Matcher = pathPatternMatcher
// Add valid entry to the results
paths = append(paths, info)
}
}
// Sort the paths by their matcher's RegEx
sort.Sort(pathInfoByRegExRev(paths))
return paths
}
func (s *apiUsageMonitoringSpec) buildClientTrackingInfo(apiIndex int, api *apiConfig, realmsTrackingMatcher *regexp.Regexp) *clientTrackingInfo {
if len(s.realmKeys) == 0 {
s.infof(
`args[%d]: skipper wide configuration "api-usage-monitoring-realm-keys" not provided, not tracking client metrics`,
apiIndex)
return nil
}
if len(s.clientKeys) == 0 {
s.infof(
`args[%d]: skipper wide configuration "api-usage-monitoring-client-keys" not provided, not tracking client metrics`,
apiIndex)
return nil
}
if api.ClientTrackingPattern == "" {
s.debugf(
`args[%d]: empty client_tracking_pattern disables the client metrics for its endpoints`,
apiIndex)
return nil
}
clientTrackingMatcher, err := loadOrCompileRegex(api.ClientTrackingPattern)
if err != nil {
s.errorf(
"args[%d].client_tracking_pattern ignored (no client tracking): error compiling regular expression %q: %v",
apiIndex, api.ClientTrackingPattern, err)
return nil
}
if clientTrackingMatcher == nil {
return nil
}
return &clientTrackingInfo{
ClientTrackingMatcher: clientTrackingMatcher,
RealmsTrackingMatcher: realmsTrackingMatcher,
}
}
var (
regexpMultipleSlashes = regexp.MustCompile(`/+`)
regexpLeadingSlashes = regexp.MustCompile(`^/*`)
regexpTrailingSlashes = regexp.MustCompile(`/*$`)
regexpMiddleSlashes = regexp.MustCompile(`([^/^])/+([^/*])`)
regexpSlashColonVar = regexp.MustCompile(`/:([^:{}/]*)`)
regexpCurlyBracketVar = regexp.MustCompile(`{([^{}]*?)([?]?)}`)
regexpEscapeBeforeChars = regexp.MustCompile(`([.*+?\\])`)
regexpEscapeAfterChars = regexp.MustCompile(`([{}[\]()|])`)
)
// pathHandler path handler interface.
type pathHandler interface {
normalizePathTemplate(path string) string
createPathPattern(path string) string
}
// defaultPathHandler default path handler implementation.
type defaultPathHandler struct{}
// normalizePathTemplate normalize path template removing the leading and
// trailing slashes, substituting multiple adjacent slashes with a single
// one, and replacing column based variable declarations by curly bracked
// based.
func (ph defaultPathHandler) normalizePathTemplate(path string) string {
path = regexpLeadingSlashes.ReplaceAllString(path, "")
path = regexpTrailingSlashes.ReplaceAllString(path, "")
path = regexpMultipleSlashes.ReplaceAllString(path, "/")
path = regexpSlashColonVar.ReplaceAllString(path, "/{$1}")
path = regexpCurlyBracketVar.ReplaceAllString(path, "{$1}")
return path
}
// createPathPattern creates a regular expression path pattern for a path
// template by escaping regexp-specific characters, adding optional matching
// of leading and trailing slashes, accepting adjacent slashes as if a single
// slash was given, and allowing free matching of content at variable locations.
func (ph defaultPathHandler) createPathPattern(path string) string {
path = regexpEscapeBeforeChars.ReplaceAllString(path, "\\$1")
path = regexpSlashColonVar.ReplaceAllString(path, "/.+")
path = regexpCurlyBracketVar.ReplaceAllStringFunc(path, selectPathVarPattern)
path = regexpLeadingSlashes.ReplaceAllString(path, "^/*")
path = regexpTrailingSlashes.ReplaceAllString(path, "/*$")
path = regexpMiddleSlashes.ReplaceAllString(path, "$1/+$2")
path = regexpEscapeAfterChars.ReplaceAllString(path, "\\$1")
return path
}
// selectPathVarPattern selects the correct path variable pattern depending
// on the path variable syntax. A trailing question mark is interpreted as
// a path variable that is allowed to be empty.
func selectPathVarPattern(match string) string {
if strings.HasSuffix(match, "\\?}") {
return ".*"
}
return ".+"
}
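// Example (sketch of the transformations above; the template and inputs are
// illustrative): a colon-style template normalizes to the curly-bracket form,
// and the generated pattern tolerates leading, trailing, and duplicate
// slashes while requiring non-empty content at the variable location:
//
//	ph := defaultPathHandler{}
//	ph.normalizePathTemplate("//orders/:orderId/items/") // "orders/{orderId}/items"
//	ph.createPathPattern("/orders/{orderId}")            // matches "/orders/42" and "//orders/42/", but not "/orders/"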
package auth
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/zalando/skipper/filters"
logfilter "github.com/zalando/skipper/filters/log"
)
type roleCheckType int
const (
checkOAuthTokeninfoAnyScopes roleCheckType = iota
checkOAuthTokeninfoAllScopes
checkOAuthTokeninfoAnyKV
checkOAuthTokeninfoAllKV
checkOAuthTokeninfoValidate
checkOAuthTokenintrospectionAnyClaims
checkOAuthTokenintrospectionAllClaims
checkOAuthTokenintrospectionAnyKV
checkOAuthTokenintrospectionAllKV
checkSecureOAuthTokenintrospectionAnyClaims
checkSecureOAuthTokenintrospectionAllClaims
checkSecureOAuthTokenintrospectionAnyKV
checkSecureOAuthTokenintrospectionAllKV
checkOIDCUserInfo
checkOIDCAnyClaims
checkOIDCAllClaims
checkOIDCQueryClaims
)
type rejectReason string
const (
missingBearerToken rejectReason = "missing-bearer-token"
missingToken rejectReason = "missing-token"
missingJWKS rejectReason = "missing-jwks"
authServiceAccess rejectReason = "auth-service-access"
invalidSub rejectReason = "invalid-sub-in-token"
inactiveToken rejectReason = "inactive-token"
invalidToken rejectReason = "invalid-token"
invalidScope rejectReason = "invalid-scope"
invalidClaim rejectReason = "invalid-claim"
invalidFilter rejectReason = "invalid-filter"
invalidAccess rejectReason = "invalid-access"
)
const (
AuthUnknown = "authUnknown"
authHeaderName = "Authorization"
authHeaderPrefix = "Bearer "
// tokenKey defined at https://tools.ietf.org/html/rfc7662#section-2.1
tokenKey = "token"
scopeKey = "scope"
uidKey = "uid"
)
type kv map[string][]string
type requestError struct {
err error
}
var (
errUnsupportedClaimSpecified = errors.New("unsupported claim specified in filter")
errInvalidToken = errors.New("invalid token")
errInvalidTokenintrospectionData = errors.New("invalid tokenintrospection data")
)
func (kv kv) String() string {
var res []string
for k, v := range kv {
res = append(res, k, strings.Join(v, " "))
}
return strings.Join(res, ",")
}
func (err *requestError) Error() string {
return err.err.Error()
}
func requestErrorf(f string, args ...interface{}) error {
return &requestError{
err: fmt.Errorf(f, args...),
}
}
func getToken(r *http.Request) (string, bool) {
h := r.Header.Get(authHeaderName)
if !strings.HasPrefix(h, authHeaderPrefix) {
return "", false
}
return h[len(authHeaderPrefix):], true
}
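// Example (sketch): getToken only strips the "Bearer " prefix; it performs no
// validation of the token itself:
//
//	req, _ := http.NewRequest("GET", "https://backend.example.org/", nil)
//	req.Header.Set(authHeaderName, authHeaderPrefix+"my-opaque-token")
//	token, ok := getToken(req) // token == "my-opaque-token", ok == true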
func reject(
ctx filters.FilterContext,
status int,
username string,
reason rejectReason,
hostname,
debuginfo string,
) {
if debuginfo == "" {
ctx.Logger().Debugf(
"Rejected: status: %d, username: %s, reason: %s.",
status, username, reason,
)
} else {
ctx.Logger().Debugf(
"Rejected: status: %d, username: %s, reason: %s, info: %s.",
status, username, reason, debuginfo,
)
}
ctx.StateBag()[logfilter.AuthUserKey] = username
ctx.StateBag()[logfilter.AuthRejectReasonKey] = string(reason)
rsp := &http.Response{
StatusCode: status,
Header: make(map[string][]string),
}
if hostname != "" {
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4.2
rsp.Header.Add("WWW-Authenticate", hostname)
}
ctx.Serve(rsp)
}
func unauthorized(ctx filters.FilterContext, username string, reason rejectReason, hostname, debuginfo string) {
reject(ctx, http.StatusUnauthorized, username, reason, hostname, debuginfo)
}
func forbidden(ctx filters.FilterContext, username string, reason rejectReason, debuginfo string) {
reject(ctx, http.StatusForbidden, username, reason, "", debuginfo)
}
func authorized(ctx filters.FilterContext, username string) {
ctx.StateBag()[logfilter.AuthUserKey] = username
}
func getStrings(args []interface{}) ([]string, error) {
s := make([]string, len(args))
var ok bool
for i, a := range args {
s[i], ok = a.(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
}
return s, nil
}
// all checks that all strings in the left are also in the
// right. Right can be a superset of left.
func all(left, right []string) bool {
for _, l := range left {
var found bool
for _, r := range right {
if l == r {
found = true
break
}
}
if !found {
return false
}
}
return true
}
package auth
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/net"
)
const (
webhookSpanName = "webhook"
tokenInfoSpanName = "tokeninfo"
tokenIntrospectionSpanName = "tokenintrospection"
)
const (
defaultMaxIdleConns = 64
)
type authClient struct {
url *url.URL
cli *net.Client
}
type tokeninfoClient interface {
getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error)
Close()
}
var _ tokeninfoClient = &authClient{}
func newAuthClient(baseURL, spanName string, timeout time.Duration, maxIdleConns int, tracer opentracing.Tracer) (*authClient, error) {
if tracer == nil {
tracer = opentracing.NoopTracer{}
}
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
u, err := url.Parse(baseURL)
if err != nil {
return nil, err
}
cli := net.NewClient(net.Options{
ResponseHeaderTimeout: timeout,
TLSHandshakeTimeout: timeout,
MaxIdleConnsPerHost: maxIdleConns,
Tracer: tracer,
OpentracingComponentTag: "skipper",
OpentracingSpanName: spanName,
})
return &authClient{url: u, cli: cli}, nil
}
func (ac *authClient) Close() {
ac.cli.Close()
}
func bindContext(ctx filters.FilterContext, req *http.Request) *http.Request {
return req.WithContext(ctx.Request().Context())
}
func (ac *authClient) getTokenintrospect(token string, ctx filters.FilterContext) (tokenIntrospectionInfo, error) {
body := url.Values{}
body.Add(tokenKey, token)
req, err := http.NewRequest("POST", ac.url.String(), strings.NewReader(body.Encode()))
if err != nil {
return nil, err
}
req = bindContext(ctx, req)
if ac.url.User != nil {
authorization := base64.StdEncoding.EncodeToString([]byte(ac.url.User.String()))
req.Header.Add("Authorization", fmt.Sprintf("Basic %s", authorization))
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
rsp, err := ac.cli.Do(req)
if err != nil {
return nil, err
}
defer rsp.Body.Close()
if rsp.StatusCode != 200 {
io.Copy(io.Discard, rsp.Body)
if rsp.StatusCode != 403 && rsp.StatusCode != 401 {
return nil, fmt.Errorf("failed with status code: %d", rsp.StatusCode)
}
return nil, errInvalidToken
}
buf, err := io.ReadAll(rsp.Body)
if err != nil {
return nil, err
}
info := make(tokenIntrospectionInfo)
err = json.Unmarshal(buf, &info)
return info, err
}
func (ac *authClient) getTokeninfo(token string, ctx filters.FilterContext) (map[string]interface{}, error) {
var doc map[string]interface{}
req, err := http.NewRequest("GET", ac.url.String(), nil)
if err != nil {
return doc, err
}
req = bindContext(ctx, req)
if token != "" {
req.Header.Set(authHeaderName, authHeaderPrefix+token)
}
rsp, err := ac.cli.Do(req)
if err != nil {
return doc, err
}
defer rsp.Body.Close()
if rsp.StatusCode != 200 {
io.Copy(io.Discard, rsp.Body)
if rsp.StatusCode != 403 && rsp.StatusCode != 401 {
return nil, fmt.Errorf("failed with status code: %d", rsp.StatusCode)
}
return doc, errInvalidToken
}
d := json.NewDecoder(rsp.Body)
err = d.Decode(&doc)
return doc, err
}
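// Example (sketch; the tokeninfo URL is an illustrative placeholder, and a
// nil tracer falls back to the opentracing NoopTracer):
//
//	ac, err := newAuthClient("https://auth.example.org/oauth2/tokeninfo",
//		tokenInfoSpanName, 2*time.Second, defaultMaxIdleConns, nil)
//	if err != nil {
//		// handle the URL parse error
//	}
//	defer ac.Close()
//	info, err := ac.getTokeninfo(token, ctx) // ctx is the current filters.FilterContext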
func (ac *authClient) getWebhook(ctx filters.FilterContext) (*http.Response, error) {
req, err := http.NewRequest("GET", ac.url.String(), nil)
if err != nil {
return nil, err
}
req = bindContext(ctx, req)
copyHeader(req.Header, ctx.Request().Header)
rsp, err := ac.cli.Do(req)
if err != nil {
return nil, err
}
defer rsp.Body.Close()
return rsp, nil
}
package auth
import (
"fmt"
"net/http"
"os"
auth "github.com/abbot/go-http-auth"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.BasicAuthName instead
Name = filters.BasicAuthName
ForceBasicAuthHeaderName = "WWW-Authenticate"
ForceBasicAuthHeaderValue = "Basic realm="
DefaultRealmName = "Basic Realm"
)
type basicSpec struct{}
type basic struct {
authenticator *auth.BasicAuth
realmDefinition string
}
func NewBasicAuth() *basicSpec {
return &basicSpec{}
}
// We do not touch the response at all
func (a *basic) Response(filters.FilterContext) {}
// Request checks the basic auth credentials of the incoming request
func (a *basic) Request(ctx filters.FilterContext) {
username := a.authenticator.CheckAuth(ctx.Request())
if username == "" {
header := http.Header{}
header.Set(ForceBasicAuthHeaderName, a.realmDefinition)
ctx.Serve(&http.Response{
StatusCode: http.StatusUnauthorized,
Header: header,
})
}
}
// CreateFilter creates our basicAuth filter.
// The first parameter specifies the htpasswd file to use.
// The second parameter is optional and defines the realm name.
func (spec *basicSpec) CreateFilter(config []interface{}) (filters.Filter, error) {
if len(config) == 0 {
return nil, filters.ErrInvalidFilterParameters
}
configFile, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
realmName := DefaultRealmName
if len(config) == 2 {
if definedName, ok := config[1].(string); ok {
realmName = definedName
}
}
if _, err := os.Stat(configFile); err != nil {
return nil, fmt.Errorf("stat failed for %q: %w", configFile, err)
}
htpasswd := auth.HtpasswdFileProvider(configFile)
authenticator := auth.NewBasicAuthenticator(realmName, htpasswd)
return &basic{
authenticator: authenticator,
realmDefinition: ForceBasicAuthHeaderValue + `"` + realmName + `"`,
}, nil
}
func (spec *basicSpec) Name() string { return filters.BasicAuthName }
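// Example route (sketch; the htpasswd path, realm name, and backend are
// illustrative):
//
//	protected: * -> basicAuth("/etc/skipper/htpasswd", "My Realm") -> "https://backend.example.org";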
package auth
import (
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/secrets"
)
const (
// Deprecated, use filters.BearerInjectorName instead
BearerInjectorName = filters.BearerInjectorName
)
type (
bearerInjectorSpec struct {
secretsReader secrets.SecretsReader
}
bearerInjectorFilter struct {
secretName string
secretsReader secrets.SecretsReader
}
)
func NewBearerInjector(sr secrets.SecretsReader) filters.Spec {
return &bearerInjectorSpec{
secretsReader: sr,
}
}
func (*bearerInjectorSpec) Name() string {
return filters.BearerInjectorName
}
func (b *bearerInjectorSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
secretName, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return newBearerInjectorFilter(secretName, b.secretsReader), nil
}
func newBearerInjectorFilter(s string, sr secrets.SecretsReader) *bearerInjectorFilter {
return &bearerInjectorFilter{
secretName: s,
secretsReader: sr,
}
}
func (f *bearerInjectorFilter) Request(ctx filters.FilterContext) {
b, ok := f.secretsReader.GetSecret(f.secretName)
if !ok {
return
}
ctx.Request().Header.Set(authHeaderName, authHeaderPrefix+string(b))
}
func (*bearerInjectorFilter) Response(filters.FilterContext) {}
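// Example route (sketch; assumes the configured SecretsReader can resolve a
// secret named "/secrets/my-token"):
//
//	api: * -> bearerinjector("/secrets/my-token") -> "https://backend.example.org";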
package auth
import (
"encoding/json"
"fmt"
"github.com/zalando/skipper/filters"
"golang.org/x/net/http/httpguts"
)
const (
// Deprecated, use filters.ForwardTokenName instead
ForwardTokenName = filters.ForwardTokenName
)
type (
forwardTokenSpec struct{}
forwardTokenFilter struct {
HeaderName string
RetainJsonKeys []string
}
)
// NewForwardToken creates a filter to forward the result of token info or
// token introspection to the backend server.
func NewForwardToken() filters.Spec {
return &forwardTokenSpec{}
}
func (s *forwardTokenSpec) Name() string {
return filters.ForwardTokenName
}
func (*forwardTokenSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 1 {
return nil, filters.ErrInvalidFilterParameters
}
headerName, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
valid := httpguts.ValidHeaderFieldName(headerName)
if !valid {
return nil, fmt.Errorf("header name %s in invalid", headerName)
}
remainingArgs := args[1:]
stringifiedRemainingArgs := make([]string, len(remainingArgs))
for i, v := range remainingArgs {
maskedKeyName, ok := v.(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
stringifiedRemainingArgs[i] = maskedKeyName
}
return &forwardTokenFilter{HeaderName: headerName, RetainJsonKeys: stringifiedRemainingArgs}, nil
}
func getTokenPayload(ctx filters.FilterContext, cacheKey string) interface{} {
cachedValue, ok := ctx.StateBag()[cacheKey]
if !ok {
return nil
}
return cachedValue
}
func (f *forwardTokenFilter) Request(ctx filters.FilterContext) {
tiMap := getTokenPayload(ctx, tokeninfoCacheKey)
if tiMap == nil {
tiMap = getTokenPayload(ctx, tokenintrospectionCacheKey)
}
if tiMap == nil {
return
}
if len(f.RetainJsonKeys) > 0 {
switch typedTiMap := tiMap.(type) {
case map[string]interface{}:
tiMap = retainKeys(typedTiMap, f.RetainJsonKeys)
case tokenIntrospectionInfo:
tiMap = retainKeys(typedTiMap, f.RetainJsonKeys)
default:
ctx.Logger().Errorf("Unexpected input type[%T] for `forwardToken` filter. Unable to apply mask", typedTiMap)
}
}
payload, err := json.Marshal(tiMap)
if err != nil {
ctx.Logger().Errorf("Error while marshaling token: %v.", err)
return
}
request := ctx.Request()
request.Header.Set(f.HeaderName, string(payload))
}
func (*forwardTokenFilter) Response(filters.FilterContext) {}
func retainKeys(data map[string]interface{}, keys []string) map[string]interface{} {
whitelistedKeys := make(map[string]interface{})
for _, v := range keys {
if val, ok := data[v]; ok {
whitelistedKeys[v] = val
}
}
return whitelistedKeys
}
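// Example routes (sketch; assume a preceding tokeninfo filter has populated
// the state bag):
//
//	a: * -> oauthTokeninfoAllScope("uid") -> forwardToken("X-Tokeninfo-Forward") -> "https://backend.example.org";
//	b: * -> oauthTokeninfoAllScope("uid") -> forwardToken("X-Tokeninfo-Forward", "scope", "uid") -> "https://backend.example.org";
//
// Route b forwards only the "scope" and "uid" keys of the token payload.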
package auth
import (
"fmt"
"github.com/zalando/skipper/filters"
"golang.org/x/net/http/httpguts"
)
const (
// Deprecated, use filters.ForwardTokenFieldName instead
ForwardTokenFieldName = filters.ForwardTokenFieldName
)
type (
forwardTokenFieldSpec struct{}
forwardTokenFieldFilter struct {
HeaderName string
Field string
}
)
// NewForwardTokenField creates a filter to forward fields from token info,
// token introspection, or OIDC user info as headers to the backend server.
func NewForwardTokenField() filters.Spec {
return &forwardTokenFieldSpec{}
}
func (s *forwardTokenFieldSpec) Name() string {
return filters.ForwardTokenFieldName
}
func (*forwardTokenFieldSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
headerName, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
valid := httpguts.ValidHeaderFieldName(headerName)
if !valid {
return nil, fmt.Errorf("header name %s is invalid", headerName)
}
field, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &forwardTokenFieldFilter{HeaderName: headerName, Field: field}, nil
}
func (f *forwardTokenFieldFilter) Request(ctx filters.FilterContext) {
payload := getPayload(ctx, tokeninfoCacheKey)
if payload == nil {
payload = getPayload(ctx, tokenintrospectionCacheKey)
}
if payload == nil {
payload = getPayload(ctx, oidcClaimsCacheKey)
}
if payload == nil {
return
}
err := setHeaders(map[string]string{
f.HeaderName: f.Field,
}, ctx, payload)
if err != nil {
ctx.Logger().Errorf("%v", err)
return
}
}
func (*forwardTokenFieldFilter) Response(filters.FilterContext) {}
func getPayload(ctx filters.FilterContext, cacheKey string) interface{} {
cachedValue, ok := ctx.StateBag()[cacheKey]
if !ok {
return nil
}
return cachedValue
}
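// Example route (sketch; assumes a preceding auth filter has stored tokeninfo,
// introspection, or OIDC claims in the state bag):
//
//	a: * -> oauthTokeninfoAllScope("uid") -> forwardTokenField("X-Auth-Uid", "uid") -> "https://backend.example.org";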
package auth
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/annotate"
"golang.org/x/oauth2"
)
const (
// Deprecated, use filters.OAuthGrantName instead
OAuthGrantName = filters.OAuthGrantName
secretsRefreshInterval = time.Minute
refreshedTokenKey = "oauth-refreshed-token"
)
var (
errExpiredToken = errors.New("expired access token")
)
type grantSpec struct {
config *OAuthConfig
}
type grantFilter struct {
config *OAuthConfig
}
func (s *grantSpec) Name() string { return filters.OAuthGrantName }
func (s *grantSpec) CreateFilter([]interface{}) (filters.Filter, error) {
return &grantFilter{
config: s.config,
}, nil
}
func providerContext(c *OAuthConfig) context.Context {
return context.WithValue(context.Background(), oauth2.HTTPClient, c.AuthClient)
}
func serverError(ctx filters.FilterContext) {
ctx.Serve(&http.Response{
StatusCode: http.StatusInternalServerError,
})
}
func badRequest(ctx filters.FilterContext) {
ctx.Serve(&http.Response{
StatusCode: http.StatusBadRequest,
})
}
func loginRedirect(ctx filters.FilterContext, config *OAuthConfig) {
loginRedirectWithOverride(ctx, config, "")
}
func loginRedirectWithOverride(ctx filters.FilterContext, config *OAuthConfig, originalOverride string) {
req := ctx.Request()
authConfig, err := config.GetConfig(req)
if err != nil {
ctx.Logger().Debugf("Failed to obtain auth config: %v", err)
ctx.Serve(&http.Response{
StatusCode: http.StatusForbidden,
})
return
}
redirect, original := config.RedirectURLs(req)
if originalOverride != "" {
original = originalOverride
}
state, err := config.flowState.createState(original)
if err != nil {
ctx.Logger().Errorf("Failed to create login redirect: %v", err)
serverError(ctx)
return
}
authCodeURL := authConfig.AuthCodeURL(state, config.GetAuthURLParameters(redirect)...)
if lrs, ok := annotate.GetAnnotations(ctx)["oauthGrant.loginRedirectStub"]; ok {
lrs = strings.ReplaceAll(lrs, "{{authCodeURL}}", authCodeURL)
lrs = strings.ReplaceAll(lrs, "{authCodeURL}", authCodeURL)
ctx.Serve(&http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Length": []string{strconv.Itoa(len(lrs))},
"X-Auth-Code-Url": []string{authCodeURL},
},
Body: io.NopCloser(strings.NewReader(lrs)),
})
} else {
ctx.Serve(&http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{
"Location": []string{authCodeURL},
},
})
}
}
func (f *grantFilter) refreshToken(token *oauth2.Token, req *http.Request) (*oauth2.Token, error) {
// Set the expiry of the token to the past to trigger oauth2.TokenSource
// to refresh the access token.
token.Expiry = time.Now().Add(-time.Minute)
ctx := providerContext(f.config)
authConfig, err := f.config.GetConfig(req)
if err != nil {
return nil, err
}
// oauth2.TokenSource implements the refresh functionality;
// we're hijacking it here.
tokenSource := authConfig.TokenSource(ctx, token)
return tokenSource.Token()
}
func (f *grantFilter) refreshTokenIfRequired(t *oauth2.Token, ctx filters.FilterContext) (*oauth2.Token, error) {
canRefresh := t.RefreshToken != ""
if time.Now().After(t.Expiry) {
if canRefresh {
token, err := f.refreshToken(t, ctx.Request())
if err == nil {
// Remember that this token was just successfully refreshed
// so that we can send an updated cookie in the response.
ctx.StateBag()[refreshedTokenKey] = token
}
return token, err
} else {
return nil, errExpiredToken
}
} else {
return t, nil
}
}
func (f *grantFilter) setupToken(token *oauth2.Token, tokeninfo map[string]interface{}, ctx filters.FilterContext) error {
if f.config.AccessTokenHeaderName != "" {
ctx.Request().Header.Set(f.config.AccessTokenHeaderName, authHeaderPrefix+token.AccessToken)
}
subject := ""
if f.config.TokeninfoSubjectKey != "" {
if s, ok := tokeninfo[f.config.TokeninfoSubjectKey].(string); ok {
subject = s
} else {
return fmt.Errorf("tokeninfo subject key '%s' is missing", f.config.TokeninfoSubjectKey)
}
}
tokeninfo["sub"] = subject
if len(f.config.grantTokeninfoKeysLookup) > 0 {
for key := range tokeninfo {
if _, ok := f.config.grantTokeninfoKeysLookup[key]; !ok {
delete(tokeninfo, key)
}
}
}
// By piggy-backing on the OIDC token container,
// we gain downstream compatibility with the oidcClaimsQuery filter.
SetOIDCClaims(ctx, tokeninfo)
// Set the tokeninfo also in the tokeninfoCacheKey state bag, so we
// can reuse e.g. the forwardToken() filter.
ctx.StateBag()[tokeninfoCacheKey] = tokeninfo
return nil
}
func (f *grantFilter) Request(ctx filters.FilterContext) {
token, err := f.config.GrantCookieEncoder.Read(ctx.Request())
if err == http.ErrNoCookie {
loginRedirect(ctx, f.config)
return
}
token, err = f.refreshTokenIfRequired(token, ctx)
if err != nil {
// Refresh failed and we no longer have a valid access token.
loginRedirect(ctx, f.config)
return
}
tokeninfo, err := f.config.TokeninfoClient.getTokeninfo(token.AccessToken, ctx)
if err != nil {
if err != errInvalidToken {
ctx.Logger().Errorf("Failed to call tokeninfo: %v.", err)
}
loginRedirect(ctx, f.config)
return
}
err = f.setupToken(token, tokeninfo, ctx)
if err != nil {
ctx.Logger().Errorf("Failed to create token container: %v.", err)
loginRedirect(ctx, f.config)
return
}
}
func (f *grantFilter) Response(ctx filters.FilterContext) {
// If the token was refreshed in this request flow, we want to send an
// updated cookie. If it wasn't, the user still has their old cookie,
// we do not need to send it again, and this function can exit early.
token, ok := ctx.StateBag()[refreshedTokenKey].(*oauth2.Token)
if !ok {
return
}
cookies, err := f.config.GrantCookieEncoder.Update(ctx.Request(), token)
if err != nil {
ctx.Logger().Errorf("Failed to generate cookie: %v.", err)
return
}
for _, c := range cookies {
ctx.Response().Header.Add("Set-Cookie", c.String())
}
}
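// Example route (sketch): the filter takes no arguments; its behavior is
// defined entirely by the shared OAuthConfig:
//
//	app: * -> oauthGrant() -> "https://backend.example.org";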
package auth
import (
"net/http"
"net/url"
"github.com/zalando/skipper/filters"
"golang.org/x/oauth2"
)
// GrantCallbackName is the filter name
// Deprecated, use filters.GrantCallbackName instead
const GrantCallbackName = filters.GrantCallbackName
type grantCallbackSpec struct {
config *OAuthConfig
}
type grantCallbackFilter struct {
config *OAuthConfig
}
func (*grantCallbackSpec) Name() string { return filters.GrantCallbackName }
func (s *grantCallbackSpec) CreateFilter([]interface{}) (filters.Filter, error) {
return &grantCallbackFilter{
config: s.config,
}, nil
}
func (f *grantCallbackFilter) exchangeAccessToken(req *http.Request, code string) (*oauth2.Token, error) {
authConfig, err := f.config.GetConfig(req)
if err != nil {
return nil, err
}
redirectURI, _ := f.config.RedirectURLs(req)
ctx := providerContext(f.config)
params := f.config.GetAuthURLParameters(redirectURI)
return authConfig.Exchange(ctx, code, params...)
}
func (f *grantCallbackFilter) Request(ctx filters.FilterContext) {
req := ctx.Request()
q := req.URL.Query()
code := q.Get("code")
if code == "" {
badRequest(ctx)
return
}
queryState := q.Get("state")
if queryState == "" {
badRequest(ctx)
return
}
state, err := f.config.flowState.extractState(queryState)
if err != nil {
if err == errExpiredAuthState {
// The login flow state expired. Instead of just returning an
// error, restart the login process with the original request
// URL.
loginRedirectWithOverride(ctx, f.config, state.RequestURL)
} else {
serverError(ctx)
}
return
}
// Redirect callback request to the host of the initial request
if initial, _ := url.Parse(state.RequestURL); initial.Host != req.Host {
location := *req.URL
location.Host = initial.Host
location.Scheme = initial.Scheme
ctx.Serve(&http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{
"Location": []string{location.String()},
},
})
return
}
token, err := f.exchangeAccessToken(req, code)
if err != nil {
ctx.Logger().Errorf("Failed to exchange access token: %v.", err)
serverError(ctx)
return
}
cookies, err := f.config.GrantCookieEncoder.Update(req, token)
if err != nil {
ctx.Logger().Errorf("Failed to create OAuth grant cookie: %v.", err)
serverError(ctx)
return
}
resp := &http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{
"Location": []string{state.RequestURL},
},
}
for _, c := range cookies {
resp.Header.Add("Set-Cookie", c.String())
}
ctx.Serve(resp)
}
func (f *grantCallbackFilter) Response(ctx filters.FilterContext) {}
//
// grantClaimsQuery filter
//
// An alias for oidcClaimsQuery filter allowing a clearer
// API when used in conjunction with the oauthGrant filter.
//
package auth
import "github.com/zalando/skipper/filters"
// GrantClaimsQueryName is the filter name
// Deprecated, use filters.GrantClaimsQueryName instead
const GrantClaimsQueryName = filters.GrantClaimsQueryName
type grantClaimsQuerySpec struct {
oidcSpec oidcIntrospectionSpec
}
func (s *grantClaimsQuerySpec) Name() string {
return filters.GrantClaimsQueryName
}
func (s *grantClaimsQuerySpec) CreateFilter(args []interface{}) (filters.Filter, error) {
return s.oidcSpec.CreateFilter(args)
}
package auth
import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/routing"
"github.com/zalando/skipper/secrets"
"golang.org/x/oauth2"
)
type OAuthConfig struct {
initialized bool
flowState *flowState
grantTokeninfoKeysLookup map[string]struct{}
getClientId func(*http.Request) (string, error)
getClientSecret func(*http.Request) (string, error)
// TokeninfoURL is the URL of the service to validate OAuth2 tokens.
TokeninfoURL string
// Secrets is a secret registry to access secret keys used for encrypting
// auth flow state and auth cookies.
Secrets *secrets.Registry
// SecretFile contains the filename with the encryption key for the authentication
// cookie and grant flow state stored in Secrets.
SecretFile string
// AuthURL, the url to redirect the requests to when login is required.
AuthURL string
// TokenURL, the url where the access code should be exchanged for the
// access token.
TokenURL string
// RevokeTokenURL, the url where the access and revoke tokens can be
// revoked during a logout.
RevokeTokenURL string
// CallbackPath contains the path where the callback requests with the
// authorization code should be redirected to.
CallbackPath string
// ClientID, the OAuth2 client id of the current service, used to exchange
// the access code. Must be set if ClientIDFile is not provided.
ClientID string
// ClientSecret, the secret associated with the ClientID, used to exchange
// the access code. Must be set if ClientSecretFile is not provided.
ClientSecret string
// ClientIDFile, the path to the file containing the OAuth2 client id of
// the current service, used to exchange the access code.
// Must be set if ClientID is not provided.
// File name may contain {host} placeholder which will be replaced by the request host.
// Requires SecretsProvider, the path (or path's directory if placeholder is present) will be added to it.
ClientIDFile string
// ClientSecretFile, the path to the file containing the secret associated
// with the ClientID, used to exchange the access code.
// Must be set if ClientSecret is not provided.
// File name may contain {host} placeholder which will be replaced by the request host.
// Requires SecretsProvider, the path (or path's directory if placeholder is present) will be added to it.
ClientSecretFile string
// SecretsProvider is used to read ClientIDFile and ClientSecretFile from the
// file system. Supports secret rotation.
SecretsProvider secrets.SecretsProvider
// TokeninfoClient, optional. When set, it will be used for the
// authorization requests to TokeninfoURL. When not set, a new default
// client is created.
TokeninfoClient *authClient
// AuthClient, optional. When set, it will be used for the
// access code exchange requests to TokenURL. When not set, a new default
// client is created.
AuthClient *snet.Client
// AuthURLParameters, optional. Extra URL parameters to add when calling
// the OAuth2 authorize or token endpoints.
AuthURLParameters map[string]string
// AccessTokenHeaderName, optional. When set, the access token will be set
// on the request to a header with this name.
AccessTokenHeaderName string
// GrantTokeninfoKeys, optional. When not empty, keys not in this list are removed from the tokeninfo map.
GrantTokeninfoKeys []string
// GrantCookieEncoder, optional. Cookie encoder stores and extracts OAuth token from cookies.
GrantCookieEncoder CookieEncoder
// TokeninfoSubjectKey, optional. When set, it is used to look up the subject
// ID in the tokeninfo map received from a tokeninfo endpoint request.
TokeninfoSubjectKey string
// TokenCookieName, optional. The name of the cookie used to store the
// encrypted access token after a successful token exchange.
TokenCookieName string
// TokenCookieRemoveSubdomains sets the number of subdomains to remove from
// the callback request hostname to obtain token cookie domain.
// Init converts default nil to 1.
TokenCookieRemoveSubdomains *int
// Insecure omits Secure attribute of the token cookie and uses http scheme for callback url.
Insecure bool
// ConnectionTimeout used for tokeninfo, access-token and refresh-token endpoint.
ConnectionTimeout time.Duration
// MaxIdleConnectionsPerHost used for tokeninfo, access-token and refresh-token endpoint.
MaxIdleConnectionsPerHost int
// Tracer used for tokeninfo, access-token and refresh-token endpoint.
Tracer opentracing.Tracer
}
var (
ErrMissingClientID = errors.New("missing client ID")
ErrMissingClientSecret = errors.New("missing client secret")
ErrMissingSecretsProvider = errors.New("missing secrets provider")
ErrMissingSecretsRegistry = errors.New("missing secrets registry")
ErrMissingSecretFile = errors.New("missing secret file")
ErrMissingTokeninfoURL = errors.New("missing tokeninfo URL")
ErrMissingProviderURLs = errors.New("missing provider URLs")
)
func (c *OAuthConfig) Init() error {
if c.initialized {
return nil
}
if c.TokeninfoURL == "" {
return ErrMissingTokeninfoURL
}
if c.AuthURL == "" || c.TokenURL == "" {
return ErrMissingProviderURLs
}
if c.Secrets == nil {
return ErrMissingSecretsRegistry
}
if c.SecretFile == "" {
return ErrMissingSecretFile
}
if c.CallbackPath == "" {
c.CallbackPath = defaultCallbackPath
}
if c.TokenCookieName == "" {
c.TokenCookieName = defaultTokenCookieName
}
if c.TokenCookieRemoveSubdomains == nil {
one := 1
c.TokenCookieRemoveSubdomains = &one
} else if *c.TokenCookieRemoveSubdomains < 0 {
return fmt.Errorf("invalid number of cookie subdomains to remove")
}
if c.TokeninfoClient == nil {
client, err := newAuthClient(
c.TokeninfoURL,
"granttokeninfo",
c.ConnectionTimeout,
c.MaxIdleConnectionsPerHost,
c.Tracer,
)
if err != nil {
return err
}
c.TokeninfoClient = client
}
if c.AuthClient == nil {
c.AuthClient = snet.NewClient(snet.Options{
ResponseHeaderTimeout: c.ConnectionTimeout,
TLSHandshakeTimeout: c.ConnectionTimeout,
MaxIdleConnsPerHost: c.MaxIdleConnectionsPerHost,
Tracer: c.Tracer,
OpentracingComponentTag: "skipper",
OpentracingSpanName: "grantauth",
})
}
c.flowState = newFlowState(c.Secrets, c.SecretFile)
if c.ClientID != "" {
c.getClientId = func(*http.Request) (string, error) {
return c.ClientID, nil
}
} else if c.ClientIDFile != "" {
if c.SecretsProvider == nil {
return ErrMissingSecretsProvider
}
if hasPlaceholders(c.ClientIDFile) {
c.getClientId = func(req *http.Request) (string, error) {
return c.getSecret(resolvePlaceholders(c.ClientIDFile, req))
}
if err := c.SecretsProvider.Add(filepath.Dir(c.ClientIDFile)); err != nil {
return err
}
} else {
c.getClientId = func(*http.Request) (string, error) {
return c.getSecret(c.ClientIDFile)
}
if err := c.SecretsProvider.Add(c.ClientIDFile); err != nil {
return err
}
}
} else {
return ErrMissingClientID
}
if c.ClientSecret != "" {
c.getClientSecret = func(*http.Request) (string, error) {
return c.ClientSecret, nil
}
} else if c.ClientSecretFile != "" {
if c.SecretsProvider == nil {
return ErrMissingSecretsProvider
}
if hasPlaceholders(c.ClientSecretFile) {
c.getClientSecret = func(req *http.Request) (string, error) {
return c.getSecret(resolvePlaceholders(c.ClientSecretFile, req))
}
if err := c.SecretsProvider.Add(filepath.Dir(c.ClientSecretFile)); err != nil {
return err
}
} else {
c.getClientSecret = func(*http.Request) (string, error) {
return c.getSecret(c.ClientSecretFile)
}
if err := c.SecretsProvider.Add(c.ClientSecretFile); err != nil {
return err
}
}
} else {
return ErrMissingClientSecret
}
if len(c.GrantTokeninfoKeys) > 0 {
c.grantTokeninfoKeysLookup = make(map[string]struct{}, len(c.GrantTokeninfoKeys))
for _, key := range c.GrantTokeninfoKeys {
c.grantTokeninfoKeysLookup[key] = struct{}{}
}
}
if c.GrantCookieEncoder == nil {
encryption, err := c.Secrets.GetEncrypter(secretsRefreshInterval, c.SecretFile)
if err != nil {
return err
}
c.GrantCookieEncoder = &EncryptedCookieEncoder{
Encryption: encryption,
CookieName: c.TokenCookieName,
RemoveSubdomains: *c.TokenCookieRemoveSubdomains,
Insecure: c.Insecure,
}
}
c.initialized = true
return nil
}
func (c *OAuthConfig) NewGrant() filters.Spec {
return &grantSpec{config: c}
}
func (c *OAuthConfig) NewGrantCallback() filters.Spec {
return &grantCallbackSpec{config: c}
}
func (c *OAuthConfig) NewGrantClaimsQuery() filters.Spec {
return &grantClaimsQuerySpec{
oidcSpec: oidcIntrospectionSpec{
typ: checkOIDCQueryClaims,
},
}
}
func (c *OAuthConfig) NewGrantLogout() filters.Spec {
return &grantLogoutSpec{config: c}
}
func (c *OAuthConfig) NewGrantPreprocessor() routing.PreProcessor {
return &grantPrep{config: c}
}
func (c *OAuthConfig) GetConfig(req *http.Request) (*oauth2.Config, error) {
var err error
authConfig := &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: c.AuthURL,
TokenURL: c.TokenURL,
},
}
authConfig.ClientID, err = c.getClientId(req)
if err != nil {
return nil, err
}
authConfig.ClientSecret, err = c.getClientSecret(req)
if err != nil {
return nil, err
}
return authConfig, nil
}
func (c *OAuthConfig) getSecret(file string) (string, error) {
if secret, ok := c.SecretsProvider.GetSecret(file); ok {
return string(secret), nil
} else {
return "", fmt.Errorf("secret %s does not exist", file)
}
}
func resolvePlaceholders(s string, r *http.Request) string {
h, _, err := net.SplitHostPort(r.Host)
if err != nil {
h = r.Host
}
return strings.ReplaceAll(s, "{host}", h)
}
func hasPlaceholders(s string) bool {
return resolvePlaceholders(s, &http.Request{Host: "example.org"}) != s
}
func (c *OAuthConfig) GetAuthURLParameters(redirectURI string) []oauth2.AuthCodeOption {
params := []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("redirect_uri", redirectURI)}
if c.AuthURLParameters != nil {
for k, v := range c.AuthURLParameters {
params = append(params, oauth2.SetAuthURLParam(k, v))
}
}
return params
}
// RedirectURLs constructs the redirect URI based on the request and the
// configured CallbackPath.
func (c *OAuthConfig) RedirectURLs(req *http.Request) (redirect, original string) {
u := *req.URL
if c.Insecure {
u.Scheme = "http"
} else {
u.Scheme = "https"
}
u.Host = req.Host
original = u.String()
redirect = (&url.URL{
Scheme: u.Scheme,
Host: u.Host,
Path: c.CallbackPath,
}).String()
return
}
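// Example (minimal sketch; the URLs and credentials are illustrative, and
// secretsRegistry is a *secrets.Registry provided by the hosting application):
//
//	cfg := &OAuthConfig{
//		TokeninfoURL: "https://auth.example.org/oauth2/tokeninfo",
//		AuthURL:      "https://auth.example.org/oauth2/authorize",
//		TokenURL:     "https://auth.example.org/oauth2/token",
//		Secrets:      secretsRegistry,
//		SecretFile:   "/etc/skipper/cookie-encryption-key",
//		ClientID:     "my-client-id",
//		ClientSecret: "my-client-secret",
//	}
//	if err := cfg.Init(); err != nil {
//		// missing URLs, secrets, or client credentials are reported here
//	}
//	spec := cfg.NewGrant() // register together with NewGrantCallback(), NewGrantLogout() and NewGrantPreprocessor()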
package auth
import (
"encoding/base64"
"encoding/json"
"net"
"net/http"
"strings"
"time"
"github.com/zalando/skipper/secrets"
"golang.org/x/oauth2"
)
type CookieEncoder interface {
// Update creates a set of cookies that encodes the token and deletes previously existing cookies if necessary.
// When token is nil it only returns cookies to delete.
Update(request *http.Request, token *oauth2.Token) ([]*http.Cookie, error)
// Read extracts the token from the request cookies.
Read(request *http.Request) (*oauth2.Token, error)
}
// EncryptedCookieEncoder is a CookieEncoder that encrypts the token before storing it in a cookie.
type EncryptedCookieEncoder struct {
Encryption secrets.Encryption
CookieName string
RemoveSubdomains int
Insecure bool
}
var _ CookieEncoder = &EncryptedCookieEncoder{}
func (ce *EncryptedCookieEncoder) Update(request *http.Request, token *oauth2.Token) ([]*http.Cookie, error) {
if token != nil {
c, err := ce.createCookie(request.Host, token)
if err != nil {
return nil, err
}
return []*http.Cookie{c}, nil
} else {
c := ce.createDeleteCookie(request.Host)
return []*http.Cookie{c}, nil
}
}
func (ce *EncryptedCookieEncoder) Read(request *http.Request) (*oauth2.Token, error) {
c, err := ce.extractCookie(request)
if err != nil {
return nil, err
}
return &oauth2.Token{
AccessToken: c.AccessToken,
TokenType: "Bearer",
RefreshToken: c.RefreshToken,
Expiry: c.Expiry,
}, nil
}
type cookie struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Expiry time.Time `json:"expiry,omitempty"`
Domain string `json:"domain,omitempty"`
}
func (ce *EncryptedCookieEncoder) decodeCookie(cookieHeader string) (c *cookie, err error) {
var eb []byte
if eb, err = base64.StdEncoding.DecodeString(cookieHeader); err != nil {
return
}
var b []byte
if b, err = ce.Encryption.Decrypt(eb); err != nil {
return
}
err = json.Unmarshal(b, &c)
return
}
// allowedForHost checks if the provided host matches the cookie domain
// according to https://www.rfc-editor.org/rfc/rfc6265#section-5.1.3
func (c *cookie) allowedForHost(host string) bool {
hostWithoutPort, _, err := net.SplitHostPort(host)
if err != nil {
hostWithoutPort = host
}
return strings.HasSuffix(hostWithoutPort, c.Domain)
}
// extractCookie removes and returns the OAuth grant token cookie from an HTTP request.
// The function supports multiple cookies with the same name and returns
// the best match (the one that decodes properly).
// The client may send multiple cookies if a parent domain has set a
// cookie of the same name.
// The grant token cookie is extracted so it does not get exposed to untrusted downstream
// services.
func (ce *EncryptedCookieEncoder) extractCookie(request *http.Request) (*cookie, error) {
cookies := request.Cookies()
for i, c := range cookies {
if c.Name != ce.CookieName {
continue
}
decoded, err := ce.decodeCookie(c.Value)
if err == nil && decoded.allowedForHost(request.Host) {
request.Header.Del("Cookie")
for j, c := range cookies {
if j != i {
request.AddCookie(c)
}
}
return decoded, nil
}
}
return nil, http.ErrNoCookie
}
// createDeleteCookie creates a cookie, which instructs the client to clear the grant
// token cookie when used with a Set-Cookie header.
func (ce *EncryptedCookieEncoder) createDeleteCookie(host string) *http.Cookie {
return &http.Cookie{
Name: ce.CookieName,
Value: "",
Path: "/",
Domain: extractDomainFromHost(host, ce.RemoveSubdomains),
MaxAge: -1,
Secure: !ce.Insecure,
HttpOnly: true,
}
}
func (ce *EncryptedCookieEncoder) createCookie(host string, t *oauth2.Token) (*http.Cookie, error) {
domain := extractDomainFromHost(host, ce.RemoveSubdomains)
c := &cookie{
AccessToken: t.AccessToken,
RefreshToken: t.RefreshToken,
Expiry: t.Expiry,
Domain: domain,
}
b, err := json.Marshal(c)
if err != nil {
return nil, err
}
eb, err := ce.Encryption.Encrypt(b)
if err != nil {
return nil, err
}
b64 := base64.StdEncoding.EncodeToString(eb)
// The cookie expiry date must not be the same as the access token
// expiry. Otherwise the browser deletes the cookie as soon as the
// access token expires, but _before_ the refresh token has expired.
// Since we don't know the actual refresh token expiry, set it to
// 30 days as a good compromise.
return &http.Cookie{
Name: ce.CookieName,
Value: b64,
Path: "/",
Domain: domain,
Expires: t.Expiry.Add(time.Hour * 24 * 30),
Secure: !ce.Insecure,
HttpOnly: true,
}, nil
}
package auth
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/zalando/skipper/secrets"
)
type state struct {
Validity int64 `json:"validity"`
Nonce string `json:"nonce"`
RequestURL string `json:"redirectUrl"`
}
type flowState struct {
secrets *secrets.Registry
secretsFile string
}
var errExpiredAuthState = errors.New("expired auth state")
func newFlowState(secrets *secrets.Registry, secretsFile string) *flowState {
return &flowState{
secrets: secrets,
secretsFile: secretsFile,
}
}
func stateValidityTime() int64 {
return time.Now().Add(time.Hour).Unix()
}
func (s *flowState) createState(redirectURL string) (string, error) {
encrypter, err := s.secrets.GetEncrypter(secretsRefreshInterval, s.secretsFile)
if err != nil {
return "", err
}
nonce, err := encrypter.CreateNonce()
if err != nil {
return "", err
}
state := state{
Validity: stateValidityTime(),
Nonce: fmt.Sprintf("%x", nonce),
RequestURL: redirectURL,
}
jb, err := json.Marshal(state)
if err != nil {
return "", err
}
eb, err := encrypter.Encrypt(jb)
if err != nil {
return "", err
}
return fmt.Sprintf("%x", eb), nil
}
func (s *flowState) extractState(st string) (state state, err error) {
var encrypter secrets.Encryption
if encrypter, err = s.secrets.GetEncrypter(secretsRefreshInterval, s.secretsFile); err != nil {
return
}
var eb []byte
if _, err = fmt.Sscanf(st, "%x", &eb); err != nil {
return
}
var jb []byte
if jb, err = encrypter.Decrypt(eb); err != nil {
return
}
if err = json.Unmarshal(jb, &state); err != nil {
return
}
validity := time.Unix(state.Validity, 0)
if validity.Before(time.Now()) {
err = errExpiredAuthState
}
return
}
func (s *flowState) Close() {
s.secrets.Close()
}
package auth
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/zalando/skipper/filters"
"golang.org/x/oauth2"
)
const (
// Deprecated, use filters.GrantLogoutName instead
GrantLogoutName = filters.GrantLogoutName
revokeTokenKey = "token"
revokeTokenTypeKey = "token_type_hint"
refreshTokenType = "refresh_token"
accessTokenType = "access_token"
errUnsupportedTokenType = "unsupported_token_type"
)
type grantLogoutSpec struct {
config *OAuthConfig
}
type grantLogoutFilter struct {
config *OAuthConfig
}
type revokeErrorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
func (*grantLogoutSpec) Name() string { return filters.GrantLogoutName }
func (s *grantLogoutSpec) CreateFilter([]interface{}) (filters.Filter, error) {
return &grantLogoutFilter{
config: s.config,
}, nil
}
func responseToError(responseData []byte, statusCode int, tokenType string) error {
var errorResponse revokeErrorResponse
err := json.Unmarshal(responseData, &errorResponse)
if err != nil {
return err
}
if errorResponse.Error == errUnsupportedTokenType && tokenType == accessTokenType {
// Provider does not support revoking access tokens, which can happen according to RFC 7009.
// In that case this is not really an error.
return nil
}
return fmt.Errorf(
"%s revocation failed: %d %s: %s",
tokenType,
statusCode,
errorResponse.Error,
errorResponse.ErrorDescription,
)
}
func (f *grantLogoutFilter) revokeTokenType(c *oauth2.Config, tokenType string, token string) error {
revokeURL, err := url.Parse(f.config.RevokeTokenURL)
if err != nil {
return err
}
query := revokeURL.Query()
for k, v := range f.config.AuthURLParameters {
query.Set(k, v)
}
revokeURL.RawQuery = query.Encode()
body := url.Values{}
body.Add(revokeTokenKey, token)
body.Add(revokeTokenTypeKey, tokenType)
revokeRequest, err := http.NewRequest(
"POST",
revokeURL.String(),
strings.NewReader(body.Encode()))
if err != nil {
return err
}
revokeRequest.SetBasicAuth(c.ClientID, c.ClientSecret)
revokeRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
revokeResponse, err := f.config.AuthClient.Do(revokeRequest)
if err != nil {
return err
}
defer revokeResponse.Body.Close()
buf, err := io.ReadAll(revokeResponse.Body)
if err != nil {
return err
}
if revokeResponse.StatusCode == 400 {
return responseToError(buf, revokeResponse.StatusCode, tokenType)
} else if revokeResponse.StatusCode != 200 {
return fmt.Errorf(
"%s revocation failed: %d",
tokenType,
revokeResponse.StatusCode,
)
}
return nil
}
func (f *grantLogoutFilter) Request(ctx filters.FilterContext) {
if f.config.RevokeTokenURL == "" {
return
}
req := ctx.Request()
token, err := f.config.GrantCookieEncoder.Read(req)
if err != nil {
unauthorized(
ctx,
"",
missingToken,
req.Host,
fmt.Sprintf("No token cookie %v in request.", f.config.TokenCookieName))
return
}
if token.AccessToken == "" && token.RefreshToken == "" {
unauthorized(
ctx,
"",
missingToken,
req.Host,
fmt.Sprintf("Token cookie %v has no tokens.", f.config.TokenCookieName))
return
}
authConfig, err := f.config.GetConfig(req)
if err != nil {
serverError(ctx)
return
}
var accessTokenRevokeError, refreshTokenRevokeError error
if token.AccessToken != "" {
accessTokenRevokeError = f.revokeTokenType(authConfig, accessTokenType, token.AccessToken)
if accessTokenRevokeError != nil {
ctx.Logger().Errorf("%v", accessTokenRevokeError)
}
}
if token.RefreshToken != "" {
refreshTokenRevokeError = f.revokeTokenType(authConfig, refreshTokenType, token.RefreshToken)
if refreshTokenRevokeError != nil {
ctx.Logger().Errorf("%v", refreshTokenRevokeError)
}
}
if refreshTokenRevokeError != nil || accessTokenRevokeError != nil {
serverError(ctx)
}
}
func (f *grantLogoutFilter) Response(ctx filters.FilterContext) {
cookies, err := f.config.GrantCookieEncoder.Update(ctx.Request(), nil)
if err != nil {
ctx.Logger().Errorf("Failed to delete cookies: %v.", err)
return
}
for _, c := range cookies {
ctx.Response().Header.Add("Set-Cookie", c.String())
}
}
package auth
import (
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
)
const (
defaultCallbackRouteID = "__oauth2_grant_callback"
defaultCallbackPath = "/.well-known/oauth2-callback"
defaultTokenCookieName = "oauth-grant"
)
type grantPrep struct {
config *OAuthConfig
}
func (p *grantPrep) Do(r []*eskip.Route) []*eskip.Route {
// In the future, route IDs will serve only a logging purpose and won't
// need to be unique.
id := defaultCallbackRouteID
return append(r, &eskip.Route{
Id: id,
Predicates: []*eskip.Predicate{{
Name: "Path",
Args: []interface{}{
p.config.CallbackPath,
},
}},
Filters: []*eskip.Filter{{
Name: filters.GrantCallbackName,
}},
BackendType: eskip.ShuntBackend,
})
}
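// The appended route corresponds to the following eskip (sketch, with the
// default callback path):
//
//	__oauth2_grant_callback: Path("/.well-known/oauth2-callback") -> grantCallback() -> <shunt>;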
package auth
import (
"fmt"
"regexp"
"slices"
"strings"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/annotate"
"github.com/zalando/skipper/jwt"
)
type (
jwtMetricsSpec struct {
yamlConfigParser[jwtMetricsFilter]
}
// jwtMetricsFilter implements [yamlConfig],
// make sure it is not modified after initialization.
jwtMetricsFilter struct {
// Issuers is *DEPRECATED* and will be removed in the future. Use the Claims field instead.
Issuers []string `json:"issuers,omitempty"`
OptOutAnnotations []string `json:"optOutAnnotations,omitempty"`
OptOutStateBag []string `json:"optOutStateBag,omitempty"`
OptOutHosts []string `json:"optOutHosts,omitempty"`
Claims []map[string]any `json:"claims,omitempty"`
optOutHostsCompiled []*regexp.Regexp
}
)
func NewJwtMetrics() filters.Spec {
return &jwtMetricsSpec{
newYamlConfigParser[jwtMetricsFilter](64),
}
}
func (s *jwtMetricsSpec) Name() string {
return filters.JwtMetricsName
}
func (s *jwtMetricsSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) == 0 {
return &jwtMetricsFilter{}, nil
}
return s.parseSingleArg(args)
}
func (f *jwtMetricsFilter) initialize() error {
for _, host := range f.OptOutHosts {
if r, err := regexp.Compile(host); err != nil {
return fmt.Errorf("failed to compile opt-out host pattern: %q", host)
} else {
f.optOutHostsCompiled = append(f.optOutHostsCompiled, r)
}
}
return nil
}
func (f *jwtMetricsFilter) Request(ctx filters.FilterContext) {}
func (f *jwtMetricsFilter) Response(ctx filters.FilterContext) {
if len(f.OptOutAnnotations) > 0 {
annotations := annotate.GetAnnotations(ctx)
for _, annotation := range f.OptOutAnnotations {
if _, ok := annotations[annotation]; ok {
return // opt-out
}
}
}
if len(f.OptOutStateBag) > 0 {
sb := ctx.StateBag()
for _, key := range f.OptOutStateBag {
if _, ok := sb[key]; ok {
return // opt-out
}
}
}
if len(f.optOutHostsCompiled) > 0 {
host := ctx.Request().Host
for _, r := range f.optOutHostsCompiled {
if r.MatchString(host) {
return // opt-out
}
}
}
response := ctx.Response()
if response.StatusCode >= 400 && response.StatusCode < 500 {
return // ignore invalid requests
}
request := ctx.Request()
count := func(metric string) {
prefix := fmt.Sprintf("%s.%s.%d.", request.Method, escapeMetricKeySegment(request.Host), response.StatusCode)
ctx.Metrics().IncCounter(prefix + metric)
if span := opentracing.SpanFromContext(ctx.Request().Context()); span != nil {
span.SetTag("jwt", metric)
}
}
ahead := request.Header.Get("Authorization")
if ahead == "" {
count("missing-token")
return
}
tv := strings.TrimPrefix(ahead, "Bearer ")
if tv == ahead {
count("invalid-token-type")
return
}
var token *jwt.Token
if len(f.Issuers) > 0 || len(f.Claims) > 0 {
t, err := jwt.Parse(tv)
if err != nil {
count("invalid-token")
return
}
token = t
}
if len(f.Issuers) > 0 {
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1
if issuer, ok := token.Claims["iss"].(string); !ok {
count("missing-issuer")
} else if !slices.Contains(f.Issuers, issuer) {
count("invalid-issuer")
}
}
if len(f.Claims) > 0 {
found := false
for _, claim := range f.Claims {
if containsAll(token.Claims, claim) {
found = true
break
}
}
if !found {
count("invalid-claims")
}
}
}
var escapeMetricKeySegmentPattern = regexp.MustCompile("[^a-zA-Z0-9_]")
func escapeMetricKeySegment(s string) string {
return escapeMetricKeySegmentPattern.ReplaceAllLiteralString(s, "_")
}
// containsAll returns true if all key-values of b are present in a.
func containsAll(a, b map[string]any) bool {
for kb, vb := range b {
if va, ok := a[kb]; !ok || va != vb {
return false
}
}
return true
}
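// Example route (sketch): the single argument is a YAML object. With claims
// configured, a response is counted as "invalid-claims" unless the token
// carries all key-values of at least one of the listed claim sets:
//
//	a: * -> jwtMetrics("{claims: [{iss: 'https://issuer.example.com'}], optOutHosts: ['^metrics[.]']}") -> "https://backend.example.org";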
package auth
import (
"fmt"
"sync"
"time"
"github.com/MicahParks/keyfunc"
jwt "github.com/golang-jwt/jwt/v4"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.JwtValidationName instead
JwtValidationName = filters.JwtValidationName
)
type (
jwtValidationSpec struct {
options TokenintrospectionOptions
}
jwtValidationFilter struct {
jwksUri string
}
)
var refreshInterval = time.Hour
var refreshRateLimit = time.Minute * 5
var refreshTimeout = time.Second * 10
var refreshUnknownKID = true
// the map of JWKS key functions, stored per jwksUri
var (
jwksMu sync.RWMutex
jwksMap map[string]*keyfunc.JWKS = make(map[string]*keyfunc.JWKS)
)
func NewJwtValidationWithOptions(o TokenintrospectionOptions) filters.Spec {
return &jwtValidationSpec{
options: o,
}
}
func (s *jwtValidationSpec) Name() string {
return filters.JwtValidationName
}
func (s *jwtValidationSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
sargs, err := getStrings(args)
if err != nil {
return nil, err
}
issuerURL := sargs[0]
cfg, err := getOpenIDConfig(issuerURL)
if err != nil {
return nil, err
}
err = registerKeyFunction(cfg.JwksURI)
if err != nil {
return nil, err
}
f := &jwtValidationFilter{
jwksUri: cfg.JwksURI,
}
return f, nil
}
func hasKeyFunction(url string) bool {
jwksMu.RLock()
defer jwksMu.RUnlock()
_, ok := jwksMap[url]
return ok
}
func putKeyFunction(url string, jwks *keyfunc.JWKS) {
jwksMu.Lock()
defer jwksMu.Unlock()
jwksMap[url] = jwks
}
func registerKeyFunction(url string) (err error) {
if hasKeyFunction(url) {
return nil
}
options := keyfunc.Options{
RefreshErrorHandler: func(err error) {
log.Errorf("There was an error on key refresh for the given URL %s\nError:%s\n", url, err.Error())
},
RefreshInterval: refreshInterval,
RefreshRateLimit: refreshRateLimit,
RefreshTimeout: refreshTimeout,
RefreshUnknownKID: refreshUnknownKID,
}
jwks, err := keyfunc.Get(url, options)
if err != nil {
return fmt.Errorf("failed to get the JWKS from the given URL %s Error:%w", url, err)
}
putKeyFunction(url, jwks)
return nil
}
func getKeyFunction(url string) (jwks *keyfunc.JWKS) {
jwksMu.RLock()
defer jwksMu.RUnlock()
return jwksMap[url]
}
func (f *jwtValidationFilter) Request(ctx filters.FilterContext) {
r := ctx.Request()
var info tokenContainer
infoTemp, ok := ctx.StateBag()[oidcClaimsCacheKey]
if !ok {
token, ok := getToken(r)
if !ok || token == "" {
unauthorized(ctx, "", missingToken, "", "")
return
}
claims, err := parseToken(token, f.jwksUri)
if err != nil {
ctx.Logger().Errorf("Error while parsing jwt token : %v.", err)
unauthorized(ctx, "", invalidToken, "", "")
return
}
info.Claims = claims
} else {
info = infoTemp.(tokenContainer)
}
sub, ok := info.Claims["sub"].(string)
if !ok {
unauthorized(ctx, sub, invalidSub, "", "")
return
}
authorized(ctx, sub)
ctx.StateBag()[oidcClaimsCacheKey] = info
}
func (f *jwtValidationFilter) Response(filters.FilterContext) {}
func parseToken(token string, jwksUri string) (map[string]interface{}, error) {
jwks := getKeyFunction(jwksUri)
var claims jwt.MapClaims
parsedToken, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
if err != nil {
return nil, fmt.Errorf("error while parsing jwt token : %w", err)
} else if !parsedToken.Valid {
return nil, fmt.Errorf("invalid token")
} else {
return claims, nil
}
}
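// Example route (sketch; the issuer must serve its OpenID configuration at
// /.well-known/openid-configuration so that the JWKS URI can be discovered):
//
//	a: * -> jwtValidation("https://login.example.com") -> "https://backend.example.org";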
package auth
import (
"bytes"
"compress/flate"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/zalando/skipper/filters"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/secrets"
"golang.org/x/oauth2"
)
const (
// Deprecated, use filters.OAuthOidcUserInfoName instead
OidcUserInfoName = filters.OAuthOidcUserInfoName
// Deprecated, use filters.OAuthOidcAnyClaimsName instead
OidcAnyClaimsName = filters.OAuthOidcAnyClaimsName
// Deprecated, use filters.OAuthOidcAllClaimsName instead
OidcAllClaimsName = filters.OAuthOidcAllClaimsName
oauthOidcCookieName = "skipperOauthOidc"
stateValidity = 1 * time.Minute
defaultCookieValidity = 1 * time.Hour
oidcInfoHeader = "Skipper-Oidc-Info"
cookieMaxSize = 4093 // common cookie size limit http://browsercookielimits.squawky.net/
// Deprecated: The host of the Azure Active Directory (AAD) graph API
azureADGraphHost = "graph.windows.net"
)
var (
distributedClaimsClients = sync.Map{}
microsoftGraphHost = "graph.microsoft.com" // global for testing
)
type distributedClaims struct {
ClaimNames map[string]string `json:"_claim_names"`
ClaimSources map[string]claimSource `json:"_claim_sources"`
}
type claimSource struct {
Endpoint string `json:"endpoint"`
AccessToken string `json:"access_token,omitempty"`
}
type azureGraphGroups struct {
OdataNextLink string `json:"@odata.nextLink,omitempty"`
Value []struct {
OnPremisesSamAccountName string `json:"onPremisesSamAccountName"`
ID string `json:"id"`
} `json:"value"`
}
// Filter parameter:
//
// oauthOidc...("https://oidc-provider.example.com", "client_id", "client_secret",
// "http://target.example.com/subpath/callback", "email profile", "name email picture",
// "parameter=value", "X-Auth-Authorization:claims.email")
const (
paramIdpURL int = iota
paramClientID
paramClientSecret
paramCallbackURL
paramScopes
paramClaims
paramAuthCodeOpts
paramUpstrHeaders
paramSubdomainsToRemove
paramCookieName
)
type OidcOptions struct {
MaxIdleConns int
CookieRemoveSubdomains *int
CookieValidity time.Duration
Timeout time.Duration
Tracer opentracing.Tracer
OidcClientId string
OidcClientSecret string
}
type (
tokenOidcSpec struct {
typ roleCheckType
SecretsFile string
secretsRegistry secrets.EncrypterCreator
options OidcOptions
}
tokenOidcFilter struct {
typ roleCheckType
config *oauth2.Config
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
claims []string
validity time.Duration
cookiename string
redirectPath string
encrypter secrets.Encryption
authCodeOptions []oauth2.AuthCodeOption
queryParams []string
compressor cookieCompression
upstreamHeaders map[string]string
subdomainsToRemove int
oidcOptions OidcOptions
}
tokenContainer struct {
OAuth2Token *oauth2.Token `json:"oauth2token"`
OIDCIDToken string `json:"oidctoken"`
UserInfo *oidc.UserInfo `json:"userInfo,omitempty"`
Subject string `json:"subject"`
Claims map[string]interface{} `json:"claims"`
}
cookieCompression interface {
compress([]byte) ([]byte, error)
decompress([]byte) ([]byte, error)
}
deflatePoolCompressor struct {
poolWriter *sync.Pool
}
)
// NewOAuthOidcUserInfosWithOptions creates a filter spec which tests user info.
func NewOAuthOidcUserInfosWithOptions(secretsFile string, secretsRegistry secrets.EncrypterCreator, o OidcOptions) filters.Spec {
return &tokenOidcSpec{typ: checkOIDCUserInfo, SecretsFile: secretsFile, secretsRegistry: secretsRegistry, options: o}
}
// Deprecated: use NewOAuthOidcUserInfosWithOptions instead.
func NewOAuthOidcUserInfos(secretsFile string, secretsRegistry secrets.EncrypterCreator) filters.Spec {
return NewOAuthOidcUserInfosWithOptions(secretsFile, secretsRegistry, OidcOptions{})
}
// NewOAuthOidcAnyClaimsWithOptions creates a filter spec which verifies that the token
// has one of the claims specified
func NewOAuthOidcAnyClaimsWithOptions(secretsFile string, secretsRegistry secrets.EncrypterCreator, o OidcOptions) filters.Spec {
return &tokenOidcSpec{typ: checkOIDCAnyClaims, SecretsFile: secretsFile, secretsRegistry: secretsRegistry, options: o}
}
// Deprecated: use NewOAuthOidcAnyClaimsWithOptions instead.
func NewOAuthOidcAnyClaims(secretsFile string, secretsRegistry secrets.EncrypterCreator) filters.Spec {
return NewOAuthOidcAnyClaimsWithOptions(secretsFile, secretsRegistry, OidcOptions{})
}
// NewOAuthOidcAllClaimsWithOptions creates a filter spec which verifies that the token
// has all the claims specified
func NewOAuthOidcAllClaimsWithOptions(secretsFile string, secretsRegistry secrets.EncrypterCreator, o OidcOptions) filters.Spec {
return &tokenOidcSpec{typ: checkOIDCAllClaims, SecretsFile: secretsFile, secretsRegistry: secretsRegistry, options: o}
}
// Deprecated: use NewOAuthOidcAllClaimsWithOptions instead.
func NewOAuthOidcAllClaims(secretsFile string, secretsRegistry secrets.EncrypterCreator) filters.Spec {
return NewOAuthOidcAllClaimsWithOptions(secretsFile, secretsRegistry, OidcOptions{})
}
// CreateFilter creates an OpenID Connect authorization filter.
//
// first arg: a provider, for example "https://accounts.google.com",
// which has the path /.well-known/openid-configuration
//
// Example:
//
// oauthOidcAllClaims("https://accounts.identity-provider.com", "some-client-id", "some-client-secret",
// "http://callback.com/auth/provider/callback", "scope1 scope2", "claim1 claim2", "<optional>", "<optional>", "<optional>") -> "https://internal.example.org";
func (s *tokenOidcSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
sargs, err := getStrings(args)
if err != nil {
return nil, err
}
if len(sargs) <= paramClaims {
return nil, filters.ErrInvalidFilterParameters
}
issuerURL, err := url.Parse(sargs[paramIdpURL])
if err != nil {
log.Errorf("Failed to parse url %s: %v.", sargs[paramIdpURL], err)
return nil, filters.ErrInvalidFilterParameters
}
provider, err := oidc.NewProvider(context.Background(), issuerURL.String())
if err != nil {
log.Errorf("Failed to create new provider %s: %v.", issuerURL, err)
return nil, filters.ErrInvalidFilterParameters
}
var cookieName string
if len(sargs) > paramCookieName && sargs[paramCookieName] != "" {
cookieName = sargs[paramCookieName]
} else {
h := sha256.New()
for i, s := range sargs {
// the callback URL is excluded from the cookie hash so that additional sub-path ingresses share the same cookie
if i == paramCallbackURL {
continue
}
// the subdomains to remove are excluded from the cookie hash so that additional sub-domain ingresses share the same cookie
if i == paramSubdomainsToRemove {
continue
}
h.Write([]byte(s))
}
byteSlice := h.Sum(nil)
sargsHash := fmt.Sprintf("%x", byteSlice)[:8]
cookieName = oauthOidcCookieName + sargsHash + "-"
}
log.Debugf("Cookie Name: %s", cookieName)
redirectURL, err := url.Parse(sargs[paramCallbackURL])
if err != nil || sargs[paramCallbackURL] == "" {
return nil, fmt.Errorf("invalid redirect url '%s': %w", sargs[paramCallbackURL], err)
}
encrypter, err := s.secretsRegistry.GetEncrypter(1*time.Minute, s.SecretsFile)
if err != nil {
return nil, err
}
subdomainsToRemove := 1
if s.options.CookieRemoveSubdomains != nil {
subdomainsToRemove = *s.options.CookieRemoveSubdomains
}
if len(sargs) > paramSubdomainsToRemove && sargs[paramSubdomainsToRemove] != "" {
subdomainsToRemove, err = strconv.Atoi(sargs[paramSubdomainsToRemove])
if err != nil {
return nil, err
}
}
if subdomainsToRemove < 0 {
return nil, fmt.Errorf("domain level cannot be negative '%d'", subdomainsToRemove)
}
validity := s.options.CookieValidity
if validity == 0 {
validity = defaultCookieValidity
}
oidcClientId := sargs[paramClientID]
if oidcClientId == "" {
oidcClientId = s.options.OidcClientId
}
oidcClientSecret := sargs[paramClientSecret]
if oidcClientSecret == "" {
oidcClientSecret = s.options.OidcClientSecret
}
f := &tokenOidcFilter{
typ: s.typ,
redirectPath: redirectURL.Path,
config: &oauth2.Config{
ClientID: oidcClientId,
ClientSecret: oidcClientSecret,
RedirectURL: sargs[paramCallbackURL], // self endpoint
Endpoint: provider.Endpoint(),
Scopes: []string{oidc.ScopeOpenID}, // mandatory scope by spec
},
provider: provider,
verifier: provider.Verifier(&oidc.Config{
ClientID: oidcClientId,
}),
validity: validity,
cookiename: cookieName,
encrypter: encrypter,
compressor: newDeflatePoolCompressor(flate.BestCompression),
subdomainsToRemove: subdomainsToRemove,
oidcOptions: s.options,
}
// user defined scopes
scopes := strings.Split(sargs[paramScopes], " ")
if len(sargs[paramScopes]) == 0 {
scopes = []string{}
}
// scopes are only used to request the claims to be included in the ID token issued by the auth server
// https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims
f.config.Scopes = append(f.config.Scopes, scopes...)
// user defined claims to check for authnz
if len(sargs[paramClaims]) > 0 {
f.claims = strings.Split(sargs[paramClaims], " ")
}
f.authCodeOptions = make([]oauth2.AuthCodeOption, 0)
if len(sargs) > paramAuthCodeOpts && sargs[paramAuthCodeOpts] != "" {
extraParameters := strings.Split(sargs[paramAuthCodeOpts], " ")
for _, p := range extraParameters {
splitP := strings.Split(p, "=")
log.Debug(splitP)
if len(splitP) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
if splitP[1] == "skipper-request-query" {
f.queryParams = append(f.queryParams, splitP[0])
} else {
f.authCodeOptions = append(f.authCodeOptions, oauth2.SetAuthURLParam(splitP[0], splitP[1]))
}
}
}
log.Debugf("Auth Code Options: %v", f.authCodeOptions)
// inject additional headers from the access token for upstream applications
if len(sargs) > paramUpstrHeaders && sargs[paramUpstrHeaders] != "" {
f.upstreamHeaders = make(map[string]string)
for _, header := range strings.Split(sargs[paramUpstrHeaders], " ") {
k, v, found := strings.Cut(header, ":")
if !found || k == "" || v == "" {
return nil, fmt.Errorf("%w: malformed filter for upstream headers %s", filters.ErrInvalidFilterParameters, header)
}
f.upstreamHeaders[k] = v
}
log.Debugf("Upstream Headers: %v", f.upstreamHeaders)
}
return f, nil
}
func (s *tokenOidcSpec) Name() string {
switch s.typ {
case checkOIDCUserInfo:
return filters.OAuthOidcUserInfoName
case checkOIDCAnyClaims:
return filters.OAuthOidcAnyClaimsName
case checkOIDCAllClaims:
return filters.OAuthOidcAllClaimsName
}
return AuthUnknown
}
func (f *tokenOidcFilter) validateAnyClaims(h map[string]interface{}) bool {
if len(f.claims) == 0 {
return true
}
if len(h) == 0 {
return false
}
for _, c := range f.claims {
if _, ok := h[c]; ok {
return true
}
}
return false
}
func (f *tokenOidcFilter) validateAllClaims(h map[string]interface{}) bool {
l := len(f.claims)
if l == 0 {
return true
}
if len(h) < l {
return false
}
for _, c := range f.claims {
if _, ok := h[c]; !ok {
return false
}
}
return true
}
type OauthState struct {
Validity int64 `json:"validity"`
Nonce string `json:"nonce"`
RedirectUrl string `json:"redirectUrl"`
}
func createState(nonce []byte, redirectUrl string) ([]byte, error) {
state := &OauthState{
Validity: time.Now().Add(stateValidity).Unix(),
Nonce: fmt.Sprintf("%x", nonce),
RedirectUrl: redirectUrl,
}
return json.Marshal(state)
}
func extractState(encState []byte) (*OauthState, error) {
var state OauthState
err := json.Unmarshal(encState, &state)
if err != nil {
return nil, err
}
return &state, nil
}
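// Round-trip sketch (nonce and redirect values assumed for illustration):
// createState and extractState are inverse operations around JSON:
//
//	b, _ := createState([]byte{0xde, 0xad}, "/foo?bar=baz")
//	s, _ := extractState(b)
//	// s.Nonce == "dead", s.RedirectUrl == "/foo?bar=baz",
//	// s.Validity is stateValidity (one minute) from now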
func (f *tokenOidcFilter) internalServerError(ctx filters.FilterContext) {
rsp := &http.Response{
StatusCode: http.StatusInternalServerError,
}
ctx.Serve(rsp)
}
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// 1. Client prepares an Authentication Request containing the desired request parameters.
// 2. Client sends the request to the Authorization Server.
func (f *tokenOidcFilter) doOauthRedirect(ctx filters.FilterContext, cookies []*http.Cookie) {
nonce, err := f.encrypter.CreateNonce()
if err != nil {
ctx.Logger().Errorf("Failed to create nonce: %v.", err)
f.internalServerError(ctx)
return
}
redirectUrl := ctx.Request().URL.String()
statePlain, err := createState(nonce, redirectUrl)
if err != nil {
ctx.Logger().Errorf("Failed to create oauth2 state: %v.", err)
f.internalServerError(ctx)
return
}
stateEnc, err := f.encrypter.Encrypt(statePlain)
if err != nil {
ctx.Logger().Errorf("Failed to encrypt data block: %v.", err)
f.internalServerError(ctx)
return
}
opts := f.authCodeOptions
if f.queryParams != nil {
opts = make([]oauth2.AuthCodeOption, len(f.authCodeOptions), len(f.authCodeOptions)+len(f.queryParams))
copy(opts, f.authCodeOptions)
for _, p := range f.queryParams {
if v := ctx.Request().URL.Query().Get(p); v != "" {
opts = append(opts, oauth2.SetAuthURLParam(p, v))
}
}
}
oauth2URL := f.config.AuthCodeURL(fmt.Sprintf("%x", stateEnc), opts...)
rsp := &http.Response{
Header: http.Header{
"Location": []string{oauth2URL},
},
StatusCode: http.StatusTemporaryRedirect,
Status: "Temporary Redirect",
}
for _, cookie := range cookies {
rsp.Header.Add("Set-Cookie", cookie.String())
}
ctx.Logger().Debugf("serve redirect: plaintextState:%s to Location: %s", statePlain, rsp.Header.Get("Location"))
ctx.Serve(rsp)
}
func (f *tokenOidcFilter) Response(filters.FilterContext) {}
func extractDomainFromHost(host string, subdomainsToRemove int) string {
h, _, err := net.SplitHostPort(host)
if err != nil {
h = host
}
ip := net.ParseIP(h)
if ip != nil {
return ip.String()
}
if subdomainsToRemove == 0 {
return h
}
subDomains := strings.Split(h, ".")
if len(subDomains)-subdomainsToRemove < 2 {
return h
}
return strings.Join(subDomains[subdomainsToRemove:], ".")
}
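// Behavior sketch of extractDomainFromHost (hosts assumed for illustration):
//
//	extractDomainFromHost("foo.bar.example.com:443", 1) // "bar.example.com"
//	extractDomainFromHost("example.com", 1)             // "example.com" (never cut below two labels)
//	extractDomainFromHost("127.0.0.1:8080", 1)          // "127.0.0.1" (IPs are returned unchanged)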
func getHost(request *http.Request) string {
if h := request.Header.Get("host"); h != "" {
return h
} else {
return request.Host
}
}
func (f *tokenOidcFilter) createOidcCookie(ctx filters.FilterContext, name string, value string, maxAge int) (cookie *http.Cookie) {
return &http.Cookie{
Name: name,
Value: value,
Path: "/",
Secure: true,
HttpOnly: true,
MaxAge: maxAge,
Domain: extractDomainFromHost(getHost(ctx.Request()), f.subdomainsToRemove),
}
}
func (f *tokenOidcFilter) deleteOidcCookie(ctx filters.FilterContext, name string) (cookie *http.Cookie) {
return f.createOidcCookie(ctx, name, "", -1)
}
func chunkCookie(cookie *http.Cookie) (cookies []*http.Cookie) {
// Copy the cookie so that the original is not modified.
cookieCopy := *cookie
for index := 'a'; index <= 'z'; index++ {
cookieSize := len(cookieCopy.String())
if cookieSize < cookieMaxSize {
cookieCopy.Name += string(index)
return append(cookies, &cookieCopy)
}
newCookie := cookieCopy
newCookie.Name += string(index)
// the cut is computed per chunk instead of at a fixed size, so that changes in the value length (e.g. due to signature changes) are supported
cut := len(cookieCopy.Value) - (cookieSize - cookieMaxSize) - 1
newCookie.Value, cookieCopy.Value = cookieCopy.Value[:cut], cookieCopy.Value[cut:]
cookies = append(cookies, &newCookie)
}
log.Error("unsupported number of cookie chunks")
return
}
func mergerCookies(cookies []*http.Cookie) *http.Cookie {
if len(cookies) == 0 {
return nil
}
cookie := *(cookies[0])
cookie.Name = cookie.Name[:len(cookie.Name)-1]
cookie.Value = ""
// the chunks may arrive shuffled, restore their order by name
sort.Slice(cookies, func(i, j int) bool {
return cookies[i].Name < cookies[j].Name
})
for _, ck := range cookies {
cookie.Value += ck.Value
}
return &cookie
}
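// chunkRoundTripSketch is an illustrative sketch, not used by the filter:
// chunkCookie splits an oversized cookie into chunks named "<name>a",
// "<name>b", ... that each fit the cookie size limit, and mergerCookies
// restores the original cookie from them.
func chunkRoundTripSketch() {
big := &http.Cookie{Name: "skipperOauthOidcdeadbeef-", Value: strings.Repeat("v", 3*cookieMaxSize)}
chunks := chunkCookie(big)
merged := mergerCookies(chunks)
log.Debugf("chunks: %d, merged value length: %d", len(chunks), len(merged.Value))
}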
func (f *tokenOidcFilter) doDownstreamRedirect(ctx filters.FilterContext, oidcState []byte, maxAge time.Duration, redirectUrl string) {
ctx.Logger().Debugf("Doing downstream redirect to: %s", redirectUrl)
r := &http.Response{
StatusCode: http.StatusTemporaryRedirect,
Header: http.Header{
"Location": {redirectUrl},
},
}
oidcCookies := chunkCookie(
f.createOidcCookie(
ctx,
f.cookiename,
base64.StdEncoding.EncodeToString(oidcState),
int(maxAge.Seconds()),
),
)
for _, cookie := range oidcCookies {
r.Header.Add("Set-Cookie", cookie.String())
}
ctx.Serve(r)
}
func (f *tokenOidcFilter) validateCookie(cookie *http.Cookie) ([]byte, bool) {
if cookie == nil {
log.Debugf("Cookie is nil")
return nil, false
}
log.Debugf("validate cookie name: %s", f.cookiename)
decodedCookie, err := base64.StdEncoding.DecodeString(cookie.Value)
if err != nil {
log.Debugf("Base64 decoding the cookie failed: %v", err)
return nil, false
}
decryptedCookie, err := f.encrypter.Decrypt(decodedCookie)
if err != nil {
log.Debugf("Decrypting the cookie failed: %v", err)
return nil, false
}
decompressedCookie, err := f.compressor.decompress(decryptedCookie)
if err != nil {
log.Error(err)
return nil, false
}
return decompressedCookie, true
}
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// 5. Authorization Server sends the End-User back to the Client with an Authorization Code.
func (f *tokenOidcFilter) callbackEndpoint(ctx filters.FilterContext) {
var (
claimsMap map[string]interface{}
oauth2Token *oauth2.Token
data []byte
resp tokenContainer
sub string
userInfo *oidc.UserInfo
oidcIDToken string
)
r := ctx.Request()
oauthState, err := f.getCallbackState(ctx)
if err != nil {
if _, ok := err.(*requestError); !ok {
ctx.Logger().Errorf("Error while retrieving callback state: %v.", err)
}
unauthorized(
ctx,
"",
invalidToken,
r.Host,
fmt.Sprintf("Failed to get state from callback: %v.", err),
)
return
}
oauth2Token, err = f.getTokenWithExchange(oauthState, ctx)
if err != nil {
if _, ok := err.(*requestError); !ok {
ctx.Logger().Errorf("Error while getting token in callback: %v.", err)
}
unauthorized(
ctx,
"",
invalidClaim,
r.Host,
fmt.Sprintf("Failed to get token in callback: %v.", err),
)
return
}
switch f.typ {
case checkOIDCUserInfo:
userInfo, err = f.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(oauth2Token))
if err != nil {
// error coming from an external library and the possible error reasons are
// not documented explicitly, so we assume that the cause is always rooted
// in the incoming request, and only log it with a debug flag, via calling
// unauthorized().
unauthorized(
ctx,
"",
invalidToken,
r.Host,
fmt.Sprintf("Failed to get userinfo: %v.", err),
)
return
}
oidcIDToken, err = f.getidtoken(oauth2Token)
if err != nil {
if _, ok := err.(*requestError); !ok {
ctx.Logger().Errorf("Error while getting id token: %v", err)
}
unauthorized(
ctx,
"",
invalidClaim,
r.Host,
fmt.Sprintf("Failed to get id token: %v", err),
)
return
}
sub = userInfo.Subject
claimsMap, _, err = f.tokenClaims(ctx, oauth2Token)
if err != nil {
unauthorized(
ctx,
"",
invalidToken,
r.Host,
fmt.Sprintf("Failed to get claims: %v.", err),
)
return
}
case checkOIDCAnyClaims, checkOIDCAllClaims:
oidcIDToken, err = f.getidtoken(oauth2Token)
if err != nil {
if _, ok := err.(*requestError); !ok {
ctx.Logger().Errorf("Error while getting id token: %v", err)
}
unauthorized(
ctx,
"",
invalidClaim,
r.Host,
fmt.Sprintf("Failed to get id token: %v", err),
)
return
}
claimsMap, sub, err = f.tokenClaims(ctx, oauth2Token)
if err != nil {
if _, ok := err.(*requestError); !ok {
ctx.Logger().Errorf("Failed to get claims with error: %v", err)
}
unauthorized(
ctx,
"",
invalidToken,
r.Host,
fmt.Sprintf(
"Failed to get claims: %s, %v",
f.claims,
err,
),
)
return
}
}
resp = tokenContainer{
OAuth2Token: oauth2Token,
OIDCIDToken: oidcIDToken,
UserInfo: userInfo,
Subject: sub,
Claims: claimsMap,
}
data, err = json.Marshal(resp)
if err != nil {
log.Errorf("Failed to serialize claims: %v.", err)
unauthorized(
ctx,
"",
invalidSub,
r.Host,
"Failed to serialize claims.",
)
return
}
compressedData, err := f.compressor.compress(data)
if err != nil {
log.Error(err)
}
encryptedData, err := f.encrypter.Encrypt(compressedData)
if err != nil {
log.Errorf("Failed to encrypt the returned oidc data: %v.", err)
unauthorized(
ctx,
"",
invalidSub,
r.Host,
"Failed to encrypt the returned oidc data.",
)
return
}
f.doDownstreamRedirect(ctx, encryptedData, f.getMaxAge(claimsMap), oauthState.RedirectUrl)
}
func (f *tokenOidcFilter) getMaxAge(claimsMap map[string]interface{}) time.Duration {
maxAge := f.validity
if exp, ok := claimsMap["exp"].(float64); ok {
val := time.Until(time.Unix(int64(exp), 0))
if val > time.Minute {
maxAge = val
log.Debugf("Setting maxAge of OIDC cookie to %s", maxAge)
}
}
return maxAge
}
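// getMaxAge illustration: with an "exp" claim two hours in the future, the
// OIDC cookie lives for those two hours; if "exp" is missing, not a number,
// or less than a minute away, the configured cookie validity is used instead.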
func (f *tokenOidcFilter) Request(ctx filters.FilterContext) {
var (
allowed bool
cookies []*http.Cookie
container tokenContainer
)
r := ctx.Request()
// Retrieve skipperOauthOidc cookie for processing and remove it from downstream request
rCookies := r.Cookies()
r.Header.Del("Cookie")
for _, cookie := range rCookies {
if strings.HasPrefix(cookie.Name, f.cookiename) {
cookies = append(cookies, cookie)
} else {
r.AddCookie(cookie)
}
}
sessionCookie := mergerCookies(cookies)
log.Debugf("Request: Cookie merged, %d chunks, len: %d", len(cookies), len(sessionCookie.String()))
cookie, ok := f.validateCookie(sessionCookie)
log.Debugf("Request: Cookie Validation: %v", ok)
if !ok {
// 5. Authorization Server sends the End-User back to the Client with an Authorization Code.
if strings.Contains(r.URL.Path, f.redirectPath) {
f.callbackEndpoint(ctx)
return
}
// 1. Client prepares an Authentication Request containing the desired request parameters.
// clear existing, invalid cookies
purgeCookies := make([]*http.Cookie, len(cookies))
for i, c := range cookies {
purgeCookies[i] = f.deleteOidcCookie(ctx, c.Name)
}
f.doOauthRedirect(ctx, purgeCookies)
return
}
err := json.Unmarshal([]byte(cookie), &container)
if err != nil {
unauthorized(
ctx,
"",
invalidToken,
r.Host,
fmt.Sprintf("Failed to deserialize cookie: %v.", err),
)
return
}
// filter specific checks
switch f.typ {
case checkOIDCUserInfo:
if container.OAuth2Token.Valid() && container.UserInfo != nil {
allowed = f.validateAllClaims(container.Claims)
}
case checkOIDCAnyClaims:
allowed = f.validateAnyClaims(container.Claims)
case checkOIDCAllClaims:
allowed = f.validateAllClaims(container.Claims)
default:
unauthorized(ctx, "unknown", invalidFilter, r.Host, "")
return
}
if !allowed {
unauthorized(ctx, container.Subject, invalidClaim, r.Host, "")
return
}
// saving token info for chained filter
ctx.StateBag()[oidcClaimsCacheKey] = container
// adding upstream headers
err = setHeaders(f.upstreamHeaders, ctx, container)
if err != nil {
ctx.Logger().Errorf("%v", err)
f.internalServerError(ctx)
return
}
}
func setHeaders(upstreamHeaders map[string]string, ctx filters.FilterContext, container interface{}) (err error) {
oidcInfoJson, err := json.Marshal(container)
if err != nil || !gjson.ValidBytes(oidcInfoJson) {
return fmt.Errorf("failed to serialize OIDC token info: %w", err)
}
// backwards compatible
if len(upstreamHeaders) == 0 {
ctx.Request().Header.Set(oidcInfoHeader, string(oidcInfoJson))
return
}
parsed := gjson.ParseBytes(oidcInfoJson)
for key, query := range upstreamHeaders {
match := parsed.Get(query)
log.Debugf("header: %s results: %s", query, match.String())
if !match.Exists() {
log.Errorf("Lookup failed for upstream header '%s'", query)
continue
}
ctx.Request().Header.Set(key, match.String())
}
return
}
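// For example, an upstream header configuration of "X-Auth-Email:claims.email"
// sets the X-Auth-Email request header to the result of the GJSON query
// claims.email against the serialized token container; with no headers
// configured, the whole serialized container is passed in the
// Skipper-Oidc-Info header instead.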
func (f *tokenOidcFilter) tokenClaims(ctx filters.FilterContext, oauth2Token *oauth2.Token) (map[string]interface{}, string, error) {
r := ctx.Request()
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, "", requestErrorf("invalid token, no id_token field in oauth2 token")
}
var idToken *oidc.IDToken
idToken, err := f.verifier.Verify(r.Context(), rawIDToken)
if err != nil {
return nil, "", requestErrorf("failed to verify id token: %v", err)
}
tokenMap := make(map[string]interface{})
if err = idToken.Claims(&tokenMap); err != nil {
return nil, "", requestErrorf("failed to deserialize id token: %v", err)
}
sub, ok := tokenMap["sub"].(string)
if !ok {
return nil, "", requestErrorf("claims do not contain sub")
}
if err = f.handleDistributedClaims(idToken, oauth2Token, tokenMap); err != nil {
return nil, "", requestErrorf("failed to handle distributed claims: %v", err)
}
return tokenMap, sub, nil
}
func (f *tokenOidcFilter) getidtoken(oauth2Token *oauth2.Token) (string, error) {
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return "", requestErrorf("invalid token, no id_token field in oauth2 token")
}
return rawIDToken, nil
}
func (f *tokenOidcFilter) getCallbackState(ctx filters.FilterContext) (*OauthState, error) {
// CSRF protection using an approach similar to the Encrypted Token Pattern:
// https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)_Prevention_Cheat_Sheet#Encrypted_Token_Pattern,
// necessary because of https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
r := ctx.Request()
stateQueryEncHex := r.URL.Query().Get("state")
if stateQueryEncHex == "" {
return nil, requestErrorf("no state parameter")
}
stateQueryEnc := make([]byte, len(stateQueryEncHex))
if _, err := fmt.Sscanf(stateQueryEncHex, "%x", &stateQueryEnc); err != nil && err != io.EOF {
return nil, requestErrorf("failed to read hex string: %v", err)
}
stateQueryPlain, err := f.encrypter.Decrypt(stateQueryEnc)
if err != nil {
// TODO: Implement metrics counter for number of incorrect tokens
return nil, requestErrorf("token from state query is invalid: %v", err)
}
ctx.Logger().Debugf("len(stateQueryPlain): %d, stateQueryEnc: %d, stateQueryEncHex: %d", len(stateQueryPlain), len(stateQueryEnc), len(stateQueryEncHex))
state, err := extractState(stateQueryPlain)
if err != nil {
return nil, requestErrorf("failed to deserialize state: %v", err)
}
return state, nil
}
func (f *tokenOidcFilter) getTokenWithExchange(state *OauthState, ctx filters.FilterContext) (*oauth2.Token, error) {
r := ctx.Request()
if state.Validity < time.Now().Unix() {
return nil, requestErrorf("state is no longer valid. %v", state.Validity)
}
// authcode flow
code := r.URL.Query().Get("code")
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowSteps
// 6. Client requests a response using the Authorization Code at the Token Endpoint.
// 7. Client receives a response that contains an ID Token and Access Token in the response body.
oauth2Token, err := f.config.Exchange(r.Context(), code, f.authCodeOptions...)
if err != nil {
// error coming from an external library and the possible error reasons are
// not documented explicitly, so we assume that the cause is always rooted
// in the incoming request.
err = requestErrorf("oauth2 exchange: %v", err)
}
return oauth2Token, err
}
// handleDistributedClaims resolves distributed / overage claims if the token carries them.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
// In Azure, if the user is an indirect member of more than 200 groups, the token
// carries _claim_names and _claim_sources instead of the groups, per OIDC Core 1.0, section 5.6.2:
// https://openid.net/specs/openid-connect-core-1_0.html#AggregatedDistributedClaims
// Example:
//
// {
// "_claim_names": {
// "groups": "src1"
// },
// "_claim_sources": {
// "src1": {
// "endpoint": "https://graph.windows.net/.../getMemberObjects"
// }
// }
// }
func (f *tokenOidcFilter) handleDistributedClaims(idToken *oidc.IDToken, oauth2Token *oauth2.Token, claimsMap map[string]interface{}) error {
// https://github.com/coreos/go-oidc/issues/171#issuecomment-1044286153
var distClaims distributedClaims
err := idToken.Claims(&distClaims)
if err != nil {
return err
}
if len(distClaims.ClaimNames) == 0 || len(distClaims.ClaimSources) == 0 {
log.Debugf("No distributed claims found")
return nil
}
for claim, ref := range distClaims.ClaimNames {
source, ok := distClaims.ClaimSources[ref]
if !ok {
return fmt.Errorf("invalid distributed claims: missing claim source for %s", claim)
}
uri, err := url.Parse(source.Endpoint)
if err != nil {
return fmt.Errorf("failed to parse distributed claim endpoint: %w", err)
}
var results []interface{}
switch uri.Host {
case azureADGraphHost, microsoftGraphHost:
results, err = f.handleDistributedClaimsAzure(uri, oauth2Token, claimsMap)
if err != nil {
return fmt.Errorf("failed to get distributed Azure claim: %w", err)
}
default:
return fmt.Errorf("unsupported distributed claims endpoint '%s', please create an issue at https://github.com/zalando/skipper/issues/new/choose", uri.Host)
}
claimsMap[claim] = results
}
return nil
}
// Azure customizations https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview
// If the endpoint provided in _claim_sources points to the deprecated "graph.windows.net" API,
// it is replaced with a handcrafted URL to graph.microsoft.com.
func (f *tokenOidcFilter) handleDistributedClaimsAzure(url *url.URL, oauth2Token *oauth2.Token, claimsMap map[string]interface{}) (values []interface{}, err error) {
url.Host = microsoftGraphHost
// transitiveMemberOf for group names
userID, ok := claimsMap["oid"].(string)
if !ok {
return nil, fmt.Errorf("oid claim not found in claims map")
}
url.Path = fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID)
q := url.Query()
q.Set("$select", "onPremisesSamAccountName,id")
url.RawQuery = q.Encode()
return f.resolveDistributedClaimAzure(url, oauth2Token)
}
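// Sketch of the rewritten endpoint, assuming an oid claim of
// "00000000-0000-0000-0000-000000000000":
//
//	https://graph.microsoft.com/v1.0/users/00000000-0000-0000-0000-000000000000/transitiveMemberOf?%24select=onPremisesSamAccountName%2Cid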
func (f *tokenOidcFilter) initClient() *snet.Client {
newCli := snet.NewClient(snet.Options{
ResponseHeaderTimeout: f.oidcOptions.Timeout,
TLSHandshakeTimeout: f.oidcOptions.Timeout,
MaxIdleConnsPerHost: f.oidcOptions.MaxIdleConns,
Tracer: f.oidcOptions.Tracer,
OpentracingComponentTag: "skipper",
OpentracingSpanName: "distributedClaims",
})
return newCli
}
func (f *tokenOidcFilter) resolveDistributedClaimAzure(url *url.URL, oauth2Token *oauth2.Token) (values []interface{}, err error) {
var target azureGraphGroups
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return nil, fmt.Errorf("error constructing groups endpoint request: %w", err)
}
oauth2Token.SetAuthHeader(req)
cli, ok := distributedClaimsClients.Load(url.Host)
if !ok {
var loaded bool
newCli := f.initClient()
cli, loaded = distributedClaimsClients.LoadOrStore(url.Host, newCli)
if loaded {
newCli.Close()
}
}
client, ok := cli.(*snet.Client)
if !ok {
return nil, errors.New("invalid distributed claims client type")
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("unable to call API: %w", err)
}
body, err := io.ReadAll(res.Body)
res.Body.Close() // closing for connection reuse
if err != nil {
return nil, fmt.Errorf("failed to read API response: %w", err)
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned error: %s", string(body))
}
err = json.Unmarshal(body, &target)
if err != nil {
return nil, fmt.Errorf("unable to decode response: %w", err)
}
for _, v := range target.Value {
if v.OnPremisesSamAccountName != "" {
values = append(values, v.OnPremisesSamAccountName)
}
}
// recursive pagination
if target.OdataNextLink != "" {
nextURL, err := url.Parse(target.OdataNextLink)
if err != nil {
return nil, fmt.Errorf("failed to parse next link: %w", err)
}
vs, err := f.resolveDistributedClaimAzure(nextURL, oauth2Token)
if err != nil {
return nil, err
}
values = append(values, vs...)
}
log.Debugf("Distributed claim is: %v", values)
return
}
func newDeflatePoolCompressor(level int) *deflatePoolCompressor {
return &deflatePoolCompressor{
poolWriter: &sync.Pool{
New: func() interface{} {
w, err := flate.NewWriter(io.Discard, level)
if err != nil {
log.Errorf("failed to generate new deflate writer: %v", err)
}
return w
},
},
}
}
func (dc *deflatePoolCompressor) compress(rawData []byte) ([]byte, error) {
pw, ok := dc.poolWriter.Get().(*flate.Writer)
if !ok || pw == nil {
return nil, fmt.Errorf("could not get a flate.Writer from the pool")
}
defer dc.poolWriter.Put(pw)
var buf bytes.Buffer
pw.Reset(&buf)
if _, err := pw.Write(rawData); err != nil {
return nil, err
}
if err := pw.Close(); err != nil {
return nil, err
}
log.Debugf("cookie compressed: %d to %d", len(rawData), buf.Len())
return buf.Bytes(), nil
}
func (dc *deflatePoolCompressor) decompress(compData []byte) ([]byte, error) {
zr := flate.NewReader(bytes.NewReader(compData))
// read fully before closing, so that a truncated stream is reported as an error
rawData, err := io.ReadAll(zr)
if err != nil {
return nil, err
}
if err := zr.Close(); err != nil {
return nil, err
}
return rawData, nil
}
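// Round-trip sketch (payload assumed for illustration):
//
//	dc := newDeflatePoolCompressor(flate.BestCompression)
//	comp, _ := dc.compress([]byte("some cookie payload"))
//	raw, _ := dc.decompress(comp) // raw == []byte("some cookie payload")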
package auth
import (
"encoding/json"
"fmt"
"strings"
"sync"
"github.com/tidwall/gjson"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.OidcClaimsQueryName instead
OidcClaimsQueryName = filters.OidcClaimsQueryName
oidcClaimsCacheKey = "oidcclaimscachekey"
)
var gjsonMu sync.RWMutex
type (
oidcIntrospectionSpec struct {
typ roleCheckType
}
oidcIntrospectionFilter struct {
typ roleCheckType
paths []pathQuery
}
pathQuery struct {
path string
queries []string
}
)
func NewOIDCQueryClaimsFilter() filters.Spec {
return &oidcIntrospectionSpec{
typ: checkOIDCQueryClaims,
}
}
// SetOIDCClaims sets OIDC claims in the state bag.
// Intended for use with the oidcClaimsQuery filter.
func SetOIDCClaims(ctx filters.FilterContext, claims map[string]interface{}) {
ctx.StateBag()[oidcClaimsCacheKey] = tokenContainer{
Claims: claims,
}
}
func (spec *oidcIntrospectionSpec) Name() string {
switch spec.typ {
case checkOIDCQueryClaims:
return filters.OidcClaimsQueryName
}
return AuthUnknown
}
func (spec *oidcIntrospectionSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
sargs, err := getStrings(args)
if err != nil {
return nil, err
}
if len(sargs) == 0 {
return nil, filters.ErrInvalidFilterParameters
}
filter := &oidcIntrospectionFilter{typ: spec.typ}
switch filter.typ {
case checkOIDCQueryClaims:
for _, arg := range sargs {
path, queries, found := strings.Cut(arg, ":")
if !found || path == "" {
return nil, fmt.Errorf("%w: malformed filter arg %s", filters.ErrInvalidFilterParameters, arg)
}
pq := pathQuery{path: path}
for _, query := range splitQueries(queries) {
if query == "" {
return nil, fmt.Errorf("%w: %s", errUnsupportedClaimSpecified, arg)
}
pq.queries = append(pq.queries, trimQuotes(query))
}
if len(pq.queries) == 0 {
return nil, fmt.Errorf("%w: %s", errUnsupportedClaimSpecified, arg)
}
filter.paths = append(filter.paths, pq)
}
if len(filter.paths) == 0 {
return nil, fmt.Errorf("%w: no queries could be parsed", errUnsupportedClaimSpecified)
}
default:
return nil, filters.ErrInvalidFilterParameters
}
gjsonMu.RLock()
// method is not thread safe
modExists := gjson.ModifierExists("_", gjsonThisModifier)
gjsonMu.RUnlock()
if !modExists {
gjsonMu.Lock()
// method is not thread safe
gjson.AddModifier("_", gjsonThisModifier)
gjsonMu.Unlock()
}
return filter, nil
}
func (filter *oidcIntrospectionFilter) String() string {
var str []string
for _, query := range filter.paths {
str = append(str, query.String())
}
return fmt.Sprintf("%s(%s)", filters.OidcClaimsQueryName, strings.Join(str, "; "))
}
func (filter *oidcIntrospectionFilter) Request(ctx filters.FilterContext) {
r := ctx.Request()
token, ok := ctx.StateBag()[oidcClaimsCacheKey].(tokenContainer)
// a pointer comparison against a fresh &tokenContainer{} is always false,
// so emptiness is effectively checked via the claims length
if !ok || len(token.Claims) == 0 {
ctx.Logger().Errorf("Error retrieving %s for OIDC token introspection", oidcClaimsCacheKey)
unauthorized(ctx, "", missingToken, r.Host, oidcClaimsCacheKey+" is unavailable in StateBag")
return
}
switch filter.typ {
case checkOIDCQueryClaims:
if !filter.validateClaimsQuery(r.URL.Path, token.Claims) {
unauthorized(ctx, "", invalidAccess, r.Host, "Path not permitted")
return
}
default:
unauthorized(ctx, fmt.Sprint(filter.typ), invalidClaim, r.Host, "Wrong oidcIntrospectionFilter type")
return
}
sub, ok := token.Claims["sub"].(string)
if !ok {
unauthorized(ctx, fmt.Sprint(filter.typ), invalidSub, r.Host, "Invalid Subject")
return
}
authorized(ctx, sub)
}
func (filter *oidcIntrospectionFilter) Response(filters.FilterContext) {}
func gjsonThisModifier(json, arg string) string {
return gjson.Get(json, "[@this].#("+arg+")").Raw
}
func (filter *oidcIntrospectionFilter) validateClaimsQuery(reqPath string, gotToken map[string]interface{}) bool {
l := len(filter.paths)
if l == 0 {
return false
}
json, err := json.Marshal(gotToken)
if err != nil || !gjson.ValidBytes(json) {
log.Errorf("Failed to serialize in validateClaimsQuery: %v", err)
return false
}
parsed := gjson.ParseBytes(json)
for _, path := range filter.paths {
if !strings.HasPrefix(reqPath, path.path) {
continue
}
for _, query := range path.queries {
match := parsed.Get(query)
log.Debugf("claim: %s results:%s", query, match.String())
if match.Exists() {
return true
}
}
return false
}
return false
}
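// validateClaimsQuery sketch (paths and claims assumed for illustration):
// with filter.paths = [{path: "/admin", queries: ["email"]}] and token claims
// {"email": "jdoe@example.org"}, a request to /admin/users passes, because
// "/admin" is a prefix of the request path and the GJSON query "email"
// matches the serialized claims.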
func (p pathQuery) String() string {
return fmt.Sprintf("path: '%s*', matching: %s", p.path, strings.Join(p.queries, ", "))
}
// splitQueries splits space-delimited GJSON queries, ignoring spaces within quoted strings
func splitQueries(s string) (q []string) {
for _, p := range strings.Split(s, " ") {
if len(q) == 0 || strings.Count(q[len(q)-1], `"`)%2 == 0 {
q = append(q, p)
} else {
q[len(q)-1] = q[len(q)-1] + " " + p
}
}
return
}
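// For example:
//
//	splitQueries(`"foo bar" baz`) // []string{`"foo bar"`, "baz"}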
func trimQuotes(s string) string {
if len(s) >= 2 {
if c := s[len(s)-1]; s[0] == c && (c == '"' || c == '\'') {
return s[1 : len(s)-1]
}
}
return s
}
package auth
import (
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/secrets"
)
type (
secretHeaderSpec struct {
secretsReader secrets.SecretsReader
}
secretHeaderFilter struct {
headerName string
secretName string
prefix string
suffix string
secretsReader secrets.SecretsReader
}
)
func NewSetRequestHeaderFromSecret(sr secrets.SecretsReader) filters.Spec {
return &secretHeaderSpec{secretsReader: sr}
}
func (*secretHeaderSpec) Name() string {
return filters.SetRequestHeaderFromSecretName
}
func (s *secretHeaderSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 2 || len(args) > 4 {
return nil, filters.ErrInvalidFilterParameters
}
var ok bool
f := &secretHeaderFilter{
secretsReader: s.secretsReader,
}
f.headerName, ok = args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
f.secretName, ok = args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) > 2 {
f.prefix, ok = args[2].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
}
if len(args) > 3 {
f.suffix, ok = args[3].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
}
return f, nil
}
func (f *secretHeaderFilter) Request(ctx filters.FilterContext) {
value, ok := f.secretsReader.GetSecret(f.secretName)
if ok {
ctx.Request().Header.Set(f.headerName, f.prefix+string(value)+f.suffix)
}
}
func (*secretHeaderFilter) Response(filters.FilterContext) {}
package auth
import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/annotate"
"github.com/zalando/skipper/metrics"
)
const (
// Deprecated: use filters.OAuthTokeninfoAnyScopeName instead
OAuthTokeninfoAnyScopeName = filters.OAuthTokeninfoAnyScopeName
// Deprecated: use filters.OAuthTokeninfoAllScopeName instead
OAuthTokeninfoAllScopeName = filters.OAuthTokeninfoAllScopeName
// Deprecated: use filters.OAuthTokeninfoAnyKVName instead
OAuthTokeninfoAnyKVName = filters.OAuthTokeninfoAnyKVName
// Deprecated: use filters.OAuthTokeninfoAllKVName instead
OAuthTokeninfoAllKVName = filters.OAuthTokeninfoAllKVName
tokeninfoCacheKey = "tokeninfo"
)
type TokeninfoOptions struct {
URL string
Timeout time.Duration
MaxIdleConns int
Tracer opentracing.Tracer
Metrics metrics.Metrics
// CacheSize configures the maximum number of cached tokens.
// The cache periodically evicts random items when the number of cached tokens exceeds CacheSize.
// Zero value disables tokeninfo cache.
CacheSize int
// CacheTTL limits the lifetime of a cached tokeninfo.
// A tokeninfo is cached for the number of seconds given by its "expires_in"
// field, or for CacheTTL if that is non-zero and shorter.
CacheTTL time.Duration
}
type (
tokeninfoSpec struct {
typ roleCheckType
options TokeninfoOptions
tokeninfoValidateYamlConfigParser *yamlConfigParser[tokeninfoValidateFilterConfig]
}
tokeninfoFilter struct {
typ roleCheckType
client tokeninfoClient
scopes []string
kv kv
}
tokeninfoValidateFilter struct {
client tokeninfoClient
config *tokeninfoValidateFilterConfig
}
// tokeninfoValidateFilterConfig implements [yamlConfig],
// make sure it is not modified after initialization.
tokeninfoValidateFilterConfig struct {
OptOutAnnotations []string `json:"optOutAnnotations,omitempty"`
UnauthorizedResponse string `json:"unauthorizedResponse,omitempty"`
OptOutHosts []string `json:"optOutHosts,omitempty"`
optOutHostsCompiled []*regexp.Regexp
}
)
var tokeninfoAuthClient = make(map[string]tokeninfoClient)
// getTokeninfoClient creates a new or returns a cached instance of tokeninfoClient
func (o *TokeninfoOptions) getTokeninfoClient() (tokeninfoClient, error) {
if c, ok := tokeninfoAuthClient[o.URL]; ok {
return c, nil
}
c, err := o.newTokeninfoClient()
if err == nil {
tokeninfoAuthClient[o.URL] = c
}
return c, err
}
// newTokeninfoClient creates a new instance of tokeninfoClient
func (o *TokeninfoOptions) newTokeninfoClient() (tokeninfoClient, error) {
var c tokeninfoClient
c, err := newAuthClient(o.URL, tokenInfoSpanName, o.Timeout, o.MaxIdleConns, o.Tracer)
if err != nil {
return nil, err
}
if o.CacheSize > 0 {
c = newTokeninfoCache(c, o.Metrics, o.CacheSize, o.CacheTTL)
}
return c, nil
}
func NewOAuthTokeninfoAllScopeWithOptions(to TokeninfoOptions) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAllScopes,
options: to,
}
}
// NewOAuthTokeninfoAllScope creates a new auth filter specification
// to validate authorization for requests. Current implementation uses
// Bearer tokens to authorize requests and checks that the token
// contains all scopes.
func NewOAuthTokeninfoAllScope(oauthTokeninfoURL string, oauthTokeninfoTimeout time.Duration) filters.Spec {
return NewOAuthTokeninfoAllScopeWithOptions(TokeninfoOptions{
URL: oauthTokeninfoURL,
Timeout: oauthTokeninfoTimeout,
})
}
func NewOAuthTokeninfoAnyScopeWithOptions(to TokeninfoOptions) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAnyScopes,
options: to,
}
}
// NewOAuthTokeninfoAnyScope creates a new auth filter specification
// to validate authorization for requests. Current implementation uses
// Bearer tokens to authorize requests and checks that the token
// contains at least one scope.
func NewOAuthTokeninfoAnyScope(OAuthTokeninfoURL string, OAuthTokeninfoTimeout time.Duration) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAnyScopes,
options: TokeninfoOptions{
URL: OAuthTokeninfoURL,
Timeout: OAuthTokeninfoTimeout,
},
}
}
func NewOAuthTokeninfoAllKVWithOptions(to TokeninfoOptions) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAllKV,
options: to,
}
}
// NewOAuthTokeninfoAllKV creates a new auth filter specification
// to validate authorization for requests. Current implementation uses
// Bearer tokens to authorize requests and checks that the token
// contains all key value pairs provided.
func NewOAuthTokeninfoAllKV(OAuthTokeninfoURL string, OAuthTokeninfoTimeout time.Duration) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAllKV,
options: TokeninfoOptions{
URL: OAuthTokeninfoURL,
Timeout: OAuthTokeninfoTimeout,
},
}
}
func NewOAuthTokeninfoAnyKVWithOptions(to TokeninfoOptions) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAnyKV,
options: to,
}
}
func NewOAuthTokeninfoValidate(to TokeninfoOptions) filters.Spec {
p := newYamlConfigParser[tokeninfoValidateFilterConfig](64)
return &tokeninfoSpec{
typ: checkOAuthTokeninfoValidate,
options: to,
tokeninfoValidateYamlConfigParser: &p,
}
}
// NewOAuthTokeninfoAnyKV creates a new auth filter specification
// to validate authorization for requests. Current implementation uses
// Bearer tokens to authorize requests and checks that the token
// contains at least one key value pair provided.
func NewOAuthTokeninfoAnyKV(OAuthTokeninfoURL string, OAuthTokeninfoTimeout time.Duration) filters.Spec {
return &tokeninfoSpec{
typ: checkOAuthTokeninfoAnyKV,
options: TokeninfoOptions{
URL: OAuthTokeninfoURL,
Timeout: OAuthTokeninfoTimeout,
},
}
}
// TokeninfoWithOptions creates a new auth filter specification
// for token validation with additional settings to the mandatory
// tokeninfo URL and timeout.
//
// Use one of the base initializer functions as the first argument:
// NewOAuthTokeninfoAllScope, NewOAuthTokeninfoAnyScope,
// NewOAuthTokeninfoAllKV or NewOAuthTokeninfoAnyKV.
func TokeninfoWithOptions(create func(string, time.Duration) filters.Spec, o TokeninfoOptions) filters.Spec {
s := create(o.URL, o.Timeout)
ts, ok := s.(*tokeninfoSpec)
if !ok {
return s
}
ts.options = o
return ts
}
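// Usage sketch (URL and values assumed for illustration):
//
//	spec := TokeninfoWithOptions(NewOAuthTokeninfoAllScope, TokeninfoOptions{
//		URL:       "https://tokeninfo.example.org/oauth2/tokeninfo",
//		Timeout:   2 * time.Second,
//		CacheSize: 1000,
//		CacheTTL:  30 * time.Second,
//	})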
func (s *tokeninfoSpec) Name() string {
switch s.typ {
case checkOAuthTokeninfoAnyScopes:
return filters.OAuthTokeninfoAnyScopeName
case checkOAuthTokeninfoAllScopes:
return filters.OAuthTokeninfoAllScopeName
case checkOAuthTokeninfoAnyKV:
return filters.OAuthTokeninfoAnyKVName
case checkOAuthTokeninfoAllKV:
return filters.OAuthTokeninfoAllKVName
case checkOAuthTokeninfoValidate:
return filters.OAuthTokeninfoValidateName
}
return AuthUnknown
}
// CreateFilter creates an auth filter. All arguments have to be
// strings. Depending on the variant of the auth tokeninfoFilter, the arguments
// represent scopes or key-value pairs to be checked in the tokeninfo
// response. How scopes or key value pairs are checked is based on the
// type. The shown example for checkOAuthTokeninfoAllScopes will grant
// access only to tokens that have the scopes read-x and write-y:
//
// s.CreateFilter("read-x", "write-y")
func (s *tokeninfoSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
sargs, err := getStrings(args)
if err != nil {
return nil, err
}
if len(sargs) == 0 {
return nil, filters.ErrInvalidFilterParameters
}
ac, err := s.options.getTokeninfoClient()
if err != nil {
return nil, filters.ErrInvalidFilterParameters
}
if s.typ == checkOAuthTokeninfoValidate {
config, err := s.tokeninfoValidateYamlConfigParser.parseSingleArg(args)
if err != nil {
return nil, err
}
return &tokeninfoValidateFilter{client: ac, config: config}, nil
}
f := &tokeninfoFilter{typ: s.typ, client: ac, kv: make(map[string][]string)}
switch f.typ {
// all scopes
case checkOAuthTokeninfoAllScopes:
fallthrough
case checkOAuthTokeninfoAnyScopes:
f.scopes = sargs[:]
// key value pairs
case checkOAuthTokeninfoAnyKV:
fallthrough
case checkOAuthTokeninfoAllKV:
// validate the arguments before building the kv map
if len(sargs) == 0 || len(sargs)%2 != 0 {
return nil, filters.ErrInvalidFilterParameters
}
for i := 0; i+1 < len(sargs); i += 2 {
f.kv[sargs[i]] = append(f.kv[sargs[i]], sargs[i+1])
}
default:
return nil, filters.ErrInvalidFilterParameters
}
return f, nil
}
// String pretty-prints the tokeninfoFilter configuration based on the
// check used.
func (f *tokeninfoFilter) String() string {
switch f.typ {
case checkOAuthTokeninfoAnyScopes:
return fmt.Sprintf("%s(%s)", filters.OAuthTokeninfoAnyScopeName, strings.Join(f.scopes, ","))
case checkOAuthTokeninfoAllScopes:
return fmt.Sprintf("%s(%s)", filters.OAuthTokeninfoAllScopeName, strings.Join(f.scopes, ","))
case checkOAuthTokeninfoAnyKV:
return fmt.Sprintf("%s(%s)", filters.OAuthTokeninfoAnyKVName, f.kv)
case checkOAuthTokeninfoAllKV:
return fmt.Sprintf("%s(%s)", filters.OAuthTokeninfoAllKVName, f.kv)
}
return AuthUnknown
}
func (f *tokeninfoFilter) validateAnyScopes(h map[string]interface{}) bool {
if len(f.scopes) == 0 {
return true
}
vI, ok := h[scopeKey]
if !ok {
return false
}
v, ok := vI.([]interface{})
if !ok {
return false
}
for _, scope := range f.scopes {
if contains(v, scope) {
return true
}
}
return false
}
func (f *tokeninfoFilter) validateAllScopes(h map[string]interface{}) bool {
if len(f.scopes) == 0 {
return true
}
vI, ok := h[scopeKey]
if !ok {
return false
}
v, ok := vI.([]interface{})
if !ok {
return false
}
for _, scope := range f.scopes {
if !contains(v, scope) {
return false
}
}
return true
}
func (f *tokeninfoFilter) validateAnyKV(h map[string]interface{}) bool {
for k, v := range f.kv {
for _, res := range v {
if v2, ok := h[k].(string); ok {
if res == v2 {
return true
}
}
}
}
return false
}
func (f *tokeninfoFilter) validateAllKV(h map[string]interface{}) bool {
if len(h) < len(f.kv) {
return false
}
for k, v := range f.kv {
for _, res := range v {
v2, ok := h[k].(string)
if !ok || res != v2 {
return false
}
}
}
return true
}
func contains(vals []interface{}, s string) bool {
for _, v := range vals {
if v == s {
return true
}
}
return false
}
// Request handles authentication based on the defined auth type.
func (f *tokeninfoFilter) Request(ctx filters.FilterContext) {
r := ctx.Request()
var authMap map[string]interface{}
authMapTemp, ok := ctx.StateBag()[tokeninfoCacheKey]
if !ok {
token, ok := getToken(r)
if !ok || token == "" {
unauthorized(ctx, "", missingBearerToken, "", "")
return
}
var err error
authMap, err = f.client.getTokeninfo(token, ctx)
if err != nil {
reason := authServiceAccess
if err == errInvalidToken {
reason = invalidToken
} else {
ctx.Logger().Errorf("Error while calling tokeninfo: %v", err)
}
unauthorized(ctx, "", reason, "", "")
return
}
} else {
authMap = authMapTemp.(map[string]interface{})
}
uid, _ := authMap[uidKey].(string) // uid may be empty; if present, it identifies the user for audit logging
var allowed bool
switch f.typ {
case checkOAuthTokeninfoAnyScopes:
allowed = f.validateAnyScopes(authMap)
case checkOAuthTokeninfoAllScopes:
allowed = f.validateAllScopes(authMap)
case checkOAuthTokeninfoAnyKV:
allowed = f.validateAnyKV(authMap)
case checkOAuthTokeninfoAllKV:
allowed = f.validateAllKV(authMap)
default:
ctx.Logger().Errorf("Wrong tokeninfoFilter type: %s.", f)
}
if !allowed {
forbidden(ctx, uid, invalidScope, "")
return
}
authorized(ctx, uid)
ctx.StateBag()[tokeninfoCacheKey] = authMap
}
func (f *tokeninfoFilter) Response(filters.FilterContext) {}
func (c *tokeninfoValidateFilterConfig) initialize() error {
for _, host := range c.OptOutHosts {
if r, err := regexp.Compile(host); err != nil {
return fmt.Errorf("failed to compile opt-out host pattern %q: %w", host, err)
} else {
c.optOutHostsCompiled = append(c.optOutHostsCompiled, r)
}
}
return nil
}
func (f *tokeninfoValidateFilter) Request(ctx filters.FilterContext) {
if _, ok := ctx.StateBag()[tokeninfoCacheKey]; ok {
return // tokeninfo was already validated by a preceding filter
}
if len(f.config.OptOutAnnotations) > 0 {
annotations := annotate.GetAnnotations(ctx)
for _, annotation := range f.config.OptOutAnnotations {
if _, ok := annotations[annotation]; ok {
return // opt-out from validation
}
}
}
if len(f.config.optOutHostsCompiled) > 0 {
host := ctx.Request().Host
for _, r := range f.config.optOutHostsCompiled {
if r.MatchString(host) {
return // opt-out from validation
}
}
}
token, ok := getToken(ctx.Request())
if !ok {
f.serveUnauthorized(ctx)
return
}
tokeninfo, err := f.client.getTokeninfo(token, ctx)
if err != nil {
f.serveUnauthorized(ctx)
return
}
uid, _ := tokeninfo[uidKey].(string)
authorized(ctx, uid)
ctx.StateBag()[tokeninfoCacheKey] = tokeninfo
}
func (f *tokeninfoValidateFilter) serveUnauthorized(ctx filters.FilterContext) {
ctx.Serve(&http.Response{
StatusCode: http.StatusUnauthorized,
Header: http.Header{
"Content-Length": []string{strconv.Itoa(len(f.config.UnauthorizedResponse))},
},
Body: io.NopCloser(strings.NewReader(f.config.UnauthorizedResponse)),
})
}
func (f *tokeninfoValidateFilter) Response(filters.FilterContext) {}
package auth
import (
"maps"
"sync"
"sync/atomic"
"time"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/metrics"
)
type (
tokeninfoCache struct {
client tokeninfoClient
metrics metrics.Metrics
size int
ttl time.Duration
now func() time.Time
cache sync.Map // map[string]*entry
count atomic.Int64 // estimated number of cached entries, see https://github.com/golang/go/issues/20680
quit chan struct{}
}
entry struct {
expiresAt time.Time
info map[string]any
infoExpiresAt time.Time
}
)
var _ tokeninfoClient = &tokeninfoCache{}
const expiresInField = "expires_in"
func newTokeninfoCache(client tokeninfoClient, metrics metrics.Metrics, size int, ttl time.Duration) *tokeninfoCache {
c := &tokeninfoCache{
client: client,
metrics: metrics,
size: size,
ttl: ttl,
now: time.Now,
quit: make(chan struct{}),
}
go c.evictLoop()
return c
}
func (c *tokeninfoCache) Close() {
c.client.Close()
close(c.quit)
}
func (c *tokeninfoCache) getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error) {
if cached := c.cached(token); cached != nil {
return cached, nil
}
info, err := c.client.getTokeninfo(token, ctx)
if err == nil {
c.tryCache(token, info)
}
return info, err
}
func (c *tokeninfoCache) cached(token string) map[string]any {
if v, ok := c.cache.Load(token); ok {
now := c.now()
e := v.(*entry)
if now.Before(e.expiresAt) {
// Clone cached value because callers may modify it,
// see e.g. [OAuthConfig.GrantTokeninfoKeys] and [grantFilter.setupToken].
info := maps.Clone(e.info)
info[expiresInField] = e.infoExpiresAt.Sub(now).Truncate(time.Second).Seconds()
return info
}
}
return nil
}
func (c *tokeninfoCache) tryCache(token string, info map[string]any) {
expiresIn := expiresIn(info)
if expiresIn <= 0 {
return
}
now := c.now()
e := &entry{
info: info,
infoExpiresAt: now.Add(expiresIn),
}
if c.ttl > 0 && expiresIn > c.ttl {
e.expiresAt = now.Add(c.ttl)
} else {
e.expiresAt = e.infoExpiresAt
}
if _, loaded := c.cache.Swap(token, e); !loaded {
c.count.Add(1)
}
}
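// Illustration of the clamp above (values assumed): with ttl = 5m and a
// tokeninfo whose "expires_in" is 3600, the entry is refreshed from the
// upstream tokeninfo after 5m, while cached() keeps reporting the remaining
// token lifetime, counting down from 3600 seconds on each hit.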
func (c *tokeninfoCache) evictLoop() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-c.quit:
return
case <-ticker.C:
c.evict()
}
}
}
func (c *tokeninfoCache) evict() {
now := c.now()
// Evict expired entries
c.cache.Range(func(key, value any) bool {
e := value.(*entry)
if now.After(e.expiresAt) {
if c.cache.CompareAndDelete(key, value) {
c.count.Add(-1)
}
}
return true
})
// Evict random entries until the cache size is within limits
if c.count.Load() > int64(c.size) {
c.cache.Range(func(key, value any) bool {
if c.cache.CompareAndDelete(key, value) {
c.count.Add(-1)
}
return c.count.Load() > int64(c.size)
})
}
if c.metrics != nil {
c.metrics.UpdateGauge("tokeninfocache.count", float64(c.count.Load()))
}
}
// expiresIn returns the lifetime of the access token if present.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2
func expiresIn(info map[string]any) time.Duration {
if v, ok := info[expiresInField]; ok {
// https://pkg.go.dev/encoding/json#Unmarshal stores JSON numbers in float64
if v, ok := v.(float64); ok {
return time.Duration(v) * time.Second
}
}
return 0
}
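// For example, expiresIn(map[string]any{"expires_in": 600.0}) yields
// 10 * time.Minute, while a missing or non-numeric field yields 0,
// which disables caching of that tokeninfo.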
package auth
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.OAuthTokenintrospectionAnyClaimsName instead
OAuthTokenintrospectionAnyClaimsName = filters.OAuthTokenintrospectionAnyClaimsName
// Deprecated: use filters.OAuthTokenintrospectionAllClaimsName instead
OAuthTokenintrospectionAllClaimsName = filters.OAuthTokenintrospectionAllClaimsName
// Deprecated: use filters.OAuthTokenintrospectionAnyKVName instead
OAuthTokenintrospectionAnyKVName = filters.OAuthTokenintrospectionAnyKVName
// Deprecated: use filters.OAuthTokenintrospectionAllKVName instead
OAuthTokenintrospectionAllKVName = filters.OAuthTokenintrospectionAllKVName
// Deprecated: use filters.SecureOAuthTokenintrospectionAnyClaimsName instead
SecureOAuthTokenintrospectionAnyClaimsName = filters.SecureOAuthTokenintrospectionAnyClaimsName
// Deprecated: use filters.SecureOAuthTokenintrospectionAllClaimsName instead
SecureOAuthTokenintrospectionAllClaimsName = filters.SecureOAuthTokenintrospectionAllClaimsName
// Deprecated: use filters.SecureOAuthTokenintrospectionAnyKVName instead
SecureOAuthTokenintrospectionAnyKVName = filters.SecureOAuthTokenintrospectionAnyKVName
// Deprecated: use filters.SecureOAuthTokenintrospectionAllKVName instead
SecureOAuthTokenintrospectionAllKVName = filters.SecureOAuthTokenintrospectionAllKVName
tokenintrospectionCacheKey = "tokenintrospection"
TokenIntrospectionConfigPath = "/.well-known/openid-configuration"
)
type TokenintrospectionOptions struct {
Timeout time.Duration
Tracer opentracing.Tracer
MaxIdleConns int
}
type (
tokenIntrospectionSpec struct {
typ roleCheckType
options TokenintrospectionOptions
secure bool
}
tokenIntrospectionInfo map[string]interface{}
tokenintrospectFilter struct {
typ roleCheckType
authClient *authClient
claims []string
kv kv
}
openIDConfig struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
RevocationEndpoint string `json:"revocation_endpoint"`
JwksURI string `json:"jwks_uri"`
RegistrationEndpoint string `json:"registration_endpoint"`
IntrospectionEndpoint string `json:"introspection_endpoint"`
ResponseTypesSupported []string `json:"response_types_supported"`
SubjectTypesSupported []string `json:"subject_types_supported"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
ClaimsSupported []string `json:"claims_supported"`
ScopesSupported []string `json:"scopes_supported"`
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
}
)
var issuerAuthClient = make(map[string]*authClient)
// Active reports whether the token is active according to the introspection
// response, i.e. not revoked and within its validity window.
// https://tools.ietf.org/html/rfc7662#section-2.2
func (tii tokenIntrospectionInfo) Active() bool {
return tii.getBoolValue("active")
}
func (tii tokenIntrospectionInfo) Sub() (string, error) {
return tii.getStringValue("sub")
}
func (tii tokenIntrospectionInfo) getBoolValue(k string) bool {
if active, ok := tii[k].(bool); ok {
return active
}
return false
}
func (tii tokenIntrospectionInfo) getStringValue(k string) (string, error) {
s, ok := tii[k].(string)
if !ok {
return "", errInvalidTokenintrospectionData
}
return s, nil
}
// NewOAuthTokenintrospectionAnyKV creates a new auth filter specification
// to validate authorization for requests. Current implementation uses
// Bearer tokens to authorize requests and checks that the token
// contains at least one key value pair provided.
//
// This is an RFC 7662 compliant implementation. It uses
// POST requests to call introspection_endpoint to get the information
// of the token validity.
//
// It uses the /.well-known/openid-configuration path on the passed
// oauthIssuerURL to find the introspection_endpoint, as defined in
// https://tools.ietf.org/html/draft-ietf-oauth-discovery-06. If
// oauthIntrospectionURL is a non-empty string, it is used as the
// IntrospectionEndpoint instead.
func NewOAuthTokenintrospectionAnyKV(timeout time.Duration) filters.Spec {
return newOAuthTokenintrospectionFilter(checkOAuthTokenintrospectionAnyKV, timeout)
}
// NewOAuthTokenintrospectionAllKV creates a new auth filter specification
// to validate authorization for requests. Current implementation uses
// Bearer tokens to authorize requests and checks that the token
// contains all key value pairs provided.
//
// This is an RFC 7662 compliant implementation. It uses
// POST requests to call introspection_endpoint to get the information
// of the token validity.
//
// It uses the /.well-known/openid-configuration path on the passed
// oauthIssuerURL to find the introspection_endpoint, as defined in
// https://tools.ietf.org/html/draft-ietf-oauth-discovery-06. If
// oauthIntrospectionURL is a non-empty string, it is used as the
// IntrospectionEndpoint instead.
func NewOAuthTokenintrospectionAllKV(timeout time.Duration) filters.Spec {
return newOAuthTokenintrospectionFilter(checkOAuthTokenintrospectionAllKV, timeout)
}
func NewOAuthTokenintrospectionAnyClaims(timeout time.Duration) filters.Spec {
return newOAuthTokenintrospectionFilter(checkOAuthTokenintrospectionAnyClaims, timeout)
}
func NewOAuthTokenintrospectionAllClaims(timeout time.Duration) filters.Spec {
return newOAuthTokenintrospectionFilter(checkOAuthTokenintrospectionAllClaims, timeout)
}
// Secure Introspection Point
func NewSecureOAuthTokenintrospectionAnyKV(timeout time.Duration) filters.Spec {
return newSecureOAuthTokenintrospectionFilter(checkSecureOAuthTokenintrospectionAnyKV, timeout)
}
func NewSecureOAuthTokenintrospectionAllKV(timeout time.Duration) filters.Spec {
return newSecureOAuthTokenintrospectionFilter(checkSecureOAuthTokenintrospectionAllKV, timeout)
}
func NewSecureOAuthTokenintrospectionAnyClaims(timeout time.Duration) filters.Spec {
return newSecureOAuthTokenintrospectionFilter(checkSecureOAuthTokenintrospectionAnyClaims, timeout)
}
func NewSecureOAuthTokenintrospectionAllClaims(timeout time.Duration) filters.Spec {
return newSecureOAuthTokenintrospectionFilter(checkSecureOAuthTokenintrospectionAllClaims, timeout)
}
// TokenintrospectionWithOptions creates a new auth filter specification
// for validating authorization of requests, with additional options
// beyond the mandatory timeout parameter.
//
// Use one of the base initializer functions as the first argument:
// NewOAuthTokenintrospectionAnyKV, NewOAuthTokenintrospectionAllKV,
// NewOAuthTokenintrospectionAnyClaims, NewOAuthTokenintrospectionAllClaims,
// NewSecureOAuthTokenintrospectionAnyKV, NewSecureOAuthTokenintrospectionAllKV,
// NewSecureOAuthTokenintrospectionAnyClaims, NewSecureOAuthTokenintrospectionAllClaims,
// and pass the opentracing.Tracer and other options in TokenintrospectionOptions.
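//
// Example (a minimal sketch; the timeout value is illustrative):
//
// spec := TokenintrospectionWithOptions(NewOAuthTokenintrospectionAnyKV, TokenintrospectionOptions{
// Timeout: 3 * time.Second,
// Tracer: opentracing.NoopTracer{},
// })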
func TokenintrospectionWithOptions(
create func(time.Duration) filters.Spec,
o TokenintrospectionOptions,
) filters.Spec {
s := create(o.Timeout)
ts, ok := s.(*tokenIntrospectionSpec)
if !ok {
return s
}
ts.options = o
return ts
}
func newOAuthTokenintrospectionFilter(typ roleCheckType, timeout time.Duration) filters.Spec {
return &tokenIntrospectionSpec{
typ: typ,
options: TokenintrospectionOptions{
Timeout: timeout,
Tracer: opentracing.NoopTracer{},
},
secure: false,
}
}
func newSecureOAuthTokenintrospectionFilter(typ roleCheckType, timeout time.Duration) filters.Spec {
return &tokenIntrospectionSpec{
typ: typ,
options: TokenintrospectionOptions{
Timeout: timeout,
Tracer: opentracing.NoopTracer{},
},
secure: true,
}
}
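// getOpenIDConfig fetches and decodes the OpenID configuration from the
// well-known configuration path of the given issuer URL.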
func getOpenIDConfig(issuerURL string) (*openIDConfig, error) {
u, err := url.Parse(issuerURL + TokenIntrospectionConfigPath)
if err != nil {
return nil, err
}
rsp, err := http.Get(u.String())
if err != nil {
return nil, err
}
defer rsp.Body.Close()
if rsp.StatusCode != 200 {
return nil, errInvalidToken
}
d := json.NewDecoder(rsp.Body)
var cfg openIDConfig
err = d.Decode(&cfg)
return &cfg, err
}
func (s *tokenIntrospectionSpec) Name() string {
switch s.typ {
case checkOAuthTokenintrospectionAnyClaims:
return filters.OAuthTokenintrospectionAnyClaimsName
case checkOAuthTokenintrospectionAllClaims:
return filters.OAuthTokenintrospectionAllClaimsName
case checkOAuthTokenintrospectionAnyKV:
return filters.OAuthTokenintrospectionAnyKVName
case checkOAuthTokenintrospectionAllKV:
return filters.OAuthTokenintrospectionAllKVName
case checkSecureOAuthTokenintrospectionAnyClaims:
return filters.SecureOAuthTokenintrospectionAnyClaimsName
case checkSecureOAuthTokenintrospectionAllClaims:
return filters.SecureOAuthTokenintrospectionAllClaimsName
case checkSecureOAuthTokenintrospectionAnyKV:
return filters.SecureOAuthTokenintrospectionAnyKVName
case checkSecureOAuthTokenintrospectionAllKV:
return filters.SecureOAuthTokenintrospectionAllKVName
}
return AuthUnknown
}
func (s *tokenIntrospectionSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
sargs, err := getStrings(args)
if err != nil {
return nil, err
}
if (s.secure && len(sargs) < 4) || (!s.secure && len(sargs) < 2) {
return nil, filters.ErrInvalidFilterParameters
}
issuerURL := sargs[0]
var clientId, clientSecret string
if s.secure {
clientId = sargs[1]
clientSecret = sargs[2]
sargs = sargs[3:]
if clientId == "" {
clientId, _ = os.LookupEnv("OAUTH_CLIENT_ID")
}
if clientSecret == "" {
clientSecret, _ = os.LookupEnv("OAUTH_CLIENT_SECRET")
}
} else {
sargs = sargs[1:]
}
cfg, err := getOpenIDConfig(issuerURL)
if err != nil {
return nil, err
}
var ac *authClient
var ok bool
if ac, ok = issuerAuthClient[issuerURL]; !ok {
ac, err = newAuthClient(cfg.IntrospectionEndpoint, tokenIntrospectionSpanName, s.options.Timeout, s.options.MaxIdleConns, s.options.Tracer)
if err != nil {
return nil, filters.ErrInvalidFilterParameters
}
issuerAuthClient[issuerURL] = ac
}
if s.secure && clientId != "" && clientSecret != "" {
ac.url.User = url.UserPassword(clientId, clientSecret)
} else {
ac.url.User = nil
}
f := &tokenintrospectFilter{
typ: s.typ,
authClient: ac,
kv: make(map[string][]string),
}
switch f.typ {
case checkOAuthTokenintrospectionAllClaims:
fallthrough
case checkSecureOAuthTokenintrospectionAllClaims:
fallthrough
case checkSecureOAuthTokenintrospectionAnyClaims:
fallthrough
case checkOAuthTokenintrospectionAnyClaims:
f.claims = sargs
if !all(f.claims, cfg.ClaimsSupported) {
return nil, fmt.Errorf("%w: %s, supported Claims: %v", errUnsupportedClaimSpecified, strings.Join(f.claims, ","), cfg.ClaimsSupported)
}
// key value pairs
case checkOAuthTokenintrospectionAllKV:
fallthrough
case checkSecureOAuthTokenintrospectionAllKV:
fallthrough
case checkSecureOAuthTokenintrospectionAnyKV:
fallthrough
case checkOAuthTokenintrospectionAnyKV:
if len(sargs) == 0 || len(sargs)%2 != 0 {
return nil, filters.ErrInvalidFilterParameters
}
for i := 0; i+1 < len(sargs); i += 2 {
f.kv[sargs[i]] = append(f.kv[sargs[i]], sargs[i+1])
}
default:
return nil, filters.ErrInvalidFilterParameters
}
return f, nil
}
// String returns a human readable representation of the tokenintrospectFilter
// configuration, based on the check used.
func (f *tokenintrospectFilter) String() string {
switch f.typ {
case checkOAuthTokenintrospectionAnyClaims:
return fmt.Sprintf("%s(%s)", filters.OAuthTokenintrospectionAnyClaimsName, strings.Join(f.claims, ","))
case checkOAuthTokenintrospectionAllClaims:
return fmt.Sprintf("%s(%s)", filters.OAuthTokenintrospectionAllClaimsName, strings.Join(f.claims, ","))
case checkOAuthTokenintrospectionAnyKV:
return fmt.Sprintf("%s(%s)", filters.OAuthTokenintrospectionAnyKVName, f.kv)
case checkOAuthTokenintrospectionAllKV:
return fmt.Sprintf("%s(%s)", filters.OAuthTokenintrospectionAllKVName, f.kv)
case checkSecureOAuthTokenintrospectionAnyClaims:
return fmt.Sprintf("%s(%s)", filters.SecureOAuthTokenintrospectionAnyClaimsName, strings.Join(f.claims, ","))
case checkSecureOAuthTokenintrospectionAllClaims:
return fmt.Sprintf("%s(%s)", filters.SecureOAuthTokenintrospectionAllClaimsName, strings.Join(f.claims, ","))
case checkSecureOAuthTokenintrospectionAnyKV:
return fmt.Sprintf("%s(%s)", filters.SecureOAuthTokenintrospectionAnyKVName, f.kv)
case checkSecureOAuthTokenintrospectionAllKV:
return fmt.Sprintf("%s(%s)", filters.SecureOAuthTokenintrospectionAllKVName, f.kv)
}
return AuthUnknown
}
func (f *tokenintrospectFilter) validateAnyClaims(info tokenIntrospectionInfo) bool {
for _, wantedClaim := range f.claims {
if claims, ok := info["claims"].(map[string]interface{}); ok {
if _, ok2 := claims[wantedClaim]; ok2 {
return true
}
}
}
return false
}
func (f *tokenintrospectFilter) validateAllClaims(info tokenIntrospectionInfo) bool {
for _, v := range f.claims {
if claims, ok := info["claims"].(map[string]interface{}); !ok {
return false
} else {
if _, ok := claims[v]; !ok {
return false
}
}
}
return true
}
func (f *tokenintrospectFilter) validateAllKV(info tokenIntrospectionInfo) bool {
for k, v := range f.kv {
for _, res := range v {
v2, ok := info[k].(string)
if !ok || res != v2 {
return false
}
}
}
return true
}
func (f *tokenintrospectFilter) validateAnyKV(info tokenIntrospectionInfo) bool {
for k, v := range f.kv {
for _, res := range v {
v2, ok := info[k].(string)
if ok && res == v2 {
return true
}
}
}
return false
}
func (f *tokenintrospectFilter) Request(ctx filters.FilterContext) {
r := ctx.Request()
var info tokenIntrospectionInfo
infoTemp, ok := ctx.StateBag()[tokenintrospectionCacheKey]
if !ok {
token, ok := getToken(r)
if !ok || token == "" {
unauthorized(ctx, "", missingToken, f.authClient.url.Hostname(), "")
return
}
var err error
info, err = f.authClient.getTokenintrospect(token, ctx)
if err != nil {
reason := authServiceAccess
if err == errInvalidToken {
reason = invalidToken
} else {
ctx.Logger().Errorf("Error while calling token introspection: %v", err)
}
unauthorized(ctx, "", reason, f.authClient.url.Hostname(), "")
return
}
} else {
info = infoTemp.(tokenIntrospectionInfo)
}
sub, err := info.Sub()
if err != nil {
if err != errInvalidTokenintrospectionData {
ctx.Logger().Errorf("Error while reading token: %v", err)
}
unauthorized(ctx, sub, invalidSub, f.authClient.url.Hostname(), "")
return
}
if !info.Active() {
unauthorized(ctx, sub, inactiveToken, f.authClient.url.Hostname(), "")
return
}
var allowed bool
switch f.typ {
case checkOAuthTokenintrospectionAnyClaims, checkSecureOAuthTokenintrospectionAnyClaims:
allowed = f.validateAnyClaims(info)
case checkOAuthTokenintrospectionAnyKV, checkSecureOAuthTokenintrospectionAnyKV:
allowed = f.validateAnyKV(info)
case checkOAuthTokenintrospectionAllClaims, checkSecureOAuthTokenintrospectionAllClaims:
allowed = f.validateAllClaims(info)
case checkOAuthTokenintrospectionAllKV, checkSecureOAuthTokenintrospectionAllKV:
allowed = f.validateAllKV(info)
default:
ctx.Logger().Errorf("Wrong tokenintrospectionFilter type: %s", f)
}
if !allowed {
unauthorized(ctx, sub, invalidClaim, f.authClient.url.Hostname(), "")
return
}
authorized(ctx, sub)
ctx.StateBag()[tokenintrospectionCacheKey] = info
}
func (f *tokenintrospectFilter) Response(filters.FilterContext) {}
package auth
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/http/httpguts"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.WebhookName instead
WebhookName = filters.WebhookName
)
type WebhookOptions struct {
Timeout time.Duration
MaxIdleConns int
Tracer opentracing.Tracer
}
type (
webhookSpec struct {
options WebhookOptions
}
webhookFilter struct {
authClient *authClient
forwardResponseHeaderKeys []string
}
)
var webhookAuthClient map[string]*authClient = make(map[string]*authClient)
// NewWebhook creates a new auth filter specification
// to validate authorization for requests via an
// external webhook.
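//
// Example (hypothetical route and webhook URL):
//
// r: * -> webhook("https://auth.example.org/check") -> "https://www.example.org";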
func NewWebhook(timeout time.Duration) filters.Spec {
return WebhookWithOptions(WebhookOptions{Timeout: timeout, Tracer: opentracing.NoopTracer{}})
}
// WebhookWithOptions creates a new auth filter specification
// to validate authorization of requests via an external
// webhook.
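//
// Example (a minimal sketch; the option values are illustrative):
//
// spec := WebhookWithOptions(WebhookOptions{
// Timeout: 2 * time.Second,
// MaxIdleConns: 20,
// Tracer: opentracing.NoopTracer{},
// })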
func WebhookWithOptions(o WebhookOptions) filters.Spec {
return &webhookSpec{options: o}
}
func (*webhookSpec) Name() string {
return filters.WebhookName
}
// CreateFilter creates an auth filter. The first argument is a URL
// string. The second, optional, argument is a comma separated list of
// headers to forward from the webhook response.
//
// s.CreateFilter("https://my-auth-service.example.org/auth")
// s.CreateFilter("https://my-auth-service.example.org/auth", "X-Auth-User,X-Auth-User-Roles")
func (ws *webhookSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if l := len(args); l == 0 || l > 2 {
return nil, filters.ErrInvalidFilterParameters
}
var ok bool
s, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
forwardResponseHeaderKeys := make([]string, 0)
if len(args) > 1 {
// Capture headers that should be forwarded from webhook responses.
headerKeysOption, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
headerKeys := strings.Split(headerKeysOption, ",")
for _, header := range headerKeys {
valid := httpguts.ValidHeaderFieldName(header)
if !valid {
return nil, fmt.Errorf("header %s is invalid", header)
}
forwardResponseHeaderKeys = append(forwardResponseHeaderKeys, http.CanonicalHeaderKey(header))
}
}
var ac *authClient
var err error
if ac, ok = webhookAuthClient[s]; !ok {
ac, err = newAuthClient(s, webhookSpanName, ws.options.Timeout, ws.options.MaxIdleConns, ws.options.Tracer)
if err != nil {
return nil, filters.ErrInvalidFilterParameters
}
webhookAuthClient[s] = ac
}
return &webhookFilter{authClient: ac, forwardResponseHeaderKeys: forwardResponseHeaderKeys}, nil
}
func copyHeader(to, from http.Header) {
for k, v := range from {
to[http.CanonicalHeaderKey(k)] = v
}
}
func (f *webhookFilter) Request(ctx filters.FilterContext) {
resp, err := f.authClient.getWebhook(ctx)
if err != nil {
ctx.Logger().Errorf("Failed to make authentication webhook request: %v.", err)
}
// forbidden
if err == nil && resp.StatusCode == http.StatusForbidden {
forbidden(ctx, "", invalidScope, filters.WebhookName)
return
}
// errors, redirects, auth errors, webhook errors
if err != nil || resp.StatusCode >= 300 {
unauthorized(ctx, "", invalidAccess, f.authClient.url.Hostname(), filters.WebhookName)
return
}
// copy required headers from webhook response into the current request
for _, hk := range f.forwardResponseHeaderKeys {
if h, ok := resp.Header[hk]; ok {
ctx.Request().Header[hk] = h
}
}
authorized(ctx, filters.WebhookName)
}
func (*webhookFilter) Response(filters.FilterContext) {}
package auth
import (
"fmt"
"github.com/ghodss/yaml"
)
// yamlConfigParser parses and caches yaml configurations of type T.
// Use [newYamlConfigParser] to create instances and ensure that *T implements [yamlConfig].
type yamlConfigParser[T any] struct {
initialize func(*T) error
cacheSize int
cache map[string]*T
}
// yamlConfig must be implemented by the config value pointer type.
// It is used to initialize the value after parsing.
type yamlConfig interface {
initialize() error
}
// newYamlConfigParser creates a new parser with a given cache size.
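//
// Example (a minimal sketch, assuming a hypothetical type fooConfig whose
// pointer type implements yamlConfig):
//
// parser := newYamlConfigParser[fooConfig](64)
// cfg, err := parser.parse("key: value")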
func newYamlConfigParser[T any, PT interface {
*T
yamlConfig
}](cacheSize int) yamlConfigParser[T] {
// We want the user to specify the config type T but ensure that *T implements [yamlConfig].
//
// Type inference only works for functions but not for types
// (see https://github.com/golang/go/issues/57270 and https://github.com/golang/go/issues/51527),
// therefore we create instances using a function with two type parameters,
// where the second parameter is inferred from the first so that the caller does not have to specify it.
//
// To use *T.initialize we set up the initialize field.
return yamlConfigParser[T]{
initialize: func(v *T) error { return PT(v).initialize() },
cacheSize: cacheSize,
cache: make(map[string]*T, cacheSize),
}
}
// parseSingleArg calls [yamlConfigParser.parse] with the first string argument.
// If the args slice does not contain a single string, it returns an error.
func (p *yamlConfigParser[T]) parseSingleArg(args []any) (*T, error) {
if len(args) != 1 {
return nil, fmt.Errorf("requires single string argument")
}
config, ok := args[0].(string)
if !ok {
return nil, fmt.Errorf("requires single string argument")
}
return p.parse(config)
}
// parse parses a yaml configuration or returns a cached value
// if the exact configuration was already parsed before.
// The returned value is shared by multiple callers and therefore must not be modified.
func (p *yamlConfigParser[T]) parse(config string) (*T, error) {
if v, ok := p.cache[config]; ok {
return v, nil
}
v := new(T)
if err := yaml.Unmarshal([]byte(config), v); err != nil {
return nil, err
}
if err := p.initialize(v); err != nil {
return nil, err
}
// evict random element if cache is full
if p.cacheSize > 0 && len(p.cache) == p.cacheSize {
for k := range p.cache {
delete(p.cache, k)
break
}
}
p.cache[config] = v
return v, nil
}
package block
import (
"bytes"
"encoding/hex"
"github.com/zalando/skipper/filters"
skpio "github.com/zalando/skipper/io"
"github.com/zalando/skipper/metrics"
)
type blockSpec struct {
MaxMatcherBufferSize uint64
hex bool
}
type toBlockKeys struct{ Str []byte }
func (b toBlockKeys) String() string {
return string(b.Str)
}
type block struct {
toblockList []toBlockKeys
maxEditorBuffer uint64
maxBufferHandling skpio.MaxBufferHandling
metrics metrics.Metrics
}
// NewBlockFilter creates a filter specification to block requests.
//
// Deprecated: use NewBlock instead.
func NewBlockFilter(maxMatcherBufferSize uint64) filters.Spec {
return NewBlock(maxMatcherBufferSize)
}
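// NewBlock creates a filter specification to block requests whose body
// contains any of the byte sequences provided as filter arguments.
//
// Example (hypothetical route):
//
// r: * -> blockContent("block me") -> "https://www.example.org";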
func NewBlock(maxMatcherBufferSize uint64) filters.Spec {
return &blockSpec{
MaxMatcherBufferSize: maxMatcherBufferSize,
}
}
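// NewBlockHex creates a filter specification like NewBlock, with the byte
// sequences provided as hex encoded strings.
//
// Example (hypothetical route):
//
// r: * -> blockContentHex("000a") -> "https://www.example.org";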
func NewBlockHex(maxMatcherBufferSize uint64) filters.Spec {
return &blockSpec{
MaxMatcherBufferSize: maxMatcherBufferSize,
hex: true,
}
}
func (bs *blockSpec) Name() string {
if bs.hex {
return filters.BlockHexName
}
return filters.BlockName
}
func (bs *blockSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) == 0 {
return nil, filters.ErrInvalidFilterParameters
}
sargs := make([]toBlockKeys, 0, len(args))
for _, w := range args {
v, ok := w.(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
if bs.hex {
a, err := hex.DecodeString(v)
if err != nil {
return nil, err
}
sargs = append(sargs, toBlockKeys{Str: a})
} else {
sargs = append(sargs, toBlockKeys{Str: []byte(v)})
}
}
return &block{
toblockList: sargs,
maxBufferHandling: skpio.MaxBufferBestEffort,
maxEditorBuffer: bs.MaxMatcherBufferSize,
metrics: metrics.Default,
}, nil
}
func blockMatcher(m metrics.Metrics, matches []toBlockKeys) func(b []byte) (int, error) {
return func(b []byte) (int, error) {
for _, s := range matches {
if bytes.Contains(b, s.Str) {
m.IncCounter("blocked.requests")
return 0, skpio.ErrBlocked
}
}
return len(b), nil
}
}
func (b *block) Request(ctx filters.FilterContext) {
req := ctx.Request()
if req.ContentLength == 0 {
return
}
// fix filter chaining - https://github.com/zalando/skipper/issues/2605
ctx.Request().Header.Del("Content-Length")
ctx.Request().ContentLength = -1
req.Body = skpio.InspectReader(
req.Context(),
skpio.BufferOptions{
MaxBufferHandling: b.maxBufferHandling,
ReadBufferSize: b.maxEditorBuffer,
},
blockMatcher(b.metrics, b.toblockList),
req.Body)
}
func (*block) Response(filters.FilterContext) {}
package builtin
import "github.com/zalando/skipper/filters"
type backendIsProxySpec struct{}
type backendIsProxyFilter struct{}
// NewBackendIsProxy returns a filter specification that is used to specify that the backend is also a proxy.
func NewBackendIsProxy() filters.Spec {
return &backendIsProxySpec{}
}
func (s *backendIsProxySpec) Name() string {
return filters.BackendIsProxyName
}
func (s *backendIsProxySpec) CreateFilter(args []interface{}) (filters.Filter, error) {
return &backendIsProxyFilter{}, nil
}
func (f *backendIsProxyFilter) Request(ctx filters.FilterContext) {
ctx.StateBag()[filters.BackendIsProxyKey] = struct{}{}
}
func (f *backendIsProxyFilter) Response(ctx filters.FilterContext) {
}
/*
Package builtin provides a small, generic set of filters.
*/
package builtin
import (
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/accesslog"
"github.com/zalando/skipper/filters/annotate"
"github.com/zalando/skipper/filters/auth"
"github.com/zalando/skipper/filters/circuit"
"github.com/zalando/skipper/filters/consistenthash"
"github.com/zalando/skipper/filters/cookie"
"github.com/zalando/skipper/filters/cors"
"github.com/zalando/skipper/filters/diag"
"github.com/zalando/skipper/filters/fadein"
"github.com/zalando/skipper/filters/flowid"
logfilter "github.com/zalando/skipper/filters/log"
"github.com/zalando/skipper/filters/rfc"
"github.com/zalando/skipper/filters/scheduler"
"github.com/zalando/skipper/filters/sed"
"github.com/zalando/skipper/filters/tee"
"github.com/zalando/skipper/filters/tls"
"github.com/zalando/skipper/filters/tracing"
"github.com/zalando/skipper/filters/xforward"
"github.com/zalando/skipper/script"
)
const (
// Deprecated: use setRequestHeader or appendRequestHeader
RequestHeaderName = "requestHeader"
// Deprecated: use setResponseHeader or appendResponseHeader
ResponseHeaderName = "responseHeader"
// Deprecated: use redirectTo
RedirectName = "redirect"
// Deprecated: use filters.SetRequestHeaderName instead
SetRequestHeaderName = filters.SetRequestHeaderName
// Deprecated: use filters.SetResponseHeaderName instead
SetResponseHeaderName = filters.SetResponseHeaderName
// Deprecated: use filters.AppendRequestHeaderName instead
AppendRequestHeaderName = filters.AppendRequestHeaderName
// Deprecated: use filters.AppendResponseHeaderName instead
AppendResponseHeaderName = filters.AppendResponseHeaderName
// Deprecated: use filters.DropRequestHeaderName instead
DropRequestHeaderName = filters.DropRequestHeaderName
// Deprecated: use filters.DropResponseHeaderName instead
DropResponseHeaderName = filters.DropResponseHeaderName
// Deprecated: use filters.SetContextRequestHeaderName instead
SetContextRequestHeaderName = filters.SetContextRequestHeaderName
// Deprecated: use filters.AppendContextRequestHeaderName instead
AppendContextRequestHeaderName = filters.AppendContextRequestHeaderName
// Deprecated: use filters.SetContextResponseHeaderName instead
SetContextResponseHeaderName = filters.SetContextResponseHeaderName
// Deprecated: use filters.AppendContextResponseHeaderName instead
AppendContextResponseHeaderName = filters.AppendContextResponseHeaderName
// Deprecated: use filters.CopyRequestHeaderName instead
CopyRequestHeaderName = filters.CopyRequestHeaderName
// Deprecated: use filters.CopyResponseHeaderName instead
CopyResponseHeaderName = filters.CopyResponseHeaderName
// Deprecated: use filters.SetDynamicBackendHostFromHeader instead
SetDynamicBackendHostFromHeader = filters.SetDynamicBackendHostFromHeader
// Deprecated: use filters.SetDynamicBackendSchemeFromHeader instead
SetDynamicBackendSchemeFromHeader = filters.SetDynamicBackendSchemeFromHeader
// Deprecated: use filters.SetDynamicBackendUrlFromHeader instead
SetDynamicBackendUrlFromHeader = filters.SetDynamicBackendUrlFromHeader
// Deprecated: use filters.SetDynamicBackendHost instead
SetDynamicBackendHost = filters.SetDynamicBackendHost
// Deprecated: use filters.SetDynamicBackendScheme instead
SetDynamicBackendScheme = filters.SetDynamicBackendScheme
// Deprecated: use filters.SetDynamicBackendUrl instead
SetDynamicBackendUrl = filters.SetDynamicBackendUrl
// Deprecated: use filters.HealthCheckName instead
HealthCheckName = filters.HealthCheckName
// Deprecated: use filters.ModPathName instead
ModPathName = filters.ModPathName
// Deprecated: use filters.SetPathName instead
SetPathName = filters.SetPathName
// Deprecated: use filters.ModRequestHeaderName instead
ModRequestHeaderName = filters.ModRequestHeaderName
// Deprecated: use filters.RedirectToName instead
RedirectToName = filters.RedirectToName
// Deprecated: use filters.RedirectToLowerName instead
RedirectToLowerName = filters.RedirectToLowerName
// Deprecated: use filters.StaticName instead
StaticName = filters.StaticName
// Deprecated: use filters.StripQueryName instead
StripQueryName = filters.StripQueryName
// Deprecated: use filters.PreserveHostName instead
PreserveHostName = filters.PreserveHostName
// Deprecated: use filters.SetFastCgiFilenameName instead
SetFastCgiFilenameName = filters.SetFastCgiFilenameName
// Deprecated: use filters.StatusName instead
StatusName = filters.StatusName
// Deprecated: use filters.CompressName instead
CompressName = filters.CompressName
// Deprecated: use filters.SetQueryName instead
SetQueryName = filters.SetQueryName
// Deprecated: use filters.DropQueryName instead
DropQueryName = filters.DropQueryName
// Deprecated: use filters.InlineContentName instead
InlineContentName = filters.InlineContentName
// Deprecated: use filters.InlineContentIfStatusName instead
InlineContentIfStatusName = filters.InlineContentIfStatusName
// Deprecated: use filters.HeaderToQueryName instead
HeaderToQueryName = filters.HeaderToQueryName
// Deprecated: use filters.QueryToHeaderName instead
QueryToHeaderName = filters.QueryToHeaderName
// Deprecated: use filters.BackendTimeoutName instead
BackendTimeoutName = filters.BackendTimeoutName
)
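// Filters returns the default set of filter specifications provided by this
// package and the related filter packages.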
func Filters() []filters.Spec {
return []filters.Spec{
NewBackendIsProxy(),
NewComment(),
annotate.New(),
NewRequestHeader(),
NewSetRequestHeader(),
NewAppendRequestHeader(),
NewDropRequestHeader(),
NewResponseHeader(),
NewSetResponseHeader(),
NewAppendResponseHeader(),
NewDropResponseHeader(),
NewSetContextRequestHeader(),
NewAppendContextRequestHeader(),
NewSetContextResponseHeader(),
NewAppendContextResponseHeader(),
NewCopyRequestHeader(),
NewCopyResponseHeader(),
NewCopyRequestHeaderDeprecated(),
NewCopyResponseHeaderDeprecated(),
NewModPath(),
NewSetPath(),
NewModRequestHeader(),
NewModResponseHeader(),
NewDropQuery(),
NewSetQuery(),
NewHealthCheck(),
NewStatic(),
NewRedirect(),
NewRedirectTo(),
NewRedirectLower(),
NewStripQuery(),
NewInlineContent(),
NewInlineContentIfStatus(),
flowid.New(),
xforward.New(),
xforward.NewFirst(),
PreserveHost(),
NewSetFastCgiFilename(),
NewStatus(),
NewCompress(),
NewDecompress(),
NewHeaderToQuery(),
NewQueryToHeader(),
NewBackendTimeout(),
NewReadTimeout(),
NewWriteTimeout(),
NewSetDynamicBackendHostFromHeader(),
NewSetDynamicBackendSchemeFromHeader(),
NewSetDynamicBackendUrlFromHeader(),
NewSetDynamicBackendHost(),
NewSetDynamicBackendScheme(),
NewSetDynamicBackendUrl(),
NewOriginMarkerSpec(),
diag.NewRandom(),
diag.NewRepeat(),
diag.NewRepeatHex(),
diag.NewWrap(),
diag.NewWrapHex(),
diag.NewLatency(),
diag.NewBandwidth(),
diag.NewChunks(),
diag.NewBackendLatency(),
diag.NewBackendBandwidth(),
diag.NewBackendChunks(),
diag.NewTarpit(),
diag.NewAbsorb(),
diag.NewAbsorbSilent(),
diag.NewLogHeader(),
diag.NewLogBody(),
diag.NewUniformRequestLatency(),
diag.NewUniformResponseLatency(),
diag.NewNormalRequestLatency(),
diag.NewNormalResponseLatency(),
diag.NewHistogramRequestLatency(),
diag.NewHistogramResponseLatency(),
tee.NewTee(),
tee.NewTeeDeprecated(),
tee.NewTeeNoFollow(),
tee.NewTeeLoopback(),
sed.New(),
sed.NewDelimited(),
sed.NewRequest(),
sed.NewDelimitedRequest(),
auth.NewBasicAuth(),
cookie.NewDropRequestCookie(),
cookie.NewRequestCookie(),
cookie.NewResponseCookie(),
cookie.NewJSCookie(),
circuit.NewConsecutiveBreaker(),
circuit.NewRateBreaker(),
circuit.NewDisableBreaker(),
script.NewLuaScript(),
cors.NewOrigin(),
logfilter.NewUnverifiedAuditLog(),
tracing.NewSpanName(),
tracing.NewBaggageToTagFilter(),
tracing.NewTag(),
tracing.NewTagFromResponse(),
tracing.NewTagFromResponseIfStatus(),
tracing.NewStateBagToTag(),
//lint:ignore SA1019 due to backward compatibility
accesslog.NewAccessLogDisabled(),
accesslog.NewDisableAccessLog(),
accesslog.NewEnableAccessLog(),
auth.NewForwardToken(),
auth.NewForwardTokenField(),
scheduler.NewFifo(),
scheduler.NewFifoWithBody(),
scheduler.NewLIFO(),
scheduler.NewLIFOGroup(),
rfc.NewPath(),
rfc.NewHost(),
fadein.NewFadeIn(),
fadein.NewEndpointCreated(),
consistenthash.NewConsistentHashKey(),
consistenthash.NewConsistentHashBalanceFactor(),
tls.New(),
}
}
// MakeRegistry returns a Registry object initialized with the default set of
// filter specifications found in the filters package, including the builtin
// and the flowid subdirectories.
func MakeRegistry() filters.Registry {
r := make(filters.Registry)
for _, s := range Filters() {
r.Register(s)
}
return r
}
package builtin
import (
"slices"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/routing"
)
type comment struct{}
// NewComment creates a filter specification to annotate (comment) a filter
// chain. The filter itself does nothing.
func NewComment() filters.Spec {
return comment{}
}
func (comment) Name() string {
return filters.CommentName
}
func (c comment) CreateFilter(args []interface{}) (filters.Filter, error) {
return c, nil
}
func (comment) Request(filters.FilterContext) {}
func (comment) Response(filters.FilterContext) {}
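// CommentPostProcessor is a route post processor that removes comment
// filters from the routes.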
type CommentPostProcessor struct{}
func (CommentPostProcessor) Do(routes []*routing.Route) []*routing.Route {
for _, r := range routes {
r.Filters = slices.DeleteFunc(r.Filters, func(f *routing.RouteFilter) bool {
return f.Name == filters.CommentName
})
}
return routes
}
package builtin
import (
"compress/flate"
"compress/gzip"
"errors"
"io"
"math"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"github.com/andybalholm/brotli"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
)
const bufferSize = 8192
type encoding struct {
name string
q float32 // encoding client priority
p int // encoding server priority
}
type encodings []*encoding
type compress struct {
mime []string
level int
encodingPriority map[string]int
}
type CompressOptions struct {
// Encodings specifies the encodings supported for compression. Their order
// defines the priority when the Accept-Encoding header has equal quality
// values, see RFC 7231 section 5.3.1.
Encodings []string
}
type encoder interface {
io.WriteCloser
Reset(io.Writer)
Flush() error
}
var (
supportedEncodings = []string{"gzip", "deflate", "br"}
errUnsupportedEncoding = errors.New("unsupported encoding")
)
var defaultCompressMIME = []string{
"text/plain",
"text/html",
"application/json",
"application/javascript",
"application/x-javascript",
"text/javascript",
"text/css",
"image/svg+xml",
"application/octet-stream",
}
var (
brotliPool = &sync.Pool{New: func() interface{} {
ge, err := newEncoder("br", flate.BestSpeed)
if err != nil {
log.Error(err)
}
return ge
}}
gzipPool = &sync.Pool{New: func() interface{} {
ge, err := newEncoder("gzip", flate.BestSpeed)
if err != nil {
log.Error(err)
}
return ge
}}
deflatePool = &sync.Pool{New: func() interface{} {
fe, err := newEncoder("deflate", flate.BestSpeed)
if err != nil {
log.Error(err)
}
return fe
}}
)
func (e encodings) Len() int { return len(e) }
func (e encodings) Less(i, j int) bool {
if e[i].q != e[j].q {
return e[i].q > e[j].q // higher first
}
return e[i].p < e[j].p // smallest first
}
func (e encodings) Swap(i, j int) { e[i], e[j] = e[j], e[i] }
// Returns a filter specification that is used to compress the response content.
//
// Example:
//
// r: * -> compress() -> "https://www.example.org";
//
// The filter, when executed on the response path, checks if the response
// entity can be compressed. To decide, it checks the Content-Encoding, the
// Cache-Control and the Content-Type headers. It doesn't compress the content
// if the Content-Encoding is set to other than identity, or the Cache-Control
// applies the no-transform pragma, or the Content-Type is set to an unsupported
// value.
//
// The default supported content types are: text/plain, text/html,
// application/json, application/javascript, application/x-javascript,
// text/javascript, text/css, image/svg+xml, application/octet-stream.
//
// The default set of MIME types can be reset or extended by passing in the desired
// types as filter arguments. When extending the defaults, the first argument needs
// to be "...". E.g. to compress tiff in addition to the defaults:
//
// r: * -> compress("...", "image/tiff") -> "https://www.example.org";
//
// To reset the supported types, e.g. to compress only HTML, the "..." argument
// needs to be omitted:
//
// r: * -> compress("text/html") -> "https://www.example.org";
//
// It is possible to control the compression level by setting it as the first
// filter argument, in front of the MIME types. The default compression level is
// best-speed. The possible values are integers between 0 and 11 (inclusive), where
// 0 means no-compression, 1 means best-speed and 11 means best-compression.
// Example:
//
// r: * -> compress(9, "image/tiff") -> "https://www.example.org";
//
// The filter also checks the incoming request, if it accepts the supported
// encodings, explicitly stated in the Accept-Encoding header. The filter currently
// supports brotli, gzip and deflate. It does not assume that the client accepts any
// encoding if the Accept-Encoding header is not set. It ignores * in the
// Accept-Encoding header.
//
// Supported encodings are prioritized on:
// - quality value if provided by client
// - server side priority (encodingPriority) otherwise
//
// When compressing the response, it updates the response header. It deletes
// the Content-Length value, triggering the proxy to always return the response
// with chunked transfer encoding, sets the Content-Encoding to the selected
// encoding and sets the Vary: Accept-Encoding header, if missing.
//
// The compression happens in a streaming way, using only a small internal buffer.
func NewCompress() filters.Spec {
c, err := NewCompressWithOptions(CompressOptions{supportedEncodings})
if err != nil {
log.Warningf("Failed to create compress filter: %v", err)
}
return c
}
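// NewCompressWithOptions returns a compress filter specification with the
// given options, restricting the supported encodings to the given list and
// using the list order as the server side priority.
//
// Example (a minimal sketch; the encoding list is illustrative):
//
// spec, err := NewCompressWithOptions(CompressOptions{Encodings: []string{"gzip", "br"}})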
func NewCompressWithOptions(options CompressOptions) (filters.Spec, error) {
m := map[string]int{}
for i, v := range options.Encodings {
if !stringsContain(supportedEncodings, v) {
return nil, errUnsupportedEncoding
}
m[v] = i
}
return &compress{encodingPriority: m}, nil
}
func (c *compress) Name() string {
return filters.CompressName
}
func (c *compress) CreateFilter(args []interface{}) (filters.Filter, error) {
f := &compress{
mime: defaultCompressMIME,
level: flate.BestSpeed,
encodingPriority: c.encodingPriority,
}
if len(args) == 0 {
return f, nil
}
if lf, ok := args[0].(float64); ok && math.Trunc(lf) == lf {
f.level = int(lf)
if f.level < flate.HuffmanOnly || f.level > brotli.BestCompression {
return nil, filters.ErrInvalidFilterParameters
}
args = args[1:]
}
if len(args) == 0 {
return f, nil
}
if args[0] == "..." {
args = args[1:]
} else {
f.mime = nil
}
for _, a := range args {
if s, ok := a.(string); ok {
f.mime = append(f.mime, s)
} else {
return nil, filters.ErrInvalidFilterParameters
}
}
return f, nil
}
func (c *compress) Request(_ filters.FilterContext) {}
func stringsContain(ss []string, s string, transform ...func(string) string) bool {
for _, si := range ss {
for _, t := range transform {
si = t(si)
}
if si == s {
return true
}
}
return false
}
func canEncodeEntity(r *http.Response, mime []string) bool {
if ce := r.Header.Get("Content-Encoding"); ce != "" && ce != "identity" /* forgiving for identity */ {
return false
}
cc := strings.ToLower(r.Header.Get("Cache-Control"))
if strings.Contains(cc, "no-transform") {
return false
}
ct := r.Header.Get("Content-Type")
if i := strings.Index(ct, ";"); i >= 0 {
ct = ct[:i]
}
if !stringsContain(mime, ct) {
return false
}
return true
}
func (c *compress) acceptedEncoding(r *http.Request) string {
var encs encodings
for s := range splitSeq(r.Header.Get("Accept-Encoding"), ",") {
name, weight, hasWeight := strings.Cut(s, ";")
name = strings.ToLower(strings.TrimSpace(name))
if name == "" {
continue
}
prio, ok := c.encodingPriority[name]
if !ok {
continue
}
enc := &encoding{name, 1, prio}
encs = append(encs, enc)
if !hasWeight {
continue
}
weight = strings.TrimSpace(weight)
if !strings.HasPrefix(weight, "q=") {
continue
}
q, err := strconv.ParseFloat(strings.TrimPrefix(weight, "q="), 32)
if err != nil {
continue
}
if float32(q) < 0 || float32(q) > 1.0 {
continue
}
enc.q = float32(q)
}
if len(encs) == 0 {
return ""
}
sort.Sort(encs)
return encs[0].name
}
// TODO: use [strings.SplitSeq] added in go1.24 once go1.25 is released.
func splitSeq(s string, sep string) func(yield func(string) bool) {
return func(yield func(string) bool) {
for {
i := strings.Index(s, sep)
if i < 0 {
break
}
frag := s[:i]
if !yield(frag) {
return
}
s = s[i+len(sep):]
}
yield(s)
}
}
func responseHeader(r *http.Response, enc string) {
r.Header.Del("Content-Length")
r.Header.Set("Content-Encoding", enc)
if !stringsContain(r.Header["Vary"], "Accept-Encoding", http.CanonicalHeaderKey) {
r.Header.Add("Vary", "Accept-Encoding")
}
}
// An unhandled encoding is considered an implementation error, since
// these functions are only called from inside the package, and the
// encoding is selected from a predefined set.
func unsupported() {
panic(errUnsupportedEncoding)
}
func newEncoder(enc string, level int) (encoder, error) {
switch enc {
case "br":
return brotli.NewWriterLevel(nil, level), nil
case "gzip":
if level > gzip.BestCompression {
level = gzip.BestCompression
}
return gzip.NewWriterLevel(nil, level)
case "deflate":
if level > flate.BestCompression {
level = flate.BestCompression
}
return flate.NewWriter(nil, level)
default:
unsupported()
return nil, nil
}
}
func encoderPool(enc string) *sync.Pool {
switch enc {
case "br":
return brotliPool
case "gzip":
return gzipPool
case "deflate":
return deflatePool
default:
unsupported()
return nil
}
}
func encode(out *io.PipeWriter, in io.ReadCloser, enc string, level int) {
var (
e encoder
err error
)
defer func() {
if e != nil {
cerr := e.Close()
if cerr == nil && level == flate.BestSpeed {
encoderPool(enc).Put(e)
}
}
if err == nil {
err = io.EOF
}
out.CloseWithError(err)
in.Close()
}()
if level == flate.BestSpeed {
e = encoderPool(enc).Get().(encoder)
// if pool.New failed to create an encoder,
// then we have already logged the error
if e == nil {
return
}
} else {
e, err = newEncoder(enc, level)
if err != nil {
log.Error(err)
return
}
}
e.Reset(out)
b := make([]byte, bufferSize)
for {
n, rerr := in.Read(b)
if n > 0 {
_, err = e.Write(b[:n])
if err != nil {
break
}
err = e.Flush()
if err != nil {
break
}
}
if rerr != nil {
err = rerr
break
}
}
}
func responseBody(rsp *http.Response, enc string, level int) {
in := rsp.Body
r, w := io.Pipe()
rsp.Body = r
go encode(w, in, enc, level)
}
func (c *compress) Response(ctx filters.FilterContext) {
rsp := ctx.Response()
if !canEncodeEntity(rsp, c.mime) {
return
}
enc := c.acceptedEncoding(ctx.Request())
if enc == "" {
return
}
responseHeader(rsp, enc)
responseBody(rsp, enc, c.level)
}
package builtin
import (
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/routing"
)
const (
maxAge = 2
metricsPrefix = "routeCreationTime."
)
// RouteCreationMetrics reports metrics about the time it took to create routes.
// It looks for filters of type OriginMarker to determine when the source object (e.g. ingress) of the route
// was created.
// If an OriginMarker with the same type and id is seen again, the creation time is not reported again, because
// a route with the same configuration already existed before.
type RouteCreationMetrics struct {
metrics filters.Metrics
originIdAges map[string]map[string]int
initialized bool
}
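// NewRouteCreationMetrics creates a RouteCreationMetrics instance reporting
// to the given metrics implementation.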
func NewRouteCreationMetrics(metrics filters.Metrics) *RouteCreationMetrics {
return &RouteCreationMetrics{metrics: metrics, originIdAges: map[string]map[string]int{}}
}
func removeOriginMarkers(r []*routing.Route) {
for _, ri := range r {
var (
f []*routing.RouteFilter
foundOriginMarkers bool
)
for i, fi := range ri.Filters {
_, isOriginMarker := fi.Filter.(*OriginMarker)
if !isOriginMarker && !foundOriginMarkers {
continue
}
if !isOriginMarker {
f = append(f, fi)
continue
}
if !foundOriginMarkers {
f = make([]*routing.RouteFilter, i, len(ri.Filters))
copy(f, ri.Filters[:i])
foundOriginMarkers = true
}
}
if foundOriginMarkers {
ri.Filters = f
}
}
}
// Do implements routing.PostProcessor and records the filter creation time.
func (m *RouteCreationMetrics) Do(routes []*routing.Route) []*routing.Route {
return m.reportRouteCreationTimes(routes)
}
func (m *RouteCreationMetrics) reportRouteCreationTimes(routes []*routing.Route) []*routing.Route {
for _, r := range routes {
for origin, start := range m.startTimes(r) {
if m.initialized {
m.metrics.MeasureSince(metricsPrefix+origin, start)
}
}
}
if !m.initialized {
// must be done after filling the cache
m.initialized = true
}
m.pruneCache()
removeOriginMarkers(routes)
return routes
}
func (m *RouteCreationMetrics) startTimes(route *routing.Route) map[string]time.Time {
startTimes := map[string]time.Time{}
for _, f := range route.Filters {
origin, t := m.originStartTime(f.Filter)
if t.IsZero() {
continue
}
old, exists := startTimes[origin]
if !exists || t.Before(old) {
startTimes[origin] = t
}
}
return startTimes
}
func (m *RouteCreationMetrics) originStartTime(f filters.Filter) (string, time.Time) {
marker, ok := f.(*OriginMarker)
if !ok {
return "", time.Time{}
}
origin := marker.Origin
id := marker.Id
created := marker.Created
if origin == "" || id == "" || created.IsZero() {
return "", time.Time{}
}
originCache := m.originIdAges[origin]
if originCache == nil {
originCache = map[string]int{}
m.originIdAges[origin] = originCache
}
_, exists := originCache[id]
originCache[id] = 0
if exists {
return "", time.Time{}
}
log.WithFields(log.Fields{
"origin": origin,
"id": id,
"seconds": time.Since(created).Seconds(),
}).Debug("route creation time")
return origin, created
}
func (m *RouteCreationMetrics) pruneCache() {
for origin, idAges := range m.originIdAges {
for id, age := range idAges {
age++
if age > maxAge {
log.WithFields(log.Fields{
"origin": origin,
"id": id,
}).Debug("delete from route creation cache")
delete(idAges, id)
} else {
idAges[id] = age
}
}
if len(idAges) == 0 {
log.WithField("origin", origin).Debug("delete from route creation cache")
delete(m.originIdAges, origin)
}
}
}
package builtin
import (
"compress/flate"
"compress/gzip"
"fmt"
"io"
"net/http"
"runtime"
"strings"
"sync"
"github.com/andybalholm/brotli"
"github.com/zalando/skipper/filters"
)
const (
// DecompressionNotPossible is the state-bag key to indicate
// to the subsequent filters during response processing that the
// content is compressed, but decompression was not possible, e.g.
// because the encoding is not supported.
DecompressionNotPossible = "filter::decompress::not-possible"
// DecompressionError is the state-bag key to indicate to the
// subsequent filters during response processing that the
// decompression of the content was attempted but failed. The
// response body may have been sniffed, and therefore it was
// discarded.
DecompressionError = "filter::decompress::error"
)
type decodedBody struct {
enc string
original io.Closer
decoder io.ReadCloser
isFromPool bool
}
type decodingError struct {
decoder error
original error
}
type decompress struct{}
// brotliWrapper is a workaround to make the brotli reader compatible with
// the decompress filter, as brotli.Reader does not implement io.Closer.
type brotliWrapper struct {
brotli.Reader
}
func (brotliWrapper) Close() error { return nil }
var supportedEncodingsDecompress = map[string]*sync.Pool{
"gzip": {},
"deflate": {},
"br": {},
}
func init() {
// #cpu * 4: pool size decided based on some
// simple tests, checking performance by binary
// steps (https://github.com/zalando/skipper)
for enc, pool := range supportedEncodingsDecompress {
for i := 0; i < runtime.NumCPU()*4; i++ {
pool.Put(newDecoder(enc))
}
}
}
func newDecoder(enc string) io.ReadCloser {
switch enc {
case "gzip":
return new(gzip.Reader)
case "br":
return new(brotliWrapper)
default:
return flate.NewReader(nil)
}
}
func fromPool(enc string) (io.ReadCloser, bool) {
d, ok := supportedEncodingsDecompress[enc].Get().(io.ReadCloser)
return d, ok
}
func reset(decoder, original io.ReadCloser, enc string) error {
switch enc {
case "gzip":
return decoder.(*gzip.Reader).Reset(original)
case "br":
return decoder.(*brotliWrapper).Reset(original)
default:
return decoder.(flate.Resetter).Reset(original, nil)
}
}
func newDecodedBody(original io.ReadCloser, encs []string) (body io.ReadCloser, err error) {
if len(encs) == 0 {
body = original
return
}
last := len(encs) - 1
enc := encs[last]
encs = encs[:last]
decoder, isFromPool := fromPool(enc)
if !isFromPool {
decoder = newDecoder(enc)
}
if err = reset(decoder, original, enc); err != nil {
return
}
decoded := decodedBody{
enc: enc,
original: original,
decoder: decoder,
isFromPool: isFromPool,
}
return newDecodedBody(decoded, encs)
}
func (b decodedBody) Read(p []byte) (int, error) {
return b.decoder.Read(p)
}
func (b decodedBody) Close() error {
derr := b.decoder.Close()
if b.isFromPool {
if derr != nil {
supportedEncodingsDecompress[b.enc].Put(newDecoder(b.enc))
} else {
supportedEncodingsDecompress[b.enc].Put(b.decoder)
}
}
oerr := b.original.Close()
var err error
if derr != nil || oerr != nil {
err = decodingError{
decoder: derr,
original: oerr,
}
}
return err
}
func (e decodingError) Error() string {
switch {
case e.decoder == nil:
return e.original.Error()
case e.original == nil:
return e.decoder.Error()
default:
return fmt.Sprintf("%v; %v", e.decoder, e.original)
}
}
// NewDecompress creates a filter specification for the decompress() filter.
// The filter attempts to decompress the response body, if it was compressed
// with any of deflate, gzip or br.
//
// If decompression is not possible, but the body is compressed, then it indicates it
// with the "filter::decompress::not-possible" key in the state-bag. If the decompression
// was attempted and failed to get initialized, it indicates it in addition with the
// "filter::decompress::error" state-bag key, storing the error. Due to the streaming,
// decompression may fail after all the filters were processed.
//
// The filter does not need any parameters.
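//
// Example (hypothetical route):
//
// r: * -> decompress() -> "https://www.example.org";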
func NewDecompress() filters.Spec {
return decompress{}
}
func (d decompress) Name() string { return filters.DecompressName }
func (d decompress) CreateFilter([]interface{}) (filters.Filter, error) {
return d, nil
}
func (d decompress) Request(filters.FilterContext) {}
func getEncodings(header string) []string {
var encs []string
for r := range splitSeq(header, ",") {
r = strings.TrimSpace(r)
if r != "" {
encs = append(encs, r)
}
}
return encs
}
func encodingsSupported(encs []string) bool {
for _, e := range encs {
if _, supported := supportedEncodingsDecompress[e]; !supported {
return false
}
}
return true
}
func (d decompress) Response(ctx filters.FilterContext) {
rsp := ctx.Response()
encs := getEncodings(rsp.Header.Get("Content-Encoding"))
if len(encs) == 0 {
return
}
if !encodingsSupported(encs) {
ctx.StateBag()[DecompressionNotPossible] = true
return
}
rsp.Header.Del("Content-Encoding")
rsp.Header.Del("Vary")
rsp.Header.Del("Content-Length")
rsp.ContentLength = -1
b, err := newDecodedBody(rsp.Body, encs)
if err != nil {
// we may have already sniffed from the response via the gzip.Reader
rsp.Body.Close()
rsp.Body = http.NoBody
sb := ctx.StateBag()
sb[DecompressionNotPossible] = true
sb[DecompressionError] = err
ctx.Logger().Errorf("Error while initializing decompression: %v", err)
return
}
rsp.Body = b
}
package builtin
import (
"net/url"
"github.com/zalando/skipper/filters"
)
type dynamicBackendFilterType int
const (
setDynamicBackendHostFromHeader dynamicBackendFilterType = iota
setDynamicBackendSchemeFromHeader
setDynamicBackendUrlFromHeader
setDynamicBackendHost
setDynamicBackendScheme
setDynamicBackendUrl
)
type dynamicBackendFilter struct {
typ dynamicBackendFilterType
input string
}
// verifies that the filter config has one string parameter
func dynamicBackendFilterConfig(config []interface{}) (string, error) {
if len(config) != 1 {
return "", filters.ErrInvalidFilterParameters
}
input, ok := config[0].(string)
if !ok {
return "", filters.ErrInvalidFilterParameters
}
return input, nil
}
// Returns a filter specification that is used to set the dynamic backend host from a header.
// Instances expect one parameter: a header name.
// Name: "setDynamicBackendHostFromHeader".
//
// If the header exists, the value is put into the `StateBag`; additionally,
// `SetOutgoingHost()` is used to set the host header.
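//
// Example (a hypothetical route with an illustrative header name, assuming a
// dynamic backend):
//
// r: * -> setDynamicBackendHostFromHeader("X-Backend-Host") -> <dynamic>;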
func NewSetDynamicBackendHostFromHeader() filters.Spec {
return &dynamicBackendFilter{typ: setDynamicBackendHostFromHeader}
}
// Returns a filter specification that is used to set the dynamic backend scheme from a header.
// Instances expect one parameter: a header name.
// Name: "setDynamicBackendSchemeFromHeader".
//
// If the header exists, the value is put into the `StateBag`.
func NewSetDynamicBackendSchemeFromHeader() filters.Spec {
return &dynamicBackendFilter{typ: setDynamicBackendSchemeFromHeader}
}
// Returns a filter specification that is used to set the dynamic backend url from a header.
// Instances expect one parameter: a header name.
// Name: "setDynamicBackendUrlFromHeader".
//
// If the header exists, the value is put into the `StateBag`; additionally,
// `SetOutgoingHost()` is used to set the host header if the header is a valid url.
func NewSetDynamicBackendUrlFromHeader() filters.Spec {
return &dynamicBackendFilter{typ: setDynamicBackendUrlFromHeader}
}
// Returns a filter specification that is used to set the dynamic backend host.
// Instances expect one parameter: a host name.
// Name: "setDynamicBackendHost".
//
// The value is put into the `StateBag`; additionally,
// `SetOutgoingHost()` is used to set the host header.
func NewSetDynamicBackendHost() filters.Spec {
return &dynamicBackendFilter{typ: setDynamicBackendHost}
}
// Returns a filter specification that is used to set the dynamic backend scheme.
// Instances expect one parameter: a scheme name.
// Name: "setDynamicBackendScheme".
//
// The value is put into the `StateBag`.
func NewSetDynamicBackendScheme() filters.Spec {
return &dynamicBackendFilter{typ: setDynamicBackendScheme}
}
// Returns a filter specification that is used to set the dynamic backend url.
// Instances expect one parameter: a url.
// Name: "setDynamicBackendUrl".
//
// The value is put into the `StateBag`; additionally, `SetOutgoingHost()`
// is used to set the host header if the input provided is a valid url.
func NewSetDynamicBackendUrl() filters.Spec {
return &dynamicBackendFilter{typ: setDynamicBackendUrl}
}
func (spec *dynamicBackendFilter) Name() string {
switch spec.typ {
case setDynamicBackendHostFromHeader:
return filters.SetDynamicBackendHostFromHeader
case setDynamicBackendSchemeFromHeader:
return filters.SetDynamicBackendSchemeFromHeader
case setDynamicBackendUrlFromHeader:
return filters.SetDynamicBackendUrlFromHeader
case setDynamicBackendHost:
return filters.SetDynamicBackendHost
case setDynamicBackendScheme:
return filters.SetDynamicBackendScheme
case setDynamicBackendUrl:
return filters.SetDynamicBackendUrl
default:
panic("invalid type")
}
}
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *dynamicBackendFilter) CreateFilter(config []interface{}) (filters.Filter, error) {
input, err := dynamicBackendFilterConfig(config)
return &dynamicBackendFilter{typ: spec.typ, input: input}, err
}
func (f *dynamicBackendFilter) Request(ctx filters.FilterContext) {
switch f.typ {
case setDynamicBackendHostFromHeader:
header := ctx.Request().Header.Get(f.input)
if header != "" {
ctx.StateBag()[filters.DynamicBackendHostKey] = header
ctx.SetOutgoingHost(header)
}
case setDynamicBackendSchemeFromHeader:
header := ctx.Request().Header.Get(f.input)
if header != "" {
ctx.StateBag()[filters.DynamicBackendSchemeKey] = header
}
case setDynamicBackendUrlFromHeader:
header := ctx.Request().Header.Get(f.input)
if header != "" {
ctx.StateBag()[filters.DynamicBackendURLKey] = header
bu, err := url.ParseRequestURI(header)
if err == nil {
ctx.SetOutgoingHost(bu.Host)
}
}
case setDynamicBackendHost:
ctx.StateBag()[filters.DynamicBackendHostKey] = f.input
ctx.SetOutgoingHost(f.input)
case setDynamicBackendScheme:
ctx.StateBag()[filters.DynamicBackendSchemeKey] = f.input
case setDynamicBackendUrl:
ctx.StateBag()[filters.DynamicBackendURLKey] = f.input
bu, err := url.ParseRequestURI(f.input)
if err == nil {
ctx.SetOutgoingHost(bu.Host)
}
}
}
func (f *dynamicBackendFilter) Response(ctx filters.FilterContext) {}
// Copyright 2015 Zalando SE
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package builtin
import (
"github.com/zalando/skipper/filters"
)
type setFastCgiFilenameSpec struct {
fileName string
}
// NewSetFastCgiFilename returns a filter spec that makes it possible to change
// the FastCGI filename.
func NewSetFastCgiFilename() filters.Spec { return &setFastCgiFilenameSpec{} }
func (s *setFastCgiFilenameSpec) Name() string { return filters.SetFastCgiFilenameName }
func (s *setFastCgiFilenameSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
if a, ok := args[0].(string); ok {
return setFastCgiFilenameSpec{a}, nil
}
return nil, filters.ErrInvalidFilterParameters
}
func (s setFastCgiFilenameSpec) Response(_ filters.FilterContext) {}
func (s setFastCgiFilenameSpec) Request(ctx filters.FilterContext) {
ctx.StateBag()["fastCgiFilename"] = s.fileName
}
package builtin
import (
"fmt"
"strings"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
)
type headerType int
const (
setRequestHeader headerType = iota
appendRequestHeader
dropRequestHeader
setResponseHeader
appendResponseHeader
dropResponseHeader
setContextRequestHeader
appendContextRequestHeader
setContextResponseHeader
appendContextResponseHeader
copyRequestHeader
copyResponseHeader
copyRequestHeaderDeprecated
copyResponseHeaderDeprecated
depRequestHeader
depResponseHeader
)
const (
copyRequestHeaderDeprecatedName = "requestCopyHeader"
copyResponseHeaderDeprecatedName = "responseCopyHeader"
)
// common structure for requestHeader, responseHeader specifications and
// filters
type headerFilter struct {
typ headerType
key, value string
template *eskip.Template
}
// verifies that the filter config has the expected one or two string
// parameters, depending on the filter type
func headerFilterConfig(typ headerType, config []interface{}) (string, string, *eskip.Template, error) {
switch typ {
case dropRequestHeader, dropResponseHeader:
if len(config) != 1 {
return "", "", nil, filters.ErrInvalidFilterParameters
}
default:
if len(config) != 2 {
return "", "", nil, filters.ErrInvalidFilterParameters
}
}
key, ok := config[0].(string)
if !ok {
return "", "", nil, filters.ErrInvalidFilterParameters
}
var value string
if len(config) == 2 {
value, ok = config[1].(string)
if !ok {
return "", "", nil, filters.ErrInvalidFilterParameters
}
}
switch typ {
case setRequestHeader, appendRequestHeader,
setResponseHeader, appendResponseHeader:
return key, "", eskip.NewTemplate(value), nil
default:
return key, value, nil, nil
}
}
// Deprecated: use setRequestHeader or appendRequestHeader
func NewRequestHeader() filters.Spec {
return &headerFilter{typ: depRequestHeader}
}
// Deprecated: use setRequestHeader or appendRequestHeader
func NewResponseHeader() filters.Spec {
return &headerFilter{typ: depResponseHeader}
}
// Returns a filter specification that is used to set headers for requests.
// Instances expect two parameters: the header name and the header value template,
// see eskip.Template.ApplyContext
// Name: "setRequestHeader".
//
// If the header name is 'Host', the filter uses the `SetOutgoingHost()`
// method to set the header in addition to the standard `Request.Header`
// map.
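//
// Example (hypothetical route; the header name and value are illustrative):
//
// r: * -> setRequestHeader("X-Example", "value") -> "https://www.example.org";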
func NewSetRequestHeader() filters.Spec {
return &headerFilter{typ: setRequestHeader}
}
// Returns a filter specification that is used to append headers for requests.
// Instances expect two parameters: the header name and the header value template,
// see eskip.Template.ApplyContext
// Name: "appendRequestHeader".
//
// If the header name is 'Host', the filter uses the `SetOutgoingHost()`
// method to set the header in addition to the standard `Request.Header`
// map.
func NewAppendRequestHeader() filters.Spec {
return &headerFilter{typ: appendRequestHeader}
}
// Returns a filter specification that is used to delete headers for requests.
// Instances expect one parameter: the header name.
// Name: "dropRequestHeader".
func NewDropRequestHeader() filters.Spec {
return &headerFilter{typ: dropRequestHeader}
}
// Returns a filter specification that is used to set headers for responses.
// Instances expect two parameters: the header name and the header value template,
// see eskip.Template.ApplyContext
// Name: "setResponseHeader".
func NewSetResponseHeader() filters.Spec {
return &headerFilter{typ: setResponseHeader}
}
// Returns a filter specification that is used to append headers for responses.
// Instances expect two parameters: the header name and the header value template,
// see eskip.Template.ApplyContext
// Name: "appendResponseHeader".
func NewAppendResponseHeader() filters.Spec {
return &headerFilter{typ: appendResponseHeader}
}
// Returns a filter specification that is used to delete headers for responses.
// Instances expect one parameter: the header name.
// Name: "dropResponseHeader".
func NewDropResponseHeader() filters.Spec {
return &headerFilter{typ: dropResponseHeader}
}
// NewSetContextRequestHeader returns a filter specification used to set
// request headers with a given name and a value taken from the filter
// context state bag identified by its key.
func NewSetContextRequestHeader() filters.Spec {
return &headerFilter{typ: setContextRequestHeader}
}
// NewAppendContextRequestHeader returns a filter specification used to append
// request headers with a given name and a value taken from the filter
// context state bag identified by its key.
func NewAppendContextRequestHeader() filters.Spec {
return &headerFilter{typ: appendContextRequestHeader}
}
// NewSetContextResponseHeader returns a filter specification used to set
// response headers with a given name and a value taken from the filter
// context state bag identified by its key.
func NewSetContextResponseHeader() filters.Spec {
return &headerFilter{typ: setContextResponseHeader}
}
// NewAppendContextResponseHeader returns a filter specification used to append
// response headers with a given name and a value taken from the filter
// context state bag identified by its key.
func NewAppendContextResponseHeader() filters.Spec {
return &headerFilter{typ: appendContextResponseHeader}
}
// NewCopyRequestHeader creates a filter specification whose instances
// copy a specified source header to a defined destination header
// on the outgoing proxy request.
func NewCopyRequestHeader() filters.Spec {
return &headerFilter{typ: copyRequestHeader}
}
// NewCopyResponseHeader creates a filter specification whose instances
// copy a specified source header to a defined destination header
// on the proxy response returned from the backend.
func NewCopyResponseHeader() filters.Spec {
return &headerFilter{typ: copyResponseHeader}
}
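// Eskip examples (illustrative; the header names and backend are made up):
//
// req: * -> copyRequestHeader("X-Request-Id", "X-Correlation-Id") -> "https://backend.example.org";
// rsp: * -> copyResponseHeader("X-Backend-Version", "X-Version") -> "https://backend.example.org";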
func NewCopyRequestHeaderDeprecated() filters.Spec {
return &headerFilter{typ: copyRequestHeaderDeprecated}
}
func NewCopyResponseHeaderDeprecated() filters.Spec {
return &headerFilter{typ: copyResponseHeaderDeprecated}
}
func (spec *headerFilter) Name() string {
switch spec.typ {
case setRequestHeader:
return filters.SetRequestHeaderName
case appendRequestHeader:
return filters.AppendRequestHeaderName
case dropRequestHeader:
return filters.DropRequestHeaderName
case setResponseHeader:
return filters.SetResponseHeaderName
case appendResponseHeader:
return filters.AppendResponseHeaderName
case dropResponseHeader:
return filters.DropResponseHeaderName
case depRequestHeader:
return RequestHeaderName
case depResponseHeader:
return ResponseHeaderName
case setContextRequestHeader:
return filters.SetContextRequestHeaderName
case appendContextRequestHeader:
return filters.AppendContextRequestHeaderName
case setContextResponseHeader:
return filters.SetContextResponseHeaderName
case appendContextResponseHeader:
return filters.AppendContextResponseHeaderName
case copyRequestHeader:
return filters.CopyRequestHeaderName
case copyResponseHeader:
return filters.CopyResponseHeaderName
case copyRequestHeaderDeprecated:
return copyRequestHeaderDeprecatedName
case copyResponseHeaderDeprecated:
return copyResponseHeaderDeprecatedName
default:
panic("invalid header type")
}
}
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *headerFilter) CreateFilter(config []interface{}) (filters.Filter, error) {
key, value, template, err := headerFilterConfig(spec.typ, config)
return &headerFilter{typ: spec.typ, key: key, value: value, template: template}, err
}
func valueFromContext(
ctx filters.FilterContext,
headerName,
contextKey string,
isRequest bool,
apply func(string, string),
) {
contextValue, ok := ctx.StateBag()[contextKey]
if !ok {
return
}
stringValue := fmt.Sprint(contextValue)
apply(headerName, stringValue)
if isRequest && strings.ToLower(headerName) == "host" {
ctx.SetOutgoingHost(stringValue)
}
}
func (f *headerFilter) Request(ctx filters.FilterContext) {
header := ctx.Request().Header
switch f.typ {
case setRequestHeader:
value, ok := f.template.ApplyContext(ctx)
if ok {
header.Set(f.key, value)
if strings.ToLower(f.key) == "host" {
ctx.SetOutgoingHost(value)
}
}
case appendRequestHeader:
value, ok := f.template.ApplyContext(ctx)
if ok {
header.Add(f.key, value)
if strings.ToLower(f.key) == "host" {
ctx.SetOutgoingHost(value)
}
}
case depRequestHeader:
header.Add(f.key, f.value)
if strings.ToLower(f.key) == "host" {
ctx.SetOutgoingHost(f.value)
}
case dropRequestHeader:
header.Del(f.key)
case setContextRequestHeader:
valueFromContext(ctx, f.key, f.value, true, header.Set)
case appendContextRequestHeader:
valueFromContext(ctx, f.key, f.value, true, header.Add)
case copyRequestHeader, copyRequestHeaderDeprecated:
headerValue := header.Get(f.key)
if headerValue != "" {
header.Set(f.value, headerValue)
if strings.ToLower(f.value) == "host" {
ctx.SetOutgoingHost(headerValue)
}
}
}
}
func (f *headerFilter) Response(ctx filters.FilterContext) {
header := ctx.Response().Header
switch f.typ {
case setResponseHeader:
value, ok := f.template.ApplyContext(ctx)
if ok {
header.Set(f.key, value)
}
case appendResponseHeader:
value, ok := f.template.ApplyContext(ctx)
if ok {
header.Add(f.key, value)
}
case depResponseHeader:
header.Add(f.key, f.value)
case dropResponseHeader:
header.Del(f.key)
case setContextResponseHeader:
valueFromContext(ctx, f.key, f.value, false, header.Set)
case appendContextResponseHeader:
valueFromContext(ctx, f.key, f.value, false, header.Add)
case copyResponseHeader, copyResponseHeaderDeprecated:
headerValue := header.Get(f.key)
if headerValue != "" {
header.Set(f.value, headerValue)
}
}
}
package builtin
import (
"github.com/zalando/skipper/filters"
)
type (
headerToQuerySpec struct {
}
headerToQueryFilter struct {
headerName string
queryParamName string
}
)
// NewHeaderToQuery creates a filter which converts a header
// from the incoming request to a query parameter
//
// headerToQuery("X-Foo-Header", "foo-query-param")
//
// The above filter sets the "foo-query-param" query parameter of the
// request to the value of the "X-Foo-Header" header, and overrides the
// value if the query parameter already exists
func NewHeaderToQuery() filters.Spec {
return &headerToQuerySpec{}
}
func (*headerToQuerySpec) Name() string {
return filters.HeaderToQueryName
}
// CreateFilter creates a `headerToQuery` filter instance, e.g.
// s.CreateFilter("X-Foo-Header", "foo-query-param")
func (*headerToQuerySpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
h, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
q, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &headerToQueryFilter{h, q}, nil
}
func (f *headerToQueryFilter) Request(ctx filters.FilterContext) {
req := ctx.Request()
params := req.URL.Query()
headerValue := req.Header.Get(f.headerName)
params.Set(f.queryParamName, headerValue)
req.URL.RawQuery = params.Encode()
}
func (*headerToQueryFilter) Response(ctx filters.FilterContext) {}
// Copyright 2015 Zalando SE
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package builtin
import (
"github.com/zalando/skipper/filters"
"net/http"
)
type healthCheck struct{}
// Creates a new filter Spec, whose instances set the status code of the
// response to 200 OK. Name: "healthcheck".
func NewHealthCheck() filters.Spec { return &healthCheck{} }
// "healthcheck"
func (h *healthCheck) Name() string { return filters.HealthCheckName }
func (h *healthCheck) CreateFilter(_ []interface{}) (filters.Filter, error) { return h, nil }
func (h *healthCheck) Request(ctx filters.FilterContext) {}
func (h *healthCheck) Response(ctx filters.FilterContext) { ctx.Response().StatusCode = http.StatusOK }
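// Eskip example (illustrative; the path is made up):
//
// health: Path("/healthz") -> healthcheck() -> <shunt>;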
package builtin
import (
"bytes"
"io"
"net/http"
"strconv"
"github.com/zalando/skipper/filters"
)
type inlineContent struct {
text string
mime string
}
// Creates a filter spec for the inlineContent() filter.
//
// Usage of the filter:
//
// r: * -> status(420) -> inlineContent("Enhance Your Calm") -> <shunt>;
//
// Or:
//
// r: * -> inlineContent("{\"foo\": 42}", "application/json") -> <shunt>;
//
// It accepts two arguments: the content and the optional content type.
// When the content type is not set, it tries to detect it using
// http.DetectContentType.
//
// The filter shunts the request with status code 200.
func NewInlineContent() filters.Spec {
return &inlineContent{}
}
func (c *inlineContent) Name() string { return filters.InlineContentName }
func stringArg(a interface{}) (s string, err error) {
var ok bool
s, ok = a.(string)
if !ok {
err = filters.ErrInvalidFilterParameters
}
return
}
func (c *inlineContent) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) == 0 || len(args) > 2 {
return nil, filters.ErrInvalidFilterParameters
}
var (
f inlineContent
err error
)
f.text, err = stringArg(args[0])
if err != nil {
return nil, err
}
if len(args) == 2 {
f.mime, err = stringArg(args[1])
if err != nil {
return nil, err
}
} else {
f.mime = http.DetectContentType([]byte(f.text))
}
return &f, nil
}
func (c *inlineContent) Request(ctx filters.FilterContext) {
ctx.Serve(&http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{c.mime},
"Content-Length": []string{strconv.Itoa(len(c.text))},
},
Body: io.NopCloser(bytes.NewBufferString(c.text)),
})
}
func (c *inlineContent) Response(filters.FilterContext) {}
package builtin
import (
"bytes"
"io"
"net/http"
"strconv"
"github.com/zalando/skipper/filters"
)
type inlineContentIfStatus struct {
statusCode int
text string
mime string
}
// Creates a filter spec for the inlineContentIfStatus() filter.
//
// r: * -> inlineContentIfStatus(401, "{\"foo\": 42}", "application/json") -> "https://www.example.org";
//
// It accepts three arguments: the status code to match, the content and the optional content type.
// When the content type is not set, it tries to detect it using http.DetectContentType.
//
// The filter replaces the response coming from the backend or the following filters.
func NewInlineContentIfStatus() filters.Spec {
return &inlineContentIfStatus{}
}
func (c *inlineContentIfStatus) Name() string { return filters.InlineContentIfStatusName }
func (c *inlineContentIfStatus) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 2 || len(args) > 3 {
return nil, filters.ErrInvalidFilterParameters
}
var (
f inlineContentIfStatus
ok bool
)
f.statusCode, ok = args[0].(int)
if !ok {
floatStatusCode, ok := args[0].(float64)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
f.statusCode = int(floatStatusCode)
}
if f.statusCode < 100 || f.statusCode >= 600 {
return nil, filters.ErrInvalidFilterParameters
}
f.text, ok = args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) == 3 {
f.mime, ok = args[2].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
} else {
f.mime = http.DetectContentType([]byte(f.text))
}
return &f, nil
}
func (c *inlineContentIfStatus) Request(filters.FilterContext) {}
func (c *inlineContentIfStatus) Response(ctx filters.FilterContext) {
if ctx.Response().StatusCode != c.statusCode {
return
}
rsp := ctx.Response()
err := rsp.Body.Close()
if err != nil {
ctx.Logger().Errorf("%v", err)
}
contentLength := len(c.text)
rsp.ContentLength = int64(contentLength)
rsp.Header.Set("Content-Type", c.mime)
rsp.Header.Set("Content-Length", strconv.Itoa(contentLength))
rsp.Body = io.NopCloser(bytes.NewBufferString(c.text))
}
package builtin
import (
"net/http"
"regexp"
"strings"
"github.com/zalando/skipper/filters"
)
type modRequestHeader struct {
headerName string
rx *regexp.Regexp
replacement string
}
// NewModRequestHeader returns a new filter Spec, whose instances execute
// regexp.ReplaceAllString on the value of a request header. Instances expect three
// parameters: the header name, the expression to match and the replacement string.
// Name: "modRequestHeader".
func NewModRequestHeader() filters.Spec { return &modRequestHeader{} }
func (spec *modRequestHeader) Name() string {
return filters.ModRequestHeaderName
}
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *modRequestHeader) CreateFilter(config []interface{}) (filters.Filter, error) {
if len(config) != 3 {
return nil, filters.ErrInvalidFilterParameters
}
headerName, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
expr, ok := config[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
replacement, ok := config[2].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
rx, err := regexp.Compile(expr)
if err != nil {
return nil, err
}
return &modRequestHeader{headerName: headerName, rx: rx, replacement: replacement}, nil
}
func (f *modRequestHeader) Request(ctx filters.FilterContext) {
req := ctx.Request()
if strings.ToLower(f.headerName) == "host" {
nh := f.rx.ReplaceAllString(getRequestHost(req), f.replacement)
req.Header.Set(f.headerName, nh)
ctx.SetOutgoingHost(nh)
return
}
if _, ok := req.Header[http.CanonicalHeaderKey(f.headerName)]; !ok {
return
}
req.Header.Set(f.headerName, f.rx.ReplaceAllString(req.Header.Get(f.headerName), f.replacement))
}
func (*modRequestHeader) Response(filters.FilterContext) {}
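// Eskip example (illustrative; the pattern and replacement are made up):
//
// r: * -> modRequestHeader("Host", "^www[.]", "api.") -> "https://backend.example.org";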
type modResponseHeader struct {
headerName string
rx *regexp.Regexp
replacement string
}
// NewModResponseHeader returns a new filter Spec, whose instances execute
// regexp.ReplaceAllString on the value of a response header. Instances expect three
// parameters: the header name, the expression to match and the replacement string.
// Name: "modResponseHeader".
func NewModResponseHeader() filters.Spec { return &modResponseHeader{} }
func (spec *modResponseHeader) Name() string {
return filters.ModResponseHeaderName
}
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *modResponseHeader) CreateFilter(config []interface{}) (filters.Filter, error) {
if len(config) != 3 {
return nil, filters.ErrInvalidFilterParameters
}
headerName, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
expr, ok := config[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
replacement, ok := config[2].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
rx, err := regexp.Compile(expr)
if err != nil {
return nil, err
}
return &modResponseHeader{headerName: headerName, rx: rx, replacement: replacement}, nil
}
func (*modResponseHeader) Request(filters.FilterContext) {}
func (f *modResponseHeader) Response(ctx filters.FilterContext) {
resp := ctx.Response()
if _, ok := resp.Header[http.CanonicalHeaderKey(f.headerName)]; !ok {
return
}
resp.Header.Set(f.headerName, f.rx.ReplaceAllString(resp.Header.Get(f.headerName), f.replacement))
}
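// Eskip example (illustrative; the header and pattern are made up):
//
// r: * -> modResponseHeader("Location", "^http://", "https://") -> "https://backend.example.org";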
package builtin
import (
"time"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.OriginMarkerName instead
OriginMarkerName = filters.OriginMarkerName
)
type originMarkerSpec struct{}
// OriginMarker carries information about the origin of a route
type OriginMarker struct {
// the type of origin, e.g. ingress
Origin string `json:"origin"`
// the unique ID (within the origin) of the source object (e.g. ingress) that created the route
Id string `json:"id"`
// when the source object was created
Created time.Time `json:"created"`
}
// NewOriginMarkerSpec creates a filter specification whose instances
// mark the origin of an eskip.Route.
func NewOriginMarkerSpec() filters.Spec {
return &originMarkerSpec{}
}
func NewOriginMarker(origin string, id string, created time.Time) *eskip.Filter {
return &eskip.Filter{
Name: filters.OriginMarkerName,
Args: []interface{}{origin, id, created},
}
}
func (s *originMarkerSpec) Name() string { return filters.OriginMarkerName }
func (s *originMarkerSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 3 {
return nil, filters.ErrInvalidFilterParameters
}
f := &OriginMarker{}
if value, ok := args[0].(string); ok {
f.Origin = value
} else {
return nil, filters.ErrInvalidFilterParameters
}
if value, ok := args[1].(string); ok {
f.Id = value
} else {
return nil, filters.ErrInvalidFilterParameters
}
switch created := args[2].(type) {
case time.Time:
f.Created = created
case string:
if value, err := time.Parse(time.RFC3339, created); err == nil {
f.Created = value
} else {
return nil, filters.ErrInvalidFilterParameters
}
default:
return nil, filters.ErrInvalidFilterParameters
}
return f, nil
}
func (m OriginMarker) Request(filters.FilterContext) {}
func (m OriginMarker) Response(filters.FilterContext) {}
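// A minimal sketch of attaching the marker to a route programmatically,
// assuming route is an *eskip.Route (the origin and id values are made up):
//
// f := NewOriginMarker("ingress", "default/my-ingress", time.Now())
// route.Filters = append(route.Filters, f)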
package builtin
import (
"regexp"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
)
type modPathBehavior int
const (
regexpReplace modPathBehavior = 1 + iota
fullReplace
)
type modPath struct {
behavior modPathBehavior
rx *regexp.Regexp
replacement string
template *eskip.Template
}
// Returns a new modPath filter Spec, whose instances execute
// regexp.ReplaceAllString on the request path. Instances expect two
// parameters: the expression to match and the replacement string.
// Name: "modPath".
func NewModPath() filters.Spec { return &modPath{behavior: regexpReplace} }
// Returns a new setPath filter Spec, whose instances replace
// the request path.
//
// Instances expect one parameter: the new path to be set, or the path
// template to be evaluated, see eskip.Template.ApplyContext
//
// Name: "setPath".
func NewSetPath() filters.Spec { return &modPath{behavior: fullReplace} }
// "modPath" or "setPath"
func (spec *modPath) Name() string {
switch spec.behavior {
case regexpReplace:
return filters.ModPathName
case fullReplace:
return filters.SetPathName
default:
panic("unspecified behavior")
}
}
func createModPath(config []interface{}) (filters.Filter, error) {
if len(config) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
expr, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
replacement, ok := config[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
rx, err := regexp.Compile(expr)
if err != nil {
return nil, err
}
return &modPath{behavior: regexpReplace, rx: rx, replacement: replacement}, nil
}
func createSetPath(config []interface{}) (filters.Filter, error) {
if len(config) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
tpl, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &modPath{behavior: fullReplace, template: eskip.NewTemplate(tpl)}, nil
}
// Creates instances of the modPath filter.
//
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *modPath) CreateFilter(config []interface{}) (filters.Filter, error) {
switch spec.behavior {
case regexpReplace:
return createModPath(config)
case fullReplace:
return createSetPath(config)
default:
panic("unspecified behavior")
}
}
// Modifies the path with regexp.ReplaceAllString.
func (f *modPath) Request(ctx filters.FilterContext) {
req := ctx.Request()
switch f.behavior {
case regexpReplace:
req.URL.Path = f.rx.ReplaceAllString(req.URL.Path, f.replacement)
case fullReplace:
req.URL.Path, _ = f.template.ApplyContext(ctx)
default:
panic("unspecified behavior")
}
}
// Noop.
func (*modPath) Response(filters.FilterContext) {}
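// Eskip examples (illustrative; the paths and backend are made up):
//
// mod: * -> modPath("^/old/", "/new/") -> "https://backend.example.org";
// set: * -> setPath("/api/v1/resource") -> "https://backend.example.org";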
// Copyright 2015 Zalando SE
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package builtin
import (
"net/url"
"github.com/zalando/skipper/filters"
)
type spec struct{}
type filter bool
// Returns a filter specification whose filter instances are used to override
// the `proxyPreserveHost` behavior for individual routes.
//
// Instances expect one argument, with the possible values: "true" or "false",
// where "true" means to use the Host header from the incoming request, and
// "false" means to use the host from the backend address.
//
// The filter has no effect in either case if another filter modifies the
// outgoing host header to a value other than the one in the incoming request
// or in the backend address.
func PreserveHost() filters.Spec { return &spec{} }
func (s *spec) Name() string { return filters.PreserveHostName }
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
if a, ok := args[0].(string); ok && (a == "true" || a == "false") {
return filter(a == "true"), nil
} else {
return nil, filters.ErrInvalidFilterParameters
}
}
func (preserve filter) Response(_ filters.FilterContext) {}
func (preserve filter) Request(ctx filters.FilterContext) {
u, err := url.Parse(ctx.BackendUrl())
if err != nil {
ctx.Logger().Errorf("failed to parse backend host in preserveHost filter %v", err)
return
}
if preserve && ctx.OutgoingHost() == u.Host {
ctx.SetOutgoingHost(ctx.Request().Host)
} else if !preserve && ctx.OutgoingHost() == ctx.Request().Host {
ctx.SetOutgoingHost(u.Host)
}
}
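// Eskip example (illustrative; the backend is made up):
//
// r: * -> preserveHost("true") -> "https://backend.example.org";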
package builtin
import (
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
)
type modQueryBehavior int
const (
set modQueryBehavior = 1 + iota
drop
)
type modQuery struct {
behavior modQueryBehavior
name *eskip.Template
value *eskip.Template
}
// Returns a new dropQuery filter Spec, whose instances drop a corresponding
// query parameter.
//
// Instances expect the name string or template parameter, see eskip.Template.ApplyContext
//
// Name: "dropQuery".
func NewDropQuery() filters.Spec { return &modQuery{behavior: drop} }
// Returns a new setQuery filter Spec, whose instances replace
// the query parameters.
//
// Instances expect two parameters: the name and the value to be set, either
// strings or templates are valid, see eskip.Template.ApplyContext
//
// Name: "setQuery".
func NewSetQuery() filters.Spec { return &modQuery{behavior: set} }
// "setQuery" or "dropQuery"
func (spec *modQuery) Name() string {
switch spec.behavior {
case drop:
return filters.DropQueryName
case set:
return filters.SetQueryName
default:
panic("unspecified behavior")
}
}
func createDropQuery(config []interface{}) (filters.Filter, error) {
if len(config) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
tpl, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &modQuery{behavior: drop, name: eskip.NewTemplate(tpl)}, nil
}
func createSetQuery(config []interface{}) (filters.Filter, error) {
l := len(config)
if l < 1 || l > 2 {
return nil, filters.ErrInvalidFilterParameters
}
name, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
if l == 1 {
return &modQuery{behavior: set, name: eskip.NewTemplate(name)}, nil
}
value, ok := config[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &modQuery{behavior: set, name: eskip.NewTemplate(name), value: eskip.NewTemplate(value)}, nil
}
// Creates instances of the modQuery filter.
//
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *modQuery) CreateFilter(config []interface{}) (filters.Filter, error) {
switch spec.behavior {
case drop:
return createDropQuery(config)
case set:
return createSetQuery(config)
default:
panic("unspecified behavior")
}
}
// Modifies the query of a request.
func (f *modQuery) Request(ctx filters.FilterContext) {
req := ctx.Request()
params := req.URL.Query()
switch f.behavior {
case drop:
name, _ := f.name.ApplyContext(ctx)
params.Del(name)
case set:
if f.value == nil {
req.URL.RawQuery, _ = f.name.ApplyContext(ctx)
return
}
name, _ := f.name.ApplyContext(ctx)
value, _ := f.value.ApplyContext(ctx)
params.Set(name, value)
default:
panic("unspecified behavior")
}
req.URL.RawQuery = params.Encode()
}
// Noop.
func (*modQuery) Response(filters.FilterContext) {}
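// Eskip examples (illustrative; the parameter names and backend are made up):
//
// set: * -> setQuery("page", "1") -> "https://backend.example.org";
// drop: * -> dropQuery("utm_source") -> "https://backend.example.org";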
package builtin
import (
"fmt"
"github.com/zalando/skipper/filters"
)
type (
queryToHeaderSpec struct {
}
queryToHeaderFilter struct {
headerName string
queryParamName string
formatString string
}
)
// NewQueryToHeader creates a filter which converts a query param
// from the incoming request to a header
//
// queryToHeader("foo-query-param", "X-Foo-Header")
//
// The above filter sets the value of the "X-Foo-Header" header to the
// value of the "foo-query-param" query param on the request, and does
// not override the value if the header already exists
//
// The header value can be created by a format string with an optional third parameter
//
// queryToHeader("foo-query-param", "X-Foo-Header", "prefix %s postfix")
// queryToHeader("access_token", "Authorization", "Bearer %s")
func NewQueryToHeader() filters.Spec {
return &queryToHeaderSpec{}
}
func (*queryToHeaderSpec) Name() string {
return filters.QueryToHeaderName
}
// CreateFilter creates a `queryToHeader` filter instance, e.g.
// s.CreateFilter("foo-query-param", "X-Foo-Header")
func (*queryToHeaderSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if l := len(args); l < 2 || l > 3 {
return nil, filters.ErrInvalidFilterParameters
}
q, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
h, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
formatString := "%s"
if len(args) == 3 {
formatString, ok = args[2].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
}
return &queryToHeaderFilter{headerName: h, queryParamName: q, formatString: formatString}, nil
}
func (f *queryToHeaderFilter) Request(ctx filters.FilterContext) {
req := ctx.Request()
headerValue := req.Header.Get(f.headerName)
if headerValue != "" {
return
}
v := req.URL.Query().Get(f.queryParamName)
if v == "" {
return
}
req.Header.Set(f.headerName, fmt.Sprintf(f.formatString, v))
}
func (*queryToHeaderFilter) Response(ctx filters.FilterContext) {}
// Copyright 2015 Zalando SE
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package builtin
import (
"net/http"
"net/url"
"strings"
"github.com/zalando/skipper/filters"
)
type redirectType int
const (
redDeprecated redirectType = iota
redTo
redToLower
)
// Filter to return
type redirect struct {
typ redirectType
code int
location *url.URL
}
// NewRedirect returns a new filter Spec, whose instances create an HTTP redirect
// response. Marks the request as served. Instances expect two
// parameters: the redirect status code and the redirect location.
// Name: "redirect".
//
// This filter is deprecated, use RedirectTo instead.
// This *DEPRECATED* filter cannot be used with filters from the scheduler package.
func NewRedirect() filters.Spec { return &redirect{typ: redDeprecated} }
// NewRedirectTo returns a new filter Spec, whose instances create an HTTP redirect
// response. It shunts the request flow, meaning that the filter chain on
// the request path is not continued. The request is not forwarded to the
// backend. Instances expect two parameters: the redirect status code and
// the redirect location.
// Name: "redirectTo".
func NewRedirectTo() filters.Spec { return &redirect{typ: redTo} }
// NewRedirectLower returns a new filter Spec, whose instances create an HTTP redirect
// response, which redirects with a lowercase path. It is similar to redirectTo
// except that it converts the path to lowercase while redirecting.
// Name: "redirectToLower".
func NewRedirectLower() filters.Spec { return &redirect{typ: redToLower} }
// "redirect" or "redirectToLower" or "redirectTo"
func (spec *redirect) Name() string {
switch spec.typ {
case redDeprecated:
return RedirectName
case redToLower:
return filters.RedirectToLowerName
default:
return filters.RedirectToName
}
}
// Creates an instance of the redirect filter.
func (spec *redirect) CreateFilter(config []interface{}) (filters.Filter, error) {
invalidArgs := func() (filters.Filter, error) {
return nil, filters.ErrInvalidFilterParameters
}
if len(config) == 1 {
config = append(config, "")
}
if len(config) != 2 {
return invalidArgs()
}
code, ok := config[0].(float64)
if !ok {
return invalidArgs()
}
location, ok := config[1].(string)
if !ok {
return invalidArgs()
}
u, err := url.Parse(location)
if err != nil {
return invalidArgs()
}
return &redirect{spec.typ, int(code), u}, nil
}
func getRequestHost(r *http.Request) string {
h := r.Header.Get("Host")
if h == "" {
h = r.Host
}
if h == "" {
h = r.URL.Host
}
return h
}
func getLocation(ctx filters.FilterContext, location *url.URL, typ redirectType) string {
r := ctx.Request()
uc := *location
u := &uc
if u.Scheme == "" {
if r.URL.Scheme != "" {
u.Scheme = r.URL.Scheme
} else {
u.Scheme = "https"
}
}
u.User = r.URL.User
if u.Host == "" {
u.Host = getRequestHost(r)
}
if u.Path == "" {
u.Path = r.URL.Path
}
// Lowercase the path if the redirect type requires it
if typ == redToLower {
u.Path = strings.ToLower(u.Path)
}
if u.RawQuery == "" {
u.RawQuery = r.URL.RawQuery
}
return u.String()
}
func redirectWithType(ctx filters.FilterContext, code int, location *url.URL, typ redirectType) {
u := getLocation(ctx, location, typ)
ctx.Serve(&http.Response{
StatusCode: code,
Header: http.Header{"Location": []string{u}}})
}
// Redirect implements the redirect logic as a standalone function.
func Redirect(ctx filters.FilterContext, code int, location *url.URL) {
redirectWithType(ctx, code, location, redTo)
}
func (spec *redirect) Request(ctx filters.FilterContext) {
if spec.typ == redDeprecated {
return
}
redirectWithType(ctx, spec.code, spec.location, spec.typ)
}
// Sets the status code and the location header of the response. Marks the
// request served.
func (spec *redirect) Response(ctx filters.FilterContext) {
if spec.typ != redDeprecated {
return
}
u := getLocation(ctx, spec.location, spec.typ)
w := ctx.ResponseWriter()
w.Header().Set("Location", u)
w.WriteHeader(spec.code)
ctx.MarkServed()
}
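// Eskip example (illustrative; the status code and location are made up):
//
// r: * -> redirectTo(308, "https://www.example.org/new") -> <shunt>;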
package builtin
import (
"fmt"
"net/http"
"os"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/serve"
)
type static struct {
handler http.Handler
}
// Returns a filter Spec to serve static content from a file system
// location. Behaves similarly to net/http.FileServer. It shunts the route.
//
// Filter instances of this specification expect two parameters: a
// request path prefix and a local directory path. When processing a
// request, it clips the prefix from the request path, and appends the
// rest of the path to the directory path. Then, it uses the resulting
// path to serve static content from the file system.
//
// Name: "static".
func NewStatic() filters.Spec { return &static{} }
// "static"
func (spec *static) Name() string { return filters.StaticName }
// Creates instances of the static filter. Expects two parameters: request path
// prefix and file system root.
//
//lint:ignore ST1016 "spec" makes sense here and we reuse the type for the filter
func (spec *static) CreateFilter(config []interface{}) (filters.Filter, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid number of args: %d, expected 2", len(config))
}
webRoot, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid parameter type, expected string for web root prefix")
}
root, ok := config[1].(string)
if !ok {
log.Errorf("Invalid parameter type, expected string for path to root dir")
return nil, filters.ErrInvalidFilterParameters
}
if ok, err := existsAndAccessible(root); !ok {
log.Errorf("Invalid parameter for root path. File %s does not exist or is not accessible: %v", root, err)
return nil, filters.ErrInvalidFilterParameters
}
return &static{http.StripPrefix(webRoot, http.FileServer(http.Dir(root)))}, nil
}
// Serves content from the file system and marks the request served.
func (f *static) Request(ctx filters.FilterContext) {
serve.ServeHTTP(ctx, f.handler)
}
// Noop.
func (f *static) Response(filters.FilterContext) {}
// Checks if the file exists and is accessible
func existsAndAccessible(path string) (bool, error) {
if _, err := os.Stat(path); err != nil {
return os.IsExist(err), err
}
return true, nil
}
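// Eskip example (illustrative; the prefix and directory are made up):
//
// assets: PathSubtree("/images") -> static("/images", "/var/www/images") -> <shunt>;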
package builtin
import "github.com/zalando/skipper/filters"
type statusSpec struct{}
type statusFilter int
// Creates a filter specification whose instances set the
// status of the response to a fixed value regardless of
// backend response.
func NewStatus() filters.Spec { return new(statusSpec) }
func (s *statusSpec) Name() string { return filters.StatusName }
func (s *statusSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
switch c := args[0].(type) {
case int:
return statusFilter(c), nil
case float64:
return statusFilter(c), nil
default:
return nil, filters.ErrInvalidFilterParameters
}
}
func (f statusFilter) Request(filters.FilterContext) {}
func (f statusFilter) Response(ctx filters.FilterContext) {
ctx.Response().StatusCode = int(f)
}
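// Eskip example (illustrative), mirroring the inlineContent documentation above:
//
// r: * -> status(503) -> inlineContent("Service Unavailable") -> <shunt>;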
// Copyright 2015 Zalando SE
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package builtin
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/zalando/skipper/filters"
)
type stripQuery struct {
preserveAsHeader bool
}
// Returns a filter Spec to strip query parameters from the request and
// optionally transpose them to request headers.
//
// It always removes the query parameters from the request URL, and if the
// first filter parameter is "true", preserves them in the form
// of x-query-param-<queryParamName>: <queryParamValue> headers, so that
// ?foo=bar becomes x-query-param-foo: bar
//
// Name: "stripQuery".
func NewStripQuery() filters.Spec { return &stripQuery{} }
// "stripQuery"
func (stripQuery) Name() string { return filters.StripQueryName }
// copied from textproto/reader
func validHeaderFieldByte(b byte) bool {
return ('A' <= b && b <= 'Z') ||
('a' <= b && b <= 'z') ||
('0' <= b && b <= '9') ||
b == '-'
}
// make sure we don't generate invalid headers
func sanitize(input string) string {
var s strings.Builder
toAscii := strconv.QuoteToASCII(input)
for _, i := range toAscii {
if validHeaderFieldByte(byte(i)) {
s.WriteRune(i)
}
}
return s.String()
}
// Strips the query parameters and optionally preserves them in the X-Query-Param-xyz headers.
func (f *stripQuery) Request(ctx filters.FilterContext) {
r := ctx.Request()
if r == nil {
return
}
url := r.URL
if url == nil {
return
}
if !f.preserveAsHeader {
url.RawQuery = ""
return
}
q := url.Query()
for k, vv := range q {
for _, v := range vv {
if r.Header == nil {
r.Header = http.Header{}
}
r.Header.Add(fmt.Sprintf("X-Query-Param-%s", sanitize(k)), v)
}
}
url.RawQuery = ""
}
// Noop.
func (stripQuery) Response(filters.FilterContext) {}
// Creates instances of the stripQuery filter. Accepts one optional parameter:
// "true", in order to preserve the stripped parameters in the request header.
func (stripQuery) CreateFilter(config []interface{}) (filters.Filter, error) {
var preserveAsHeader = false
if len(config) == 1 {
preserveAsHeaderString, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
if strings.ToLower(preserveAsHeaderString) == "true" {
preserveAsHeader = true
}
}
return &stripQuery{preserveAsHeader}, nil
}
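// Eskip example (illustrative; the backend is made up). With this route, a
// request to /path?foo=bar reaches the backend as /path, with the stripped
// parameter preserved as the x-query-param-foo: bar header:
//
// r: * -> stripQuery("true") -> "https://backend.example.org";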
package builtin
import (
"time"
"github.com/zalando/skipper/filters"
)
type timeoutType int
const (
backendTimeout timeoutType = iota + 1
readTimeout
writeTimeout
)
type timeout struct {
typ timeoutType
timeout time.Duration
}
func NewBackendTimeout() filters.Spec {
return &timeout{
typ: backendTimeout,
}
}
func NewReadTimeout() filters.Spec {
return &timeout{
typ: readTimeout,
}
}
func NewWriteTimeout() filters.Spec {
return &timeout{
typ: writeTimeout,
}
}
func (t *timeout) Name() string {
switch t.typ {
case backendTimeout:
return filters.BackendTimeoutName
case readTimeout:
return filters.ReadTimeoutName
case writeTimeout:
return filters.WriteTimeoutName
}
return "unknownFilter"
}
func (t *timeout) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
var tf timeout
tf.typ = t.typ
switch v := args[0].(type) {
case string:
d, err := time.ParseDuration(v)
if err != nil {
return nil, err
}
tf.timeout = d
case time.Duration:
tf.timeout = v
default:
return nil, filters.ErrInvalidFilterParameters
}
return &tf, nil
}
// Request allows overwrite of timeout settings.
//
// Type backend timeout sets the timeout for the backend roundtrip.
//
// Type read timeout sets the timeout to read the request including the body.
// It uses http.ResponseController to SetReadDeadline().
//
// Type write timeout allows to set a timeout for writing the response.
// It uses http.ResponseController to SetWriteDeadline().
//
// All these timeouts are set at specific points in proxy.Proxy.
func (t *timeout) Request(ctx filters.FilterContext) {
switch t.typ {
case backendTimeout:
ctx.StateBag()[filters.BackendTimeout] = t.timeout
case readTimeout:
ctx.StateBag()[filters.ReadTimeout] = t.timeout
case writeTimeout:
ctx.StateBag()[filters.WriteTimeout] = t.timeout
}
}
func (*timeout) Response(filters.FilterContext) {}
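// Eskip examples (illustrative; the durations and backend are made up):
//
// slow: * -> backendTimeout("2s") -> "https://backend.example.org";
// read: * -> readTimeout("5s") -> "https://backend.example.org";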
/*
Package circuit provides filters to control the circuit breaker settings on the route level.
For detailed documentation of the circuit breakers, see https://godoc.org/github.com/zalando/skipper/circuit.
*/
package circuit
import (
"time"
"github.com/zalando/skipper/circuit"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.ConsecutiveBreakerName instead
ConsecutiveBreakerName = filters.ConsecutiveBreakerName
// Deprecated, use filters.RateBreakerName instead
RateBreakerName = filters.RateBreakerName
// Deprecated, use filters.DisableBreakerName instead
DisableBreakerName = filters.DisableBreakerName
RouteSettingsKey = "#circuitbreakersettings"
)
type spec struct {
typ circuit.BreakerType
}
type filter struct {
settings circuit.BreakerSettings
}
func getIntArg(a interface{}) (int, error) {
if i, ok := a.(int); ok {
return i, nil
}
if f, ok := a.(float64); ok {
return int(f), nil
}
return 0, filters.ErrInvalidFilterParameters
}
func getDurationArg(a interface{}) (time.Duration, error) {
if s, ok := a.(string); ok {
return time.ParseDuration(s)
}
i, err := getIntArg(a)
return time.Duration(i) * time.Millisecond, err
}
// NewConsecutiveBreaker creates a filter specification to instantiate consecutiveBreaker() filters.
//
// These filters set a breaker for the current route that opens if the backend failures for the route reach a
// value of N, where N is a mandatory argument of the filter:
//
// consecutiveBreaker(15)
//
// The filter accepts the following optional arguments: timeout (milliseconds or duration string),
// half-open-requests (integer), idle-ttl (milliseconds or duration string).
func NewConsecutiveBreaker() filters.Spec {
return &spec{typ: circuit.ConsecutiveFailures}
}
// NewRateBreaker creates a filter specification to instantiate rateBreaker() filters.
//
// These filters set a breaker for the current route that opens if the backend failures for the route reach a
// value of N within a window of the last M requests, where N and M are mandatory arguments of the filter:
//
// rateBreaker(30, 300)
//
// The filter accepts the following optional arguments: timeout (milliseconds or duration string),
// half-open-requests (integer), idle-ttl (milliseconds or duration string).
func NewRateBreaker() filters.Spec {
return &spec{typ: circuit.FailureRate}
}
// NewDisableBreaker disables the circuit breaker for a route. It doesn't accept any arguments.
func NewDisableBreaker() filters.Spec {
return &spec{}
}
func (s *spec) Name() string {
switch s.typ {
case circuit.ConsecutiveFailures:
return filters.ConsecutiveBreakerName
case circuit.FailureRate:
return filters.RateBreakerName
default:
return filters.DisableBreakerName
}
}
func consecutiveFilter(args []interface{}) (filters.Filter, error) {
if len(args) == 0 || len(args) > 4 {
return nil, filters.ErrInvalidFilterParameters
}
failures, err := getIntArg(args[0])
if err != nil {
return nil, err
}
var timeout time.Duration
if len(args) > 1 {
timeout, err = getDurationArg(args[1])
if err != nil {
return nil, err
}
}
var halfOpenRequests int
if len(args) > 2 {
halfOpenRequests, err = getIntArg(args[2])
if err != nil {
return nil, err
}
}
var idleTTL time.Duration
if len(args) > 3 {
idleTTL, err = getDurationArg(args[3])
if err != nil {
return nil, err
}
}
return &filter{
settings: circuit.BreakerSettings{
Type: circuit.ConsecutiveFailures,
Failures: failures,
Timeout: timeout,
HalfOpenRequests: halfOpenRequests,
IdleTTL: idleTTL,
},
}, nil
}
func rateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 2 || len(args) > 5 {
return nil, filters.ErrInvalidFilterParameters
}
failures, err := getIntArg(args[0])
if err != nil {
return nil, err
}
window, err := getIntArg(args[1])
if err != nil {
return nil, err
}
var timeout time.Duration
if len(args) > 2 {
timeout, err = getDurationArg(args[2])
if err != nil {
return nil, err
}
}
var halfOpenRequests int
if len(args) > 3 {
halfOpenRequests, err = getIntArg(args[3])
if err != nil {
return nil, err
}
}
var idleTTL time.Duration
if len(args) > 4 {
idleTTL, err = getDurationArg(args[4])
if err != nil {
return nil, err
}
}
return &filter{
settings: circuit.BreakerSettings{
Type: circuit.FailureRate,
Failures: failures,
Window: window,
Timeout: timeout,
HalfOpenRequests: halfOpenRequests,
IdleTTL: idleTTL,
},
}, nil
}
func disableFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 0 {
return nil, filters.ErrInvalidFilterParameters
}
return &filter{
settings: circuit.BreakerSettings{
Type: circuit.BreakerDisabled,
},
}, nil
}
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
switch s.typ {
case circuit.ConsecutiveFailures:
return consecutiveFilter(args)
case circuit.FailureRate:
return rateFilter(args)
default:
return disableFilter(args)
}
}
func (f *filter) Request(ctx filters.FilterContext) {
ctx.StateBag()[RouteSettingsKey] = f.settings
}
func (f *filter) Response(filters.FilterContext) {}
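// Eskip example showing the optional arguments (illustrative; the values and
// backend are made up):
//
// r: * -> consecutiveBreaker(15, "10s", 3, "1m") -> "https://backend.example.org";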
package consistenthash
import (
"fmt"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/loadbalancer"
)
type consistentHashBalanceFactor struct {
balanceFactor float64
}
// NewConsistentHashBalanceFactor creates a filter Spec, whose instances
// set the balance factor used by the `consistentHash` algorithm to avoid
// popular hashes overloading a single endpoint.
func NewConsistentHashBalanceFactor() filters.Spec { return &consistentHashBalanceFactor{} }
func (*consistentHashBalanceFactor) Name() string {
return filters.ConsistentHashBalanceFactorName
}
func (*consistentHashBalanceFactor) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
value, ok := args[0].(float64)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
if value < 1 {
return nil, fmt.Errorf("invalid consistentHashBalanceFactor filter value, must be >=1 but got %f", value)
}
return &consistentHashBalanceFactor{value}, nil
}
func (c *consistentHashBalanceFactor) Request(ctx filters.FilterContext) {
ctx.StateBag()[loadbalancer.ConsistentHashBalanceFactor] = c.balanceFactor
}
func (*consistentHashBalanceFactor) Response(filters.FilterContext) {}
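// Eskip example (illustrative; the factor and endpoints are made up):
//
// r: * -> consistentHashBalanceFactor(1.25) -> <consistentHash, "http://10.0.0.1:8080", "http://10.0.0.2:8080">;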
package consistenthash
import (
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/loadbalancer"
)
type consistentHashKey struct {
template *eskip.Template
}
// NewConsistentHashKey creates a filter Spec, whose instances
// set the request key used by the `consistentHash` algorithm to select the backend endpoint.
func NewConsistentHashKey() filters.Spec { return &consistentHashKey{} }
func (*consistentHashKey) Name() string {
return filters.ConsistentHashKeyName
}
func (*consistentHashKey) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
value, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &consistentHashKey{eskip.NewTemplate(value)}, nil
}
func (c *consistentHashKey) Request(ctx filters.FilterContext) {
if key, ok := c.template.ApplyContext(ctx); ok {
ctx.StateBag()[loadbalancer.ConsistentHashKey] = key
}
}
func (*consistentHashKey) Response(filters.FilterContext) {}
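// Eskip example (illustrative; the template placeholder is an assumption
// about the eskip.Template syntax, and the endpoints are made up):
//
// r: * -> consistentHashKey("${request.path}") -> <consistentHash, "http://10.0.0.1:8080", "http://10.0.0.2:8080">;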
/*
Package cookie implements filters to append to requests or responses.
It implements two filters, one for appending cookies to requests in
the "Cookie" header, and one for appending cookies to responses in the
"Set-Cookie" header.
Both the request and response cookies expect a name and a value argument.
The response cookie accepts an optional argument to control the max-age
property of the cookie, of type number, in seconds.
The response cookie accepts an optional fourth argument, "change-only",
to control if the cookie should be set on every response, or only if the
request doesn't contain a cookie with the provided name and value. If the
fourth argument is "change-only", and a cookie with the same name and value
is found in the request, the cookie is not set. This argument can be used
to disable sliding TTL of the cookie.
The JS cookie behaves exactly as the response cookie, but it doesn't
set the HttpOnly directive, so these cookies will be
accessible from JS code running in web browsers.
Examples:
requestCookie("test-session", "abc")
responseCookie("test-session", "abc", 31536000)
responseCookie("test-session", "abc", 31536000, "change-only")
// response cookie without HttpOnly:
jsCookie("test-session-info", "abc-debug", 31536000, "change-only")
*/
package cookie
import (
"net"
"net/http"
"strings"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.RequestCookieName instead
RequestCookieFilterName = filters.RequestCookieName
// Deprecated, use filters.ResponseCookieName instead
ResponseCookieFilterName = filters.ResponseCookieName
// Deprecated, use filters.JsCookieName instead
ResponseJSCookieFilterName = filters.JsCookieName
ChangeOnlyArg = "change-only"
SetCookieHttpHeader = "Set-Cookie"
)
type direction int
const (
request direction = iota
response
responseJS
)
type spec struct {
typ direction
filterName string
}
type filter struct {
typ direction
name string
value string
maxAge int
changeOnly bool
}
type dropCookie struct {
typ direction
name string
}
func NewDropRequestCookie() filters.Spec {
return &dropCookie{
typ: request,
}
}
func NewDropResponseCookie() filters.Spec {
return &dropCookie{
typ: response,
}
}
func (d *dropCookie) Name() string {
switch d.typ {
case request:
return filters.DropRequestCookieName
case response:
return filters.DropResponseCookieName
}
return "unknown"
}
func (d *dropCookie) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
s, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &dropCookie{
typ: d.typ,
name: s,
}, nil
}
func removeCookie(request *http.Request, name string) bool {
cookies := request.Cookies()
hasCookie := false
for _, c := range cookies {
if c.Name == name {
hasCookie = true
break
}
}
if hasCookie {
request.Header.Del("Cookie")
for _, c := range cookies {
if c.Name != name {
request.AddCookie(c)
}
}
}
return hasCookie
}
func removeCookieResponse(rsp *http.Response, name string) bool {
cookies := rsp.Cookies()
hasCookie := false
for _, c := range cookies {
if c.Name == name {
hasCookie = true
break
}
}
if hasCookie {
rsp.Header.Del("Set-Cookie")
for _, c := range cookies {
if c.Name != name {
rsp.Header.Add("Set-Cookie", c.String())
}
}
}
return hasCookie
}
func (d *dropCookie) Request(ctx filters.FilterContext) {
if d.typ != request {
return
}
removeCookie(ctx.Request(), d.name)
}
func (d *dropCookie) Response(ctx filters.FilterContext) {
if d.typ != response {
return
}
removeCookieResponse(ctx.Response(), d.name)
}
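// Eskip examples (illustrative; the cookie names and backend are made up):
//
// req: * -> dropRequestCookie("session") -> "https://backend.example.org";
// rsp: * -> dropResponseCookie("tracking") -> "https://backend.example.org";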
// Creates a filter spec for appending cookies to requests.
// Name: requestCookie
func NewRequestCookie() filters.Spec {
return &spec{request, filters.RequestCookieName}
}
// Creates a filter spec for appending cookies to responses.
// Name: responseCookie
func NewResponseCookie() filters.Spec {
return &spec{response, filters.ResponseCookieName}
}
// Creates a filter spec for appending cookies to responses without the
// HttpOnly directive.
// Name: jsCookie
func NewJSCookie() filters.Spec {
return &spec{responseJS, filters.JsCookieName}
}
func (s *spec) Name() string { return s.filterName }
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 2 || (len(args) > 2 && s.typ == request) || len(args) > 4 {
return nil, filters.ErrInvalidFilterParameters
}
f := &filter{typ: s.typ}
if name, ok := args[0].(string); ok && name != "" {
f.name = name
} else {
return nil, filters.ErrInvalidFilterParameters
}
if value, ok := args[1].(string); ok {
f.value = value
} else {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) >= 3 {
if maxAge, ok := args[2].(float64); ok {
// https://pkg.go.dev/net/http#Cookie uses zero to omit Max-Age attribute:
// > MaxAge=0 means no 'Max-Age' attribute specified.
// > MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
// > MaxAge>0 means Max-Age attribute present and given in seconds
//
// Here we know user specified Max-Age explicitly, so we interpret zero
// as a signal to delete the cookie similar to what user would expect naturally,
// see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#max-agenumber
// > A zero or negative number will expire the cookie immediately.
if maxAge == 0 {
f.maxAge = -1
} else {
f.maxAge = int(maxAge)
}
} else {
return nil, filters.ErrInvalidFilterParameters
}
}
if len(args) == 4 {
f.changeOnly = args[3] == ChangeOnlyArg
}
return f, nil
}
func (f *filter) Request(ctx filters.FilterContext) {
if f.typ != request {
return
}
ctx.StateBag()["CookieSet:"+f.name] = f.value
ctx.Request().AddCookie(&http.Cookie{Name: f.name, Value: f.value})
}
func (f *filter) Response(ctx filters.FilterContext) {
var set func(filters.FilterContext, string, string, int)
switch f.typ {
case request:
return
case response:
set = configSetCookie(false)
case responseJS:
set = configSetCookie(true)
default:
panic("invalid cookie filter type")
}
ctx.StateBag()["CookieSet:"+f.name] = f.value
if !f.changeOnly {
set(ctx, f.name, f.value, f.maxAge)
return
}
var req *http.Request
if req = ctx.OriginalRequest(); req == nil {
req = ctx.Request()
}
requestCookie, err := req.Cookie(f.name)
if err == nil && requestCookie.Value == f.value {
return
}
set(ctx, f.name, f.value, f.maxAge)
}
func setCookie(ctx filters.FilterContext, name, value string, maxAge int, jsEnabled bool) {
var req = ctx.Request()
if ctx.OriginalRequest() != nil {
req = ctx.OriginalRequest()
}
d := extractDomainFromHost(req.Host)
c := &http.Cookie{
Name: name,
Value: value,
HttpOnly: !jsEnabled,
Secure: true,
Domain: d,
Path: "/",
MaxAge: maxAge,
}
ctx.Response().Header.Add(SetCookieHttpHeader, c.String())
}
func configSetCookie(jscookie bool) func(filters.FilterContext, string, string, int) {
return func(ctx filters.FilterContext, name, value string, maxAge int) {
setCookie(ctx, name, value, maxAge, jscookie)
}
}
func extractDomainFromHost(host string) string {
h, _, err := net.SplitHostPort(host)
if err != nil {
h = host
}
if strings.Count(h, ".") < 2 {
return h
}
return strings.Join(strings.Split(h, ".")[1:], ".")
}
package cors
import (
"github.com/zalando/skipper/filters"
)
const (
allowOriginHeader = "Access-Control-Allow-Origin"
)
type basicSpec struct {
}
type filter struct {
allowedOrigins []string
}
// NewOrigin creates a CORS origin handler
// that can check for an allowed origin or set an allow-all ('*') header.
func NewOrigin() filters.Spec {
return &basicSpec{}
}
// Response checks the Origin header against the allowed origins, if any;
// otherwise it just sets '*' as the value.
func (a filter) Response(ctx filters.FilterContext) {
if len(a.allowedOrigins) == 0 {
ctx.Response().Header.Set(allowOriginHeader, "*")
return
}
origin := ctx.Request().Header.Get("Origin")
if origin == "" {
return
}
for _, o := range a.allowedOrigins {
if o == origin {
ctx.Response().Header.Set(allowOriginHeader, o)
return
}
}
}
// Request is a noop
func (a filter) Request(filters.FilterContext) {}
// CreateFilter takes an optional string array.
// If any argument is not a string, it returns an error.
func (spec basicSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
f := &filter{}
for _, a := range args {
if s, ok := a.(string); ok {
f.allowedOrigins = append(f.allowedOrigins, s)
} else {
return nil, filters.ErrInvalidFilterParameters
}
}
return f, nil
}
func (spec basicSpec) Name() string { return filters.CorsOriginName }
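// Eskip examples (illustrative; the origins and backend are made up):
//
// any: * -> corsOrigin() -> "https://backend.example.org";
// some: * -> corsOrigin("https://www.example.org", "https://example.org") -> "https://backend.example.org";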
package diag
import (
"io"
"net/http"
"time"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/flowid"
"github.com/zalando/skipper/logging"
)
// AbsorbName contains the name of the absorb filter.
// Deprecated, use filters.AbsorbName instead
const AbsorbName = filters.AbsorbName
// AbsorbSilentName contains the name of the absorbSilent filter.
// Deprecated, use filters.AbsorbSilentName instead
const AbsorbSilentName = filters.AbsorbSilentName
const loggingInterval = time.Second
type absorb struct {
logger logging.Logger
id flowid.Generator
silent bool
}
func withLogger(silent bool, l logging.Logger) filters.Spec {
if l == nil {
l = &logging.DefaultLog{}
}
id, err := flowid.NewStandardGenerator(flowid.MinLength)
if err != nil {
l.Errorf("failed to create ID generator: %v", err)
}
return &absorb{
logger: l,
id: id,
silent: silent,
}
}
// NewAbsorb initializes a filter spec for the absorb filter.
//
// The absorb filter reads and discards the payload of the incoming requests.
// It logs with INFO level and a unique ID per request:
// - the event of receiving the request
// - partial and final events for consuming request payload and total consumed byte count
// - the finishing event of the request
// - any read errors other than EOF
func NewAbsorb() filters.Spec {
return withLogger(false, nil)
}
// NewAbsorbSilent initializes a filter spec for the absorbSilent filter,
// similar to the absorb filter, but without verbose logging of the absorbed
// payload.
//
// The absorbSilent filter reads and discards the payload of the incoming requests. It only
// logs read errors other than EOF.
func NewAbsorbSilent() filters.Spec {
return withLogger(true, nil)
}
func (a *absorb) Name() string {
if a.silent {
return filters.AbsorbSilentName
}
return filters.AbsorbName
}
func (a *absorb) CreateFilter(args []interface{}) (filters.Filter, error) { return a, nil }
func (a *absorb) Response(filters.FilterContext) {}
func (a *absorb) Request(ctx filters.FilterContext) {
req := ctx.Request()
id := req.Header.Get(flowid.HeaderName)
if id == "" {
if a.id == nil {
id = "-"
} else {
var err error
if id, err = a.id.Generate(); err != nil {
a.logger.Error(err)
}
}
}
sink := io.Discard
if !a.silent {
a.logger.Infof("received request to be absorbed: %s", id)
sink = &loggingSink{id: id, logger: a.logger, next: time.Now().Add(loggingInterval)}
}
count, err := io.Copy(sink, req.Body)
if !a.silent {
if err != nil {
a.logger.Infof("request %s, error while consuming request: %v", id, err)
}
a.logger.Infof("request %s, consumed bytes: %d", id, count)
a.logger.Infof("request finished: %s", id)
}
ctx.Serve(&http.Response{StatusCode: http.StatusOK})
}
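// Eskip example (illustrative):
//
// sink: * -> absorb() -> <shunt>;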
type loggingSink struct {
id string
logger logging.Logger
next time.Time
count int64
}
func (s *loggingSink) Write(p []byte) (n int, err error) {
n, err = len(p), nil
s.count += int64(n)
if time.Now().After(s.next) {
s.logger.Infof("request %s, consumed bytes: %d", s.id, s.count)
s.next = s.next.Add(loggingInterval)
}
return
}
/*
Package diag provides a set of network throttling filters for diagnostic purpose.
The filters enable adding artificial latency, limiting bandwidth or chunking responses with custom chunk size
and delay. This throttling can be applied to the proxy responses or to the outgoing backend requests. An
additional filter, randomContent, can be used to generate response with random text of specified length.
*/
package diag
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"sync"
"time"
"github.com/zalando/skipper/filters"
)
const defaultChunkSize = 512
const (
// Deprecated, use filters.RandomContentName instead
RandomName = filters.RandomContentName
// Deprecated, use filters.RepeatContentName instead
RepeatName = filters.RepeatContentName
// Deprecated, use filters.LatencyName instead
LatencyName = filters.LatencyName
// Deprecated, use filters.ChunksName instead
ChunksName = filters.ChunksName
// Deprecated, use filters.BandwidthName instead
BandwidthName = filters.BandwidthName
// Deprecated, use filters.BackendLatencyName instead
BackendLatencyName = filters.BackendLatencyName
// Deprecated, use filters.BackendBandwidthName instead
BackendBandwidthName = filters.BackendBandwidthName
// Deprecated, use filters.BackendChunksName instead
BackendChunksName = filters.BackendChunksName
)
type throttleType int
const (
latency throttleType = iota
bandwidth
chunks
backendLatency
backendBandwidth
backendChunks
)
type random struct {
mu sync.Mutex
rand *rand.Rand
len int64
}
type (
repeatSpec struct {
hex bool
}
repeat struct {
bytes []byte
len int64
}
repeatReader struct {
bytes []byte
offset int
}
)
type (
wrapSpec struct {
hex bool
}
wrap struct {
prefix, suffix []byte
}
wrapReadCloser struct {
io.Reader
io.Closer
}
)
type throttle struct {
typ throttleType
chunkSize int
delay time.Duration
}
type distribution int
const (
uniformRequestDistribution distribution = iota
normalRequestDistribution
uniformResponseDistribution
normalResponseDistribution
)
type jitter struct {
mean time.Duration
delta time.Duration
typ distribution
sleep func(time.Duration)
}
var randomChars = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")
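// kbps2bpms converts a bandwidth given in kilobytes per second to bytes per
// millisecond, the unit used below to compute the per-chunk delay.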
func kbps2bpms(kbps float64) float64 {
return kbps * 1024 / 1000
}
// NewRandom creates a filter specification whose filter instances can be used
// to respond to requests with random text of specified length. It expects the
// byte length of the random response to be generated as an argument.
// Eskip example:
//
// r: * -> randomContent(2048) -> <shunt>;
func NewRandom() filters.Spec { return &random{} }
// NewRepeat creates a filter specification whose filter instances can be used
// to respond to requests with a repeated text. It expects the text and
// the byte length of the response body to be generated as arguments.
// Eskip example:
//
// r: * -> repeatContent("x", 100) -> <shunt>;
func NewRepeat() filters.Spec { return &repeatSpec{hex: false} }
// NewRepeatHex creates a filter specification whose filter instances can be used
// to respond to requests with repeated bytes. It expects a hexadecimal string
// of even length, representing the bytes to repeat, and the byte length of the
// response body to be generated as arguments.
// Eskip example:
//
// r: * -> repeatContentHex("0123456789abcdef", 16) -> <shunt>;
func NewRepeatHex() filters.Spec { return &repeatSpec{hex: true} }
// NewWrap creates a filter specification whose filter instances can be used
// to add prefix and suffix to the response.
// Eskip example:
//
// r: * -> wrapContent("foo", "baz") -> inlineContent("bar") -> <shunt>;
func NewWrap() filters.Spec { return &wrapSpec{hex: false} }
// NewWrapHex creates a filter specification whose filter instances can be used
// to add a prefix and a suffix, given as hexadecimal strings of even length, to the response.
// Eskip example:
//
// r: * -> wrapContentHex("68657861", "6d616c") -> inlineContent("deci") -> <shunt>;
func NewWrapHex() filters.Spec { return &wrapSpec{hex: true} }
// NewLatency creates a filter specification whose filter instances can be used
// to add additional latency to responses. It expects the latency in milliseconds
// as an argument. It always adds this value in addition to the natural latency,
// and does not do any adjustments. Eskip example:
//
// r: * -> latency(120) -> "https://www.example.org";
func NewLatency() filters.Spec { return &throttle{typ: latency} }
// NewBandwidth creates a filter specification whose filter instances can be used
// to set a maximum bandwidth for the responses. It expects the bandwidth in
// kbyte/sec as an argument.
//
// r: * -> bandwidth(30) -> "https://www.example.org";
func NewBandwidth() filters.Spec { return &throttle{typ: bandwidth} }
// NewChunks creates a filter specification whose filter instances can be used
// to set artificial delays in between response chunks. It expects the byte length
// of the chunks and the delay between the chunks as arguments. Eskip example:
//
// r: * -> chunks(1024, "120ms") -> "https://www.example.org";
func NewChunks() filters.Spec { return &throttle{typ: chunks} }
// NewBackendLatency is the equivalent of NewLatency but for outgoing backend
// requests. Eskip example:
//
// r: * -> backendLatency(120) -> "https://www.example.org";
func NewBackendLatency() filters.Spec { return &throttle{typ: backendLatency} }
// NewBackendBandwidth is the equivalent of NewBandwidth but for outgoing backend
// requests. Eskip example:
//
// r: * -> backendBandwidth(30) -> "https://www.example.org";
func NewBackendBandwidth() filters.Spec { return &throttle{typ: backendBandwidth} }
// NewBackendChunks is the equivalent of NewChunks but for outgoing backend
// requests. Eskip example:
//
// r: * -> backendChunks(1024, 120) -> "https://www.example.org";
func NewBackendChunks() filters.Spec { return &throttle{typ: backendChunks} }
// NewUniformRequestLatency creates a latency for requests with uniform
// distribution. Example delay around 1s with +/-120ms.
//
// r: * -> uniformRequestLatency("1s", "120ms") -> "https://www.example.org";
func NewUniformRequestLatency() filters.Spec { return &jitter{typ: uniformRequestDistribution} }
// NewNormalRequestLatency creates a latency for requests with normal
// distribution. Example delay around 1s with +/-120ms.
//
// r: * -> normalRequestLatency("1s", "120ms") -> "https://www.example.org";
func NewNormalRequestLatency() filters.Spec { return &jitter{typ: normalRequestDistribution} }
// NewUniformResponseLatency creates a latency for responses with uniform
// distribution. Example delay around 1s with +/-120ms.
//
// r: * -> uniformResponseLatency("1s", "120ms") -> "https://www.example.org";
func NewUniformResponseLatency() filters.Spec { return &jitter{typ: uniformResponseDistribution} }
// NewNormalResponseLatency creates a latency for responses with normal
// distribution. Example delay around 1s with +/-120ms.
//
// r: * -> normalResponseLatency("1s", "120ms") -> "https://www.example.org";
func NewNormalResponseLatency() filters.Spec { return &jitter{typ: normalResponseDistribution} }
func (r *random) Name() string { return filters.RandomContentName }
func (r *random) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
if l, ok := args[0].(float64); ok {
return &random{rand: rand.New(rand.NewSource(time.Now().UnixNano())), len: int64(l)}, nil // #nosec
} else {
return nil, filters.ErrInvalidFilterParameters
}
}
func (r *random) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
for i := 0; i < len(p); i++ {
p[i] = randomChars[r.rand.Intn(len(randomChars))]
}
return len(p), nil
}
func (r *random) Request(ctx filters.FilterContext) {
ctx.Serve(&http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(io.LimitReader(r, r.len)),
})
}
func (r *random) Response(ctx filters.FilterContext) {}
func (r *repeatSpec) Name() string {
if r.hex {
return filters.RepeatContentHexName
} else {
return filters.RepeatContentName
}
}
func (r *repeatSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
text, ok := args[0].(string)
if !ok || text == "" {
return nil, filters.ErrInvalidFilterParameters
}
f := &repeat{}
if r.hex {
var err error
f.bytes, err = hex.DecodeString(text)
if err != nil {
return nil, err
}
} else {
f.bytes = []byte(text)
}
switch v := args[1].(type) {
case float64:
f.len = int64(v)
case int:
f.len = int64(v)
case int64:
f.len = v
default:
return nil, filters.ErrInvalidFilterParameters
}
if f.len < 0 {
return nil, filters.ErrInvalidFilterParameters
}
return f, nil
}
func (r *repeat) Request(ctx filters.FilterContext) {
ctx.Serve(&http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Length": []string{strconv.FormatInt(r.len, 10)}},
Body: io.NopCloser(io.LimitReader(&repeatReader{r.bytes, 0}, r.len)),
})
}
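// Read fills p with the repeating byte pattern: it copies from the current
// offset to the end of the pattern, wraps around to its start, and then keeps
// doubling the already filled prefix of p until p is full.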
func (r *repeatReader) Read(p []byte) (int, error) {
n := copy(p, r.bytes[r.offset:])
if n < len(p) {
n += copy(p[n:], r.bytes[:r.offset])
for n < len(p) {
copy(p[n:], p[:n])
n *= 2
}
}
r.offset = (r.offset + len(p)) % len(r.bytes)
return len(p), nil
}
func (r *repeat) Response(ctx filters.FilterContext) {}
func (w *wrapSpec) Name() string {
if w.hex {
return filters.WrapContentHexName
} else {
return filters.WrapContentName
}
}
func (w *wrapSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
prefix, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
suffix, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
f := &wrap{}
if w.hex {
var err error
f.prefix, err = hex.DecodeString(prefix)
if err != nil {
return nil, err
}
f.suffix, err = hex.DecodeString(suffix)
if err != nil {
return nil, err
}
} else {
f.prefix = []byte(prefix)
f.suffix = []byte(suffix)
}
return f, nil
}
func (w *wrap) Request(ctx filters.FilterContext) {}
func (w *wrap) Response(ctx filters.FilterContext) {
rsp := ctx.Response()
if s := rsp.Header.Get("Content-Length"); s != "" {
if n, err := strconv.ParseInt(s, 10, 64); err == nil {
n += int64(len(w.prefix) + len(w.suffix))
rsp.Header["Content-Length"] = []string{strconv.FormatInt(n, 10)}
}
}
if rsp.ContentLength != -1 {
rsp.ContentLength += int64(len(w.prefix) + len(w.suffix))
}
rsp.Body = &wrapReadCloser{
Reader: io.MultiReader(
bytes.NewReader(w.prefix),
rsp.Body,
bytes.NewReader(w.suffix),
),
Closer: rsp.Body,
}
}
func (t *throttle) Name() string {
switch t.typ {
case latency:
return filters.LatencyName
case bandwidth:
return filters.BandwidthName
case chunks:
return filters.ChunksName
case backendLatency:
return filters.BackendLatencyName
case backendBandwidth:
return filters.BackendBandwidthName
case backendChunks:
return filters.BackendChunksName
default:
panic("invalid throttle type")
}
}
func parseDuration(v interface{}) (time.Duration, error) {
var d time.Duration
switch vt := v.(type) {
case float64:
d = time.Duration(vt) * time.Millisecond
case string:
var err error
d, err = time.ParseDuration(vt)
if err != nil {
return 0, filters.ErrInvalidFilterParameters
}
}
if d < 0 {
return 0, filters.ErrInvalidFilterParameters
}
return d, nil
}
func parseLatencyArgs(args []interface{}) (int, time.Duration, error) {
if len(args) != 1 {
return 0, 0, filters.ErrInvalidFilterParameters
}
d, err := parseDuration(args[0])
return 0, d, err
}
func parseBandwidthArgs(args []interface{}) (int, time.Duration, error) {
if len(args) != 1 {
return 0, 0, filters.ErrInvalidFilterParameters
}
kbps, ok := args[0].(float64)
if !ok || kbps <= 0 {
return 0, 0, filters.ErrInvalidFilterParameters
}
bpms := kbps2bpms(kbps)
return defaultChunkSize, time.Duration(float64(defaultChunkSize)/bpms) * time.Millisecond, nil
}
func parseChunksArgs(args []interface{}) (int, time.Duration, error) {
if len(args) != 2 {
return 0, 0, filters.ErrInvalidFilterParameters
}
size, ok := args[0].(float64)
if !ok || size <= 0 {
return 0, 0, filters.ErrInvalidFilterParameters
}
d, err := parseDuration(args[1])
return int(size), d, err
}
func (t *throttle) CreateFilter(args []interface{}) (filters.Filter, error) {
var (
chunkSize int
delay time.Duration
err error
)
switch t.typ {
case latency, backendLatency:
chunkSize, delay, err = parseLatencyArgs(args)
case bandwidth, backendBandwidth:
chunkSize, delay, err = parseBandwidthArgs(args)
case chunks, backendChunks:
chunkSize, delay, err = parseChunksArgs(args)
default:
panic("invalid throttle type")
}
if err != nil {
return nil, err
}
return &throttle{t.typ, chunkSize, delay}, nil
}
func (t *throttle) goThrottle(in io.ReadCloser, close bool) io.ReadCloser {
if t.chunkSize <= 0 {
time.Sleep(t.delay)
return in
}
r, w := io.Pipe()
time.Sleep(t.delay)
go func() {
var err error
defer func() {
w.CloseWithError(err)
if close {
in.Close()
}
}()
b := make([]byte, defaultChunkSize)
for err == nil {
n := 0
var start time.Time
switch t.typ {
case bandwidth, backendBandwidth:
start = time.Now()
}
for n < t.chunkSize {
ni := 0
eof := false
bi := b
if t.chunkSize-n < len(bi) {
bi = bi[:t.chunkSize-n]
}
ni, err = in.Read(bi)
if err == io.EOF {
eof = true
err = nil
}
if err != nil {
break
}
ni, err = w.Write(bi[:ni])
if err != nil {
break
}
n += ni
if eof {
err = io.EOF
break
}
}
if err == nil {
delay := t.delay
switch t.typ {
case bandwidth, backendBandwidth:
delay -= time.Since(start)
}
time.Sleep(delay)
}
}
}()
return r
}
func (t *throttle) Request(ctx filters.FilterContext) {
switch t.typ {
case latency, bandwidth, chunks:
return
}
req := ctx.Request()
req.Body = t.goThrottle(req.Body, false)
}
func (t *throttle) Response(ctx filters.FilterContext) {
switch t.typ {
case backendLatency, backendBandwidth, backendChunks:
return
}
rsp := ctx.Response()
rsp.Body = t.goThrottle(rsp.Body, true)
}
func (j *jitter) Name() string {
switch j.typ {
case normalRequestDistribution:
return filters.NormalRequestLatencyName
case uniformRequestDistribution:
return filters.UniformRequestLatencyName
case normalResponseDistribution:
return filters.NormalResponseLatencyName
case uniformResponseDistribution:
return filters.UniformResponseLatencyName
}
return "unknown"
}
func (j *jitter) CreateFilter(args []interface{}) (filters.Filter, error) {
var (
mean time.Duration
delta time.Duration
err error
)
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
if mean, err = parseDuration(args[0]); err != nil {
return nil, fmt.Errorf("failed to parse duration mean %v: %w", args[0], err)
}
if delta, err = parseDuration(args[1]); err != nil {
return nil, fmt.Errorf("failed to parse duration delta %v: %w", args[1], err)
}
return &jitter{
typ: j.typ,
mean: mean,
delta: delta,
sleep: time.Sleep,
}, nil
}
func (j *jitter) Request(filters.FilterContext) {
var r float64
switch j.typ {
case uniformRequestDistribution:
/* #nosec */
r = 2*rand.Float64() - 1 // +/- sizing
case normalRequestDistribution:
r = rand.NormFloat64()
default:
return
}
f := r * float64(j.delta)
j.sleep(j.mean + time.Duration(int64(f)))
}
func (j *jitter) Response(filters.FilterContext) {
var r float64
switch j.typ {
case uniformResponseDistribution:
/* #nosec */
r = 2*rand.Float64() - 1 // +/- sizing
case normalResponseDistribution:
r = rand.NormFloat64()
default:
return
}
f := r * float64(j.delta)
j.sleep(j.mean + time.Duration(int64(f)))
}
package diag
import (
"math/rand"
"time"
"github.com/zalando/skipper/filters"
)
type (
histSpec struct {
typ string
}
histFilter struct {
response bool
sleep func(time.Duration)
weights []float64
boundaries []time.Duration
}
)
// NewHistogramRequestLatency creates filters that add latency to requests according to the histogram distribution.
// It expects a list of interleaved duration strings and numbers that defines a histogram.
// Duration strings define boundaries of consecutive buckets and numbers define bucket weights.
// The filter randomly selects a bucket with probability equal to its weight divided by the sum of all bucket weights
// (which must be non-zero) and then sleeps for a random duration in between bucket boundaries.
// Eskip example:
//
// r: * -> histogramRequestLatency("0ms", 50, "5ms", 0, "10ms", 30, "15ms", 20, "20ms") -> "https://www.example.org";
//
// The example above adds a latency
// * between 0ms and 5ms to 50% of the requests
// * between 5ms and 10ms to 0% of the requests
// * between 10ms and 15ms to 30% of the requests
// * and between 15ms and 20ms to 20% of the requests.
func NewHistogramRequestLatency() filters.Spec {
return &histSpec{typ: filters.HistogramRequestLatencyName}
}
// NewHistogramResponseLatency creates filters that add latency to responses according to the histogram distribution, similar to NewHistogramRequestLatency.
func NewHistogramResponseLatency() filters.Spec {
return &histSpec{typ: filters.HistogramResponseLatencyName}
}
func (s *histSpec) Name() string {
return s.typ
}
func (s *histSpec) CreateFilter(args []any) (filters.Filter, error) {
if len(args) < 3 || len(args)%2 != 1 {
return nil, filters.ErrInvalidFilterParameters
}
f := &histFilter{
response: s.typ == filters.HistogramResponseLatencyName,
sleep: time.Sleep,
}
sum := 0.0
for i, arg := range args {
if i%2 == 0 {
ds, ok := arg.(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
d, err := time.ParseDuration(ds)
if err != nil || d < 0 {
return nil, filters.ErrInvalidFilterParameters
}
if len(f.boundaries) > 0 && d <= f.boundaries[len(f.boundaries)-1] {
return nil, filters.ErrInvalidFilterParameters
}
f.boundaries = append(f.boundaries, d)
} else {
weight, ok := arg.(float64)
if !ok || weight < 0 {
return nil, filters.ErrInvalidFilterParameters
}
f.weights = append(f.weights, weight)
sum += weight
}
}
if sum == 0 {
return nil, filters.ErrInvalidFilterParameters
}
for i := range f.weights {
f.weights[i] /= sum
}
return f, nil
}
func (f *histFilter) Request(filters.FilterContext) {
if !f.response {
f.sleep(f.sample())
}
}
func (f *histFilter) Response(filters.FilterContext) {
if f.response {
f.sleep(f.sample())
}
}
func (f *histFilter) sample() time.Duration {
r := rand.Float64() // #nosec
i, w, sum := 0, 0.0, 0.0
for i, w = range f.weights {
sum += w
if sum > r {
break
}
}
// len(f.boundaries) = len(f.weights) + 1
min := f.boundaries[i]
max := f.boundaries[i+1]
return min + time.Duration(rand.Int63n(int64(max-min))) // #nosec
}
package diag
import (
"fmt"
"io"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/flowid"
)
type logBody struct {
limit int
request bool
response bool
}
// NewLogBody creates a filter specification for the 'logBody()' filter.
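//
// The filter expects a scope argument, "request" or "response", and the
// maximum number of body bytes to log. Eskip example:
//
// r: * -> logBody("request", 1024) -> "https://www.example.org";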
func NewLogBody() filters.Spec { return logBody{} }
// Name returns the logBody filter name.
func (logBody) Name() string {
return filters.LogBodyName
}
func (logBody) CreateFilter(args []interface{}) (filters.Filter, error) {
var (
request = false
response = false
)
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
opt, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
switch opt {
case "response":
response = true
case "request":
request = true
default:
return nil, fmt.Errorf("failed to match %q: %w", opt, filters.ErrInvalidFilterParameters)
}
limit, ok := args[1].(float64)
if !ok || float64(int(limit)) != limit {
return nil, fmt.Errorf("failed to convert to int: %w", filters.ErrInvalidFilterParameters)
}
return &logBody{
limit: int(limit),
request: request,
response: response,
}, nil
}
func (lb *logBody) Request(ctx filters.FilterContext) {
if !lb.request {
return
}
req := ctx.Request()
if req.Body != nil {
req.Body = newLogBodyStream(
lb.limit,
func(chunk []byte) {
ctx.Logger().Infof(
`logBody("request") %s: %q`,
req.Header.Get(flowid.HeaderName),
chunk)
},
req.Body,
)
}
}
func (lb *logBody) Response(ctx filters.FilterContext) {
if !lb.response {
return
}
rsp := ctx.Response()
if rsp.Body != nil {
rsp.Body = newLogBodyStream(
lb.limit,
func(chunk []byte) {
ctx.Logger().Infof(
`logBody("response") %s: %q`,
ctx.Request().Header.Get(flowid.HeaderName),
chunk)
},
rsp.Body,
)
}
}
type logBodyStream struct {
left int
f func([]byte)
input io.ReadCloser
}
func newLogBodyStream(left int, f func([]byte), rc io.ReadCloser) io.ReadCloser {
return &logBodyStream{
left: left,
f: f,
input: rc,
}
}
func (lb *logBodyStream) Read(p []byte) (n int, err error) {
n, err = lb.input.Read(p)
if lb.left > 0 && n > 0 {
m := min(n, lb.left)
lb.f(p[:m])
lb.left -= m
}
return n, err
}
func (lb *logBodyStream) Close() error {
return lb.input.Close()
}
package diag
import (
"bytes"
"strings"
"github.com/zalando/skipper/filters"
)
type logHeader struct {
request bool
response bool
}
// NewLogHeader creates a filter specification for the 'logHeader()' filter.
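//
// Without arguments it logs the request headers only; the optional "request"
// and "response" arguments select which side to log. Eskip example:
//
// r: * -> logHeader("request", "response") -> "https://www.example.org";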
func NewLogHeader() filters.Spec { return logHeader{} }
// Name returns the logHeader filter name.
func (logHeader) Name() string {
return filters.LogHeaderName
}
func (logHeader) CreateFilter(args []interface{}) (filters.Filter, error) {
var (
request = false
response = false
)
// default behavior
if len(args) == 0 {
request = true
}
for i := range args {
opt, ok := args[i].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
switch strings.ToLower(opt) {
case "response":
response = true
case "request":
request = true
}
}
return logHeader{
request: request,
response: response,
}, nil
}
func (lh logHeader) Response(ctx filters.FilterContext) {
if !lh.response {
return
}
req := ctx.Request()
resp := ctx.Response()
buf := bytes.NewBuffer(nil)
buf.WriteString(req.Method)
buf.WriteString(" ")
buf.WriteString(req.URL.Path)
buf.WriteString(" ")
buf.WriteString(req.Proto)
buf.WriteString("\r\n")
buf.WriteString(resp.Status)
buf.WriteString("\r\n")
for k, v := range resp.Header {
if strings.ToLower(k) == "authorization" {
buf.WriteString(k)
buf.WriteString(": ")
buf.WriteString("TRUNCATED\r\n")
} else {
buf.WriteString(k)
buf.WriteString(": ")
buf.WriteString(strings.Join(v, " "))
buf.WriteString("\r\n")
}
}
buf.WriteString("\r\n")
ctx.Logger().Infof("Response for %s", buf.String())
}
func (lh logHeader) Request(ctx filters.FilterContext) {
if !lh.request {
return
}
req := ctx.Request()
buf := bytes.NewBuffer(nil)
buf.WriteString(req.Method)
buf.WriteString(" ")
buf.WriteString(req.URL.Path)
buf.WriteString(" ")
buf.WriteString(req.Proto)
buf.WriteString("\r\nHost: ")
buf.WriteString(req.Host)
buf.WriteString("\r\n")
for k, v := range req.Header {
if strings.ToLower(k) == "authorization" {
buf.WriteString(k)
buf.WriteString(": ")
buf.WriteString("TRUNCATED\r\n")
} else {
buf.WriteString(k)
buf.WriteString(": ")
buf.WriteString(strings.Join(v, " "))
buf.WriteString("\r\n")
}
}
buf.WriteString("\r\n")
ctx.Logger().Infof("%s", buf.String())
}
package diag
import (
"net/http"
"time"
"github.com/zalando/skipper/filters"
)
type tarpitSpec struct{}
type tarpit struct {
d time.Duration
}
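// NewTarpit creates a filter specification for the tarpit filter. The filter
// responds to the request itself with status 200 OK and a body that trickles
// out one byte per the configured duration, keeping the client connection
// busy for a long time. Eskip example:
//
// r: * -> tarpit("1s") -> <shunt>;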
func NewTarpit() filters.Spec {
return &tarpitSpec{}
}
func (t *tarpitSpec) Name() string {
return filters.TarpitName
}
func (t *tarpitSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
s, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
d, err := time.ParseDuration(s)
if err != nil {
return nil, filters.ErrInvalidFilterParameters
}
return &tarpit{d: d}, nil
}
func (t *tarpit) Request(ctx filters.FilterContext) {
ctx.Serve(&http.Response{StatusCode: http.StatusOK, Body: &slowBlockingReader{d: t.d}})
}
func (*tarpit) Response(filters.FilterContext) {}
type slowBlockingReader struct {
d time.Duration
}
func (r *slowBlockingReader) Read(p []byte) (int, error) {
time.Sleep(r.d)
n := copy(p, []byte(" "))
return n, nil
}
func (r *slowBlockingReader) Close() error {
return nil
}
package fadein
import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/routing"
)
const (
// Deprecated, use filters.FadeInName instead
FadeInName = filters.FadeInName
// Deprecated, use filters.EndpointCreatedName instead
EndpointCreatedName = filters.EndpointCreatedName
)
type (
fadeIn struct {
duration time.Duration
exponent float64
}
endpointCreated struct {
when time.Time
which string
}
detectedFadeIn struct {
when time.Time
duration time.Duration
lastActive time.Time
}
postProcessor struct {
endpointRegistry *routing.EndpointRegistry
// example entry: "http://10.2.1.53:1234": {when: t0, duration: 60s, lastActive: t0-10s}
detected map[string]detectedFadeIn
}
)
// NewFadeIn creates a filter spec for the fade-in filter.
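//
// The filter takes the fade-in duration and an optional exponent as
// arguments. Eskip example (illustrative backend addresses):
//
// r: * -> fadeIn("3m") -> <roundRobin, "http://10.0.0.1:8080", "http://10.0.0.2:8080">;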
func NewFadeIn() filters.Spec {
return fadeIn{}
}
func (fadeIn) Name() string { return filters.FadeInName }
func (fadeIn) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) == 0 || len(args) > 2 {
return nil, filters.ErrInvalidFilterParameters
}
var f fadeIn
switch v := args[0].(type) {
case int:
f.duration = time.Duration(v * int(time.Millisecond))
case float64:
f.duration = time.Duration(v * float64(time.Millisecond))
case string:
d, err := time.ParseDuration(v)
if err != nil {
return nil, err
}
f.duration = d
case time.Duration:
f.duration = v
default:
return nil, filters.ErrInvalidFilterParameters
}
f.exponent = 1
if len(args) == 2 {
switch v := args[1].(type) {
case int:
f.exponent = float64(v)
case float64:
f.exponent = v
default:
return nil, filters.ErrInvalidFilterParameters
}
}
return f, nil
}
func (fadeIn) Request(filters.FilterContext) {}
func (fadeIn) Response(filters.FilterContext) {}
// NewEndpointCreated creates a filter spec for the endpointCreated filter.
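//
// The filter takes the endpoint address and its creation time, as a UNIX
// timestamp or an RFC3339 string, as arguments. Eskip example (illustrative
// values):
//
// r: * -> fadeIn("3m") -> endpointCreated("http://10.0.0.1:8080", "2020-12-18T15:30:00Z") -> <roundRobin, "http://10.0.0.1:8080">;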
func NewEndpointCreated() filters.Spec {
var ec endpointCreated
return ec
}
func (endpointCreated) Name() string { return filters.EndpointCreatedName }
func endpointKey(scheme, host string) string {
return fmt.Sprintf("%s://%s", scheme, host)
}
func (endpointCreated) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
e, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
s, h, err := snet.SchemeHost(e)
if err != nil {
return nil, err
}
ec := endpointCreated{which: endpointKey(s, h)}
switch v := args[1].(type) {
case int:
ec.when = time.Unix(int64(v), 0)
case float64:
ec.when = time.Unix(int64(v), 0)
case string:
t, err := time.Parse(time.RFC3339, v)
if err != nil {
return nil, err
}
ec.when = t
case time.Time:
ec.when = v
default:
return nil, filters.ErrInvalidFilterParameters
}
// mitigate potential flakiness caused by clock skew. When the created time is in the future based on
// the local clock, we ignore it.
now := time.Now()
if ec.when.After(now) {
log.Errorf(
"Endpoint created time in the future, fading in without endpoint creation time: %v. Potential clock skew.",
ec.when,
)
ec.when = time.Time{}
}
return ec, nil
}
func (endpointCreated) Request(filters.FilterContext) {}
func (endpointCreated) Response(filters.FilterContext) {}
type PostProcessorOptions struct {
EndpointRegistry *routing.EndpointRegistry
}
// NewPostProcessor creates post-processor for maintaining the detection time of LB endpoints with fade-in
// behavior.
func NewPostProcessor(options PostProcessorOptions) routing.PostProcessor {
return &postProcessor{
endpointRegistry: options.EndpointRegistry,
detected: make(map[string]detectedFadeIn),
}
}
func (p *postProcessor) Do(r []*routing.Route) []*routing.Route {
now := time.Now()
for _, ri := range r {
if ri.Route.BackendType != eskip.LBBackend {
continue
}
ri.LBFadeInDuration = 0
ri.LBFadeInExponent = 1
endpointsCreated := make(map[string]time.Time)
for _, f := range ri.Filters {
switch fi := f.Filter.(type) {
case fadeIn:
ri.LBFadeInDuration = fi.duration
ri.LBFadeInExponent = fi.exponent
case endpointCreated:
endpointsCreated[fi.which] = fi.when
}
}
if ri.LBFadeInDuration <= 0 {
continue
}
for i := range ri.LBEndpoints {
ep := &ri.LBEndpoints[i]
key := endpointKey(ep.Scheme, ep.Host)
detected := p.detected[key].when
if detected.IsZero() || endpointsCreated[key].After(detected) {
detected = now
}
if p.endpointRegistry != nil {
metrics := p.endpointRegistry.GetMetrics(ep.Host)
if endpointsCreated[key].After(metrics.DetectedTime()) {
metrics.SetDetected(endpointsCreated[key])
}
}
p.detected[key] = detectedFadeIn{
when: detected,
duration: ri.LBFadeInDuration,
lastActive: now,
}
}
}
// cleanup old detected, considering last active
for key, d := range p.detected {
// this allows tolerating when a fade-in endpoint temporarily disappears:
if d.lastActive.Add(d.duration).Before(now) {
delete(p.detected, key)
}
}
return r
}
package filters
import (
"errors"
"io"
"net/http"
"time"
log "github.com/sirupsen/logrus"
"github.com/opentracing/opentracing-go"
)
const (
// DynamicBackendHostKey is the key used in the state bag to pass host to the proxy.
DynamicBackendHostKey = "backend:dynamic:host"
// DynamicBackendSchemeKey is the key used in the state bag to pass scheme to the proxy.
DynamicBackendSchemeKey = "backend:dynamic:scheme"
// DynamicBackendURLKey is the key used in the state bag to pass url to the proxy.
DynamicBackendURLKey = "backend:dynamic:url"
// BackendIsProxyKey is the key used in the state bag to notify proxy that the backend is also a proxy.
BackendIsProxyKey = "backend:isproxy"
// BackendTimeout is the key used in the state bag to configure backend timeout in proxy
BackendTimeout = "backend:timeout"
// ReadTimeout is the key used in the state bag to configure read request body timeout in proxy
ReadTimeout = "read:timeout"
// WriteTimeout is the key used in the state bag to configure write response body timeout in proxy
WriteTimeout = "write:timeout"
// BackendRatelimit is the key used in the state bag to configure backend ratelimit in proxy
BackendRatelimit = "backend:ratelimit"
)
// FilterContext object providing state and information that is unique to a request.
type FilterContext interface {
// The response writer object belonging to the incoming request. Used by
// filters that handle the requests themselves.
// Deprecated: use Response() or Serve()
ResponseWriter() http.ResponseWriter
// The incoming request object. It is forwarded to the route endpoint
// with its properties changed by the filters.
Request() *http.Request
// The response object. It is returned to the client with its
// properties changed by the filters.
Response() *http.Response
// The copy (deep) of the original incoming request or nil if the
// implementation does not provide it.
//
// The object received from this method contains an empty body, and all
// payload related properties have zero value.
OriginalRequest() *http.Request
// The copy (deep) of the original incoming response or nil if the
// implementation does not provide it.
//
// The object received from this method contains an empty body, and all
// payload related properties have zero value.
OriginalResponse() *http.Response
// This method is deprecated. A FilterContext implementation should flag this state
// internally
Served() bool
// This method is deprecated. You should call Serve providing the desired response
MarkServed()
// Serve a request with the provided response. It can be used by filters that handle the requests
// themselves. FilterContext implementations should flag this state and prevent the filter chain
// from continuing
Serve(*http.Response)
// Provides the wildcard parameter values from the request path by their
// name as the key.
PathParam(string) string
// Provides a read-write state bag, unique to a request and shared by all
// the filters in the route.
StateBag() map[string]interface{}
// Gives filters access to the backend url specified in the route or an empty
// value in case it's a shunt or loopback route. In case of a dynamic backend it is empty, too.
BackendUrl() string
// Returns the host that will be set for the outgoing proxy request as the
// 'Host' header.
OutgoingHost() string
// Allows explicitly setting the Host header to be sent to the backend, overriding the
// strategy used by the implementation, which can be either the Host header from the
// incoming request or the host fragment of the backend url.
//
// Filters that need to modify the outgoing 'Host' header, need to use
// this method instead of setting the Request().Headers["Host"] value.
// (The requestHeader filter automatically detects if the header name
// is 'Host' and calls this method.)
SetOutgoingHost(string)
// Allow filters to collect metrics other than the default metrics (Filter Request, Filter Response methods)
Metrics() Metrics
// Allow filters to add Tags, Baggage to the trace or set the ComponentName.
//
// Deprecated: OpenTracing is deprecated, see https://github.com/zalando/skipper/issues/2104.
// Use opentracing.SpanFromContext(ctx.Request().Context()).Tracer() to get the Tracer.
Tracer() opentracing.Tracer
// Allow filters to create their own spans
//
// Deprecated: OpenTracing is deprecated, see https://github.com/zalando/skipper/issues/2104.
// Filter spans should be children of the request span,
// use opentracing.SpanFromContext(ctx.Request().Context()) to get it.
ParentSpan() opentracing.Span
// Returns a clone of the FilterContext including a brand new request object.
// The stream body of the new request is shared with the original.
// Whenever the body of the original request is read, the same data is
// written to the body of the new request.
// The StateBag and filterMetrics object are not preserved in the new context.
// Therefore, you can't access state bag values set in the previous context.
Split() (FilterContext, error)
// Performs a new route lookup and executes the matched route if any
Loopback()
Logger() FilterContextLogger
}
// FilterContextLogger is the logger which logs messages with additional context information.
type FilterContextLogger interface {
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Warnf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// Metrics provides the possibility to use custom metrics from filter implementations. The custom metrics will
// be exposed by the common metrics endpoint exposed by the proxy, where they can be accessed by the custom
// key prefixed by the filter name and the string 'custom'. E.g: <filtername>.custom.<customkey>.
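//
// A minimal usage sketch from within a filter's Request method (ctx is the
// FilterContext passed to the filter, "mykey" is an arbitrary example key):
//
// start := time.Now()
// ctx.Metrics().IncCounter("mykey")
// ctx.Metrics().MeasureSince("mykey", start)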
type Metrics interface {
// MeasureSince adds values to a timer with a custom key.
MeasureSince(key string, start time.Time)
// IncCounter increments a custom counter identified by its key.
IncCounter(key string)
// IncCounterBy increments a custom counter identified by its key by a certain value.
IncCounterBy(key string, value int64)
// IncFloatCounterBy increments a custom counter identified by its key by a certain
// float (decimal) value. IMPORTANT: Not all Metrics implementation support float
// counters. In that case, a call to IncFloatCounterBy is dropped.
IncFloatCounterBy(key string, value float64)
}
// Filters are created by the Spec components, optionally using filter
// specific settings. When implementing filters, it needs to be taken
// into consideration, that filter instances are route specific and not
// request specific, so any state stored with a filter is shared between
// all requests for the same route and can cause concurrency issues.
type Filter interface {
// The Request method is called while processing the incoming request.
Request(FilterContext)
// The Response method is called while processing the response to be
// returned.
Response(FilterContext)
}
// FilterCloser are Filters that need to cleanup resources after
// filter termination. For example Filters, that create a goroutine
// for some reason need to cleanup their goroutine or they would leak
// goroutines.
type FilterCloser interface {
Filter
io.Closer
}
// Spec objects are specifications for filters. When initializing the routes,
// the Filter instances are created using the Spec objects found in the
// registry.
type Spec interface {
// Name gives the name of the Spec. It is used to identify filters in a route definition.
Name() string
// CreateFilter creates a Filter instance. Called with the parameters in the route
// definition while initializing a route.
CreateFilter(config []interface{}) (Filter, error)
}
// Registry is used to look up Spec objects while initializing routes.
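//
// A minimal usage sketch (mySpec stands for any Spec implementation and is
// hypothetical):
//
// r := make(Registry)
// r.Register(mySpec)
// spec, ok := r[mySpec.Name()]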
type Registry map[string]Spec
// ErrInvalidFilterParameters is used in case of invalid filter parameters.
var ErrInvalidFilterParameters = errors.New("invalid filter parameters")
// Register registers a filter specification.
func (r Registry) Register(s Spec) {
name := s.Name()
if _, ok := r[name]; ok {
log.Infof("Replacing %s filter specification", name)
}
r[name] = s
}
// All Skipper filter names
const (
BackendIsProxyName = "backendIsProxy"
CommentName = "comment"
AnnotateName = "annotate"
ModRequestHeaderName = "modRequestHeader"
SetRequestHeaderName = "setRequestHeader"
AppendRequestHeaderName = "appendRequestHeader"
DropRequestHeaderName = "dropRequestHeader"
ModResponseHeaderName = "modResponseHeader"
SetResponseHeaderName = "setResponseHeader"
AppendResponseHeaderName = "appendResponseHeader"
DropResponseHeaderName = "dropResponseHeader"
SetContextRequestHeaderName = "setContextRequestHeader"
AppendContextRequestHeaderName = "appendContextRequestHeader"
SetContextResponseHeaderName = "setContextResponseHeader"
AppendContextResponseHeaderName = "appendContextResponseHeader"
CopyRequestHeaderName = "copyRequestHeader"
CopyResponseHeaderName = "copyResponseHeader"
ModPathName = "modPath"
SetPathName = "setPath"
RedirectToName = "redirectTo"
RedirectToLowerName = "redirectToLower"
StaticName = "static"
StripQueryName = "stripQuery"
PreserveHostName = "preserveHost"
StatusName = "status"
CompressName = "compress"
DecompressName = "decompress"
SetQueryName = "setQuery"
DropQueryName = "dropQuery"
InlineContentName = "inlineContent"
InlineContentIfStatusName = "inlineContentIfStatus"
FlowIdName = "flowId"
XforwardName = "xforward"
XforwardFirstName = "xforwardFirst"
RandomContentName = "randomContent"
RepeatContentName = "repeatContent"
RepeatContentHexName = "repeatContentHex"
WrapContentName = "wrapContent"
WrapContentHexName = "wrapContentHex"
BackendTimeoutName = "backendTimeout"
ReadTimeoutName = "readTimeout"
WriteTimeoutName = "writeTimeout"
BlockName = "blockContent"
BlockHexName = "blockContentHex"
LatencyName = "latency"
BandwidthName = "bandwidth"
ChunksName = "chunks"
BackendLatencyName = "backendLatency"
BackendBandwidthName = "backendBandwidth"
BackendChunksName = "backendChunks"
TarpitName = "tarpit"
AbsorbName = "absorb"
AbsorbSilentName = "absorbSilent"
UniformRequestLatencyName = "uniformRequestLatency"
UniformResponseLatencyName = "uniformResponseLatency"
NormalRequestLatencyName = "normalRequestLatency"
NormalResponseLatencyName = "normalResponseLatency"
HistogramRequestLatencyName = "histogramRequestLatency"
HistogramResponseLatencyName = "histogramResponseLatency"
LogBodyName = "logBody"
LogHeaderName = "logHeader"
TeeName = "tee"
TeenfName = "teenf"
TeeLoopbackName = "teeLoopback"
SedName = "sed"
SedDelimName = "sedDelim"
SedRequestName = "sedRequest"
SedRequestDelimName = "sedRequestDelim"
BasicAuthName = "basicAuth"
WebhookName = "webhook"
OAuthTokeninfoAnyScopeName = "oauthTokeninfoAnyScope"
OAuthTokeninfoAllScopeName = "oauthTokeninfoAllScope"
OAuthTokeninfoAnyKVName = "oauthTokeninfoAnyKV"
OAuthTokeninfoAllKVName = "oauthTokeninfoAllKV"
OAuthTokeninfoValidateName = "oauthTokeninfoValidate"
OAuthTokenintrospectionAnyClaimsName = "oauthTokenintrospectionAnyClaims"
OAuthTokenintrospectionAllClaimsName = "oauthTokenintrospectionAllClaims"
OAuthTokenintrospectionAnyKVName = "oauthTokenintrospectionAnyKV"
OAuthTokenintrospectionAllKVName = "oauthTokenintrospectionAllKV"
SecureOAuthTokenintrospectionAnyClaimsName = "secureOauthTokenintrospectionAnyClaims"
SecureOAuthTokenintrospectionAllClaimsName = "secureOauthTokenintrospectionAllClaims"
SecureOAuthTokenintrospectionAnyKVName = "secureOauthTokenintrospectionAnyKV"
SecureOAuthTokenintrospectionAllKVName = "secureOauthTokenintrospectionAllKV"
ForwardTokenName = "forwardToken"
ForwardTokenFieldName = "forwardTokenField"
OAuthGrantName = "oauthGrant"
GrantCallbackName = "grantCallback"
GrantLogoutName = "grantLogout"
GrantClaimsQueryName = "grantClaimsQuery"
JwtValidationName = "jwtValidation"
JwtMetricsName = "jwtMetrics"
OAuthOidcUserInfoName = "oauthOidcUserInfo"
OAuthOidcAnyClaimsName = "oauthOidcAnyClaims"
OAuthOidcAllClaimsName = "oauthOidcAllClaims"
OidcClaimsQueryName = "oidcClaimsQuery"
DropRequestCookieName = "dropRequestCookie"
DropResponseCookieName = "dropResponseCookie"
RequestCookieName = "requestCookie"
ResponseCookieName = "responseCookie"
JsCookieName = "jsCookie"
ConsecutiveBreakerName = "consecutiveBreaker"
RateBreakerName = "rateBreaker"
DisableBreakerName = "disableBreaker"
AdmissionControlName = "admissionControl"
ClientRatelimitName = "clientRatelimit"
RatelimitName = "ratelimit"
ClusterClientRatelimitName = "clusterClientRatelimit"
ClusterRatelimitName = "clusterRatelimit"
ClusterLeakyBucketRatelimitName = "clusterLeakyBucketRatelimit"
BackendRateLimitName = "backendRatelimit"
RatelimitFailClosedName = "ratelimitFailClosed"
LuaName = "lua"
CorsOriginName = "corsOrigin"
HeaderToQueryName = "headerToQuery"
QueryToHeaderName = "queryToHeader"
DisableAccessLogName = "disableAccessLog"
EnableAccessLogName = "enableAccessLog"
AuditLogName = "auditLog"
UnverifiedAuditLogName = "unverifiedAuditLog"
SetDynamicBackendHostFromHeader = "setDynamicBackendHostFromHeader"
SetDynamicBackendSchemeFromHeader = "setDynamicBackendSchemeFromHeader"
SetDynamicBackendUrlFromHeader = "setDynamicBackendUrlFromHeader"
SetDynamicBackendHost = "setDynamicBackendHost"
SetDynamicBackendScheme = "setDynamicBackendScheme"
SetDynamicBackendUrl = "setDynamicBackendUrl"
ApiUsageMonitoringName = "apiUsageMonitoring"
FifoName = "fifo"
FifoWithBodyName = "fifoWithBody"
LifoName = "lifo"
LifoGroupName = "lifoGroup"
RfcPathName = "rfcPath"
RfcHostName = "rfcHost"
BearerInjectorName = "bearerinjector"
SetRequestHeaderFromSecretName = "setRequestHeaderFromSecret"
TracingBaggageToTagName = "tracingBaggageToTag"
StateBagToTagName = "stateBagToTag"
TracingTagName = "tracingTag"
TracingTagFromResponseName = "tracingTagFromResponse"
TracingTagFromResponseIfStatusName = "tracingTagFromResponseIfStatus"
TracingSpanNameName = "tracingSpanName"
OriginMarkerName = "originMarker"
FadeInName = "fadeIn"
EndpointCreatedName = "endpointCreated"
ConsistentHashKeyName = "consistentHashKey"
ConsistentHashBalanceFactorName = "consistentHashBalanceFactor"
OpaAuthorizeRequestName = "opaAuthorizeRequest"
OpaAuthorizeRequestWithBodyName = "opaAuthorizeRequestWithBody"
OpaServeResponseName = "opaServeResponse"
OpaServeResponseWithReqBodyName = "opaServeResponseWithReqBody"
TLSName = "tlsPassClientCertificates"
AWSSigV4Name = "awsSigv4"
// Undocumented filters
HealthCheckName = "healthcheck"
SetFastCgiFilenameName = "setFastCgiFilename"
DisableRatelimitName = "disableRatelimit"
UnknownRatelimitName = "unknownRatelimit"
)
package flowid
import (
"fmt"
"log"
"strings"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.FlowIdName instead
Name = filters.FlowIdName
ReuseParameterValue = "reuse"
HeaderName = "X-Flow-Id"
)
// Generator interface should be implemented by types that can generate request tracing Flow IDs.
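//
// A minimal usage sketch, using the standard generator defined in this
// package:
//
// g, err := NewStandardGenerator(defaultLen)
// if err != nil {
//	return err // the length must be between MinLength and MaxLength
// }
// id := g.MustGenerate()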
type Generator interface {
// Generate returns a new Flow ID using the implementation specific format or an error in case of failure.
Generate() (string, error)
// MustGenerate behaves like Generate but panics on failure instead of returning an error.
MustGenerate() string
// IsValid asserts if a given flow ID follows an expected format
IsValid(string) bool
}
type flowIdSpec struct {
generator Generator
}
type flowId struct {
reuseExisting bool
generator Generator
}
// NewFlowId creates a new standard generator with the defined length and returns a Flow ID.
//
// Deprecated: For backward compatibility this exported function is still available but will be removed in upcoming
// releases. Use the new Generator interface and respective implementations instead.
func NewFlowId(l int) (string, error) {
g, err := NewStandardGenerator(l)
if err != nil {
return "", fmt.Errorf("deprecated new flowid: %w", err)
}
return g.Generate()
}
// New creates a new instance of the flowId filter spec which uses the StandardGenerator.
// To use another type of Generator use NewWithGenerator()
func New() *flowIdSpec {
g, err := NewStandardGenerator(defaultLen)
if err != nil {
panic(err)
}
return NewWithGenerator(g)
}
// NewWithGenerator behaves like New but allows you to specify any other Generator.
func NewWithGenerator(g Generator) *flowIdSpec {
return &flowIdSpec{generator: g}
}
// Request will inspect the current Request for the presence of an X-Flow-Id header which will be kept in case the
// "reuse" flag has been set. In any other case it will set the same header with the value returned from the
// defined Flow ID Generator
func (f *flowId) Request(fc filters.FilterContext) {
r := fc.Request()
var flowId string
if f.reuseExisting {
flowId = r.Header.Get(HeaderName)
if f.generator.IsValid(flowId) {
return
}
}
flowId, err := f.generator.Generate()
if err == nil {
r.Header.Set(HeaderName, flowId)
} else {
log.Println(err)
}
}
// Response is No-Op in this filter
func (*flowId) Response(filters.FilterContext) {}
// CreateFilter will return a new flowId filter from the spec.
// If at least one argument is present and it contains the value "reuse", the filter instance is configured to
// keep the value of the X-Flow-Id header if it is already set.
func (spec *flowIdSpec) CreateFilter(fc []interface{}) (filters.Filter, error) {
var reuseExisting bool
if len(fc) > 0 {
if r, ok := fc[0].(string); ok {
reuseExisting = strings.ToLower(r) == ReuseParameterValue
} else {
return nil, filters.ErrInvalidFilterParameters
}
if len(fc) > 1 {
log.Println("flow id filter warning: this syntax is deprecated and will be removed soon. " +
"please check updated docs")
}
}
return &flowId{reuseExisting: reuseExisting, generator: spec.generator}, nil
}
// Name returns the canonical filter name
func (*flowIdSpec) Name() string { return filters.FlowIdName }
package flowid
import (
"fmt"
"math/rand"
"regexp"
)
const (
flowIdAlphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-+"
alphabetBitMask = 63
MaxLength = 64
MinLength = 8
defaultLen = 16
)
var (
ErrInvalidLen = fmt.Errorf("invalid length, must be between %d and %d", MinLength, MaxLength)
standardFlowIDRegex = regexp.MustCompile(`^[0-9a-zA-Z+-]+$`)
)
type standardGenerator struct {
length int
}
// NewStandardGenerator creates a new FlowID generator that generates flow IDs with length l.
// The alphabet is limited to 64 elements and requires a random 6 bit value to index any of them.
// The cost of the rnd.IntXX calls is not very relevant, but the bit shifting operations are faster.
// For this reason a single call to rnd.Int63 is used and its bits are mapped up to 10 chunks of 6 bits each.
// The byte data type carries 2 additional bits for the next chunk which are cleared with the alphabet bit mask.
// It is safe for concurrent use.
func NewStandardGenerator(l int) (Generator, error) {
if l < MinLength || l > MaxLength {
return nil, ErrInvalidLen
}
return &standardGenerator{length: l}, nil
}
// Generate returns a new Flow ID from the built-in generator with the configured length
func (g *standardGenerator) Generate() (string, error) {
u := make([]byte, g.length)
for i := 0; i < g.length; i += 10 {
b := rand.Int63() // #nosec
for e := 0; e < 10 && i+e < g.length; e++ {
c := byte(b>>uint(6*e)) & alphabetBitMask // 6 bits only
u[i+e] = flowIdAlphabet[c]
}
}
return string(u), nil
}
// MustGenerate is a convenience function equivalent to Generate that panics on failure instead of returning an error.
func (g *standardGenerator) MustGenerate() string {
id, err := g.Generate()
if err != nil {
panic(err)
}
return id
}
// IsValid checks if the given flowId follows the format of this generator
func (g *standardGenerator) IsValid(flowId string) bool {
return len(flowId) >= MinLength && len(flowId) <= MaxLength && standardFlowIDRegex.MatchString(flowId)
}
package flowid
import (
"io"
"math/rand"
"regexp"
"sync"
"time"
"github.com/oklog/ulid"
)
const (
flowIDLength = 26
)
type ulidGenerator struct {
sync.Mutex
r io.Reader
}
var ulidFlowIDRegex = regexp.MustCompile(`^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$`)
// NewULIDGenerator returns a flow ID generator that is able to generate Universally Unique Lexicographically
// Sortable Identifier (ULID) flow IDs.
// It uses a shared, pseudo-random source of entropy, seeded with the current timestamp.
// It is safe for concurrent usage.
func NewULIDGenerator() Generator {
return NewULIDGeneratorWithEntropyProvider(rand.New(rand.NewSource(time.Now().UTC().UnixNano()))) // #nosec
}
// NewULIDGeneratorWithEntropyProvider behaves like NewULIDGenerator but allows you to specify your own source of
// entropy.
// Access to the entropy provider is safe for concurrent usage.
func NewULIDGeneratorWithEntropyProvider(r io.Reader) Generator {
return &ulidGenerator{r: r}
}
// Generate returns a random ULID flow ID or an empty string in case of failure. The returned error can be inspected
// to assess the failure reason
func (g *ulidGenerator) Generate() (string, error) {
g.Lock()
id, err := ulid.New(ulid.Now(), g.r)
g.Unlock()
if err != nil {
return "", err
}
return id.String(), nil
}
// MustGenerate behaves like Generate but panics in case of failure
func (g *ulidGenerator) MustGenerate() string {
flowId, err := g.Generate()
if err != nil {
panic(err)
}
return flowId
}
// IsValid checks if the given flowId follows the format of this generator
func (g *ulidGenerator) IsValid(flowId string) bool {
return len(flowId) == flowIDLength && ulidFlowIDRegex.MatchString(flowId)
}
/*
Package log provides a request logging filter, usable also for
audit logging. Audit logging shows who made a request in case the
OAuth2 provider returns a "uid" key and value.
*/
package log
import (
"bytes"
"encoding/json"
"io"
"os"
"regexp"
"strings"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/jwt"
)
const (
// Deprecated, use filters.AuditLogName instead
AuditLogName = filters.AuditLogName
// AuthUserKey is used by the auth package to set the user
// information into the state bag to pass the information to
// the auditLog filter.
AuthUserKey = "auth-user"
// AuthRejectReasonKey is used by the auth package to set the
// reject reason information into the state bag to pass the
// information to the auditLog filter.
AuthRejectReasonKey = "auth-reject-reason"
// Deprecated, use filters.UnverifiedAuditLogName instead
UnverifiedAuditLogName = filters.UnverifiedAuditLogName
// UnverifiedAuditHeader is the name of the header added to the request which contains the unverified audit details
UnverifiedAuditHeader = "X-Unverified-Audit"
authHeaderName = "Authorization"
authHeaderPrefix = "Bearer "
defaultSub = "<invalid-sub>"
defaultUnverifiedAuditLogKey = "sub"
)
var (
re = regexp.MustCompile("^[a-zA-Z0-9_/:?=&%@.#-]*$")
)
type auditLog struct {
writer io.Writer
maxBodyLog int
}
type teeBody struct {
body io.ReadCloser
buffer *bytes.Buffer
teeReader io.Reader
maxTee int
}
type auditDoc struct {
Method string `json:"method"`
Path string `json:"path"`
Status int `json:"status"`
AuthStatus *authStatusDoc `json:"authStatus,omitempty"`
RequestBody string `json:"requestBody,omitempty"`
}
type authStatusDoc struct {
User string `json:"user,omitempty"`
Rejected bool `json:"rejected"`
Reason string `json:"reason,omitempty"`
}
func newTeeBody(rc io.ReadCloser, maxTee int) io.ReadCloser {
b := bytes.NewBuffer(nil)
tb := &teeBody{
body: rc,
buffer: b,
maxTee: maxTee}
tb.teeReader = io.TeeReader(rc, tb)
return tb
}
func (tb *teeBody) Read(b []byte) (int, error) { return tb.teeReader.Read(b) }
func (tb *teeBody) Close() error { return tb.body.Close() }
func (tb *teeBody) Write(b []byte) (int, error) {
if tb.maxTee < 0 {
return tb.buffer.Write(b)
}
wl := len(b)
if wl >= tb.maxTee {
wl = tb.maxTee
}
n, err := tb.buffer.Write(b[:wl])
if err != nil {
return n, err
}
tb.maxTee -= n
// lie to avoid short write
return len(b), nil
}
// NewAuditLog creates an auditLog filter specification. It expects a
// maxAuditBody attribute to limit the size of the log. It will use
// os.Stderr as writer for the output of the log entries.
//
// spec := NewAuditLog(1024)
func NewAuditLog(maxAuditBody int) filters.Spec {
return &auditLog{
writer: os.Stderr,
maxBodyLog: maxAuditBody,
}
}
func (al *auditLog) Name() string { return filters.AuditLogName }
// CreateFilter has no arguments. It creates the filter if the user
// specifies auditLog() in their route.
func (al *auditLog) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 0 {
return nil, filters.ErrInvalidFilterParameters
}
return &auditLog{writer: al.writer, maxBodyLog: al.maxBodyLog}, nil
}
func (al *auditLog) Request(ctx filters.FilterContext) {
if al.maxBodyLog != 0 {
ctx.Request().Body = newTeeBody(ctx.Request().Body, al.maxBodyLog)
}
}
func (al *auditLog) Response(ctx filters.FilterContext) {
req := ctx.Request()
rsp := ctx.Response()
doc := auditDoc{
Method: req.Method,
Path: req.URL.Path,
Status: rsp.StatusCode}
sb := ctx.StateBag()
au, _ := sb[AuthUserKey].(string)
rr, _ := sb[AuthRejectReasonKey].(string)
if au != "" || rr != "" {
doc.AuthStatus = &authStatusDoc{User: au}
if rr != "" {
doc.AuthStatus.Rejected = true
doc.AuthStatus.Reason = rr
}
}
if tb, ok := req.Body.(*teeBody); ok {
if tb.maxTee < 0 {
io.Copy(tb.buffer, tb.body)
} else {
io.CopyN(tb.buffer, tb.body, int64(tb.maxTee))
}
if tb.buffer.Len() > 0 {
doc.RequestBody = tb.buffer.String()
}
}
enc := json.NewEncoder(al.writer)
err := enc.Encode(&doc)
if err != nil {
ctx.Logger().Errorf("Failed to json encode auditDoc: %v", err)
}
}
type (
unverifiedAuditLogSpec struct{}
unverifiedAuditLogFilter struct {
TokenKeys []string
}
)
// NewUnverifiedAuditLog creates a filter spec that logs the "sub" claim from the payload of an unverified JWT
// token, or the first of the configured claim keys that is present.
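//
// Eskip example ("azp" is an illustrative claim key):
//
// r: * -> unverifiedAuditLog("azp") -> "https://www.example.org";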
func NewUnverifiedAuditLog() filters.Spec { return &unverifiedAuditLogSpec{} }
func (ual *unverifiedAuditLogSpec) Name() string { return filters.UnverifiedAuditLogName }
// CreateFilter has no arguments. It creates the filter if the user
// specifies unverifiedAuditLog() in their route.
func (ual *unverifiedAuditLogSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) == 0 {
return &unverifiedAuditLogFilter{TokenKeys: []string{defaultUnverifiedAuditLogKey}}, nil
}
keys := make([]string, len(args))
for i := 0; i < len(args); i++ {
keyName, ok := args[i].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
keys[i] = keyName
}
return &unverifiedAuditLogFilter{TokenKeys: keys}, nil
}
func (ual *unverifiedAuditLogFilter) Request(ctx filters.FilterContext) {
req := ctx.Request()
ahead := req.Header.Get(authHeaderName)
tv := strings.TrimPrefix(ahead, authHeaderPrefix)
if tv == ahead {
return
}
token, err := jwt.Parse(tv)
if err != nil {
return
}
for i := 0; i < len(ual.TokenKeys); i++ {
if k, ok := token.Claims[ual.TokenKeys[i]]; ok {
if v, ok2 := k.(string); ok2 {
req.Header.Add(UnverifiedAuditHeader, cleanSub(v))
return
}
}
}
}
func (*unverifiedAuditLogFilter) Response(filters.FilterContext) {}
func cleanSub(s string) string {
if re.MatchString(s) {
return s
}
return defaultSub
}
package openpolicyagent
import (
"context"
"fmt"
"time"
ext_authz_v3_core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
ext_authz_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/open-policy-agent/opa-envoy-plugin/envoyauth"
"github.com/open-policy-agent/opa-envoy-plugin/opa/decisionlog"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/plugins/logs"
"github.com/open-policy-agent/opa/v1/server"
"github.com/open-policy-agent/opa/v1/topdown"
"github.com/opentracing/opentracing-go"
pbstruct "google.golang.org/protobuf/types/known/structpb"
)
func (opa *OpenPolicyAgentInstance) Eval(ctx context.Context, req *ext_authz_v3.CheckRequest) (*envoyauth.EvalResult, error) {
decisionId, err := opa.idGenerator.Generate()
if err != nil {
opa.Logger().WithFields(map[string]interface{}{"err": err}).Error("Unable to generate decision ID.")
return nil, err
}
err = setDecisionIdInRequest(req, decisionId)
if err != nil {
opa.Logger().WithFields(map[string]interface{}{"err": err}).Error("Unable to set decision ID in Request.")
return nil, err
}
result, stopeval, err := envoyauth.NewEvalResult(withDecisionID(decisionId))
if err != nil {
opa.Logger().WithFields(map[string]interface{}{"err": err}).Error("Unable to generate new result with decision ID.")
return nil, err
}
span := opentracing.SpanFromContext(ctx)
if span != nil {
span.SetTag("opa.decision_id", result.DecisionID)
}
var input map[string]interface{}
defer func() {
stopeval()
if topdown.IsCancel(err) {
// If the evaluation was canceled, we don't want to log the decision.
return
}
err := opa.logDecision(ctx, input, result, err)
if err != nil {
opa.Logger().WithFields(map[string]interface{}{"err": err}).Error("Unable to log decision to control plane.")
}
}()
if ctx.Err() != nil {
return nil, fmt.Errorf("check request timed out before query execution: %w", ctx.Err())
}
logger := opa.Logger().WithFields(map[string]interface{}{"decision-id": result.DecisionID})
input, err = envoyauth.RequestToInput(req, logger, nil, opa.EnvoyPluginConfig().SkipRequestBodyParse)
if err != nil {
return nil, fmt.Errorf("failed to convert request to input: %w", err)
}
inputValue, err := ast.InterfaceToValue(input)
if err != nil {
return nil, err
}
err = envoyauth.Eval(ctx, opa, inputValue, result)
if err != nil {
return nil, err
}
return result, nil
}
func setDecisionIdInRequest(req *ext_authz_v3.CheckRequest, decisionId string) error {
if req.Attributes.MetadataContext == nil {
req.Attributes.MetadataContext = &ext_authz_v3_core.Metadata{
FilterMetadata: map[string]*pbstruct.Struct{},
}
}
filterMeta, err := FormOpenPolicyAgentMetaDataObject(decisionId)
if err != nil {
return err
}
req.Attributes.MetadataContext.FilterMetadata["open_policy_agent"] = filterMeta
return nil
}
func FormOpenPolicyAgentMetaDataObject(decisionId string) (*pbstruct.Struct, error) {
innerFields := make(map[string]interface{})
innerFields["decision_id"] = decisionId
return pbstruct.NewStruct(innerFields)
}
func (opa *OpenPolicyAgentInstance) logDecision(ctx context.Context, input interface{}, result *envoyauth.EvalResult, err error) error {
info := &server.Info{
Timestamp: time.Now(),
Input: &input,
}
if opa.EnvoyPluginConfig().Path != "" {
info.Path = opa.EnvoyPluginConfig().Path
}
plugin := logs.Lookup(opa.manager)
if plugin == nil {
return nil
}
return decisionlog.LogDecision(ctx, plugin, info, result, err)
}
func withDecisionID(decisionID string) func(*envoyauth.EvalResult) {
return func(result *envoyauth.EvalResult) {
result.DecisionID = decisionID
}
}
package internal
import (
"context"
"encoding/json"
"github.com/open-policy-agent/opa/v1/config"
"github.com/open-policy-agent/opa/v1/plugins"
"github.com/open-policy-agent/opa/v1/plugins/bundle"
"github.com/open-policy-agent/opa/v1/plugins/discovery"
)
// ManualOverride overrides the plugin trigger to manual trigger mode, allowing the openpolicyagent filter
// to take control of the trigger mechanism.
// * OnConfig will handle a general config change and will override the trigger mode for both discovery
// and bundle plugins.
// * OnConfigDiscovery will handle a config change via discovery mechanism and will only override
// the trigger mode for the bundle plugin as the discovery plugin is involved in the trigger for
// this config change.
// See https://github.com/open-policy-agent/opa/pull/6053 for details on the hooks.
type ManualOverride struct {
}
func (m *ManualOverride) OnConfig(ctx context.Context, config *config.Config) (*config.Config, error) {
config, err := discoveryPluginOverride(config)
if err != nil {
return nil, err
}
return bundlePluginConfigOverride(config)
}
func (m *ManualOverride) OnConfigDiscovery(ctx context.Context, config *config.Config) (*config.Config, error) {
return bundlePluginConfigOverride(config)
}
func discoveryPluginOverride(config *config.Config) (*config.Config, error) {
var (
discoveryConfig discovery.Config
triggerManual = plugins.TriggerManual
message []byte
)
if config.Discovery != nil {
err := json.Unmarshal(config.Discovery, &discoveryConfig)
if err != nil {
return nil, err
}
discoveryConfig.Trigger = &triggerManual
message, err = json.Marshal(discoveryConfig)
if err != nil {
return nil, err
}
config.Discovery = message
}
return config, nil
}
func bundlePluginConfigOverride(config *config.Config) (*config.Config, error) {
var (
bundlesConfig map[string]*bundle.Source
manualTrigger = plugins.TriggerManual
message []byte
)
if config.Bundles != nil {
err := json.Unmarshal(config.Bundles, &bundlesConfig)
if err != nil {
return nil, err
}
for _, bndlCfg := range bundlesConfig {
bndlCfg.Trigger = &manualTrigger
}
message, err = json.Marshal(bundlesConfig)
if err != nil {
return nil, err
}
config.Bundles = message
}
return config, nil
}
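// For illustration: given a raw bundles config {"main": {"resource": "/bundles/main"}},
// bundlePluginConfigOverride rewrites the trigger mode so the bundle plugin no
// longer polls on its own. A minimal sketch (assumed input, not a complete
// OPA config):
func exampleBundlePluginConfigOverride() (*config.Config, error) {
cfg := &config.Config{
Bundles: []byte(`{"main": {"resource": "/bundles/main"}}`),
}
// After the call, cfg.Bundles carries "trigger": "manual" for every bundle.
return bundlePluginConfigOverride(cfg)
}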
package envoy
import (
"context"
"strconv"
"strings"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/plugins"
"github.com/open-policy-agent/opa/v1/util"
)
// Factory defines the interface OPA uses to instantiate a plugin.
type Factory struct{}
// New returns the object initialized with a valid plugin configuration.
func (Factory) New(m *plugins.Manager, cfg interface{}) plugins.Plugin {
p := &Plugin{
manager: m,
cfg: *cfg.(*PluginConfig),
}
m.UpdatePluginStatus(PluginName, &plugins.Status{State: plugins.StateNotReady})
return p
}
// Validate returns a valid configuration to instantiate the plugin.
func (Factory) Validate(m *plugins.Manager, bs []byte) (interface{}, error) {
cfg := PluginConfig{
DryRun: defaultDryRun,
}
if err := util.Unmarshal(bs, &cfg); err != nil {
return nil, err
}
if err := cfg.ParseQuery(); err != nil {
return nil, err
}
return &cfg, nil
}
func (p *Plugin) Reconfigure(ctx context.Context, config interface{}) {
p.cfg = *config.(*PluginConfig)
}
// PluginConfig represents the plugin configuration.
type PluginConfig struct {
Path string `json:"path"`
DryRun bool `json:"dry-run"`
SkipRequestBodyParse bool `json:"skip-request-body-parse"`
ParsedQuery ast.Body
}
type Plugin struct {
cfg PluginConfig
manager *plugins.Manager
}
func (p *Plugin) Start(ctx context.Context) error {
p.manager.UpdatePluginStatus(PluginName, &plugins.Status{State: plugins.StateOK})
return nil
}
func (cfg *PluginConfig) ParseQuery() error {
var parsedQuery ast.Body
var err error
if cfg.Path == "" {
cfg.Path = defaultPath
}
path := stringPathToDataRef(cfg.Path)
parsedQuery, err = ast.ParseBody(path.String())
if err != nil {
return err
}
cfg.ParsedQuery = parsedQuery
return nil
}
func (p *Plugin) Stop(ctx context.Context) {
p.manager.UpdatePluginStatus(PluginName, &plugins.Status{State: plugins.StateNotReady})
}
func (p *Plugin) GetConfig() PluginConfig {
return p.cfg
}
func (p *Plugin) ParsedQuery() ast.Body { return p.cfg.ParsedQuery }
func (p *Plugin) Path() string { return p.cfg.Path }
func stringPathToDataRef(s string) ast.Ref {
result := ast.Ref{ast.DefaultRootDocument}
result = append(result, stringPathToRef(s)...)
return result
}
func stringPathToRef(s string) (r ast.Ref) {
if len(s) == 0 {
return r
}
p := strings.Split(s, "/")
for _, x := range p {
if x == "" {
continue
}
i, err := strconv.Atoi(x)
if err != nil {
r = append(r, ast.StringTerm(x))
} else {
r = append(r, ast.IntNumberTerm(i))
}
}
return r
}
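// For illustration: slash-separated paths become refs rooted at the data
// document, and purely numeric segments become integer terms. A minimal
// sketch (hypothetical helper, not part of the plugin API):
func exampleStringPathToDataRef() {
ref := stringPathToDataRef("envoy/authz/allow")
_ = ref.String() // "data.envoy.authz.allow"
ref = stringPathToDataRef("bundles/0/name")
_ = ref.String() // assumed to render as data.bundles[0].name
}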
package envoy
import (
"fmt"
"net/http"
"net/url"
"strings"
"unicode/utf8"
ext_authz_v3_core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
ext_authz_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
)
func AdaptToExtAuthRequest(req *http.Request, metadata *ext_authz_v3_core.Metadata, contextExtensions map[string]string, rawBody []byte) (*ext_authz_v3.CheckRequest, error) {
if err := validateURLForInvalidUTF8(req.URL); err != nil {
return nil, fmt.Errorf("invalid url: %w", err)
}
headers := make(map[string]string, len(req.Header))
for h, vv := range req.Header {
// This makes headers in the input compatible with what Envoy does, i.e. allows to use policy fragments designed for envoy
// See: https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/header_casing#http-1-1-header-casing
headers[strings.ToLower(h)] = strings.Join(vv, ", ")
}
ereq := &ext_authz_v3.CheckRequest{
Attributes: &ext_authz_v3.AttributeContext{
Request: &ext_authz_v3.AttributeContext_Request{
Http: &ext_authz_v3.AttributeContext_HttpRequest{
Host: req.Host,
Method: req.Method,
Path: req.URL.RequestURI(),
Headers: headers,
RawBody: rawBody,
},
},
ContextExtensions: contextExtensions,
MetadataContext: metadata,
},
}
return ereq, nil
}
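// For illustration: header keys are lower-cased and multi-value headers are
// joined with ", " to match Envoy's canonical input form. A minimal sketch
// (hypothetical values, errors ignored for brevity):
func exampleAdaptToExtAuthRequest() {
req, _ := http.NewRequest("GET", "https://www.example.org/api?x=1", nil)
req.Header.Add("X-Forwarded-For", "10.0.0.1")
req.Header.Add("X-Forwarded-For", "10.0.0.2")
checkReq, _ := AdaptToExtAuthRequest(req, nil, map[string]string{"hint": "api"}, nil)
// checkReq.Attributes.Request.Http.Headers now contains
// {"x-forwarded-for": "10.0.0.1, 10.0.0.2"}.
_ = checkReq
}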
func validateURLForInvalidUTF8(u *url.URL) error {
if !utf8.ValidString(u.Path) {
return fmt.Errorf("invalid utf8 in path: %q", u.Path)
}
decodedQuery, err := url.QueryUnescape(u.RawQuery)
if err != nil {
return fmt.Errorf("error unescaping query string %q: %w", u.RawQuery, err)
}
if !utf8.ValidString(decodedQuery) {
return fmt.Errorf("invalid utf8 in query: %q", u.RawQuery)
}
return nil
}
package opaauthorizerequest
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
ext_authz_v3_core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
"github.com/zalando/skipper/filters"
"gopkg.in/yaml.v2"
"github.com/zalando/skipper/filters/openpolicyagent"
"github.com/zalando/skipper/filters/openpolicyagent/internal/envoy"
)
const responseHeadersKey = "open-policy-agent:decision-response-headers"
type spec struct {
registry *openpolicyagent.OpenPolicyAgentRegistry
opts []func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error
name string
bodyParsing bool
}
func NewOpaAuthorizeRequestSpec(registry *openpolicyagent.OpenPolicyAgentRegistry, opts ...func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error) filters.Spec {
return &spec{
registry: registry,
opts: opts,
name: filters.OpaAuthorizeRequestName,
}
}
func NewOpaAuthorizeRequestWithBodySpec(registry *openpolicyagent.OpenPolicyAgentRegistry, opts ...func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error) filters.Spec {
return &spec{
registry: registry,
opts: opts,
name: filters.OpaAuthorizeRequestWithBodyName,
bodyParsing: true,
}
}
func (s *spec) Name() string {
return s.name
}
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
var err error
if len(args) < 1 {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) > 2 {
return nil, filters.ErrInvalidFilterParameters
}
bundleName, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
envoyContextExtensions := map[string]string{}
if len(args) > 1 {
extensionsStr, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
err = yaml.Unmarshal([]byte(extensionsStr), &envoyContextExtensions)
if err != nil {
return nil, err
}
}
configOptions := s.opts
opaConfig, err := openpolicyagent.NewOpenPolicyAgentConfig(configOptions...)
if err != nil {
return nil, err
}
opa, err := s.registry.NewOpenPolicyAgentInstance(bundleName, *opaConfig, s.Name())
if err != nil {
return nil, err
}
return &opaAuthorizeRequestFilter{
opa: opa,
registry: s.registry,
envoyContextExtensions: envoyContextExtensions,
bodyParsing: s.bodyParsing,
}, nil
}
type opaAuthorizeRequestFilter struct {
opa *openpolicyagent.OpenPolicyAgentInstance
registry *openpolicyagent.OpenPolicyAgentRegistry
envoyContextExtensions map[string]string
bodyParsing bool
}
func (f *opaAuthorizeRequestFilter) Request(fc filters.FilterContext) {
req := fc.Request()
span, ctx := f.opa.StartSpanFromFilterContext(fc)
defer span.Finish()
var rawBodyBytes []byte
if f.bodyParsing {
var body io.ReadCloser
var err error
var finalizer func()
body, rawBodyBytes, finalizer, err = f.opa.ExtractHttpBodyOptionally(req)
defer finalizer()
if err != nil {
f.opa.HandleInvalidDecisionError(fc, span, nil, err, !f.opa.EnvoyPluginConfig().DryRun)
return
}
req.Body = body
}
authzreq, err := envoy.AdaptToExtAuthRequest(req, f.opa.InstanceConfig().GetEnvoyMetadata(), f.envoyContextExtensions, rawBodyBytes)
if err != nil {
f.opa.HandleEvaluationError(fc, span, nil, err, !f.opa.EnvoyPluginConfig().DryRun, http.StatusBadRequest)
return
}
start := time.Now()
result, err := f.opa.Eval(ctx, authzreq)
fc.Metrics().MeasureSince(f.opa.MetricsKey("eval_time"), start)
var jsonErr *json.SyntaxError
if errors.As(err, &jsonErr) {
f.opa.HandleEvaluationError(fc, span, result, err, !f.opa.EnvoyPluginConfig().DryRun, http.StatusBadRequest)
return
}
if err != nil {
f.opa.HandleInvalidDecisionError(fc, span, result, err, !f.opa.EnvoyPluginConfig().DryRun)
return
}
if f.opa.EnvoyPluginConfig().DryRun {
return
}
allowed, err := result.IsAllowed()
if err != nil {
f.opa.HandleInvalidDecisionError(fc, span, result, err, !f.opa.EnvoyPluginConfig().DryRun)
return
}
span.SetTag("opa.decision.allowed", allowed)
if !allowed {
fc.Metrics().IncCounter(f.opa.MetricsKey("decision.deny"))
f.opa.ServeResponse(fc, span, result)
return
}
fc.Metrics().IncCounter(f.opa.MetricsKey("decision.allow"))
headersToRemove, err := result.GetRequestHTTPHeadersToRemove()
if err != nil {
f.opa.HandleInvalidDecisionError(fc, span, result, err, !f.opa.EnvoyPluginConfig().DryRun)
return
}
removeRequestHeaders(fc, headersToRemove)
headers, err := result.GetResponseHTTPHeaders()
if err != nil {
f.opa.HandleInvalidDecisionError(fc, span, result, err, !f.opa.EnvoyPluginConfig().DryRun)
return
}
addRequestHeaders(fc, headers)
if responseHeaders, err := result.GetResponseHTTPHeadersToAdd(); err != nil {
f.opa.HandleInvalidDecisionError(fc, span, result, err, !f.opa.EnvoyPluginConfig().DryRun)
return
} else if len(responseHeaders) > 0 {
fc.StateBag()[responseHeadersKey] = responseHeaders
}
}
func removeRequestHeaders(fc filters.FilterContext, headersToRemove []string) {
for _, header := range headersToRemove {
fc.Request().Header.Del(header)
}
}
func addRequestHeaders(fc filters.FilterContext, headers http.Header) {
for key, values := range headers {
for _, value := range values {
// This is the default behavior from https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/core/v3/base.proto#config-core-v3-headervalueoption
fc.Request().Header.Add(key, value)
}
}
}
func (f *opaAuthorizeRequestFilter) Response(fc filters.FilterContext) {
if headers, ok := fc.StateBag()[responseHeadersKey].([]*ext_authz_v3_core.HeaderValueOption); ok {
addResponseHeaders(fc, headers)
}
}
func addResponseHeaders(fc filters.FilterContext, headersToAdd []*ext_authz_v3_core.HeaderValueOption) {
for _, headerToAdd := range headersToAdd {
header := headerToAdd.GetHeader()
fc.Response().Header.Add(header.GetKey(), header.GetValue())
}
}
func (f *opaAuthorizeRequestFilter) OpenPolicyAgent() *openpolicyagent.OpenPolicyAgentInstance {
return f.opa
}
package opaserveresponse
import (
"io"
"net/http"
"time"
"github.com/zalando/skipper/filters"
"gopkg.in/yaml.v2"
"github.com/zalando/skipper/filters/openpolicyagent"
"github.com/zalando/skipper/filters/openpolicyagent/internal/envoy"
)
type spec struct {
registry *openpolicyagent.OpenPolicyAgentRegistry
opts []func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error
name string
bodyParsing bool
}
func NewOpaServeResponseSpec(registry *openpolicyagent.OpenPolicyAgentRegistry, opts ...func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error) filters.Spec {
return &spec{
registry: registry,
opts: opts,
name: filters.OpaServeResponseName,
bodyParsing: false,
}
}
func NewOpaServeResponseWithReqBodySpec(registry *openpolicyagent.OpenPolicyAgentRegistry, opts ...func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error) filters.Spec {
return &spec{
registry: registry,
opts: opts,
name: filters.OpaServeResponseWithReqBodyName,
bodyParsing: true,
}
}
func (s *spec) Name() string {
return s.name
}
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
var err error
if len(args) < 1 {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) > 2 {
return nil, filters.ErrInvalidFilterParameters
}
bundleName, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
envoyContextExtensions := map[string]string{}
if len(args) > 1 {
extensionsStr, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
err = yaml.Unmarshal([]byte(extensionsStr), &envoyContextExtensions)
if err != nil {
return nil, err
}
}
configOptions := s.opts
opaConfig, err := openpolicyagent.NewOpenPolicyAgentConfig(configOptions...)
if err != nil {
return nil, err
}
opa, err := s.registry.NewOpenPolicyAgentInstance(bundleName, *opaConfig, s.Name())
if err != nil {
return nil, err
}
return &opaServeResponseFilter{
opa: opa,
registry: s.registry,
envoyContextExtensions: envoyContextExtensions,
bodyParsing: s.bodyParsing,
}, nil
}
type opaServeResponseFilter struct {
opa *openpolicyagent.OpenPolicyAgentInstance
registry *openpolicyagent.OpenPolicyAgentRegistry
envoyContextExtensions map[string]string
bodyParsing bool
}
func (f *opaServeResponseFilter) Request(fc filters.FilterContext) {
span, ctx := f.opa.StartSpanFromFilterContext(fc)
defer span.Finish()
req := fc.Request()
var rawBodyBytes []byte
if f.bodyParsing {
var body io.ReadCloser
var err error
var finalizer func()
body, rawBodyBytes, finalizer, err = f.opa.ExtractHttpBodyOptionally(req)
defer finalizer()
if err != nil {
f.opa.ServeInvalidDecisionError(fc, span, nil, err)
return
}
req.Body = body
}
authzreq, err := envoy.AdaptToExtAuthRequest(fc.Request(), f.opa.InstanceConfig().GetEnvoyMetadata(), f.envoyContextExtensions, rawBodyBytes)
if err != nil {
f.opa.HandleEvaluationError(fc, span, nil, err, !f.opa.EnvoyPluginConfig().DryRun, http.StatusBadRequest)
return
}
start := time.Now()
result, err := f.opa.Eval(ctx, authzreq)
fc.Metrics().MeasureSince(f.opa.MetricsKey("eval_time"), start)
if err != nil {
f.opa.ServeInvalidDecisionError(fc, span, result, err)
return
}
allowed, err := result.IsAllowed()
if err != nil {
f.opa.ServeInvalidDecisionError(fc, span, result, err)
return
}
span.SetTag("opa.decision.allowed", allowed)
if allowed {
fc.Metrics().IncCounter(f.opa.MetricsKey("decision.allow"))
} else {
fc.Metrics().IncCounter(f.opa.MetricsKey("decision.deny"))
}
f.opa.ServeResponse(fc, span, result)
}
func (f *opaServeResponseFilter) Response(fc filters.FilterContext) {}
func (f *opaServeResponseFilter) OpenPolicyAgent() *openpolicyagent.OpenPolicyAgentInstance {
return f.opa
}
package openpolicyagent
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"maps"
"math/rand"
"net/http"
"net/url"
"os"
"slices"
"strings"
"sync"
"text/template"
"time"
"google.golang.org/protobuf/proto"
ext_authz_v3_core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
"github.com/google/uuid"
"github.com/open-policy-agent/opa-envoy-plugin/envoyauth"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/config"
"github.com/open-policy-agent/opa/v1/download"
"github.com/open-policy-agent/opa/v1/hooks"
"github.com/open-policy-agent/opa/v1/logging"
"github.com/open-policy-agent/opa/v1/plugins"
"github.com/open-policy-agent/opa/v1/plugins/discovery"
"github.com/open-policy-agent/opa/v1/rego"
"github.com/open-policy-agent/opa/v1/runtime"
"github.com/open-policy-agent/opa/v1/storage"
"github.com/open-policy-agent/opa/v1/storage/inmem"
iCache "github.com/open-policy-agent/opa/v1/topdown/cache"
opatracing "github.com/open-policy-agent/opa/v1/tracing"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters/openpolicyagent/internal"
"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/encoding/protojson"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/flowid"
"github.com/zalando/skipper/filters/openpolicyagent/internal/envoy"
"github.com/zalando/skipper/routing"
"github.com/zalando/skipper/tracing"
)
const (
DefaultCleanIdlePeriod = 10 * time.Second
DefaultControlLoopInterval = 60 * time.Second
DefaultControlLoopMaxJitter = 1000 * time.Millisecond
defaultReuseDuration = 30 * time.Second
defaultShutdownGracePeriod = 30 * time.Second
DefaultOpaStartupTimeout = 30 * time.Second
DefaultMaxRequestBodySize = 1 << 20 // 1 MB
DefaultMaxMemoryBodyParsing = 100 * DefaultMaxRequestBodySize
DefaultRequestBodyBufferSize = 8 * 1024 // 8 KB
spanNameEval = "open-policy-agent"
)
type OpenPolicyAgentRegistry struct {
// Ideally we would share one bundle storage across many OPA "instances" using this registry.
// This would save memory on bundles that are shared
// between different policies (e.g. global team memberships).
// It is currently not possible due to limitations in OPA,
// see https://github.com/open-policy-agent/opa/issues/5707
mu sync.Mutex
instances map[string]*OpenPolicyAgentInstance
lastused map[*OpenPolicyAgentInstance]time.Time
once sync.Once
closed bool
quit chan struct{}
reuseDuration time.Duration
cleanInterval time.Duration
instanceStartupTimeout time.Duration
maxMemoryBodyParsingSem *semaphore.Weighted
maxRequestBodyBytes int64
bodyReadBufferSize int64
tracer opentracing.Tracer
enableCustomControlLoop bool
controlLoopInterval time.Duration
controlLoopMaxJitter time.Duration
enableDataPreProcessingOptimization bool
}
type OpenPolicyAgentFilter interface {
OpenPolicyAgent() *OpenPolicyAgentInstance
}
func WithReuseDuration(duration time.Duration) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.reuseDuration = duration
return nil
}
}
func WithMaxRequestBodyBytes(n int64) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.maxRequestBodyBytes = n
return nil
}
}
func WithMaxMemoryBodyParsing(n int64) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.maxMemoryBodyParsingSem = semaphore.NewWeighted(n)
return nil
}
}
func WithReadBodyBufferSize(n int64) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.bodyReadBufferSize = n
return nil
}
}
func WithCleanInterval(interval time.Duration) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.cleanInterval = interval
return nil
}
}
func WithInstanceStartupTimeout(timeout time.Duration) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.instanceStartupTimeout = timeout
return nil
}
}
func WithTracer(tracer opentracing.Tracer) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.tracer = tracer
return nil
}
}
func WithEnableCustomControlLoop(enabled bool) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.enableCustomControlLoop = enabled
return nil
}
}
func WithEnableDataPreProcessingOptimization(enabled bool) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.enableDataPreProcessingOptimization = enabled
return nil
}
}
func WithControlLoopInterval(interval time.Duration) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.controlLoopInterval = interval
return nil
}
}
func WithControlLoopMaxJitter(maxJitter time.Duration) func(*OpenPolicyAgentRegistry) error {
return func(cfg *OpenPolicyAgentRegistry) error {
cfg.controlLoopMaxJitter = maxJitter
return nil
}
}
func NewOpenPolicyAgentRegistry(opts ...func(*OpenPolicyAgentRegistry) error) *OpenPolicyAgentRegistry {
registry := &OpenPolicyAgentRegistry{
reuseDuration: defaultReuseDuration,
cleanInterval: DefaultCleanIdlePeriod,
instanceStartupTimeout: DefaultOpaStartupTimeout,
instances: make(map[string]*OpenPolicyAgentInstance),
lastused: make(map[*OpenPolicyAgentInstance]time.Time),
quit: make(chan struct{}),
maxRequestBodyBytes: DefaultMaxRequestBodySize,
bodyReadBufferSize: DefaultRequestBodyBufferSize,
controlLoopInterval: DefaultControlLoopInterval,
controlLoopMaxJitter: DefaultControlLoopMaxJitter,
}
for _, opt := range opts {
opt(registry)
}
if registry.maxMemoryBodyParsingSem == nil {
registry.maxMemoryBodyParsingSem = semaphore.NewWeighted(DefaultMaxMemoryBodyParsing)
}
go registry.startCleanerDaemon()
if registry.enableCustomControlLoop {
go registry.startCustomControlLoopDaemon()
}
return registry
}
type OpenPolicyAgentInstanceConfig struct {
envoyMetadata *ext_authz_v3_core.Metadata
configTemplate []byte
}
func WithConfigTemplate(configTemplate []byte) func(*OpenPolicyAgentInstanceConfig) error {
return func(cfg *OpenPolicyAgentInstanceConfig) error {
cfg.configTemplate = configTemplate
return nil
}
}
func WithConfigTemplateFile(configTemplateFile string) func(*OpenPolicyAgentInstanceConfig) error {
return func(cfg *OpenPolicyAgentInstanceConfig) error {
var err error
cfg.configTemplate, err = os.ReadFile(configTemplateFile)
return err
}
}
func WithEnvoyMetadata(metadata *ext_authz_v3_core.Metadata) func(*OpenPolicyAgentInstanceConfig) error {
return func(cfg *OpenPolicyAgentInstanceConfig) error {
cfg.envoyMetadata = metadata
return nil
}
}
func WithEnvoyMetadataBytes(content []byte) func(*OpenPolicyAgentInstanceConfig) error {
return func(cfg *OpenPolicyAgentInstanceConfig) error {
cfg.envoyMetadata = &ext_authz_v3_core.Metadata{}
return protojson.Unmarshal(content, cfg.envoyMetadata)
}
}
func WithEnvoyMetadataFile(file string) func(*OpenPolicyAgentInstanceConfig) error {
return func(cfg *OpenPolicyAgentInstanceConfig) error {
content, err := os.ReadFile(file)
if err != nil {
return err
}
err = WithEnvoyMetadataBytes(content)(cfg)
if err != nil {
return fmt.Errorf("cannot parse '%q': %w", file, err)
}
return nil
}
}
func (cfg *OpenPolicyAgentInstanceConfig) GetEnvoyMetadata() *ext_authz_v3_core.Metadata {
if cfg.envoyMetadata != nil {
return proto.Clone(cfg.envoyMetadata).(*ext_authz_v3_core.Metadata)
}
return nil
}
func NewOpenPolicyAgentConfig(opts ...func(*OpenPolicyAgentInstanceConfig) error) (*OpenPolicyAgentInstanceConfig, error) {
cfg := OpenPolicyAgentInstanceConfig{}
for _, opt := range opts {
if err := opt(&cfg); err != nil {
return nil, err
}
}
if cfg.configTemplate == nil {
var err error
cfg.configTemplate, err = os.ReadFile("opaconfig.yaml")
if err != nil {
return nil, err
}
}
return &cfg, nil
}
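// A minimal usage sketch (the file name and metadata content are assumptions):
func exampleNewOpenPolicyAgentConfig() (*OpenPolicyAgentInstanceConfig, error) {
return NewOpenPolicyAgentConfig(
WithConfigTemplateFile("opaconfig.yaml"),
WithEnvoyMetadataBytes([]byte(`{"filter_metadata": {"open_policy_agent": {}}}`)),
)
}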
func (registry *OpenPolicyAgentRegistry) Close() {
registry.once.Do(func() {
registry.mu.Lock()
defer registry.mu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownGracePeriod)
defer cancel()
for _, instance := range registry.instances {
instance.Close(ctx)
}
registry.closed = true
close(registry.quit)
})
}
func (registry *OpenPolicyAgentRegistry) cleanUnusedInstances(t time.Time) {
registry.mu.Lock()
defer registry.mu.Unlock()
if registry.closed {
return
}
ctx, cancel := context.WithTimeout(context.Background(), defaultShutdownGracePeriod)
defer cancel()
for key, inst := range registry.instances {
lastused, ok := registry.lastused[inst]
if ok && t.Sub(lastused) > registry.reuseDuration {
inst.Close(ctx)
delete(registry.instances, key)
delete(registry.lastused, inst)
}
}
}
func (registry *OpenPolicyAgentRegistry) startCleanerDaemon() {
ticker := time.NewTicker(registry.cleanInterval)
defer ticker.Stop()
for {
select {
case <-registry.quit:
return
case t := <-ticker.C:
registry.cleanUnusedInstances(t)
}
}
}
// startCustomControlLoopDaemon starts a custom control loop that triggers the discovery and bundle plugins for all OPA instances in the registry.
// The processing is done in sequence to avoid memory spikes if the bundles of multiple instances are updated at the same time.
// The timeout for the processing of each instance is set to the startup timeout to ensure that the behavior is the same as during startup.
// It is accepted that runs can be skipped if the processing of all instances takes longer than the interval.
func (registry *OpenPolicyAgentRegistry) startCustomControlLoopDaemon() {
ticker := time.NewTicker(registry.controlLoopIntervalWithJitter())
defer ticker.Stop()
for {
select {
case <-registry.quit:
return
case <-ticker.C:
registry.mu.Lock()
instances := slices.Collect(maps.Values(registry.instances))
registry.mu.Unlock()
for _, opa := range instances {
func() {
ctx, cancel := context.WithTimeout(context.Background(), registry.instanceStartupTimeout)
defer cancel()
opa.triggerPlugins(ctx)
}()
}
ticker.Reset(registry.controlLoopIntervalWithJitter())
}
}
}
// controlLoopIntervalWithJitter jitters the control loop interval to prevent different OPA instances from triggering plugins (e.g. downloading new bundles) at the same time.
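// For example, with the defaults (interval 60s, max jitter 1000ms) each wait
// is drawn uniformly from [59.5s, 60.5s), so instances that started at the
// same moment drift apart over successive runs.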
func (registry *OpenPolicyAgentRegistry) controlLoopIntervalWithJitter() time.Duration {
if registry.controlLoopMaxJitter > 0 {
return registry.controlLoopInterval + time.Duration(rand.Int63n(int64(registry.controlLoopMaxJitter))) - registry.controlLoopMaxJitter/2
}
return registry.controlLoopInterval
}
// Do implements routing.PostProcessor and cleans unused OPA instances
func (registry *OpenPolicyAgentRegistry) Do(routes []*routing.Route) []*routing.Route {
inUse := make(map[*OpenPolicyAgentInstance]struct{})
for _, ri := range routes {
for _, fi := range ri.Filters {
if ff, ok := fi.Filter.(OpenPolicyAgentFilter); ok {
inUse[ff.OpenPolicyAgent()] = struct{}{}
}
}
}
registry.markUnused(inUse)
return routes
}
func (registry *OpenPolicyAgentRegistry) NewOpenPolicyAgentInstance(bundleName string, config OpenPolicyAgentInstanceConfig, filterName string) (*OpenPolicyAgentInstance, error) {
registry.mu.Lock()
defer registry.mu.Unlock()
if registry.closed {
return nil, fmt.Errorf("open policy agent registry is already closed")
}
if instance, ok := registry.instances[bundleName]; ok {
delete(registry.lastused, instance)
return instance, nil
}
instance, err := registry.newOpenPolicyAgentInstance(bundleName, config, filterName)
if err != nil {
return nil, err
}
registry.instances[bundleName] = instance
return instance, nil
}
func (registry *OpenPolicyAgentRegistry) markUnused(inUse map[*OpenPolicyAgentInstance]struct{}) {
registry.mu.Lock()
defer registry.mu.Unlock()
for _, instance := range registry.instances {
if _, ok := inUse[instance]; !ok {
registry.lastused[instance] = time.Now()
}
}
}
func (registry *OpenPolicyAgentRegistry) newOpenPolicyAgentInstance(bundleName string, config OpenPolicyAgentInstanceConfig, filterName string) (*OpenPolicyAgentInstance, error) {
runtime.RegisterPlugin(envoy.PluginName, envoy.Factory{})
configBytes, err := interpolateConfigTemplate(config.configTemplate, bundleName)
if err != nil {
return nil, err
}
engine, err := registry.new(inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(registry.enableDataPreProcessingOptimization)), configBytes, config, filterName, bundleName,
registry.maxRequestBodyBytes, registry.bodyReadBufferSize)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), registry.instanceStartupTimeout)
defer cancel()
if registry.enableCustomControlLoop {
if err = engine.StartAndTriggerPlugins(ctx); err != nil {
return nil, err
}
} else {
if err = engine.Start(ctx, registry.instanceStartupTimeout); err != nil {
return nil, err
}
}
return engine, nil
}
type OpenPolicyAgentInstance struct {
manager *plugins.Manager
instanceConfig OpenPolicyAgentInstanceConfig
opaConfig *config.Config
bundleName string
preparedQuery *rego.PreparedEvalQuery
preparedQueryDoOnce *sync.Once
preparedQueryErr error
interQueryBuiltinCache iCache.InterQueryCache
once sync.Once
closing bool
registry *OpenPolicyAgentRegistry
maxBodyBytes int64
bodyReadBufferSize int64
idGenerator flowid.Generator
}
func envVariablesMap() map[string]string {
rawEnvVariables := os.Environ()
envVariables := make(map[string]string)
for _, item := range rawEnvVariables {
tokens := strings.SplitN(item, "=", 2)
envVariables[tokens[0]] = tokens[1]
}
return envVariables
}
// interpolateConfigTemplate renders the OPA configuration template, making the bundle name and the process environment available to it.
func interpolateConfigTemplate(configTemplate []byte, bundleName string) ([]byte, error) {
var buf bytes.Buffer
tpl := template.Must(template.New("opa-config").Parse(string(configTemplate)))
binding := make(map[string]interface{})
binding["bundlename"] = bundleName
binding["Env"] = envVariablesMap()
err := tpl.ExecuteTemplate(&buf, "opa-config", binding)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func buildTracingOptions(tracer opentracing.Tracer, bundleName string, manager *plugins.Manager) opatracing.Options {
return opatracing.NewOptions(WithTracingOptTracer(tracer), WithTracingOptBundleName(bundleName), WithTracingOptManager(manager))
}
func (registry *OpenPolicyAgentRegistry) withTracingOptions(bundleName string) func(*plugins.Manager) {
return func(m *plugins.Manager) {
options := buildTracingOptions(
registry.tracer,
bundleName,
m,
)
plugins.WithDistributedTracingOpts(options)(m)
}
}
// new returns a new OPA object.
func (registry *OpenPolicyAgentRegistry) new(store storage.Store, configBytes []byte, instanceConfig OpenPolicyAgentInstanceConfig, filterName string, bundleName string, maxBodyBytes int64, bodyReadBufferSize int64) (*OpenPolicyAgentInstance, error) {
id := uuid.New().String()
uniqueIDGenerator, err := flowid.NewStandardGenerator(32)
if err != nil {
return nil, err
}
opaConfig, err := config.ParseConfig(configBytes, id)
if err != nil {
return nil, err
}
runtime.RegisterPlugin(envoy.PluginName, envoy.Factory{})
var logger logging.Logger = &QuietLogger{target: logging.Get()}
logger = logger.WithFields(map[string]interface{}{"skipper-filter": filterName, "bundle-name": bundleName})
configHooks := hooks.New()
if registry.enableCustomControlLoop {
configHooks = hooks.New(&internal.ManualOverride{})
}
manager, err := plugins.New(configBytes, id, store, configLabelsInfo(*opaConfig), plugins.Logger(logger), registry.withTracingOptions(bundleName), plugins.WithHooks(configHooks))
if err != nil {
return nil, err
}
discovery, err := discovery.New(manager, discovery.Factories(map[string]plugins.Factory{envoy.PluginName: envoy.Factory{}}), discovery.Hooks(configHooks))
if err != nil {
return nil, err
}
manager.Register("discovery", discovery)
opa := &OpenPolicyAgentInstance{
registry: registry,
instanceConfig: instanceConfig,
manager: manager,
opaConfig: opaConfig,
bundleName: bundleName,
maxBodyBytes: maxBodyBytes,
bodyReadBufferSize: bodyReadBufferSize,
preparedQueryDoOnce: new(sync.Once),
interQueryBuiltinCache: iCache.NewInterQueryCache(manager.InterQueryBuiltinCacheConfig()),
idGenerator: uniqueIDGenerator,
}
manager.RegisterCompilerTrigger(opa.compilerUpdated)
return opa, nil
}
// Start asynchronously starts the policy engine's plugins that download
// policies, report status, etc.
func (opa *OpenPolicyAgentInstance) Start(ctx context.Context, timeout time.Duration) error {
err := opa.manager.Start(ctx)
if err != nil {
return err
}
// check readiness of all plugins
pluginsReady := func() bool {
for _, status := range opa.manager.PluginStatus() {
if status != nil && status.State != plugins.StateOK {
return false
}
}
return true
}
err = waitFunc(ctx, pluginsReady, 100*time.Millisecond)
if err != nil {
for pluginName, status := range opa.manager.PluginStatus() {
if status != nil && status.State != plugins.StateOK {
opa.Logger().WithFields(map[string]interface{}{
"plugin_name": pluginName,
"plugin_state": status.State,
"error_message": status.Message,
}).Error("Open policy agent plugin did not start in time")
}
}
opa.Close(ctx)
return fmt.Errorf("one or more open policy agent plugins failed to start in %v with error: %w", timeout, err)
}
return nil
}
// StartAndTriggerPlugins starts the policy engine's plugin manager and triggers the plugins to download policies etc.
func (opa *OpenPolicyAgentInstance) StartAndTriggerPlugins(ctx context.Context) error {
err := opa.manager.Start(ctx)
if err != nil {
return err
}
err = opa.triggerPluginsWithRetry(ctx)
if err != nil {
opa.Close(ctx)
return err
}
err = opa.verifyAllPluginsStarted()
if err != nil {
opa.Close(ctx)
return err
}
return nil
}
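// triggerPluginsWithRetry retries the plugin trigger with exponential
// backoff: for illustration, successive retryable failures wait 100ms, 200ms,
// 400ms and so on, until the passed context expires or a non-retryable error
// is returned.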
func (opa *OpenPolicyAgentInstance) triggerPluginsWithRetry(ctx context.Context) error {
var err error
backoff := 100 * time.Millisecond
retryTrigger := time.NewTimer(backoff)
defer retryTrigger.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("context cancelled while triggering plugins: %w, last retry returned: %w", ctx.Err(), err)
case <-retryTrigger.C:
err = opa.triggerPlugins(ctx)
if !opa.isRetryable(err) {
return err
}
backoff *= 2
retryTrigger.Reset(backoff)
}
}
}
func (opa *OpenPolicyAgentInstance) isRetryable(err error) bool {
var httpError download.HTTPError
if errors.As(err, &httpError) {
opa.Logger().WithFields(map[string]interface{}{
"error": httpError.Error(),
}).Warn("Triggering bundles failed. Response code %v, Retrying.", httpError.StatusCode)
return httpError.StatusCode == 429 || httpError.StatusCode >= 500
}
var urlError *url.Error
if errors.As(err, &urlError) {
retry := strings.Contains(urlError.Error(), "net/http: timeout awaiting response headers")
if retry {
opa.Logger().WithFields(map[string]interface{}{
"error": urlError.Error(),
}).Warn("Triggering bundles failed. Retrying.")
}
return retry
}
return false
}
func (opa *OpenPolicyAgentInstance) verifyAllPluginsStarted() error {
allPluginsReady := true
for pluginName, status := range opa.manager.PluginStatus() {
if status != nil && status.State != plugins.StateOK {
opa.Logger().WithFields(map[string]interface{}{
"plugin_name": pluginName,
"plugin_state": status.State,
"error_message": status.Message,
}).Error("Open policy agent plugin failed to start %s", pluginName)
allPluginsReady = false
}
}
if !allPluginsReady {
return fmt.Errorf("open policy agent plugins failed to start")
}
return nil
}
func (opa *OpenPolicyAgentInstance) triggerPlugins(ctx context.Context) error {
if opa.closing {
return nil
}
for _, pluginName := range []string{"discovery", "bundle"} {
if plugin, ok := opa.manager.Plugin(pluginName).(plugins.Triggerable); ok {
if err := plugin.Trigger(ctx); err != nil {
return err
}
} else if pluginName == "bundle" { // only fail for bundle plugin as discovery plugin is optional
return fmt.Errorf("plugin %s not found", pluginName)
}
}
return nil
}
func (opa *OpenPolicyAgentInstance) Close(ctx context.Context) {
opa.once.Do(func() {
opa.closing = true
opa.manager.Stop(ctx)
})
}
func waitFunc(ctx context.Context, fun func() bool, interval time.Duration) error {
if fun() {
return nil
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timed out while starting: %w", ctx.Err())
case <-ticker.C:
if fun() {
return nil
}
}
}
}
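// For illustration: with opaConfig.Labels = map[string]string{"environment": "prod"},
// configLabelsInfo below yields the runtime info term
// {"config": {"labels": {"environment": "prod"}}}, which policies can
// presumably read via opa.runtime().config.labels (an assumption based on
// OPA's runtime info conventions, not verified here).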
func configLabelsInfo(opaConfig config.Config) func(*plugins.Manager) {
info := ast.NewObject()
labels := ast.NewObject()
labelsWrapper := ast.NewObject()
for key, value := range opaConfig.Labels {
labels.Insert(ast.StringTerm(key), ast.StringTerm(value))
}
labelsWrapper.Insert(ast.StringTerm("labels"), ast.NewTerm(labels))
info.Insert(ast.StringTerm("config"), ast.NewTerm(labelsWrapper))
return plugins.Info(ast.NewTerm(info))
}
func (opa *OpenPolicyAgentInstance) InstanceConfig() *OpenPolicyAgentInstanceConfig {
return &opa.instanceConfig
}
func (opa *OpenPolicyAgentInstance) compilerUpdated(txn storage.Transaction) {
opa.preparedQueryDoOnce = new(sync.Once)
}
func (opa *OpenPolicyAgentInstance) EnvoyPluginConfig() envoy.PluginConfig {
if plugin, ok := opa.manager.Plugin(envoy.PluginName).(*envoy.Plugin); ok {
return plugin.GetConfig()
}
defaultConfig := envoy.PluginConfig{
Path: "envoy/authz/allow",
DryRun: false,
}
defaultConfig.ParseQuery()
return defaultConfig
}
func setSpanTags(span opentracing.Span, bundleName string, manager *plugins.Manager) {
if bundleName != "" {
span.SetTag("opa.bundle_name", bundleName)
}
if manager != nil {
for label, value := range manager.Labels() {
span.SetTag("opa.label."+label, value)
}
}
}
func (opa *OpenPolicyAgentInstance) startSpanFromContextWithTracer(tr opentracing.Tracer, parent opentracing.Span, ctx context.Context) (opentracing.Span, context.Context) {
var span opentracing.Span
if parent != nil {
span = tr.StartSpan(spanNameEval, opentracing.ChildOf(parent.Context()))
} else {
span = tracing.CreateSpan(spanNameEval, ctx, tr)
}
setSpanTags(span, opa.bundleName, opa.manager)
return span, opentracing.ContextWithSpan(ctx, span)
}
func (opa *OpenPolicyAgentInstance) StartSpanFromFilterContext(fc filters.FilterContext) (opentracing.Span, context.Context) {
return opa.StartSpanFromContext(fc.Request().Context())
}
func (opa *OpenPolicyAgentInstance) StartSpanFromContext(ctx context.Context) (opentracing.Span, context.Context) {
span := opentracing.SpanFromContext(ctx)
if span != nil {
return opa.startSpanFromContextWithTracer(span.Tracer(), span, ctx)
}
return opa.startSpanFromContextWithTracer(opentracing.GlobalTracer(), nil, ctx)
}
func (opa *OpenPolicyAgentInstance) MetricsKey(key string) string {
return key + "." + opa.bundleName
}
var (
ErrClosed = errors.New("reader closed")
ErrTotalBodyBytesExceeded = errors.New("buffer for in-flight request body authorization in Open Policy Agent exceeded")
)
type bufferedBodyReader struct {
input io.ReadCloser
maxBufferSize int64
bodyBuffer bytes.Buffer
readBuffer []byte
once sync.Once
err error
closed bool
}
func newBufferedBodyReader(input io.ReadCloser, maxBufferSize int64, readBufferSize int64) *bufferedBodyReader {
return &bufferedBodyReader{
input: input,
maxBufferSize: maxBufferSize,
readBuffer: make([]byte, readBufferSize),
}
}
func (m *bufferedBodyReader) fillBuffer(expectedSize int64) ([]byte, error) {
var err error
for err == nil && int64(m.bodyBuffer.Len()) < m.maxBufferSize && int64(m.bodyBuffer.Len()) < expectedSize {
var n int
n, err = m.input.Read(m.readBuffer)
m.bodyBuffer.Write(m.readBuffer[:n])
}
if err == io.EOF {
err = nil
}
return m.bodyBuffer.Bytes(), err
}
func (m *bufferedBodyReader) Read(p []byte) (int, error) {
if m.closed {
return 0, ErrClosed
}
if m.err != nil {
return 0, m.err
}
// First read the buffered body
if m.bodyBuffer.Len() != 0 {
return m.bodyBuffer.Read(p)
}
// Continue reading from the underlying body reader
return m.input.Read(p)
}
// Close closes the underlying reader if it implements io.Closer.
func (m *bufferedBodyReader) Close() error {
var err error
m.once.Do(func() {
m.closed = true
if c, ok := m.input.(io.Closer); ok {
err = c.Close()
}
})
return err
}
func bodyUpperBound(contentLength, maxBodyBytes int64) int64 {
if contentLength <= 0 {
return maxBodyBytes
}
if contentLength < maxBodyBytes {
return contentLength
}
return maxBodyBytes
}
func (opa *OpenPolicyAgentInstance) ExtractHttpBodyOptionally(req *http.Request) (io.ReadCloser, []byte, func(), error) {
body := req.Body
if body != nil && !opa.EnvoyPluginConfig().SkipRequestBodyParse &&
req.ContentLength <= int64(opa.maxBodyBytes) {
wrapper := newBufferedBodyReader(req.Body, opa.maxBodyBytes, opa.bodyReadBufferSize)
requestedBodyBytes := bodyUpperBound(req.ContentLength, opa.maxBodyBytes)
if !opa.registry.maxMemoryBodyParsingSem.TryAcquire(requestedBodyBytes) {
return req.Body, nil, func() {}, ErrTotalBodyBytesExceeded
}
rawBody, err := wrapper.fillBuffer(req.ContentLength)
return wrapper, rawBody, func() { opa.registry.maxMemoryBodyParsingSem.Release(requestedBodyBytes) }, err
}
return req.Body, nil, func() {}, nil
}
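// Illustrative flow: with maxBodyBytes = 1 MB and a request carrying
// Content-Length 10 KB, bodyUpperBound reserves 10 KB from the shared parsing
// semaphore, fillBuffer reads the body into memory, and the returned reader
// replays the buffered bytes before streaming any remainder. A minimal usage
// sketch (caller names are assumptions, mirroring the filters above):
func exampleExtractHttpBodyOptionally(opa *OpenPolicyAgentInstance, req *http.Request) ([]byte, error) {
body, raw, done, err := opa.ExtractHttpBodyOptionally(req)
defer done() // always release the reserved semaphore capacity
if err != nil {
return nil, err
}
req.Body = body // keep the request body readable for the proxy
return raw, nil // raw is the buffered prefix handed to the policy input
}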
// ParsedQuery is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) ParsedQuery() ast.Body {
return opa.EnvoyPluginConfig().ParsedQuery
}
// Store is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) Store() storage.Store { return opa.manager.Store }
// Compiler is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) Compiler() *ast.Compiler { return opa.manager.GetCompiler() }
// Runtime is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) Runtime() *ast.Term { return opa.manager.Info }
// Logger is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) Logger() logging.Logger { return opa.manager.Logger() }
// InterQueryBuiltinCache is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) InterQueryBuiltinCache() iCache.InterQueryCache {
return opa.interQueryBuiltinCache
}
// Config is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) Config() *config.Config { return opa.opaConfig }
// DistributedTracing is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) DistributedTracing() opatracing.Options {
return buildTracingOptions(opa.registry.tracer, opa.bundleName, opa.manager)
}
// CreatePreparedQueryOnce is an implementation of the envoyauth.EvalContext interface
func (opa *OpenPolicyAgentInstance) CreatePreparedQueryOnce(opts envoyauth.PrepareQueryOpts) (*rego.PreparedEvalQuery, error) {
opa.preparedQueryDoOnce.Do(func() {
regoOpts := append(opts.Opts, rego.DistributedTracingOpts(opa.DistributedTracing()))
pq, err := rego.New(regoOpts...).PrepareForEval(context.Background())
opa.preparedQuery = &pq
opa.preparedQueryErr = err
})
return opa.preparedQuery, opa.preparedQueryErr
}
// QuietLogger is a logging.Logger that does not pollute the info level with debug logs.
type QuietLogger struct {
target logging.Logger
}
func (l *QuietLogger) WithFields(fields map[string]interface{}) logging.Logger {
return &QuietLogger{target: l.target.WithFields(fields)}
}
func (l *QuietLogger) SetLevel(level logging.Level) {
l.target.SetLevel(level)
}
func (l *QuietLogger) GetLevel() logging.Level {
return l.target.GetLevel()
}
func (l *QuietLogger) Debug(fmt string, a ...interface{}) {
l.target.Debug(fmt, a...)
}
func (l *QuietLogger) Info(fmt string, a ...interface{}) {
l.target.Debug(fmt, a...)
}
func (l *QuietLogger) Error(fmt string, a ...interface{}) {
l.target.Error(fmt, a...)
}
func (l *QuietLogger) Warn(fmt string, a ...interface{}) {
l.target.Warn(fmt, a...)
}
package openpolicyagent
import (
"bytes"
"io"
"net/http"
"github.com/open-policy-agent/opa-envoy-plugin/envoyauth"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
)
func (opa *OpenPolicyAgentInstance) ServeInvalidDecisionError(fc filters.FilterContext, span opentracing.Span, result *envoyauth.EvalResult, err error) {
opa.HandleInvalidDecisionError(fc, span, result, err, true)
}
func (opa *OpenPolicyAgentInstance) HandleInvalidDecisionError(fc filters.FilterContext, span opentracing.Span, result *envoyauth.EvalResult, err error, serve bool) {
opa.HandleEvaluationError(fc, span, result, err, serve, http.StatusInternalServerError)
}
func (opa *OpenPolicyAgentInstance) HandleEvaluationError(fc filters.FilterContext, span opentracing.Span, result *envoyauth.EvalResult, err error, serve bool, status int) {
fc.Metrics().IncCounter(opa.MetricsKey("decision.err"))
span.SetTag("error", true)
if result != nil {
span.LogKV(
"event", "error",
"opa.decision_id", result.DecisionID,
"message", err.Error(),
)
opa.Logger().WithFields(map[string]interface{}{
"decision": result.Decision,
"err": err,
"decision_id": result.DecisionID,
}).Info("Rejecting request because of an invalid decision")
} else {
span.LogKV(
"event", "error",
"message", err.Error(),
)
opa.Logger().WithFields(map[string]interface{}{
"err": err,
}).Info("Rejecting request because of an error")
}
if serve {
resp := http.Response{}
resp.StatusCode = status
fc.Serve(&resp)
}
}
func (opa *OpenPolicyAgentInstance) ServeResponse(fc filters.FilterContext, span opentracing.Span, result *envoyauth.EvalResult) {
resp := http.Response{}
var err error
resp.StatusCode, err = result.GetResponseHTTPStatus()
if err != nil {
opa.ServeInvalidDecisionError(fc, span, result, err)
return
}
resp.Header, err = result.GetResponseHTTPHeaders()
if err != nil {
opa.ServeInvalidDecisionError(fc, span, result, err)
return
}
if result.HasResponseBody() {
body, err := result.GetResponseBody()
if err != nil {
opa.ServeInvalidDecisionError(fc, span, result, err)
return
}
resp.Body = io.NopCloser(bytes.NewReader([]byte(body)))
}
fc.Serve(&resp)
}
package openpolicyagent
import (
"net/http"
"github.com/open-policy-agent/opa/v1/plugins"
opatracing "github.com/open-policy-agent/opa/v1/tracing"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/proxy"
)
const (
spanNameHttpOut = "open-policy-agent.http"
)
func init() {
opatracing.RegisterHTTPTracing(&tracingFactory{})
}
type tracingFactory struct{}
type transport struct {
tracer opentracing.Tracer
bundleName string
manager *plugins.Manager
wrapped http.RoundTripper
}
func WithTracingOptTracer(tracer opentracing.Tracer) func(*transport) {
return func(t *transport) {
t.tracer = tracer
}
}
func WithTracingOptBundleName(bundleName string) func(*transport) {
return func(t *transport) {
t.bundleName = bundleName
}
}
func WithTracingOptManager(manager *plugins.Manager) func(*transport) {
return func(t *transport) {
t.manager = manager
}
}
func (*tracingFactory) NewTransport(tr http.RoundTripper, opts opatracing.Options) http.RoundTripper {
log := &logging.DefaultLog{}
wrapper := &transport{
wrapped: tr,
}
for _, o := range opts {
opt, ok := o.(func(*transport))
if !ok {
log.Warnf("invalid type for OPA tracing option, expected func(*transport) got %T, tracing information might be incomplete", o)
} else {
opt(wrapper)
}
}
return wrapper
}
func (*tracingFactory) NewHandler(f http.Handler, label string, opts opatracing.Options) http.Handler {
return f
}
func (tr *transport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
spanOpts := []opentracing.StartSpanOption{opentracing.Tags{
proxy.HTTPMethodTag: req.Method,
proxy.HTTPUrlTag: req.URL.String(),
proxy.HostnameTag: req.Host,
proxy.HTTPPathTag: req.URL.Path,
proxy.ComponentTag: "skipper",
proxy.SpanKindTag: proxy.SpanKindClient,
}}
var span opentracing.Span
if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
spanOpts = append(spanOpts, opentracing.ChildOf(parentSpan.Context()))
span = parentSpan.Tracer().StartSpan(spanNameHttpOut, spanOpts...)
} else if tr.tracer != nil {
span = tr.tracer.StartSpan(spanNameHttpOut, spanOpts...)
}
if span == nil {
return tr.wrapped.RoundTrip(req)
}
defer span.Finish()
setSpanTags(span, tr.bundleName, tr.manager)
req = req.WithContext(opentracing.ContextWithSpan(ctx, span))
carrier := opentracing.HTTPHeadersCarrier(req.Header)
span.Tracer().Inject(span.Context(), opentracing.HTTPHeaders, carrier)
resp, err := tr.wrapped.RoundTrip(req)
if err != nil {
span.SetTag("error", true)
span.LogKV("event", "error", "message", err.Error())
return resp, err
}
span.SetTag(proxy.HTTPStatusCodeTag, resp.StatusCode)
if resp.StatusCode > 399 {
span.SetTag("error", true)
span.LogKV("event", "error", "message", resp.Status)
}
return resp, nil
}
package ratelimit
import (
"net/http"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/ratelimit"
)
type BackendRatelimit struct {
Settings ratelimit.Settings
StatusCode int
}
// NewBackendRatelimit creates a filter Spec whose instances
// instruct the proxy to limit the request rate towards a particular backend endpoint.
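//
// Example (illustrative eskip route, the filter name corresponds to
// filters.BackendRateLimitName):
//
// backend: Path("/api")
// -> backendRatelimit("groupA", 20, "1s")
// -> "https://backend.example.org";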
func NewBackendRatelimit() filters.Spec { return &BackendRatelimit{} }
func (*BackendRatelimit) Name() string {
return filters.BackendRateLimitName
}
func (*BackendRatelimit) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 3 && len(args) != 4 {
return nil, filters.ErrInvalidFilterParameters
}
group, err := getStringArg(args[0])
if err != nil {
return nil, err
}
maxHits, err := getIntArg(args[1])
if err != nil {
return nil, err
}
timeWindow, err := getDurationArg(args[2])
if err != nil {
return nil, err
}
f := &BackendRatelimit{
Settings: ratelimit.Settings{
Type: ratelimit.ClusterServiceRatelimit,
Group: "backend." + group,
MaxHits: maxHits,
TimeWindow: timeWindow,
},
StatusCode: http.StatusServiceUnavailable,
}
if len(args) == 4 {
code, err := getIntArg(args[3])
if err != nil {
return nil, err
}
f.StatusCode = code
}
return f, nil
}
func (limit *BackendRatelimit) Request(ctx filters.FilterContext) {
// allows overwrite
ctx.StateBag()[filters.BackendRatelimit] = limit
}
func (*BackendRatelimit) Response(filters.FilterContext) {}
package ratelimit
import (
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/routing"
)
type failClosedSpec struct{}
type failClosed struct{}
type FailClosedPostProcessor struct{}
func NewFailClosedPostProcessor() *FailClosedPostProcessor {
return &FailClosedPostProcessor{}
}
// Do implements the routing.PostProcessor interface to change the filter
// configs at route processing time. The fail open/closed decision
// needs to be made only once and can be processed before we activate the
// new routes.
func (*FailClosedPostProcessor) Do(routes []*routing.Route) []*routing.Route {
for _, r := range routes {
var failClosed bool
for _, f := range r.Filters {
if f.Name == filters.RatelimitFailClosedName {
failClosed = true
break
}
}
// no config changes detected
if !failClosed {
continue
}
for _, f := range r.Filters {
switch f.Name {
// leaky bucket has no Settings
case filters.ClusterLeakyBucketRatelimitName:
lf, ok := f.Filter.(*leakyBucketFilter)
if ok {
lf.failClosed = true
}
case filters.BackendRateLimitName:
bf, ok := f.Filter.(*BackendRatelimit)
if ok {
bf.Settings.FailClosed = true
}
case
filters.ClientRatelimitName,
filters.ClusterClientRatelimitName,
filters.ClusterRatelimitName:
ff, ok := f.Filter.(*filter)
if ok {
ff.settings.FailClosed = true
}
}
}
}
return routes
}
func NewFailClosed() filters.Spec {
return &failClosedSpec{}
}
func (*failClosedSpec) Name() string {
return filters.RatelimitFailClosedName
}
func (*failClosedSpec) CreateFilter([]interface{}) (filters.Filter, error) {
return &failClosed{}, nil
}
func (*failClosed) Request(filters.FilterContext) {}
func (*failClosed) Response(filters.FilterContext) {}
package ratelimit
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/ratelimit"
)
type leakyBucket interface {
Add(ctx context.Context, label string, increment int) (added bool, retry time.Duration, err error)
}
type leakyBucketSpec struct {
create func(capacity int, emission time.Duration) leakyBucket
}
type leakyBucketFilter struct {
label *eskip.Template
bucket leakyBucket
increment int
failClosed bool
}
// NewClusterLeakyBucketRatelimit creates a filter Spec, whose instances implement rate limiting using the leaky bucket algorithm.
//
// The leaky bucket is an algorithm based on an analogy of how a bucket with a constant leak will overflow if either
// the average rate at which water is poured in exceeds the rate at which the bucket leaks or if more water than
// the capacity of the bucket is poured in all at once.
// See https://en.wikipedia.org/wiki/Leaky_bucket
//
// Example to allow each unique Authorization header once in five seconds:
//
// clusterLeakyBucketRatelimit("auth-${request.header.Authorization}", 1, "5s", 2, 1)
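//
// With these arguments leakVolume = 1 and leakPeriod = "5s", so the emission
// computed in CreateFilter below is 5s / 1 = 5s: the bucket drains one unit
// every five seconds. The capacity of 2 bounds the burst to two units, and
// the increment of 1 adds one unit per request.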
func NewClusterLeakyBucketRatelimit(registry *ratelimit.Registry) filters.Spec {
return &leakyBucketSpec{
create: func(capacity int, emission time.Duration) leakyBucket {
return ratelimit.NewClusterLeakyBucket(registry, capacity, emission)
},
}
}
func (s *leakyBucketSpec) Name() string {
return filters.ClusterLeakyBucketRatelimitName
}
func (s *leakyBucketSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 5 {
return nil, filters.ErrInvalidFilterParameters
}
label, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
leakVolume, err := natural(args[1])
if err != nil {
return nil, err
}
leakPeriod, err := getDurationArg(args[2])
if err != nil {
return nil, err
}
if leakPeriod <= 0 {
return nil, filters.ErrInvalidFilterParameters
}
capacity, err := natural(args[3])
if err != nil {
return nil, err
}
increment, err := natural(args[4])
if err != nil {
return nil, err
}
// emission is the reciprocal of the leak rate
emission := leakPeriod / time.Duration(leakVolume)
return &leakyBucketFilter{
label: eskip.NewTemplate(label),
bucket: s.create(capacity, emission),
increment: increment,
}, nil
}
func fail(ctx filters.FilterContext, header http.Header) {
ctx.Serve(&http.Response{StatusCode: http.StatusTooManyRequests, Header: header})
}
func (f *leakyBucketFilter) Request(ctx filters.FilterContext) {
label, ok := f.label.ApplyContext(ctx)
if !ok {
return // allow on missing placeholders
}
added, retry, err := f.bucket.Add(ctx.Request().Context(), label, f.increment)
if err != nil {
if f.failClosed {
header := http.Header{}
header.Set("Retry-After", "60")
fail(ctx, header)
}
return
}
if added {
return // allow if successfully added
}
header := http.Header{}
if retry > 0 {
header.Set("Retry-After", strconv.Itoa(int(retry/time.Second)))
}
fail(ctx, header)
}
func (*leakyBucketFilter) Response(filters.FilterContext) {}
func natural(arg interface{}) (n int, err error) {
n, err = getIntArg(arg)
if err == nil && n <= 0 {
err = fmt.Errorf(`number %d must be positive`, n)
}
return
}
/*
Package ratelimit provides filters to control the rate limiter settings on the route level.
For detailed documentation of the ratelimit, see https://godoc.org/github.com/zalando/skipper/ratelimit.
*/
package ratelimit
import (
"context"
"net/http"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/ratelimit"
)
const defaultStatusCode = http.StatusTooManyRequests
type spec struct {
typ ratelimit.RatelimitType
provider RatelimitProvider
filterName string
maxShards int
}
type filter struct {
settings ratelimit.Settings
provider RatelimitProvider
statusCode int
maxHits int // overrides settings.MaxHits
}
// RatelimitProvider returns a limit instance for provided Settings
type RatelimitProvider interface {
get(s ratelimit.Settings) limit
}
type limit interface {
// Allow is used to decide if call with context is allowed to pass
Allow(context.Context, string) bool
// RetryAfter is used to inform the client how many seconds it
// should wait before making a new request
RetryAfter(string) int
}
// registryAdapter adapts ratelimit.Registry to the RatelimitProvider interface.
// ratelimit.Registry is not an interface and its Get method returns
// ratelimit.Ratelimit which is not an interface either.
// registryAdapter narrows the ratelimit interfaces to the necessary minimum
// and enables easier test stubbing.
type registryAdapter struct {
registry *ratelimit.Registry
}
func (a *registryAdapter) get(s ratelimit.Settings) limit {
return a.registry.Get(s)
}
func NewRatelimitProvider(registry *ratelimit.Registry) RatelimitProvider {
return ®istryAdapter{registry}
}
// NewLocalRatelimit is *DEPRECATED*, use NewClientRatelimit instead.
func NewLocalRatelimit(provider RatelimitProvider) filters.Spec {
return &spec{typ: ratelimit.LocalRatelimit, provider: provider, filterName: ratelimit.LocalRatelimitName}
}
// NewClientRatelimit creates an instance-based client rate limit. If
// you have 5 instances with 20 req/s, then it would allow 100 req/s
// to the backend from the same client. A third argument can be used to
// set which HTTP header of the request should be used to find the
// same user. The third argument defaults to XForwardedForLookuper,
// meaning the X-Forwarded-For header.
//
// Example:
//
// backendHealthcheck: Path("/healthcheck")
// -> clientRatelimit(20, "1m")
// -> "https://foo.backend.net";
//
// Example rate limit per Authorization Header:
//
// login: Path("/login")
// -> clientRatelimit(3, "1m", "Authorization")
// -> "https://login.backend.net";
func NewClientRatelimit(provider RatelimitProvider) filters.Spec {
return &spec{typ: ratelimit.ClientRatelimit, provider: provider, filterName: filters.ClientRatelimitName}
}
// NewRatelimit creates a service rate limit that is
// only aware of itself. If you have 5 instances with 20 req/s, then
// it would at max allow 100 req/s to the backend.
//
// Example:
//
// backendHealthcheck: Path("/healthcheck")
// -> ratelimit(20, "1s")
// -> "https://foo.backend.net";
//
// Optionally a custom response status code can be provided as an argument (default is 429).
//
// Example:
//
// backendHealthcheck: Path("/healthcheck")
// -> ratelimit(20, "1s", 503)
// -> "https://foo.backend.net";
func NewRatelimit(provider RatelimitProvider) filters.Spec {
return &spec{typ: ratelimit.ServiceRatelimit, provider: provider, filterName: filters.RatelimitName}
}
// NewClusterRateLimit creates a rate limit that is aware of the
// other instances. The value given here should be the combined rate
// of all instances. The ratelimit group parameter can be used to
// select the same ratelimit group across one or more routes.
//
// Example:
//
// backendHealthcheck: Path("/healthcheck")
// -> clusterRatelimit("groupA", 200, "1m")
// -> "https://foo.backend.net";
//
// Optionally a custom response status code can be provided as an argument (default is 429).
//
// Example:
//
// backendHealthcheck: Path("/healthcheck")
// -> clusterRatelimit("groupA", 200, "1m", 503)
// -> "https://foo.backend.net";
func NewClusterRateLimit(provider RatelimitProvider) filters.Spec {
return NewShardedClusterRateLimit(provider, 1)
}
// NewShardedClusterRateLimit creates a cluster rate limiter that uses multiple group shards to count hits.
// Based on the configured group and maxHits each filter instance selects N distinct group shards from [1, maxGroupShards].
// For every subsequent request it uniformly picks one of N group shards and limits number of allowed requests per group shard to maxHits/N.
//
// For example if maxGroupShards = 10, clusterRatelimit("groupA", 200, "1m") will use 10 distinct groups to count hits and
// will allow up to 20 hits per each group and thus up to configured 200 hits in total.
func NewShardedClusterRateLimit(provider RatelimitProvider, maxGroupShards int) filters.Spec {
return &spec{typ: ratelimit.ClusterServiceRatelimit, provider: provider, filterName: filters.ClusterRatelimitName, maxShards: maxGroupShards}
}
// NewClusterClientRateLimit creates a rate limit that is aware of
// the other instances. The value given here should be the combined
// rate of all instances. The ratelimit group parameter can be used to
// select the same ratelimit group across one or more routes.
//
// Example:
//
// backendHealthcheck: Path("/login")
// -> clusterClientRatelimit("groupB", 20, "1h")
// -> "https://foo.backend.net";
//
// The above example would limit access to "/login" if the client did
// more than 20 requests within the last hour to this route across all
// running skippers in the cluster. A single client is identified by
// data from the http request, by default the client IP, or the
// X-Forwarded-For header if it exists. The optional last parameter
// selects the HTTP header that identifies requests as coming from the
// same client.
//
// Example:
//
// backendHealthcheck: Path("/login")
// -> clusterClientRatelimit("groupC", 20, "1h", "Authorization")
// -> "https://foo.backend.net";
func NewClusterClientRateLimit(provider RatelimitProvider) filters.Spec {
return &spec{typ: ratelimit.ClusterClientRatelimit, provider: provider, filterName: filters.ClusterClientRatelimitName}
}
// NewDisableRatelimit disables rate limiting
//
// Example:
//
// backendHealthcheck: Path("/healthcheck")
// -> disableRatelimit()
// -> "https://foo.backend.net";
func NewDisableRatelimit(provider RatelimitProvider) filters.Spec {
return &spec{typ: ratelimit.DisableRatelimit, provider: provider, filterName: filters.DisableRatelimitName}
}
func (s *spec) Name() string {
return s.filterName
}
func serviceRatelimitFilter(args []interface{}) (*filter, error) {
if !(len(args) == 2 || len(args) == 3) {
return nil, filters.ErrInvalidFilterParameters
}
maxHits, err := getIntArg(args[0])
if err != nil {
return nil, err
}
timeWindow, err := getDurationArg(args[1])
if err != nil {
return nil, err
}
statusCode, err := getStatusCodeArg(args, 2)
if err != nil {
return nil, err
}
return &filter{
settings: ratelimit.Settings{
Type: ratelimit.ServiceRatelimit,
MaxHits: maxHits,
TimeWindow: timeWindow,
Lookuper: ratelimit.NewSameBucketLookuper(),
},
statusCode: statusCode,
}, nil
}
func clusterRatelimitFilter(maxShards int, args []interface{}) (*filter, error) {
if !(len(args) == 3 || len(args) == 4) {
return nil, filters.ErrInvalidFilterParameters
}
group, err := getStringArg(args[0])
if err != nil {
return nil, err
}
maxHits, err := getIntArg(args[1])
if err != nil {
return nil, err
}
timeWindow, err := getDurationArg(args[2])
if err != nil {
return nil, err
}
statusCode, err := getStatusCodeArg(args, 3)
if err != nil {
return nil, err
}
f := &filter{statusCode: statusCode, maxHits: maxHits}
keyShards := getKeyShards(maxHits, maxShards)
if keyShards > 1 {
f.settings = ratelimit.Settings{
Type: ratelimit.ClusterServiceRatelimit,
Group: group + "." + strconv.Itoa(keyShards),
MaxHits: maxHits / keyShards,
TimeWindow: timeWindow,
Lookuper: ratelimit.NewRoundRobinLookuper(uint64(keyShards)),
}
} else {
f.settings = ratelimit.Settings{
Type: ratelimit.ClusterServiceRatelimit,
Group: group,
MaxHits: maxHits,
TimeWindow: timeWindow,
Lookuper: ratelimit.NewSameBucketLookuper(),
}
}
log.Debugf("maxHits: %d, keyShards: %d", maxHits, keyShards)
return f, nil
}
// getKeyShards returns number of key shards based on max hits and max allowed shards.
// Number of key shards k is the largest number from `[1, maxShards]` interval such that `maxHits % k == 0`
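//
// For example: getKeyShards(200, 10) == 10, getKeyShards(100, 3) == 2 and
// getKeyShards(7, 2) == 1.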
func getKeyShards(maxHits, maxShards int) int {
for k := maxShards; k > 1; k-- {
if maxHits%k == 0 {
return k
}
}
return 1
}
func clusterClientRatelimitFilter(args []interface{}) (*filter, error) {
if !(len(args) == 3 || len(args) == 4) {
return nil, filters.ErrInvalidFilterParameters
}
group, err := getStringArg(args[0])
if err != nil {
return nil, err
}
maxHits, err := getIntArg(args[1])
if err != nil {
return nil, err
}
timeWindow, err := getDurationArg(args[2])
if err != nil {
return nil, err
}
s := ratelimit.Settings{
Type: ratelimit.ClusterClientRatelimit,
Group: group,
MaxHits: maxHits,
TimeWindow: timeWindow,
CleanInterval: 10 * timeWindow,
}
if len(args) > 3 {
lookuperString, err := getStringArg(args[3])
if err != nil {
return nil, err
}
if strings.Contains(lookuperString, ",") {
var lookupers []ratelimit.Lookuper
for _, ls := range strings.Split(lookuperString, ",") {
lookupers = append(lookupers, getLookuper(ls))
}
s.Lookuper = ratelimit.NewTupleLookuper(lookupers...)
} else {
s.Lookuper = getLookuper(lookuperString)
}
} else {
s.Lookuper = ratelimit.NewXForwardedForLookuper()
}
return &filter{settings: s, statusCode: defaultStatusCode}, nil
}
func getLookuper(s string) ratelimit.Lookuper {
headerName := http.CanonicalHeaderKey(s)
if headerName == "X-Forwarded-For" {
return ratelimit.NewXForwardedForLookuper()
} else {
return ratelimit.NewHeaderLookuper(headerName)
}
}
func clientRatelimitFilter(args []interface{}) (*filter, error) {
if !(len(args) == 2 || len(args) == 3) {
return nil, filters.ErrInvalidFilterParameters
}
maxHits, err := getIntArg(args[0])
if err != nil {
return nil, err
}
timeWindow, err := getDurationArg(args[1])
if err != nil {
return nil, err
}
var lookuper ratelimit.Lookuper
if len(args) > 2 {
lookuperString, err := getStringArg(args[2])
if err != nil {
return nil, err
}
if strings.Contains(lookuperString, ",") {
var lookupers []ratelimit.Lookuper
for _, ls := range strings.Split(lookuperString, ",") {
lookupers = append(lookupers, getLookuper(ls))
}
lookuper = ratelimit.NewTupleLookuper(lookupers...)
} else {
lookuper = ratelimit.NewHeaderLookuper(lookuperString)
}
} else {
lookuper = ratelimit.NewXForwardedForLookuper()
}
return &filter{
settings: ratelimit.Settings{
Type: ratelimit.ClientRatelimit,
MaxHits: maxHits,
TimeWindow: timeWindow,
CleanInterval: 10 * timeWindow,
Lookuper: lookuper,
},
statusCode: defaultStatusCode,
}, nil
}
func disableFilter([]interface{}) (*filter, error) {
return &filter{
settings: ratelimit.Settings{
Type: ratelimit.DisableRatelimit,
},
statusCode: defaultStatusCode,
}, nil
}
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
f, err := s.createFilter(args)
if f != nil {
f.provider = s.provider
}
return f, err
}
func (s *spec) createFilter(args []interface{}) (*filter, error) {
switch s.typ {
case ratelimit.ServiceRatelimit:
return serviceRatelimitFilter(args)
case ratelimit.LocalRatelimit:
log.Warning("ratelimit.LocalRatelimit is deprecated, please use ratelimit.ClientRatelimit")
fallthrough
case ratelimit.ClientRatelimit:
return clientRatelimitFilter(args)
case ratelimit.ClusterServiceRatelimit:
return clusterRatelimitFilter(s.maxShards, args)
case ratelimit.ClusterClientRatelimit:
return clusterClientRatelimitFilter(args)
default:
return disableFilter(args)
}
}
func getIntArg(a interface{}) (int, error) {
if i, ok := a.(int); ok {
return i, nil
}
if f, ok := a.(float64); ok {
return int(f), nil
}
return 0, filters.ErrInvalidFilterParameters
}
func getStringArg(a interface{}) (string, error) {
if s, ok := a.(string); ok {
return s, nil
}
return "", filters.ErrInvalidFilterParameters
}
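// getDurationArg accepts either a duration string, e.g. "1m", or a plain
// number that is interpreted as seconds, e.g. 30 becomes 30s.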
func getDurationArg(a interface{}) (time.Duration, error) {
if s, ok := a.(string); ok {
return time.ParseDuration(s)
}
i, err := getIntArg(a)
return time.Duration(i) * time.Second, err
}
func getStatusCodeArg(args []interface{}, index int) (int, error) {
// status code arg is optional so we return default status code but no error
if len(args) <= index {
return defaultStatusCode, nil
}
return getIntArg(args[index])
}
// Request checks the ratelimit using the filter settings and serves the configured status code response (429 Too Many Requests by default) if the limit is reached
func (f *filter) Request(ctx filters.FilterContext) {
rateLimiter := f.provider.get(f.settings)
if rateLimiter == nil {
ctx.Logger().Errorf("RateLimiter is nil for settings: %s", f.settings)
return
}
if f.settings.Lookuper == nil {
ctx.Logger().Errorf("Lookuper is nil for settings: %s", f.settings)
return
}
s := f.settings.Lookuper.Lookup(ctx.Request())
if s == "" {
ctx.Logger().Debugf("Lookuper found no data in request for settings: %s and request: %v", f.settings, ctx.Request())
return
}
if !rateLimiter.Allow(ctx.Request().Context(), s) {
maxHits := f.settings.MaxHits
if f.maxHits != 0 {
maxHits = f.maxHits
}
ctx.Serve(&http.Response{
StatusCode: f.statusCode,
Header: ratelimit.Headers(maxHits, f.settings.TimeWindow, rateLimiter.RetryAfter(s)),
})
}
}
func (*filter) Response(filters.FilterContext) {}
package rfc
import (
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/rfc"
)
const (
// Name is the filter name
// Deprecated, use filters.RfcPathName instead
Name = filters.RfcPathName
// NameHost is the filter name
// Deprecated, use filters.RfcHostName instead
NameHost = filters.RfcHostName
)
type path struct{}
// NewPath creates a filter specification for the rfcPath() filter, that
// reencodes the reserved characters in the request path, if it detects
// that they are encoded in the raw path.
//
// See also the PatchPath documentation in the rfc package.
func NewPath() filters.Spec { return path{} }
func (p path) Name() string { return filters.RfcPathName }
func (p path) CreateFilter([]interface{}) (filters.Filter, error) { return path{}, nil }
func (p path) Response(filters.FilterContext) {}
func (p path) Request(ctx filters.FilterContext) {
req := ctx.Request()
req.URL.Path = rfc.PatchPath(req.URL.Path, req.URL.RawPath)
}
type host struct{}
// NewHost creates a filter specification for the rfcHost() filter, that
// removes a trailing dot in the host header.
//
// See also the PatchHost documentation in the rfc package.
func NewHost() filters.Spec { return host{} }
func (host) Name() string { return filters.RfcHostName }
func (host) CreateFilter([]interface{}) (filters.Filter, error) { return host{}, nil }
func (host) Response(filters.FilterContext) {}
func (host) Request(ctx filters.FilterContext) {
ctx.Request().Host = rfc.PatchHost(ctx.Request().Host)
ctx.SetOutgoingHost(rfc.PatchHost(ctx.OutgoingHost()))
}
package scheduler
import (
"fmt"
"net/http"
"time"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/scheduler"
)
type (
fifoSpec struct {
typ string
}
fifoFilter struct {
config scheduler.Config
queue *scheduler.FifoQueue
typ string
}
)
func NewFifo() filters.Spec {
return &fifoSpec{
typ: filters.FifoName,
}
}
func NewFifoWithBody() filters.Spec {
return &fifoSpec{
typ: filters.FifoWithBodyName,
}
}
func (s *fifoSpec) Name() string {
return s.typ
}
// CreateFilter creates a fifoFilter that uses a semaphore-based
// queue for handling requests to limit the concurrency of a route. The
// first parameter is maxConcurrency, the second maxQueueSize and the
// third timeout.
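//
// Example (a hypothetical route for illustration):
//
// fifoRoute: Path("/api")
// -> fifo(100, 150, "10s")
// -> "https://api.backend.net";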
func (s *fifoSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 3 {
return nil, filters.ErrInvalidFilterParameters
}
cc, err := intArg(args[0])
if err != nil {
return nil, err
}
if cc < 1 {
return nil, fmt.Errorf("maxconcurrency requires value >0, %w", filters.ErrInvalidFilterParameters)
}
qs, err := intArg(args[1])
if err != nil {
return nil, err
}
if qs < 0 {
return nil, fmt.Errorf("maxqueuesize requires value >=0, %w", filters.ErrInvalidFilterParameters)
}
d, err := durationArg(args[2])
if err != nil {
return nil, err
}
if d < 1*time.Millisecond {
return nil, fmt.Errorf("timeout requires value >=1ms, %w", filters.ErrInvalidFilterParameters)
}
return &fifoFilter{
typ: s.typ,
config: scheduler.Config{
MaxConcurrency: cc,
MaxQueueSize: qs,
Timeout: d,
},
}, nil
}
func (f *fifoFilter) Config() scheduler.Config {
return f.config
}
func (f *fifoFilter) GetQueue() *scheduler.FifoQueue {
return f.queue
}
func (f *fifoFilter) SetQueue(fq *scheduler.FifoQueue) {
f.queue = fq
}
// Request is the filter.Filter interface implementation. Request will
// increase the number of inflight requests and respond to the caller,
// if the bounded queue returns an error. Status code by Error:
//
// - 503 if queue full
// - 502 if queue timeout
// - 500 if error unknown
func (f *fifoFilter) Request(ctx filters.FilterContext) {
q := f.GetQueue()
c := ctx.Request().Context()
done, err := q.Wait(c)
if err != nil {
if span := opentracing.SpanFromContext(c); span != nil {
ext.Error.Set(span, true)
span.LogKV("fifo error", fmt.Sprintf("Failed to wait for fifo queue: %v", err))
}
ctx.Logger().Debugf("Failed to wait for fifo queue: %v", err)
switch err {
case scheduler.ErrQueueFull:
ctx.Serve(&http.Response{
StatusCode: http.StatusServiceUnavailable,
Status: "Queue Full - https://opensource.zalando.com/skipper/operation/operation/#scheduler",
})
return
case scheduler.ErrQueueTimeout:
ctx.Serve(&http.Response{
StatusCode: http.StatusBadGateway,
Status: "Queue Timeout - https://opensource.zalando.com/skipper/operation/operation/#scheduler",
})
return
case scheduler.ErrClientCanceled:
// This case is handled in the proxy with status code 499
return
default:
ctx.Logger().Errorf("Unknown error in fifo() please create an issue https://github.com/zalando/skipper/issues/new/choose: %v", err)
ctx.Serve(&http.Response{
StatusCode: http.StatusInternalServerError,
Status: "Unknown error in fifo https://opensource.zalando.com/skipper/operation/operation/#scheduler, please create an issue https://github.com/zalando/skipper/issues/new/choose",
})
return
}
}
// ok
pending, _ := ctx.StateBag()[f.typ].([]func())
ctx.StateBag()[f.typ] = append(pending, done)
}
// Response will decrease the number of inflight requests to release
// the concurrency reservation for the request.
func (f *fifoFilter) Response(ctx filters.FilterContext) {
switch f.typ {
case filters.FifoName:
pending, ok := ctx.StateBag()[f.typ].([]func())
if !ok {
return
}
last := len(pending) - 1
if last < 0 {
return
}
pending[last]()
ctx.StateBag()[f.typ] = pending[:last]
case filters.FifoWithBodyName:
// nothing to do here, handled in the proxy after copyStream()
}
}
// HandleErrorResponse is to opt-in for filters to get called
// Response(ctx) in case of errors via proxy. It has to return true to opt-in.
func (f *fifoFilter) HandleErrorResponse() bool {
return true
}
package scheduler
import (
"net/http"
"time"
"github.com/aryszka/jobqueue"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/scheduler"
)
type (
lifoSpec struct{}
lifoGroupSpec struct{}
lifoFilter struct {
config scheduler.Config
queue *scheduler.Queue
}
lifoGroupFilter struct {
name string
hasConfig bool
config scheduler.Config
queue *scheduler.Queue
}
)
const (
// Deprecated, use filters.LifoName instead
LIFOName = filters.LifoName
// Deprecated, use filters.LifoGroupName instead
LIFOGroupName = filters.LifoGroupName
defaultMaxConcurrency = 100
defaultMaxQueueSize = 100
defaultTimeout = 10 * time.Second
)
func NewLIFO() filters.Spec {
return &lifoSpec{}
}
func NewLIFOGroup() filters.Spec {
return &lifoGroupSpec{}
}
func intArg(a interface{}) (int, error) {
switch v := a.(type) {
case int:
return v, nil
case float64:
return int(v), nil
default:
return 0, filters.ErrInvalidFilterParameters
}
}
func durationArg(a interface{}) (time.Duration, error) {
switch v := a.(type) {
case string:
return time.ParseDuration(v)
default:
return 0, filters.ErrInvalidFilterParameters
}
}
func (s *lifoSpec) Name() string { return filters.LifoName }
// CreateFilter creates a lifoFilter that will use a LIFO queue
// for handling requests instead of a FIFO queue. The first
// parameter is MaxConcurrency, the second MaxQueueSize and the third
// Timeout.
//
// The implementation is based on
// https://godoc.org/github.com/aryszka/jobqueue, which provides more
// detailed documentation.
//
// All parameters are optional and default to
// MaxConcurrency 100, MaxQueueSize 100, Timeout 10s.
//
// The total maximum number of requests has to be computed by adding
// MaxConcurrency and MaxQueueSize: total max = MaxConcurrency + MaxQueueSize
//
// Min values are 1 for MaxConcurrency, 0 for MaxQueueSize, and 1ms for
// Timeout. Configured values below these minimums are ignored and the
// defaults are used instead.
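//
// Example (a hypothetical route for illustration):
//
// apiRoute: Path("/api")
// -> lifo(100, 100, "10s")
// -> "https://api.backend.net";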
func (s *lifoSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
var l lifoFilter
// set defaults
l.config.MaxConcurrency = defaultMaxConcurrency
l.config.MaxQueueSize = defaultMaxQueueSize
l.config.Timeout = defaultTimeout
if len(args) > 0 {
c, err := intArg(args[0])
if err != nil {
return nil, err
}
if c >= 1 {
l.config.MaxConcurrency = c
}
}
if len(args) > 1 {
c, err := intArg(args[1])
if err != nil {
return nil, err
}
if c >= 0 {
l.config.MaxQueueSize = c
}
}
if len(args) > 2 {
d, err := durationArg(args[2])
if err != nil {
return nil, err
}
if d >= 1*time.Millisecond {
l.config.Timeout = d
}
}
if len(args) > 3 {
return nil, filters.ErrInvalidFilterParameters
}
return &l, nil
}
func (*lifoGroupSpec) Name() string { return filters.LifoGroupName }
// CreateFilter creates a lifoGroupFilter that will use a LIFO queue
// for handling requests instead of a FIFO queue. The first
// parameter is the Name, the second MaxConcurrency, the third
// MaxQueueSize and the fourth Timeout.
//
// The Name parameter is used to group the queue by one or
// multiple routes. All other parameters are optional and default to
// MaxConcurrency 100, MaxQueueSize 100, Timeout 10s. If the
// configuration for the same Name is different, the behavior is
// undefined.
//
// The implementation is based on
// https://godoc.org/github.com/aryszka/jobqueue, which provides more
// detailed documentation.
//
// The total maximum number of requests has to be computed by adding
// MaxConcurrency and MaxQueueSize: total max = MaxConcurrency + MaxQueueSize
//
// Min values are 1 for MaxConcurrency and MaxQueueSize, and 1ms for
// Timeout. Configured values below these minimums are ignored and the
// defaults are used instead.
//
// It is enough to set the concurrency, queue size and timeout parameters for
// one instance of the filter in the group, and only the group name for the
// rest. Setting these values for multiple instances is fine, too. While only
// one of them will be used as the source for the applied settings, if there
// is accidentally a difference between the settings in the same group, a
// warning will be logged.
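//
// Example (hypothetical routes sharing one group; only the first instance
// carries the full configuration):
//
// routeA: Path("/a") -> lifoGroup("mygroup", 100, 100, "10s") -> "https://a.backend.net";
// routeB: Path("/b") -> lifoGroup("mygroup") -> "https://b.backend.net";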
func (*lifoGroupSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 1 || len(args) > 4 {
return nil, filters.ErrInvalidFilterParameters
}
l := &lifoGroupFilter{}
switch v := args[0].(type) {
case string:
l.name = v
default:
return nil, filters.ErrInvalidFilterParameters
}
// set defaults
cfg := scheduler.Config{
MaxConcurrency: defaultMaxConcurrency,
MaxQueueSize: defaultMaxQueueSize,
Timeout: defaultTimeout,
}
l.config = cfg
if len(args) > 1 {
l.hasConfig = true
c, err := intArg(args[1])
if err != nil {
return nil, err
}
if c >= 1 {
l.config.MaxConcurrency = c
}
}
if len(args) > 2 {
c, err := intArg(args[2])
if err != nil {
return nil, err
}
if c >= 1 {
l.config.MaxQueueSize = c
}
}
if len(args) > 3 {
d, err := durationArg(args[3])
if err != nil {
return nil, err
}
if d >= 1*time.Millisecond {
l.config.Timeout = d
}
}
return l, nil
}
// Config returns the scheduler configuration for the given filter
func (l *lifoFilter) Config() scheduler.Config {
return l.config
}
// SetQueue binds the queue to the current filter context
func (l *lifoFilter) SetQueue(q *scheduler.Queue) {
l.queue = q
}
// GetQueue is only used in tests.
func (l *lifoFilter) GetQueue() *scheduler.Queue {
return l.queue
}
// Close will cleanup underlying queues
func (l *lifoFilter) Close() error {
l.queue.Close()
return nil
}
// Request is the filter.Filter interface implementation. Request will
// increase the number of inflight requests and respond to the caller,
// if the bounded queue returns an error. Status code by Error:
//
// - 503 if jobqueue.ErrStackFull
// - 502 if jobqueue.ErrTimeout
func (l *lifoFilter) Request(ctx filters.FilterContext) {
request(l.GetQueue(), scheduler.LIFOKey, ctx)
}
// Response is the filter.Filter interface implementation. Response
// will decrease the number of inflight requests.
func (l *lifoFilter) Response(ctx filters.FilterContext) {
response(scheduler.LIFOKey, ctx)
}
// HandleErrorResponse is to opt-in for filters to get called
// Response(ctx) in case of errors via proxy. It has to return true to opt-in.
func (l *lifoFilter) HandleErrorResponse() bool {
return true
}
func (l *lifoGroupFilter) Group() string {
return l.name
}
func (l *lifoGroupFilter) HasConfig() bool {
return l.hasConfig
}
// Config returns the scheduler configuration for the given filter
func (l *lifoGroupFilter) Config() scheduler.Config {
return l.config
}
// SetQueue binds the queue to the current filter context
func (l *lifoGroupFilter) SetQueue(q *scheduler.Queue) {
l.queue = q
}
// GetQueue is only used in tests.
func (l *lifoGroupFilter) GetQueue() *scheduler.Queue {
return l.queue
}
// Close will cleanup underlying queues
func (l *lifoGroupFilter) Close() error {
l.queue.Close()
return nil
}
// Request is the filter.Filter interface implementation. Request will
// increase the number of inflight requests and respond to the caller,
// if the bounded queue returns an error. Status code by Error:
//
// - 503 if jobqueue.ErrStackFull
// - 502 if jobqueue.ErrTimeout
func (l *lifoGroupFilter) Request(ctx filters.FilterContext) {
request(l.GetQueue(), scheduler.LIFOKey, ctx)
}
// Response is the filter.Filter interface implementation. Response
// will decrease the number of inflight requests.
func (l *lifoGroupFilter) Response(ctx filters.FilterContext) {
response(scheduler.LIFOKey, ctx)
}
// HandleErrorResponse is to opt-in for filters to get called
// Response(ctx) in case of errors via proxy. It has to return true to opt-in.
func (l *lifoGroupFilter) HandleErrorResponse() bool {
return true
}
func request(q *scheduler.Queue, key string, ctx filters.FilterContext) {
if q == nil {
ctx.Logger().Warnf("Unexpected scheduler.Queue is nil for key %s", key)
return
}
done, err := q.Wait()
if err != nil {
if span := opentracing.SpanFromContext(ctx.Request().Context()); span != nil {
ext.Error.Set(span, true)
}
switch err {
case jobqueue.ErrStackFull:
ctx.Logger().Debugf("Failed to get an entry on to the queue to process QueueFull: %v for host %s", err, ctx.Request().Host)
ctx.Serve(&http.Response{
StatusCode: http.StatusServiceUnavailable,
Status: "Queue Full - https://opensource.zalando.com/skipper/operation/operation/#scheduler",
})
case jobqueue.ErrTimeout:
ctx.Logger().Debugf("Failed to get an entry on to the queue to process Timeout: %v for host %s", err, ctx.Request().Host)
ctx.Serve(&http.Response{
StatusCode: http.StatusBadGateway,
Status: "Queue timeout - https://opensource.zalando.com/skipper/operation/operation/#scheduler",
})
default:
ctx.Logger().Errorf("Unknown error for route based LIFO: %v for host %s", err, ctx.Request().Host)
ctx.Serve(&http.Response{StatusCode: http.StatusInternalServerError})
}
return
}
pending, _ := ctx.StateBag()[key].([]func())
ctx.StateBag()[key] = append(pending, done)
}
func response(key string, ctx filters.FilterContext) {
pending, _ := ctx.StateBag()[key].([]func())
last := len(pending) - 1
if last < 0 {
return
}
pending[last]()
ctx.StateBag()[key] = pending[:last]
}
package sed
import (
"bytes"
"errors"
"io"
"regexp"
)
const (
readBufferSize = 8192
defaultMaxEditorBufferSize = 2097152 // 2Mi
)
type maxBufferHandling int
const (
maxBufferBestEffort maxBufferHandling = iota
maxBufferAbort
)
// editor provides a reader that wraps an input reader, and replaces each occurrence of
// the provided search pattern with the provided replacement. It can be used with a
// delimiter or without.
//
// When using it with a delimiter, it reads enough data from the input until meeting
// a delimiter or reaching maxBufferSize. The chunk includes the delimiter if any. Then
// every occurrence of the pattern is replaced, and the entire edited chunk is returned
// to the caller.
//
// When not using a delimiter, it reads enough data until at least a complete match of the
// pattern is met or the maxBufferSize is reached. When the pattern matches the entire
// buffered input, the replaced content is returned to the caller when maxBufferSize is
// reached. This also means that more replacements can happen than if we edited the
// entire content in one piece, but this is necessary to be able to use the editor for
// input with unknown length.
//
// When the maxBufferHandling is set to maxBufferAbort, then the streaming is aborted
// and the rest of the payload is dropped.
//
// To limit the number of repeated scans over the buffered data, the size of the
// additional data read from the input grows exponentially with every iteration that
// didn't result in any edited data returned to the caller. If any edited data was
// returned to the caller, the read size is reset to the initial value.
//
// When the input returns an error, e.g. EOF, the editor finishes editing the buffered
// data, returns it to the caller, and returns the received error on every subsequent
// read.
//
// When the editor is closed, it doesn't read anymore from the input or return any
// buffered data. If the input implements io.Closer, closing the editor closes the
// input, too.
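//
// A minimal usage sketch (assuming an io.Reader named in, replacing every
// "foo" with "bar", using the default buffer size and best-effort handling):
//
// e := newEditor(in, regexp.MustCompile("foo"), []byte("bar"), nil, 0, maxBufferBestEffort)
// edited, err := io.ReadAll(e)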
type editor struct {
// init:
input io.Reader
pattern *regexp.Regexp
replacement []byte
delimiter []byte
maxBufferSize int
maxBufferHandling maxBufferHandling
prefix []byte
readBuffer []byte
// state:
ready *bytes.Buffer
pending *bytes.Buffer
// final:
err error
closed bool
}
var (
ErrClosed = errors.New("reader closed")
ErrEditorBufferFull = errors.New("editor buffer full")
)
func newEditor(
input io.Reader,
pattern *regexp.Regexp,
replacement []byte,
delimiter []byte,
maxBufferSize int,
mbh maxBufferHandling,
) *editor {
if maxBufferSize <= 0 {
maxBufferSize = defaultMaxEditorBufferSize
}
rsize := readBufferSize
if maxBufferSize < rsize {
rsize = maxBufferSize
}
prefix, _ := pattern.LiteralPrefix()
return &editor{
input: input,
pattern: pattern,
replacement: replacement,
delimiter: delimiter,
maxBufferSize: maxBufferSize,
maxBufferHandling: mbh,
prefix: []byte(prefix),
readBuffer: make([]byte, rsize),
pending: bytes.NewBuffer(nil),
ready: bytes.NewBuffer(nil),
}
}
func (e *editor) readNTimes(times int) (bool, error) {
var consumedInput bool
for i := 0; i < times; i++ {
n, err := e.input.Read(e.readBuffer)
e.pending.Write(e.readBuffer[:n])
if n > 0 {
consumedInput = true
}
if err != nil {
return consumedInput, err
}
}
return consumedInput, nil
}
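// edit scans b for matches of the pattern, writing the edited output to
// e.ready. It returns the number of input bytes consumed and, when
// keepLastChunk is set, whether a match ending exactly at the end of b was
// left pending in the buffer.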
func (e *editor) edit(b []byte, keepLastChunk bool) (int, bool) {
var consumed int
for len(b) > 0 {
if len(e.prefix) > 0 && len(b) >= len(e.prefix) {
skip := bytes.Index(b, e.prefix)
if skip > 0 {
e.ready.Write(b[:skip])
consumed += skip
b = b[skip:]
}
}
match := e.pattern.FindIndex(b)
if len(match) == 0 {
if keepLastChunk {
return consumed, false
}
e.ready.Write(b)
consumed += len(b)
return consumed, false
}
e.ready.Write(b[:match[0]])
consumed += match[0]
if match[1] == match[0] {
if keepLastChunk {
return consumed, false
}
e.ready.Write(b[match[0]:])
consumed += len(b) - match[0]
return consumed, false
}
if keepLastChunk && match[1] == len(b) {
return consumed, true
}
e.ready.Write(e.replacement)
consumed += match[1] - match[0]
b = b[match[1]:]
}
return consumed, false
}
func (e *editor) editUnbound() bool {
consumed, pendingMatches := e.edit(e.pending.Bytes(), true)
e.pending.Next(consumed)
return pendingMatches
}
func (e *editor) editDelimited() {
for {
endChunk := bytes.Index(e.pending.Bytes(), e.delimiter)
if endChunk < 0 {
return
}
chunk := e.pending.Next(endChunk + len(e.delimiter))
e.edit(chunk, false)
}
}
func (e *editor) finalizeEdit(pendingMatches bool) {
if pendingMatches {
e.ready.Write(e.replacement)
return
}
if len(e.delimiter) == 0 {
io.CopyBuffer(e.ready, e.pending, e.readBuffer)
return
}
e.edit(e.pending.Bytes(), false)
}
func (e *editor) fill(requested int) error {
var pendingMatches bool
readSize := 1
for e.ready.Len() < requested {
consumedInput, err := e.readNTimes(readSize)
if !consumedInput {
if err != nil {
e.finalizeEdit(pendingMatches)
}
return err
}
if len(e.delimiter) == 0 {
pendingMatches = e.editUnbound()
} else {
e.editDelimited()
}
if err != nil {
e.finalizeEdit(pendingMatches)
return err
}
if e.pending.Len() > e.maxBufferSize {
switch e.maxBufferHandling {
case maxBufferAbort:
return ErrEditorBufferFull
default:
e.edit(e.pending.Bytes(), false)
e.pending.Reset()
readSize = 1
}
}
readSize *= 2
}
return nil
}
func (e *editor) Read(p []byte) (int, error) {
if e.closed {
return 0, ErrClosed
}
if e.ready.Len() == 0 && e.err != nil {
return 0, e.err
}
if e.ready.Len() < len(p) {
e.err = e.fill(len(p))
}
if e.err == ErrEditorBufferFull {
return 0, ErrEditorBufferFull
}
n, _ := e.ready.Read(p)
if n == 0 && len(p) > 0 && e.err != nil {
return 0, e.err
}
return n, nil
}
// Close closes the underlying reader if it implements io.Closer.
func (e *editor) Close() error {
e.closed = true
if c, ok := e.input.(io.Closer); ok {
return c.Close()
}
return nil
}
package sed
import (
"regexp"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated, use filters.SedName instead
Name = filters.SedName
// Deprecated, use filters.SedDelimName instead
NameDelimit = filters.SedDelimName
// Deprecated, use filters.SedRequestName instead
NameRequest = filters.SedRequestName
// Deprecated, use filters.SedRequestDelimName instead
NameRequestDelimit = filters.SedRequestDelimName
)
type typ int
const (
simple typ = iota
delimited
simpleRequest
delimitedRequest
)
type spec struct {
typ typ
}
type filter struct {
typ typ
pattern *regexp.Regexp
replacement []byte
delimiter []byte
maxEditorBuffer int
maxBufferHandling maxBufferHandling
}
func ofType(t typ) spec {
return spec{typ: t}
}
// New creates a filter specification for the sed() filter.
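//
// Example (a hypothetical route for illustration):
//
// editorRoute: * -> sed("foo", "bar") -> "https://www.example.org";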
func New() filters.Spec {
return ofType(simple)
}
// NewDelimited creates a filter specification for the sedDelim() filter.
func NewDelimited() filters.Spec {
return ofType(delimited)
}
// NewRequest creates a filter specification for the sedRequest() filter.
func NewRequest() filters.Spec {
return ofType(simpleRequest)
}
// NewDelimitedRequest creates a filter specification for the sedRequestDelim() filter.
func NewDelimitedRequest() filters.Spec {
return ofType(delimitedRequest)
}
func (s spec) Name() string {
switch s.typ {
case delimited:
return filters.SedDelimName
case simpleRequest:
return filters.SedRequestName
case delimitedRequest:
return filters.SedRequestDelimName
default:
return filters.SedName
}
}
func parseMaxBufferHandling(h interface{}) (maxBufferHandling, error) {
switch h {
case "best-effort":
return maxBufferBestEffort, nil
case "abort":
return maxBufferAbort, nil
default:
return 0, filters.ErrInvalidFilterParameters
}
}
func (s spec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 2 {
return nil, filters.ErrInvalidFilterParameters
}
pattern, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
patternRx, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
replacement, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
f := &filter{
typ: s.typ,
pattern: patternRx,
replacement: []byte(replacement),
maxBufferHandling: maxBufferBestEffort,
}
var (
delimiterString string
maxBuf interface{}
maxBufHandling interface{}
)
switch s.typ {
case delimited, delimitedRequest:
if len(args) < 3 || len(args) > 5 {
return nil, filters.ErrInvalidFilterParameters
}
if delimiterString, ok = args[2].(string); !ok {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) >= 4 {
maxBuf = args[3]
}
if len(args) == 5 {
maxBufHandling = args[4]
}
f.delimiter = []byte(delimiterString)
default:
if len(args) > 4 {
return nil, filters.ErrInvalidFilterParameters
}
if len(args) >= 3 {
maxBuf = args[2]
}
if len(args) == 4 {
maxBufHandling = args[3]
}
}
if maxBuf != nil {
switch v := maxBuf.(type) {
case int:
f.maxEditorBuffer = v
case float64:
f.maxEditorBuffer = int(v)
default:
return nil, filters.ErrInvalidFilterParameters
}
}
if maxBufHandling != nil {
mbh, err := parseMaxBufferHandling(maxBufHandling)
if err != nil {
return nil, err
}
f.maxBufferHandling = mbh
}
return *f, nil
}
func (f filter) Request(ctx filters.FilterContext) {
switch f.typ {
case simple, delimited:
return
}
req := ctx.Request()
req.Header.Del("Content-Length")
req.ContentLength = -1
req.Body = newEditor(
req.Body,
f.pattern,
f.replacement,
f.delimiter,
f.maxEditorBuffer,
f.maxBufferHandling,
)
}
func (f filter) Response(ctx filters.FilterContext) {
switch f.typ {
case simpleRequest, delimitedRequest:
return
}
rsp := ctx.Response()
rsp.Header.Del("Content-Length")
rsp.ContentLength = -1
rsp.Body = newEditor(
rsp.Body,
f.pattern,
f.replacement,
f.delimiter,
f.maxEditorBuffer,
f.maxBufferHandling,
)
}
/*
Package serve provides a wrapper of net/http.Handler to be used as a filter.
*/
package serve
import (
"io"
"net/http"
"github.com/zalando/skipper/filters"
)
type pipedResponse struct {
response *http.Response
reader *io.PipeReader
writer *io.PipeWriter
headerDone chan struct{}
}
// ServeHTTP creates a response from a handler and a request.
//
// It calls the handler's ServeHTTP method with an internal response
// writer that shares the status code, headers and the response body
// with the returned response. It blocks until the handler calls the
// response writer's WriteHeader, or starts writing the body, or
// returns. The written body is not buffered, but piped to the returned
// response's body.
//
// Example, a simple file server:
//
// var handler = http.StripPrefix(webRoot, http.FileServer(http.Dir(root)))
//
// func (f *myFilter) Request(ctx filters.FilterContext) {
// serve.ServeHTTP(ctx, handler)
// }
func ServeHTTP(ctx filters.FilterContext, h http.Handler) {
rsp := &http.Response{Header: make(http.Header)}
r, w := io.Pipe()
d := &pipedResponse{
response: rsp,
reader: r,
writer: w,
headerDone: make(chan struct{})}
req := ctx.Request()
go func() {
h.ServeHTTP(d, req)
select {
case <-d.headerDone:
default:
d.WriteHeader(http.StatusOK)
}
w.CloseWithError(io.EOF)
}()
<-d.headerDone
rsp.Body = d
ctx.Serve(rsp)
}
func (d *pipedResponse) Read(data []byte) (int, error) { return d.reader.Read(data) }
func (d *pipedResponse) Header() http.Header { return d.response.Header }
// Write implements http.ResponseWriter.Write. If WriteHeader was
// not called before Write, Write calls WriteHeader with the default
// 200 status code.
func (d *pipedResponse) Write(data []byte) (int, error) {
select {
case <-d.headerDone:
default:
d.WriteHeader(http.StatusOK)
}
return d.writer.Write(data)
}
// WriteHeader sets the status code for the outgoing response and
// signals that the header is done.
func (d *pipedResponse) WriteHeader(status int) {
d.response.StatusCode = status
close(d.headerDone)
}
func (d *pipedResponse) Close() error {
return d.reader.Close()
}
package shedder
import (
"context"
"math"
"math/rand"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/routing"
)
func getIntArg(a interface{}) (int, error) {
if i, ok := a.(int); ok {
return i, nil
}
if f, ok := a.(float64); ok {
return int(f), nil
}
return 0, filters.ErrInvalidFilterParameters
}
func getDurationArg(a interface{}) (time.Duration, error) {
if s, ok := a.(string); ok {
return time.ParseDuration(s)
}
return 0, filters.ErrInvalidFilterParameters
}
func getFloat64Arg(a interface{}) (float64, error) {
if f, ok := a.(float64); ok {
return f, nil
}
return 0, filters.ErrInvalidFilterParameters
}
func getModeArg(a interface{}) (mode, error) {
s, ok := a.(string)
if !ok {
return 0, filters.ErrInvalidFilterParameters
}
switch s {
case "active":
return active, nil
case "inactive":
return inactive, nil
case "logInactive":
return logInactive, nil
}
return 0, filters.ErrInvalidFilterParameters
}
type mode int
const (
inactive mode = iota + 1
logInactive
active
)
func (m mode) String() string {
switch m {
case active:
return "active"
case inactive:
return "inactive"
case logInactive:
return "logInactive"
}
return "unknown"
}
const (
counterPrefix = "shedder.admission_control."
admissionControlSpanName = "admission_control"
admissionSignalHeaderKey = "Admission-Control"
admissionSignalHeaderValue = "true"
admissionControlKey = "shedder:admission_control"
admissionControlValue = "reject"
minWindowSize = 1
maxWindowSize = 100
)
type Options struct {
Tracer opentracing.Tracer
}
type admissionControlPre struct{}
// Do removes duplicate filters, because we can only handle one in a
// chain. The last one will override the others.
func (spec *admissionControlPre) Do(routes []*eskip.Route) []*eskip.Route {
for _, r := range routes {
foundAt := -1
toDelete := make(map[int]struct{})
for i, f := range r.Filters {
if f.Name == filters.AdmissionControlName {
if foundAt != -1 {
toDelete[foundAt] = struct{}{}
}
foundAt = i
}
}
if len(toDelete) == 0 {
continue
}
rf := make([]*eskip.Filter, 0, len(r.Filters)-len(toDelete))
for i, f := range r.Filters {
if _, ok := toDelete[i]; !ok {
rf = append(rf, f)
}
}
r.Filters = rf
}
return routes
}
type admissionControlPost struct {
filters map[string]*admissionControl
}
// Do implements routing.PostProcessor and makes it possible to close goroutines.
func (spec *admissionControlPost) Do(routes []*routing.Route) []*routing.Route {
inUse := make(map[string]struct{})
for _, r := range routes {
for _, f := range r.Filters {
if ac, ok := f.Filter.(*admissionControl); ok {
oldAc, okOld := spec.filters[r.Id]
if okOld {
// replace: close the old one
oldAc.Close()
}
spec.filters[r.Id] = ac
inUse[r.Id] = struct{}{}
}
}
}
for id, f := range spec.filters {
if _, ok := inUse[id]; !ok {
// delete: close the old one
f.Close()
}
}
return routes
}
type AdmissionControlSpec struct {
tracer opentracing.Tracer
}
type admissionControl struct {
once sync.Once
mu sync.Mutex
quit chan struct{}
closed bool
metrics metrics.Metrics
metricSuffix string
tracer opentracing.Tracer
mode mode
windowSize int
minRps int
d time.Duration
successThreshold float64 // (0,1]
maxRejectProbability float64 // (0,1]
exponent float64 // >0
averageRpsFactor float64
totals []int64
success []int64
counter *atomic.Int64
successCounter *atomic.Int64
}
func NewAdmissionControl(o Options) filters.Spec {
tracer := o.Tracer
if tracer == nil {
tracer = &opentracing.NoopTracer{}
}
return &AdmissionControlSpec{
tracer: tracer,
}
}
func (*AdmissionControlSpec) PreProcessor() *admissionControlPre {
return &admissionControlPre{}
}
func (*AdmissionControlSpec) PostProcessor() *admissionControlPost {
return &admissionControlPost{
filters: make(map[string]*admissionControl),
}
}
func (*AdmissionControlSpec) Name() string { return filters.AdmissionControlName }
// CreateFilter creates a new admissionControl filter with passed configuration:
//
// admissionControl(metricSuffix, mode, d, windowSize, minRps, successThreshold, maxRejectProbability, exponent)
// admissionControl("$app", "active", "1s", 5, 10, 0.1, 0.95, 0.5)
//
// metricSuffix is the suffix key used to expose the reject counter; it should be unique per filter instance
// mode is one of "active", "inactive", "logInactive"
//
// active will reject traffic
// inactive will never reject traffic
// logInactive will not reject traffic, but log to debug filter settings
//
// windowSize is within [minWindowSize, maxWindowSize]
// minRps is the threshold that needs to be reached for the filter to apply
// successThreshold is within (0,1] and sets the lowest request success rate at which the filter will not reject requests.
// maxRejectProbability is within (0,1] and sets the upper bound of reject probability.
// exponent >0, 1: linear, 1/2: quadratic, 1/3: cubic, ...
//
// see also https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_filters/admission_control_filter#admission-control
func (spec *AdmissionControlSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
var err error
if len(args) != 8 {
return nil, filters.ErrInvalidFilterParameters
}
metricSuffix, ok := args[0].(string)
if !ok {
log.Warn("metricsuffix required as string")
return nil, filters.ErrInvalidFilterParameters
}
mode, err := getModeArg(args[1])
if err != nil {
log.Warnf("mode failed: %v", err)
return nil, filters.ErrInvalidFilterParameters
}
d, err := getDurationArg(args[2])
if err != nil {
log.Warnf("d failed: %v", err)
return nil, filters.ErrInvalidFilterParameters
}
windowSize, err := getIntArg(args[3])
if err != nil {
log.Warnf("windowsize failed: %v", err)
return nil, filters.ErrInvalidFilterParameters
}
if minWindowSize > windowSize || windowSize > maxWindowSize {
log.Warnf("windowsize too small, should be within: [%d,%d], got: %d", minWindowSize, maxWindowSize, windowSize)
return nil, filters.ErrInvalidFilterParameters
}
minRps, err := getIntArg(args[4])
if err != nil {
log.Warnf("minRequests failed: %v", err)
return nil, filters.ErrInvalidFilterParameters
}
threshold, err := getFloat64Arg(args[5])
if err != nil {
log.Warnf("threshold failed %v", err)
return nil, filters.ErrInvalidFilterParameters
}
maxRejectProbability, err := getFloat64Arg(args[6])
if err != nil {
log.Warnf("maxRejectProbability failed: %v", err)
return nil, filters.ErrInvalidFilterParameters
}
exponent, err := getFloat64Arg(args[7])
if err != nil {
log.Warnf("exponent failed: %v", err)
return nil, filters.ErrInvalidFilterParameters
}
if exponent <= 0.0 {
log.Warn("exponent should be >0")
return nil, filters.ErrInvalidFilterParameters
}
averageRpsFactor := float64(time.Second) / (float64(d) * float64(windowSize))
ac := &admissionControl{
once: sync.Once{},
quit: make(chan struct{}),
metrics: metrics.Default,
metricSuffix: metricSuffix,
tracer: spec.tracer,
mode: mode,
d: d,
windowSize: windowSize,
minRps: minRps,
successThreshold: threshold,
maxRejectProbability: maxRejectProbability,
exponent: exponent,
averageRpsFactor: averageRpsFactor,
totals: make([]int64, windowSize),
success: make([]int64, windowSize),
counter: new(atomic.Int64),
successCounter: new(atomic.Int64),
}
go ac.tickWindows(d)
return ac, nil
}
// Close stops the background goroutine. The filter keeps working on stale data.
func (ac *admissionControl) Close() error {
ac.once.Do(func() {
ac.closed = true
close(ac.quit)
})
return nil
}
func (ac *admissionControl) tickWindows(d time.Duration) {
t := time.NewTicker(d)
defer t.Stop()
i := 0
for range t.C {
select {
case <-ac.quit:
return
default:
}
val := ac.counter.Swap(0)
ok := ac.successCounter.Swap(0)
ac.mu.Lock()
ac.totals[i] = val
ac.success[i] = ok
ac.mu.Unlock()
i = (i + 1) % ac.windowSize
}
}
func (ac *admissionControl) count() (float64, float64) {
ac.mu.Lock()
defer ac.mu.Unlock()
return float64(sum(ac.totals)), float64(sum(ac.success))
}
func sum(a []int64) int64 {
var result int64
for _, v := range a {
result += v
}
return result
}
func (ac *admissionControl) setCommonTags(span opentracing.Span) {
span.SetTag("admissionControl.group", ac.metricSuffix)
span.SetTag("admissionControl.mode", ac.mode.String())
span.SetTag("admissionControl.duration", ac.d.String())
span.SetTag("admissionControl.windowSize", ac.windowSize)
}
// calculates P_{reject} see https://opensource.zalando.com/skipper/reference/filters/#admissioncontrol
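//
// For example, with exponent 1, successThreshold 0.95 and a window holding
// 100 total and 80 successful requests: s = 80/0.95 ≈ 84.2 and
// P_reject = (100 - 84.2) / (100 + 1) ≈ 0.16 (then clamped to
// [0, maxRejectProbability]).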
func (ac *admissionControl) pReject() float64 {
var rejectP float64
total, success := ac.count()
avgRps := total * ac.averageRpsFactor
if avgRps < float64(ac.minRps) {
if ac.mode == logInactive {
log.Infof("avgRps %0.2f does not reach minRps %d", avgRps, ac.minRps)
}
return -1
}
s := success / ac.successThreshold
if ac.mode == logInactive {
log.Infof("%s: total < s = %v, rejectP = (%0.2f - %0.2f) / (%0.2f + 1) --- success: %0.2f and threshold: %0.2f", filters.AdmissionControlName, total < s, total, s, total, success, ac.successThreshold)
}
if total < s {
return -1
}
rejectP = (total - s) / (total + 1)
rejectP = math.Pow(rejectP, ac.exponent)
rejectP = math.Min(rejectP, ac.maxRejectProbability)
return math.Max(rejectP, 0.0)
}
func (ac *admissionControl) shouldReject() bool {
p := ac.pReject() // [0, ac.maxRejectProbability] and -1 to disable
/* #nosec */
r := rand.Float64() // [0,1)
if ac.mode == logInactive {
log.Infof("%s: p: %0.2f, r: %0.2f", filters.AdmissionControlName, p, r)
}
return p > r
}
func (ac *admissionControl) Request(ctx filters.FilterContext) {
span := ac.startSpan(ctx.Request().Context())
defer span.Finish()
ac.setCommonTags(span)
ac.metrics.IncCounter(counterPrefix + "total." + ac.metricSuffix)
if ac.shouldReject() {
ac.metrics.IncCounter(counterPrefix + "reject." + ac.metricSuffix)
ext.Error.Set(span, true)
ctx.StateBag()[admissionControlKey] = admissionControlValue
// shadow mode to measure data
if ac.mode != active {
return
}
header := make(http.Header)
header.Set(admissionSignalHeaderKey, admissionSignalHeaderValue)
ctx.Serve(&http.Response{
Header: header,
StatusCode: http.StatusServiceUnavailable,
})
}
}
func (ac *admissionControl) Response(ctx filters.FilterContext) {
// we don't want to count our short-circuited responses as errors
if ctx.StateBag()[admissionControlKey] == admissionControlValue {
return
}
// we don't want to count other shedders in the call path as errors
if ctx.Response().Header.Get(admissionSignalHeaderKey) == admissionSignalHeaderValue {
return
}
if ctx.Response().StatusCode < 499 {
ac.successCounter.Add(1)
}
ac.counter.Add(1)
}
func (ac *admissionControl) startSpan(ctx context.Context) (span opentracing.Span) {
parent := opentracing.SpanFromContext(ctx)
if parent != nil {
span = ac.tracer.StartSpan(admissionControlSpanName, opentracing.ChildOf(parent.Context()))
ext.Component.Set(span, "skipper")
span.SetTag("mode", ac.mode.String())
}
return
}
// HandleErrorResponse is to opt-in for filters to get called
// Response(ctx) in case of errors via proxy. It has to return true to
// opt-in.
func (ac *admissionControl) HandleErrorResponse() bool { return true }
package tee
import (
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"sync"
"time"
"github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/net"
)
const (
// Deprecated, use filters.TeeName instead
Name = filters.TeeName
// Deprecated, use filters.TeeName instead
DeprecatedName = "Tee"
// Deprecated, use filters.TeenfName instead
NoFollowName = filters.TeenfName
)
const (
defaultTeeTimeout = time.Second
defaultMaxIdleConns = 100
defaultMaxIdleConnsPerHost = 100
defaultIdleConnTimeout = 30 * time.Second
)
type teeSpec struct {
deprecated bool
options Options
}
// Options for tee filter.
type Options struct {
// NoFollow specifies whether tee should follow redirects or not.
// If NoFollow is true, it won't follow, otherwise it will.
NoFollow bool
// Timeout specifies a time limit for requests made by tee filter.
Timeout time.Duration
// Tracer is the opentracing tracer to use in the client
Tracer opentracing.Tracer
// MaxIdleConns defaults to 100
MaxIdleConns int
// MaxIdleConnsPerHost defaults to 100
MaxIdleConnsPerHost int
// IdleConnTimeout defaults to 30s
IdleConnTimeout time.Duration
}
type teeType int
const (
asBackend teeType = iota + 1
pathModified
)
type teeClient struct {
mu sync.Mutex
store map[string]*net.Client
}
var teeClients *teeClient = &teeClient{
store: make(map[string]*net.Client),
}
type tee struct {
client *net.Client
typ teeType
host string
scheme string
rx *regexp.Regexp
replacement string
shadowRequestDone func() // test hook
}
type teeTie struct {
r io.Reader
w *io.PipeWriter
}
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// and
// https://golang.org/src/net/http/httputil/reverseproxy.go
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
// NewTee returns a new tee filter Spec, whose instances execute the exact same Request against a shadow backend.
// Parameters: shadow backend url, optional - the path (as a regexp) to match and the replacement string.
//
// Name: "tee".
func NewTee() filters.Spec {
return WithOptions(Options{
Timeout: defaultTeeTimeout,
NoFollow: false,
MaxIdleConns: defaultMaxIdleConns,
MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost,
IdleConnTimeout: defaultIdleConnTimeout,
})
}
// NewTeeDeprecated returns a new tee filter Spec, whose instances execute the exact same Request against a shadow backend.
// Parameters: shadow backend url, optional - the path (as a regexp) to match and the replacement string.
//
// This version uses the capitalized version of the filter name. To follow naming conventions it is deprecated,
// and NewTee() (providing the name "tee") should be used instead.
//
// Name: "Tee".
func NewTeeDeprecated() filters.Spec {
sp := WithOptions(Options{
NoFollow: false,
Timeout: defaultTeeTimeout,
MaxIdleConns: defaultMaxIdleConns,
MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost,
IdleConnTimeout: defaultIdleConnTimeout,
})
ts := sp.(*teeSpec)
ts.deprecated = true
return ts
}
// NewTeeNoFollow returns a new tee filter Spec, whose instances execute the exact same Request against a shadow backend.
// It does not follow the redirects from the backend.
// Parameters: shadow backend url, optional - the path (as a regexp) to match and the replacement string.
//
// Name: "teenf".
func NewTeeNoFollow() filters.Spec {
return WithOptions(Options{
NoFollow: true,
Timeout: defaultTeeTimeout,
MaxIdleConns: defaultMaxIdleConns,
MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost,
IdleConnTimeout: defaultIdleConnTimeout,
})
}
// WithOptions returns a new tee filter Spec, whose instances execute the exact same Request against a shadow backend with
// the given options. Available options are NoFollow and Timeout for the http client. For more available options see the Options type.
// Parameters: shadow backend url, optional - the path (as a regexp) to match and the replacement string.
func WithOptions(o Options) filters.Spec {
if o.Timeout == 0 {
o.Timeout = defaultTeeTimeout
}
if o.MaxIdleConns == 0 {
o.MaxIdleConns = defaultMaxIdleConns
}
if o.MaxIdleConnsPerHost == 0 {
o.MaxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
if o.IdleConnTimeout == 0 {
o.IdleConnTimeout = defaultIdleConnTimeout
}
return &teeSpec{options: o}
}
func (tt *teeTie) Read(b []byte) (int, error) {
n, err := tt.r.Read(b)
if err != nil && err != io.EOF {
tt.w.CloseWithError(err)
return n, err
}
if n > 0 {
if _, werr := tt.w.Write(b[:n]); werr != nil {
log.Error("tee: error while tee request", werr)
}
}
if err == io.EOF {
tt.w.Close()
}
return n, err
}
func (tt *teeTie) Close() error { return nil }
// We do not touch response at all
func (r *tee) Response(filters.FilterContext) {}
// Request clones the incoming request and modifies the clone to target the shadow backend
func (r *tee) Request(fc filters.FilterContext) {
req := fc.Request()
copyOfRequest, tr, err := cloneRequest(r, req)
if err != nil {
fc.Logger().Warnf("tee: error while cloning the tee request %v", err)
return
}
req.Body = tr
go func() {
defer func() {
if r.shadowRequestDone != nil {
r.shadowRequestDone()
}
}()
rsp, err := r.client.Do(copyOfRequest)
if err != nil {
fc.Logger().Warnf("tee: error while tee request %v", err)
return
}
rsp.Body.Close()
}()
}
// cloneRequest copies the request and changes the URL and Host of the copy.
// If the 2nd and 3rd filter params were given, the path is also modified by applying the regexp.
// It returns the cloned request and the tee body to be used on the main request.
func cloneRequest(t *tee, req *http.Request) (*http.Request, io.ReadCloser, error) {
u := new(url.URL)
*u = *req.URL
u.Host = t.host
u.Scheme = t.scheme
if t.typ == pathModified {
u.Path = t.rx.ReplaceAllString(u.Path, t.replacement)
}
h := make(http.Header)
for k, v := range req.Header {
h[k] = v
}
for _, k := range hopHeaders {
h.Del(k)
}
var teeBody io.ReadCloser
mainBody := req.Body
// see proxy.go:231
if req.ContentLength != 0 {
pr, pw := io.Pipe()
teeBody = pr
mainBody = &teeTie{mainBody, pw}
}
clone, err := http.NewRequest(req.Method, u.String(), teeBody)
if err != nil {
return nil, nil, err
}
clone.Header = h
clone.Host = t.host
clone.ContentLength = req.ContentLength
return clone, mainBody, nil
}
// CreateFilter creates our tee filter.
// If only one parameter is given, the shadow backend is used as specified.
// If the second and third parameters are also set, then the path is modified.
func (spec *teeSpec) CreateFilter(config []interface{}) (filters.Filter, error) {
var checkRedirect func(req *http.Request, via []*http.Request) error
if spec.options.NoFollow {
checkRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
if len(config) == 0 {
return nil, filters.ErrInvalidFilterParameters
}
backend, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
u, err := url.Parse(backend)
if err != nil {
return nil, err
}
var client *net.Client
teeClients.mu.Lock()
if cc, ok := teeClients.store[u.Host]; !ok {
client = net.NewClient(net.Options{
Timeout: spec.options.Timeout,
TLSHandshakeTimeout: spec.options.Timeout,
ResponseHeaderTimeout: spec.options.Timeout,
CheckRedirect: checkRedirect,
MaxIdleConns: spec.options.MaxIdleConns,
MaxIdleConnsPerHost: spec.options.MaxIdleConnsPerHost,
IdleConnTimeout: spec.options.IdleConnTimeout,
Tracer: spec.options.Tracer,
OpentracingComponentTag: "skipper",
OpentracingSpanName: spec.Name(),
})
teeClients.store[u.Host] = client
} else {
client = cc
}
teeClients.mu.Unlock()
tee := tee{
client: client,
host: u.Host,
scheme: u.Scheme,
}
switch len(config) {
case 1:
tee.typ = asBackend
return &tee, nil
case 3:
// modpath
expr, ok := config[1].(string)
if !ok {
return nil, fmt.Errorf("invalid filter config in %s, expecting regexp and string, got: %v", filters.TeeName, config)
}
replacement, ok := config[2].(string)
if !ok {
return nil, fmt.Errorf("invalid filter config in %s, expecting regexp and string, got: %v", filters.TeeName, config)
}
rx, err := regexp.Compile(expr)
if err != nil {
return nil, err
}
tee.typ = pathModified
tee.rx = rx
tee.replacement = replacement
return &tee, nil
default:
return nil, filters.ErrInvalidFilterParameters
}
}
func (spec *teeSpec) Name() string {
if spec.deprecated {
return DeprecatedName
}
if spec.options.NoFollow {
return filters.TeenfName
}
return filters.TeeName
}
package tee
import (
"github.com/zalando/skipper/filters"
teepredicate "github.com/zalando/skipper/predicates/tee"
)
// FilterName is the filter name.
// Deprecated: use filters.TeeLoopbackName instead.
const FilterName = filters.TeeLoopbackName
type teeLoopbackSpec struct{}
type teeLoopbackFilter struct {
teeKey string
}
func (t *teeLoopbackSpec) Name() string {
return filters.TeeLoopbackName
}
func (t *teeLoopbackSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
teeKey, _ := args[0].(string)
if teeKey == "" {
return nil, filters.ErrInvalidFilterParameters
}
return &teeLoopbackFilter{
teeKey,
}, nil
}
func NewTeeLoopback() filters.Spec {
return &teeLoopbackSpec{}
}
func (f *teeLoopbackFilter) Request(ctx filters.FilterContext) {
cc, err := ctx.Split()
if err != nil {
ctx.Logger().Errorf("teeloopback: failed to split the context request: %v", err)
return
}
cc.Request().Header.Set(teepredicate.HeaderKey, f.teeKey)
go cc.Loopback()
}
func (f *teeLoopbackFilter) Response(_ filters.FilterContext) {}
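// Illustrative usage sketch, an assumption not taken from the original
// source: teeLoopback is typically paired with the Tee predicate, so that a
// separate route receives the shadow traffic (route ids, key and backend
// URLs are hypothetical):
//
// main: * -> teeLoopback("shadow-test") -> "https://backend.example.org";
// shadow: Tee("shadow-test") && True() -> "https://shadow.example.org";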
package tls
import (
"crypto/x509"
"encoding/pem"
"strings"
"github.com/zalando/skipper/filters"
)
type tlsSpec struct{}
type tlsFilter struct{}
func New() filters.Spec {
return &tlsSpec{}
}
func (*tlsSpec) Name() string {
return filters.TLSName
}
func (c *tlsSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 0 {
return nil, filters.ErrInvalidFilterParameters
}
return &tlsFilter{}, nil
}
const (
certSeparator = ","
certHeaderName = "X-Forwarded-Tls-Client-Cert"
)
var (
replacer = strings.NewReplacer(
"-----BEGIN CERTIFICATE-----", "",
"-----END CERTIFICATE-----", "",
"\n", "",
)
)
// sanitize strips the PEM armor and newlines from the raw certificate, making the value HTTP header compliant.
func sanitize(cert []byte) string {
return replacer.Replace(string(cert))
}
// getCertificates builds a header value string from the client certificates.
func getCertificates(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
headerValues = append(headerValues, extractCertificate(peerCert))
}
return strings.Join(headerValues, certSeparator)
}
// extractCertificate returns the sanitized PEM encoding of the given certificate.
func extractCertificate(cert *x509.Certificate) string {
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if certPEM == nil {
return ""
}
return sanitize(certPEM)
}
// Request passes cert information via X-Forwarded-Tls-Client-Cert header to the backend.
// Largely inspired by traefik, see also https://github.com/traefik/traefik/blob/6c19a9cb8fb9e41a274bf712580df3712b69dc3e/pkg/middlewares/passtlsclientcert/pass_tls_client_cert.go#L146
func (f *tlsFilter) Request(ctx filters.FilterContext) {
if t := ctx.Request().TLS; t != nil {
if len(t.PeerCertificates) > 0 {
ctx.Request().Header.Set(certHeaderName, getCertificates(ctx.Request().TLS.PeerCertificates))
}
}
}
func (f *tlsFilter) Response(ctx filters.FilterContext) {}
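// exampleSanitizedHeaderValue is an illustrative sketch, not part of the
// original source: it shows that sanitize collapses a PEM block into a single
// base64 line suitable as the X-Forwarded-Tls-Client-Cert header value. The
// certificate bytes are hypothetical.
func exampleSanitizedHeaderValue() string {
pemBlock := "-----BEGIN CERTIFICATE-----\nMIIBhTCCASugAwIBAgIQ\n-----END CERTIFICATE-----\n"
return sanitize([]byte(pemBlock)) // "MIIBhTCCASugAwIBAgIQ"
}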
package tracing
import (
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.TracingBaggageToTagName instead
BaggageToTagFilterName = filters.TracingBaggageToTagName
)
type baggageToTagSpec struct{}
type baggageToTagFilter struct {
baggageItemName string
tagName string
}
func (baggageToTagSpec) Name() string {
return filters.TracingBaggageToTagName
}
func (baggageToTagSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 1 {
return nil, filters.ErrInvalidFilterParameters
}
baggageItemName, ok := args[0].(string)
if !ok || baggageItemName == "" {
return nil, filters.ErrInvalidFilterParameters
}
tagName := baggageItemName
if len(args) > 1 {
tagNameArg, ok := args[1].(string)
if !ok || tagNameArg == "" {
return nil, filters.ErrInvalidFilterParameters
}
tagName = tagNameArg
}
return baggageToTagFilter{
baggageItemName,
tagName,
}, nil
}
func NewBaggageToTagFilter() filters.Spec {
return baggageToTagSpec{}
}
func (f baggageToTagFilter) Request(ctx filters.FilterContext) {
span := opentracing.SpanFromContext(ctx.Request().Context())
if span == nil {
return
}
baggageItem := span.BaggageItem(f.baggageItemName)
if baggageItem == "" {
return
}
span.SetTag(f.tagName, baggageItem)
}
func (baggageToTagFilter) Response(ctx filters.FilterContext) {}
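// Illustrative usage sketch, an assumption not taken from the original
// source (the baggage item and tag names are hypothetical):
//
// tracingBaggageToTag("example-item", "example.tag")
//
// copies the "example-item" baggage value from the active span into the span
// tag "example.tag"; with a single argument the baggage item name is reused
// as the tag name.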
/*
Package tracing provides filters to instrument distributed tracing.
*/
package tracing
import (
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.TracingSpanNameName instead
SpanNameFilterName = filters.TracingSpanNameName
// OpenTracingProxySpanKey is the key used in the state bag to pass the span name to the proxy.
OpenTracingProxySpanKey = "statebag:opentracing:proxy:span"
)
type spec struct{}
type filter struct {
spanName string
}
// NewSpanName creates a filter spec for setting the name of the outgoing span. (By default "proxy".)
//
// tracingSpanName("example-operation")
func NewSpanName() filters.Spec {
return &spec{}
}
func (s *spec) Name() string { return filters.TracingSpanNameName }
func (s *spec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) != 1 {
return nil, filters.ErrInvalidFilterParameters
}
name, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
return &filter{spanName: name}, nil
}
func (f *filter) Request(ctx filters.FilterContext) {
bag := ctx.StateBag()
bag[OpenTracingProxySpanKey] = f.spanName
}
func (f *filter) Response(filters.FilterContext) {}
package tracing
import (
"fmt"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
)
const (
// Deprecated: use filters.StateBagToTagName instead
StateBagToTagFilterName = filters.StateBagToTagName
)
type stateBagToTagSpec struct{}
type stateBagToTagFilter struct {
stateBagItemName string
tagName string
}
func (stateBagToTagSpec) Name() string {
return filters.StateBagToTagName
}
func (stateBagToTagSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
if len(args) < 1 || len(args) > 2 {
return nil, filters.ErrInvalidFilterParameters
}
stateBagItemName, ok := args[0].(string)
if !ok || stateBagItemName == "" {
return nil, filters.ErrInvalidFilterParameters
}
tagName := stateBagItemName
if len(args) > 1 {
tagNameArg, ok := args[1].(string)
if !ok || tagNameArg == "" {
return nil, filters.ErrInvalidFilterParameters
}
tagName = tagNameArg
}
return &stateBagToTagFilter{
stateBagItemName: stateBagItemName,
tagName: tagName,
}, nil
}
func NewStateBagToTag() filters.Spec {
return stateBagToTagSpec{}
}
func (f *stateBagToTagFilter) Request(ctx filters.FilterContext) {
value, ok := ctx.StateBag()[f.stateBagItemName]
if !ok {
return
}
span := opentracing.SpanFromContext(ctx.Request().Context())
if span == nil {
return
}
if _, ok := value.(string); ok {
span.SetTag(f.tagName, value)
} else {
span.SetTag(f.tagName, fmt.Sprint(value))
}
}
func (*stateBagToTagFilter) Response(ctx filters.FilterContext) {}
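// Illustrative usage sketch, an assumption not taken from the original
// source (the state bag key and tag name are hypothetical):
//
// stateBagToTag("auth-user", "client.id")
//
// reads the "auth-user" entry from the state bag and sets it as the
// "client.id" tag on the span; non-string values are formatted with
// fmt.Sprint.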
package tracing
import (
"net/http"
opentracing "github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
)
type tagSpec struct {
typ string
}
type tagFilterType int
const (
tagRequest tagFilterType = iota + 1
tagResponse
tagResponseCondition
)
type tagFilter struct {
typ tagFilterType
tagName string
tagValue *eskip.Template
condition func(*http.Response) bool
}
// NewTag creates a filter specification for the tracingTag filter.
func NewTag() filters.Spec {
return &tagSpec{typ: filters.TracingTagName}
}
// NewTagFromResponse creates a filter similar to NewTag, but applies tags after the request has been processed.
func NewTagFromResponse() filters.Spec {
return &tagSpec{typ: filters.TracingTagFromResponseName}
}
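// NewTagFromResponseIfStatus creates a filter similar to NewTagFromResponse, but it only applies the tag when the response status code falls within the configured range.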
func NewTagFromResponseIfStatus() filters.Spec {
return &tagSpec{typ: filters.TracingTagFromResponseIfStatusName}
}
func (s *tagSpec) Name() string {
return s.typ
}
func (s *tagSpec) CreateFilter(args []interface{}) (filters.Filter, error) {
var typ tagFilterType
switch s.typ {
case filters.TracingTagName:
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
typ = tagRequest
case filters.TracingTagFromResponseName:
if len(args) != 2 {
return nil, filters.ErrInvalidFilterParameters
}
typ = tagResponse
case filters.TracingTagFromResponseIfStatusName:
if len(args) != 4 {
return nil, filters.ErrInvalidFilterParameters
}
typ = tagResponseCondition
default:
return nil, filters.ErrInvalidFilterParameters
}
tagName, ok := args[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
tagValue, ok := args[1].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
f := &tagFilter{
typ: typ,
tagName: tagName,
tagValue: eskip.NewTemplate(tagValue),
}
if len(args) == 4 {
minValue, ok := args[2].(float64)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
maxValue, ok := args[3].(float64)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
minVal := int(minValue)
maxVal := int(maxValue)
if minVal < 0 || maxVal > 599 || minVal > maxVal {
return nil, filters.ErrInvalidFilterParameters
}
f.condition = func(rsp *http.Response) bool {
return minVal <= rsp.StatusCode && rsp.StatusCode <= maxVal
}
}
return f, nil
}
func (f *tagFilter) Request(ctx filters.FilterContext) {
if f.typ == tagRequest {
f.setTag(ctx)
}
}
func (f *tagFilter) Response(ctx filters.FilterContext) {
switch f.typ {
case tagResponse:
f.setTag(ctx)
case tagResponseCondition:
if f.condition(ctx.Response()) {
f.setTag(ctx)
}
}
}
func (f *tagFilter) setTag(ctx filters.FilterContext) {
span := opentracing.SpanFromContext(ctx.Request().Context())
if span == nil {
return
}
if v, ok := f.tagValue.ApplyContext(ctx); ok {
span.SetTag(f.tagName, v)
}
}
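// Illustrative usage sketch, an assumption not taken from the original
// source (the tag names and template placeholders are hypothetical):
//
// tracingTag("http.flow_id", "${request.header.X-Flow-Id}")
// tracingTagFromResponseIfStatus("error.status", "${response.status}", 500, 599)
//
// The tag value is an eskip template, resolved against the filter context at
// the time the tag is set.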
package xforward
import (
"github.com/zalando/skipper/filters"
snet "github.com/zalando/skipper/net"
)
const (
// Deprecated: use filters.XforwardName instead
Name = filters.XforwardName
// Deprecated: use filters.XforwardFirstName instead
NameFirst = filters.XforwardFirstName
)
type filter struct {
headers *snet.ForwardedHeaders
}
// New creates a specification for the xforward filter
// that appends the remote IP of the incoming request to the
// X-Forwarded-For header, and sets the X-Forwarded-Host header
// to the value of the incoming request's Host header.
func New() filters.Spec {
return filter{headers: &snet.ForwardedHeaders{For: true, Host: true}}
}
// NewFirst creates a specification for the xforwardFirst filter
// that prepends the remote IP of the incoming request to the
// X-Forwarded-For header, and sets the X-Forwarded-Host header
// to the value of the incoming request's Host header.
func NewFirst() filters.Spec {
return filter{headers: &snet.ForwardedHeaders{PrependFor: true, Host: true}}
}
func (f filter) Name() string {
if f.headers.PrependFor {
return filters.XforwardFirstName
}
return filters.XforwardName
}
func (f filter) CreateFilter([]interface{}) (filters.Filter, error) {
return filter(f), nil
}
func (f filter) Request(ctx filters.FilterContext) {
req := ctx.OriginalRequest()
if req == nil {
req = ctx.Request()
}
f.headers.Set(req)
}
func (filter) Response(filters.FilterContext) {}
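// Illustrative sketch, not taken from the original source: for a request
// arriving from 203.0.113.7 with Host "www.example.org" and an existing
// X-Forwarded-For of "198.51.100.1", the xforward filter yields
//
// X-Forwarded-For: 198.51.100.1, 203.0.113.7
// X-Forwarded-Host: www.example.org
//
// while xforwardFirst prepends instead: "203.0.113.7, 198.51.100.1".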
//go:build gofuzz
// +build gofuzz
package fuzz
import "github.com/zalando/skipper/net"
func FuzzParseCIDRs(data []byte) int {
if _, err := net.ParseCIDRs([]string{string(data)}); err != nil {
return 0
}
return 1
}
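// Note added for clarity: under the dvyukov/go-fuzz conventions used by these
// harnesses, returning 1 marks the input as interesting so its priority is
// increased, 0 marks it as ordinary, and -1 tells the fuzzer not to add it to
// the corpus.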
//go:build gofuzz
// +build gofuzz
package fuzz
import "github.com/zalando/skipper/eskip"
func FuzzParseEskip(data []byte) int {
if _, err := eskip.Parse(string(data)); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import "github.com/zalando/skipper/eskip"
func FuzzParseFilters(data []byte) int {
if _, err := eskip.ParseFilters(string(data)); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import "github.com/zalando/skipper/net"
func FuzzParseIPCIDRs(data []byte) int {
if _, err := net.ParseIPCIDRs([]string{string(data)}); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import (
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
)
func FuzzParseIngressV1JSON(data []byte) int {
if _, err := definitions.ParseIngressV1JSON(data); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import "github.com/zalando/skipper/jwt"
func FuzzParseJwt(data []byte) int {
if _, err := jwt.Parse(string(data)); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import "github.com/zalando/skipper/eskip"
func FuzzParsePredicates(data []byte) int {
if _, err := eskip.ParsePredicates(string(data)); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import (
"github.com/zalando/skipper/dataclients/kubernetes/definitions"
)
func FuzzParseRouteGroupsJSON(data []byte) int {
if _, err := definitions.ParseRouteGroupsJSON(data); err != nil {
return 0
}
return 1
}
//go:build gofuzz
// +build gofuzz
package fuzz
import (
"errors"
"log"
"net"
"os"
"time"
"github.com/sirupsen/logrus"
"github.com/zalando/skipper"
"github.com/zalando/skipper/config"
)
var (
address = ""
initialized = false
)
func findAddress() (string, error) {
l, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1")})
if err != nil {
return "", err
}
defer l.Close()
return l.Addr().String(), nil
}
func connect(host string) (net.Conn, error) {
for i := 0; i < 15; i++ {
conn, err := net.Dial("tcp", host)
if err != nil {
time.Sleep(10 * time.Millisecond)
continue
}
return conn, err
}
return nil, errors.New("unable to connect")
}
func runServer() {
addr, err := findAddress()
if err != nil {
log.Printf("failed to find address: %v\n", err)
os.Exit(-1)
}
cfg := config.NewConfig()
cfg.InlineRoutes = `r: * -> status(200) -> inlineContent("ok") -> <shunt>`
cfg.ApplicationLogLevel = logrus.PanicLevel
cfg.AccessLogDisabled = true
cfg.ApplicationLog = "/dev/null"
cfg.Address = addr
cfg.SupportListener = "127.0.0.1:0"
go func() {
log.Fatal(skipper.Run(cfg.ToOptions()))
}()
address = cfg.Address
}
func FuzzServer(data []byte) int {
if !initialized {
runServer()
initialized = true
}
conn, err := connect(address)
if err != nil {
log.Printf("failed to dial: %v\n", err)
return -1
}
conn.Write(data)
conn.Close()
return 1
}
package io
import (
"bytes"
"context"
"errors"
"io"
"sync"
)
var (
ErrClosed = errors.New("reader closed")
ErrBlocked = errors.New("blocked string match found in stream")
)
const (
defaultReadBufferSize uint64 = 8192
)
type MaxBufferHandling int
const (
MaxBufferBestEffort MaxBufferHandling = iota
MaxBufferAbort
)
type matcher struct {
ctx context.Context
once sync.Once
input io.ReadCloser
f func([]byte) (int, error)
maxBufferSize uint64
maxBufferHandling MaxBufferHandling
readBuffer []byte
ready *bytes.Buffer
pending *bytes.Buffer
err error
closed bool
}
var (
ErrMatcherBufferFull = errors.New("matcher buffer full")
)
func newMatcher(
ctx context.Context,
input io.ReadCloser,
f func([]byte) (int, error),
maxBufferSize uint64,
mbh MaxBufferHandling,
) *matcher {
rsize := defaultReadBufferSize
if maxBufferSize < rsize {
rsize = maxBufferSize
}
return &matcher{
ctx: ctx,
once: sync.Once{},
input: input,
f: f,
maxBufferSize: maxBufferSize,
maxBufferHandling: mbh,
readBuffer: make([]byte, rsize),
pending: bytes.NewBuffer(nil),
ready: bytes.NewBuffer(nil),
}
}
func (m *matcher) readNTimes(times int) (bool, error) {
var consumedInput bool
for i := 0; i < times; i++ {
n, err := m.input.Read(m.readBuffer)
_, err2 := m.pending.Write(m.readBuffer[:n])
if n > 0 {
consumedInput = true
}
if err != nil {
return consumedInput, err
}
if err2 != nil {
return consumedInput, err2
}
}
return consumedInput, nil
}
func (m *matcher) fill(requested int) error {
readSize := 1
for m.ready.Len() < requested {
consumedInput, err := m.readNTimes(readSize)
if !consumedInput {
io.CopyBuffer(m.ready, m.pending, m.readBuffer)
return err
}
if uint64(m.pending.Len()) > m.maxBufferSize {
switch m.maxBufferHandling {
case MaxBufferAbort:
return ErrMatcherBufferFull
default:
select {
case <-m.ctx.Done():
m.Close()
return m.ctx.Err()
default:
}
_, err := m.f(m.pending.Bytes())
if err != nil {
return err
}
m.pending.Reset()
readSize = 1
}
}
readSize *= 2
}
return nil
}
func (m *matcher) Read(p []byte) (int, error) {
if m.closed {
return 0, ErrClosed
}
if m.ready.Len() == 0 && m.err != nil {
return 0, m.err
}
if m.ready.Len() < len(p) {
m.err = m.fill(len(p))
}
switch m.err {
case ErrMatcherBufferFull, ErrBlocked:
return 0, m.err
}
n, _ := m.ready.Read(p)
if n == 0 && len(p) > 0 && m.err != nil {
return 0, m.err
}
p = p[:n]
select {
case <-m.ctx.Done():
m.Close()
return 0, m.ctx.Err()
default:
}
n, err := m.f(p)
if err != nil {
m.closed = true
return 0, err
}
return n, nil
}
// Close closes the underlying reader if it implements io.Closer.
func (m *matcher) Close() error {
var err error
m.once.Do(func() {
m.closed = true
if c, ok := m.input.(io.Closer); ok {
err = c.Close()
}
})
return err
}
/*
Wants:
- [x] filters can read the body content for example WAF scoring
- [ ] filters can change the body content for example sedRequest()
- [x] filters need to be chainable (support -> )
- [x] filters need to be able to stop streaming to request blockContent() or WAF deny()
TODO(sszuecs):
1) major optimization: use a registry pattern and wrap only one body
for the concatenated readers, running all f() in a loop, so that
streaming happens once for all readers instead of once per
reader. Important: if a write sits between two readers, we cannot
do this, so we need to detect that case.
3) in case of ErrBlocked, we break the loop or cancel the
context to stop processing. The registry control layer should be
able to stop all processing.
*/
type BufferOptions struct {
MaxBufferHandling MaxBufferHandling
ReadBufferSize uint64
}
// InspectReader wraps the given ReadCloser such that the given
// function f can inspect the content while it is streamed to the
// target. A target can be any io.ReadCloser, for example the
// request body sent to the backend or the response body sent to the
// client. InspectReader applies the given BufferOptions to the matcher.
//
// NOTE: This function is *experimental* and will likely change or disappear in the future.
func InspectReader(ctx context.Context, bo BufferOptions, f func([]byte) (int, error), rc io.ReadCloser) io.ReadCloser {
if bo.ReadBufferSize < 1 {
bo.ReadBufferSize = defaultReadBufferSize
}
return newMatcher(ctx, rc, f, bo.ReadBufferSize, bo.MaxBufferHandling)
}
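// exampleBlockBody is an illustrative sketch, not part of the original
// source: it wraps a request body so that streaming aborts as soon as a
// blocked byte sequence appears in the buffered data. The "forbidden" token
// is hypothetical.
func exampleBlockBody(ctx context.Context, body io.ReadCloser) io.ReadCloser {
inspect := func(p []byte) (int, error) {
if bytes.Contains(p, []byte("forbidden")) {
return 0, ErrBlocked
}
return len(p), nil
}
return InspectReader(ctx, BufferOptions{MaxBufferHandling: MaxBufferAbort}, inspect, body)
}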
package jwt
import (
"encoding/base64"
"encoding/json"
"errors"
"strings"
)
var (
errInvalidToken = errors.New("invalid jwt token")
)
type Token struct {
Claims map[string]interface{}
}
func Parse(value string) (*Token, error) {
parts := strings.SplitN(value, ".", 4)
if len(parts) != 3 {
return nil, errInvalidToken
}
var token Token
err := unmarshalBase64JSON(parts[1], &token.Claims)
if err != nil {
return nil, errInvalidToken
}
return &token, nil
}
func unmarshalBase64JSON(s string, v interface{}) error {
d, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
return err
}
return json.Unmarshal(d, v)
}
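// exampleParseClaims is an illustrative sketch, not part of the original
// source: it extracts a single claim from a bearer token. The header prefix
// and claim name are hypothetical.
func exampleParseClaims(authorization string) (string, bool) {
token, err := Parse(strings.TrimPrefix(authorization, "Bearer "))
if err != nil {
return "", false
}
sub, ok := token.Claims["sub"].(string)
return sub, ok
}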
package loadbalancer
import (
"errors"
"fmt"
"math/rand"
"sort"
"sync"
"sync/atomic"
"github.com/cespare/xxhash/v2"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/routing"
)
// Algorithm indicates the used load balancing algorithm.
type Algorithm int
const (
// None is the default non-specified algorithm.
None Algorithm = iota
// RoundRobin indicates round-robin load balancing between the backend endpoints.
RoundRobin
// Random indicates random choice between the backend endpoints.
Random
// ConsistentHash indicates choice between the backends based on their hashed address.
ConsistentHash
// PowerOfRandomNChoices selects N random endpoints and picks the one with the fewest outstanding requests among them.
PowerOfRandomNChoices
)
const powerOfRandomNChoicesDefaultN = 2
const (
ConsistentHashKey = "consistentHashKey"
ConsistentHashBalanceFactor = "consistentHashBalanceFactor"
)
var (
algorithms = map[Algorithm]initializeAlgorithm{
RoundRobin: newRoundRobin,
Random: newRandom,
ConsistentHash: newConsistentHash,
PowerOfRandomNChoices: newPowerOfRandomNChoices,
}
defaultAlgorithm = newRoundRobin
)
type roundRobin struct {
index int64
}
func newRoundRobin(endpoints []string) routing.LBAlgorithm {
rnd := rand.New(NewLockedSource()) // #nosec
return &roundRobin{
index: int64(rnd.Intn(len(endpoints))),
}
}
// Apply implements routing.LBAlgorithm with a round-robin algorithm.
func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.LBEndpoints) == 1 {
return ctx.LBEndpoints[0]
}
choice := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints)))
return ctx.LBEndpoints[choice]
}
type random struct {
rnd *rand.Rand
}
func newRandom(endpoints []string) routing.LBAlgorithm {
// #nosec
return &random{
rnd: rand.New(NewLockedSource()),
}
}
// Apply implements routing.LBAlgorithm with a stateless random algorithm.
func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.LBEndpoints) == 1 {
return ctx.LBEndpoints[0]
}
choice := r.rnd.Intn(len(ctx.LBEndpoints))
return ctx.LBEndpoints[choice]
}
type (
endpointHash struct {
index int // index of endpoint in endpoint list
hash uint64 // hash of endpoint
}
consistentHash struct {
hashRing []endpointHash // list of endpoints sorted by hash value
}
)
func (ch *consistentHash) Len() int { return len(ch.hashRing) }
func (ch *consistentHash) Less(i, j int) bool { return ch.hashRing[i].hash < ch.hashRing[j].hash }
func (ch *consistentHash) Swap(i, j int) {
ch.hashRing[i], ch.hashRing[j] = ch.hashRing[j], ch.hashRing[i]
}
func newConsistentHashInternal(endpoints []string, hashesPerEndpoint int) routing.LBAlgorithm {
ch := &consistentHash{
hashRing: make([]endpointHash, hashesPerEndpoint*len(endpoints)),
}
for i, ep := range endpoints {
endpointStartIndex := hashesPerEndpoint * i
for j := 0; j < hashesPerEndpoint; j++ {
ch.hashRing[endpointStartIndex+j] = endpointHash{i, hash(fmt.Sprintf("%s-%d", ep, j))}
}
}
sort.Sort(ch)
return ch
}
func newConsistentHash(endpoints []string) routing.LBAlgorithm {
return newConsistentHashInternal(endpoints, 100)
}
func hash(s string) uint64 {
return xxhash.Sum64String(s)
}
func skipEndpoint(c *routing.LBContext, index int) bool {
host := c.Route.LBEndpoints[index].Host
for i := range c.LBEndpoints {
if c.LBEndpoints[i].Host == host {
return false
}
}
return true
}
// searchRing returns the index in the hash ring with the closest hash to the key's hash.
func (ch *consistentHash) searchRing(key string, ctx *routing.LBContext) int {
h := hash(key)
i := sort.Search(ch.Len(), func(i int) bool { return ch.hashRing[i].hash >= h })
if i == ch.Len() { // rollover
i = 0
}
for skipEndpoint(ctx, ch.hashRing[i].index) {
i = (i + 1) % ch.Len()
}
return i
}
// search returns the index of the endpoint with the closest hash to the key's hash.
func (ch *consistentHash) search(key string, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key, ctx)
return ch.hashRing[ringIndex].index
}
func computeLoadAverage(ctx *routing.LBContext) float64 {
sum := 1.0 // add 1 to include the request that just arrived
endpoints := ctx.LBEndpoints
for _, v := range endpoints {
sum += float64(v.Metrics.InflightRequests())
}
return sum / float64(len(endpoints))
}
// boundedLoadSearch returns the index of the endpoint with the hash closest to the key's hash that is also below the target load.
// The skipEndpoint function is used to skip endpoints we don't want, for example fading endpoints.
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key, ctx)
averageLoad := computeLoadAverage(ctx)
targetLoad := averageLoad * balanceFactor
// Loop round ring, starting at endpoint with closest hash. Stop when we find one whose load is less than targetLoad.
for i := 0; i < ch.Len(); i++ {
endpointIndex := ch.hashRing[ringIndex].index
if skipEndpoint(ctx, endpointIndex) {
continue
}
load := ctx.Route.LBEndpoints[endpointIndex].Metrics.InflightRequests()
// We know there must be an endpoint whose load <= average load.
// Since targetLoad >= average load (balanceFactor >= 1), there must also be an endpoint with load <= targetLoad.
if float64(load) <= targetLoad {
break
}
ringIndex = (ringIndex + 1) % ch.Len()
}
return ch.hashRing[ringIndex].index
}
// Apply implements routing.LBAlgorithm with a consistent hash algorithm.
func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.LBEndpoints) == 1 {
return ctx.LBEndpoints[0]
}
// The index returned from this call is taken from hash ring which is built from data about
// all endpoints, including fading in, unhealthy, etc. ones. The index stored in hash ring is
// the index of the endpoint in the original list of endpoints.
choice := ch.chooseConsistentHashEndpoint(ctx)
return ctx.Route.LBEndpoints[choice]
}
func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext) int {
key, ok := ctx.Params[ConsistentHashKey].(string)
if !ok {
key = snet.RemoteHost(ctx.Request).String()
}
balanceFactor, ok := ctx.Params[ConsistentHashBalanceFactor].(float64)
var choice int
if !ok {
choice = ch.search(key, ctx)
} else {
choice = ch.boundedLoadSearch(key, balanceFactor, ctx)
}
return choice
}
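// Illustrative sketch, an assumption not taken from the original source:
// callers can select the hashing key and the bounded-load factor through the
// LBContext params (the header name and factor are hypothetical):
//
// ctx.Params[ConsistentHashKey] = req.Header.Get("X-User-Id")
// ctx.Params[ConsistentHashBalanceFactor] = 1.25
//
// When no key is set, the client address is used as the fallback key.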
type powerOfRandomNChoices struct {
mu sync.Mutex
rnd *rand.Rand
numberOfChoices int
}
// newPowerOfRandomNChoices selects N random backends and picks the one with the fewest outstanding requests.
func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {
rnd := rand.New(NewLockedSource()) // #nosec
return &powerOfRandomNChoices{
rnd: rnd,
numberOfChoices: powerOfRandomNChoicesDefaultN,
}
}
// Apply implements routing.LBAlgorithm with power of random N choices algorithm.
func (p *powerOfRandomNChoices) Apply(ctx *routing.LBContext) routing.LBEndpoint {
ne := len(ctx.LBEndpoints)
p.mu.Lock()
defer p.mu.Unlock()
best := ctx.LBEndpoints[p.rnd.Intn(ne)]
for i := 1; i < p.numberOfChoices; i++ {
ce := ctx.LBEndpoints[p.rnd.Intn(ne)]
if p.getScore(ce) > p.getScore(best) {
best = ce
}
}
return best
}
// getScore returns the negated count of inflight requests.
func (p *powerOfRandomNChoices) getScore(e routing.LBEndpoint) int64 {
// endpoints with higher inflight request should have lower score
return -int64(e.Metrics.InflightRequests())
}
type (
algorithmProvider struct{}
initializeAlgorithm func(endpoints []string) routing.LBAlgorithm
)
// NewAlgorithmProvider creates a routing.PostProcessor used to initialize
// the algorithm of load balancing routes.
func NewAlgorithmProvider() routing.PostProcessor {
return &algorithmProvider{}
}
// AlgorithmFromString parses the string representation of the algorithm definition.
func AlgorithmFromString(a string) (Algorithm, error) {
switch a {
case "":
// This means that the user didn't explicitly specify which
// algorithm should be used, and we will use a default one.
return None, nil
case "roundRobin":
return RoundRobin, nil
case "random":
return Random, nil
case "consistentHash":
return ConsistentHash, nil
case "powerOfRandomNChoices":
return PowerOfRandomNChoices, nil
default:
return None, errors.New("unsupported algorithm")
}
}
// String returns the string representation of an algorithm definition.
func (a Algorithm) String() string {
switch a {
case RoundRobin:
return "roundRobin"
case Random:
return "random"
case ConsistentHash:
return "consistentHash"
case PowerOfRandomNChoices:
return "powerOfRandomNChoices"
default:
return ""
}
}
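// exampleSelectAlgorithm is an illustrative sketch, not part of the original
// source: it resolves an algorithm name as it would appear in a route
// definition, falling back to the package default for the empty string. The
// endpoint addresses are hypothetical.
func exampleSelectAlgorithm(name string) (routing.LBAlgorithm, error) {
a, err := AlgorithmFromString(name)
if err != nil {
return nil, err
}
endpoints := []string{"http://10.0.0.1:8080", "http://10.0.0.2:8080"}
if a == None {
return defaultAlgorithm(endpoints), nil
}
return algorithms[a](endpoints), nil
}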
func parseEndpoints(r *routing.Route) error {
r.LBEndpoints = make([]routing.LBEndpoint, len(r.Route.LBEndpoints))
for i, e := range r.Route.LBEndpoints {
scheme, host, err := snet.SchemeHost(e)
if err != nil {
return err
}
r.LBEndpoints[i] = routing.LBEndpoint{
Scheme: scheme,
Host: host,
}
}
return nil
}
func setAlgorithm(r *routing.Route) error {
t, err := AlgorithmFromString(r.Route.LBAlgorithm)
if err != nil {
return err
}
initialize := defaultAlgorithm
if t != None {
initialize = algorithms[t]
}
r.LBAlgorithm = initialize(r.Route.LBEndpoints)
return nil
}
// Do implements routing.PostProcessor
func (p *algorithmProvider) Do(r []*routing.Route) []*routing.Route {
rr := make([]*routing.Route, 0, len(r))
for _, ri := range r {
if ri.Route.BackendType != eskip.LBBackend {
rr = append(rr, ri)
continue
}
if len(ri.Route.LBEndpoints) == 0 {
log.Errorf("failed to post-process LB route: %s, no endpoints defined", ri.Id)
continue
}
if err := parseEndpoints(ri); err != nil {
log.Errorf("failed to parse LB endpoints for route %s: %v", ri.Id, err)
continue
}
if err := setAlgorithm(ri); err != nil {
log.Errorf("failed to set LB algorithm implementation for route %s: %v", ri.Id, err)
continue
}
rr = append(rr, ri)
}
return rr
}
package loadbalancer
import (
"math/rand"
"sync"
"time"
)
type lockedSource struct {
mu sync.Mutex
r rand.Source
}
func NewLockedSource() *lockedSource {
return &lockedSource{r: rand.NewSource(time.Now().UnixNano())}
}
func (s *lockedSource) Int63() int64 {
s.mu.Lock()
defer s.mu.Unlock()
return s.r.Int63()
}
func (s *lockedSource) Seed(seed int64) {
s.mu.Lock()
defer s.mu.Unlock()
s.r.Seed(seed)
}
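// Note added for clarity: math/rand sources are not safe for concurrent use,
// so the mutex above allows the load balancing algorithms to share a single
// source across goroutines.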
package logging
import (
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/sirupsen/logrus"
flowidFilter "github.com/zalando/skipper/filters/flowid"
logFilter "github.com/zalando/skipper/filters/log"
)
const (
dateFormat = "02/Jan/2006:15:04:05 -0700"
commonLogFormat = `%s - %s [%s] "%s %s %s" %d %d`
// format:
// remote_host - auth_user [date] "method uri protocol" status response_size "referer" "user_agent"
combinedLogFormat = commonLogFormat + ` "%s" "%s"`
// We add the duration in ms, the requested host, the flow id and the audit header
accessLogFormat = combinedLogFormat + " %d %s %s %s\n"
)
type accessLogFormatter struct {
format string
}
// AccessEntry is an access log entry.
type AccessEntry struct {
// The client request.
Request *http.Request
// The status code of the response.
StatusCode int
// The size of the response in bytes.
ResponseSize int64
// The time spent processing the request.
Duration time.Duration
// The time that the request was received.
RequestTime time.Time
// The id of the authenticated user
AuthUser string
}
// TODO: create individual instances from the access log and
// delegate the ownership from the package level to the user
// code.
var (
accessLog *logrus.Logger
stripQuery bool
)
// strip port from addresses with hostname, ipv4 or ipv6
func stripPort(address string) string {
if h, _, err := net.SplitHostPort(address); err == nil {
return h
}
return address
}
// remoteHost returns the remote host of the client. When the
// 'X-Forwarded-For' header is set, its value is used as is.
func remoteHost(r *http.Request) string {
ff := r.Header.Get("X-Forwarded-For")
if ff != "" {
return ff
}
return stripPort(r.RemoteAddr)
}
func omitWhitespace(h string) string {
if h != "" {
return h
}
return "-"
}
func (f *accessLogFormatter) Format(e *logrus.Entry) ([]byte, error) {
keys := []string{
"host", "auth-user", "timestamp", "method", "uri", "proto",
"status", "response-size", "referer", "user-agent",
"duration", "requested-host", "flow-id", "audit"}
values := make([]interface{}, len(keys))
for i, key := range keys {
if s, ok := e.Data[key].(string); ok {
values[i] = omitWhitespace(s)
} else {
values[i] = e.Data[key]
}
}
return []byte(fmt.Sprintf(f.format, values...)), nil
}
func stripQueryString(u string) string {
if i := strings.IndexRune(u, '?'); i >= 0 {
return u[:i]
}
return u
}
// LogAccess logs an access event in Apache combined log format (with a minor customization of adding the duration).
// The additional parameter allows providing extra data that may also be logged, depending on the specific log format.
func LogAccess(entry *AccessEntry, additional map[string]interface{}) {
if accessLog == nil || entry == nil {
return
}
host := ""
method := ""
uri := ""
proto := ""
referer := ""
userAgent := ""
requestedHost := ""
flowId := ""
auditHeader := ""
ts := entry.RequestTime.Format(dateFormat)
status := entry.StatusCode
responseSize := entry.ResponseSize
duration := int64(entry.Duration / time.Millisecond)
authUser := entry.AuthUser
if entry.Request != nil {
host = remoteHost(entry.Request)
method = entry.Request.Method
proto = entry.Request.Proto
referer = entry.Request.Referer()
userAgent = entry.Request.UserAgent()
requestedHost = entry.Request.Host
flowId = entry.Request.Header.Get(flowidFilter.HeaderName)
uri = entry.Request.RequestURI
if stripQuery {
uri = stripQueryString(uri)
}
auditHeader = entry.Request.Header.Get(logFilter.UnverifiedAuditHeader)
}
logData := logrus.Fields{
"timestamp": ts,
"host": host,
"method": method,
"uri": uri,
"proto": proto,
"referer": referer,
"user-agent": userAgent,
"status": status,
"response-size": responseSize,
"requested-host": requestedHost,
"duration": duration,
"flow-id": flowId,
"audit": auditHeader,
"auth-user": authUser,
}
for k, v := range additional {
logData[k] = v
}
logEntry := accessLog.WithFields(logData)
if entry.Request != nil {
logEntry = logEntry.WithContext(entry.Request.Context())
}
logEntry.Infoln()
}
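// exampleLogAccess is an illustrative sketch, not part of the original
// source: it shows how a handler could record an access entry after serving
// a response.
func exampleLogAccess(r *http.Request, code int, size int64, begin time.Time) {
LogAccess(&AccessEntry{
Request: r,
StatusCode: code,
ResponseSize: size,
RequestTime: begin,
Duration: time.Since(begin),
}, nil)
}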
package logging
import (
"io"
"os"
"github.com/sirupsen/logrus"
)
type prefixFormatter struct {
prefix string
formatter logrus.Formatter
}
// Options provides the init options for logging.
type Options struct {
// Prefix for application log entries. Primarily used to be
// able to select between access log and application log
// entries.
ApplicationLogPrefix string
// Output for the application log entries, when nil,
// os.Stderr is used.
ApplicationLogOutput io.Writer
// When set, the logs are written in JSON format
ApplicationLogJSONEnabled bool
// ApplicationLogJsonFormatter, when set and JSON logging is enabled, is passed along to the underlying
// Logrus logger for application logs. To enable structured logging, use ApplicationLogJSONEnabled.
ApplicationLogJsonFormatter *logrus.JSONFormatter
// Output for the access log entries, when nil, os.Stderr is
// used.
AccessLogOutput io.Writer
// When set, the logs are written in JSON format
AccessLogJSONEnabled bool
// AccessLogStripQuery, when set, causes the query string to be stripped
// from the request URI in the access logs.
AccessLogStripQuery bool
// AccessLogJsonFormatter, when set and JSON logging is enabled, is passed along to the underlying
// Logrus logger for access logs. To enable structured logging, use AccessLogJSONEnabled.
// Deprecated: use [AccessLogFormatter].
AccessLogJsonFormatter *logrus.JSONFormatter
// AccessLogFormatter, when set, is passed along to the underlying Logrus logger for access logs.
AccessLogFormatter logrus.Formatter
}
func (f *prefixFormatter) Format(e *logrus.Entry) ([]byte, error) {
b, err := f.formatter.Format(e)
if err != nil {
return nil, err
}
return append([]byte(f.prefix), b...), nil
}
func initApplicationLog(o Options) {
if o.ApplicationLogJSONEnabled {
if o.ApplicationLogJsonFormatter != nil {
logrus.SetFormatter(o.ApplicationLogJsonFormatter)
} else {
logrus.SetFormatter(&logrus.JSONFormatter{})
}
} else if o.ApplicationLogPrefix != "" {
logrus.SetFormatter(&prefixFormatter{o.ApplicationLogPrefix, logrus.StandardLogger().Formatter})
}
if o.ApplicationLogOutput != nil {
logrus.SetOutput(o.ApplicationLogOutput)
}
}
func initAccessLog(o Options) {
l := logrus.New()
if o.AccessLogFormatter != nil {
l.Formatter = o.AccessLogFormatter
} else if o.AccessLogJSONEnabled {
if o.AccessLogJsonFormatter != nil {
l.Formatter = o.AccessLogJsonFormatter
} else {
l.Formatter = &logrus.JSONFormatter{TimestampFormat: dateFormat, DisableTimestamp: true}
}
} else {
l.Formatter = &accessLogFormatter{accessLogFormat}
}
l.Out = o.AccessLogOutput
l.Level = logrus.InfoLevel
accessLog = l
stripQuery = o.AccessLogStripQuery
}
// Init initializes logging.
func Init(o Options) {
initApplicationLog(o)
if o.AccessLogOutput == nil {
o.AccessLogOutput = os.Stderr
}
initAccessLog(o)
}
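// exampleInitLogging is an illustrative sketch, not part of the original
// source: it enables JSON access logs on stdout with query strings stripped,
// while application log entries keep a plain text prefix.
func exampleInitLogging() {
Init(Options{
ApplicationLogPrefix: "[skipper] ",
AccessLogJSONEnabled: true,
AccessLogOutput: os.Stdout,
AccessLogStripQuery: true,
})
}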
package logging
import "github.com/sirupsen/logrus"
// DefaultLog provides a default implementation of the Logger interface.
type DefaultLog struct{}
// Logger instances provide custom logging.
type Logger interface {
// Log with level ERROR
Error(...interface{})
// Log formatted messages with level ERROR
Errorf(string, ...interface{})
// Log with level WARN
Warn(...interface{})
// Log formatted messages with level WARN
Warnf(string, ...interface{})
// Log with level INFO
Info(...interface{})
// Log formatted messages with level INFO
Infof(string, ...interface{})
// Log with level DEBUG
Debug(...interface{})
// Log formatted messages with level DEBUG
Debugf(string, ...interface{})
}
func (dl *DefaultLog) Error(a ...interface{}) { logrus.Error(a...) }
func (dl *DefaultLog) Errorf(f string, a ...interface{}) { logrus.Errorf(f, a...) }
func (dl *DefaultLog) Warn(a ...interface{}) { logrus.Warn(a...) }
func (dl *DefaultLog) Warnf(f string, a ...interface{}) { logrus.Warnf(f, a...) }
func (dl *DefaultLog) Info(a ...interface{}) { logrus.Info(a...) }
func (dl *DefaultLog) Infof(f string, a ...interface{}) { logrus.Infof(f, a...) }
func (dl *DefaultLog) Debug(a ...interface{}) { logrus.Debug(a...) }
func (dl *DefaultLog) Debugf(f string, a ...interface{}) { logrus.Debugf(f, a...) }
package logging
import (
"bufio"
"fmt"
"net"
"net/http"
)
type LoggingWriter struct {
writer http.ResponseWriter
bytes int64
code int
}
func NewLoggingWriter(writer http.ResponseWriter) *LoggingWriter {
return &LoggingWriter{writer: writer}
}
func (lw *LoggingWriter) Write(data []byte) (count int, err error) {
count, err = lw.writer.Write(data)
lw.bytes += int64(count)
return
}
func (lw *LoggingWriter) WriteHeader(code int) {
lw.writer.WriteHeader(code)
if code == 0 {
code = 200
}
lw.code = code
}
func (lw *LoggingWriter) Header() http.Header {
return lw.writer.Header()
}
func (lw *LoggingWriter) Flush() {
lw.writer.(http.Flusher).Flush()
}
func (lw *LoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hij, ok := lw.writer.(http.Hijacker)
if ok {
return hij.Hijack()
}
return nil, nil, fmt.Errorf("could not hijack connection")
}
func (lw *LoggingWriter) Unwrap() http.ResponseWriter {
return lw.writer
}
func (lw *LoggingWriter) GetBytes() int64 {
return lw.bytes
}
func (lw *LoggingWriter) GetCode() int {
return lw.code
}
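// exampleInstrumentHandler is an illustrative sketch, not part of the
// original source: it wraps an http.Handler with a LoggingWriter to capture
// the response code and the number of bytes written.
func exampleInstrumentHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
lw := NewLoggingWriter(w)
next.ServeHTTP(lw, r)
_ = lw.GetCode() // e.g. 200
_ = lw.GetBytes() // response body size in bytes
})
}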
package metrics
import (
"net/http"
"time"
)
type All struct {
prometheus *Prometheus
codaHale *CodaHale
prometheusHandler http.Handler
codaHaleHandler http.Handler
}
func NewAll(o Options) *All {
return &All{
prometheus: NewPrometheus(o),
codaHale: NewCodaHale(o),
}
}
func (a *All) MeasureSince(key string, start time.Time) {
a.prometheus.MeasureSince(key, start)
a.codaHale.MeasureSince(key, start)
}
func (a *All) IncCounter(key string) {
a.prometheus.IncCounter(key)
a.codaHale.IncCounter(key)
}
func (a *All) IncCounterBy(key string, value int64) {
a.prometheus.IncCounterBy(key, value)
a.codaHale.IncCounterBy(key, value)
}
func (a *All) IncFloatCounterBy(key string, value float64) {
a.prometheus.IncFloatCounterBy(key, value)
a.codaHale.IncFloatCounterBy(key, value)
}
func (a *All) UpdateGauge(key string, v float64) {
a.prometheus.UpdateGauge(key, v)
a.codaHale.UpdateGauge(key, v)
}
func (a *All) MeasureRouteLookup(start time.Time) {
a.prometheus.MeasureRouteLookup(start)
a.codaHale.MeasureRouteLookup(start)
}
func (a *All) MeasureFilterCreate(filterName string, start time.Time) {
a.prometheus.MeasureFilterCreate(filterName, start)
a.codaHale.MeasureFilterCreate(filterName, start)
}
func (a *All) MeasureFilterRequest(filterName string, start time.Time) {
a.prometheus.MeasureFilterRequest(filterName, start)
a.codaHale.MeasureFilterRequest(filterName, start)
}
func (a *All) MeasureAllFiltersRequest(routeId string, start time.Time) {
a.prometheus.MeasureAllFiltersRequest(routeId, start)
a.codaHale.MeasureAllFiltersRequest(routeId, start)
}
func (a *All) MeasureBackendRequestHeader(host string, size int) {
a.prometheus.MeasureBackendRequestHeader(host, size)
a.codaHale.MeasureBackendRequestHeader(host, size)
}
func (a *All) MeasureBackend(routeId string, start time.Time) {
a.prometheus.MeasureBackend(routeId, start)
a.codaHale.MeasureBackend(routeId, start)
}
func (a *All) MeasureBackendHost(routeBackendHost string, start time.Time) {
a.prometheus.MeasureBackendHost(routeBackendHost, start)
a.codaHale.MeasureBackendHost(routeBackendHost, start)
}
func (a *All) MeasureFilterResponse(filterName string, start time.Time) {
a.prometheus.MeasureFilterResponse(filterName, start)
a.codaHale.MeasureFilterResponse(filterName, start)
}
func (a *All) MeasureAllFiltersResponse(routeId string, start time.Time) {
a.prometheus.MeasureAllFiltersResponse(routeId, start)
a.codaHale.MeasureAllFiltersResponse(routeId, start)
}
func (a *All) MeasureResponse(code int, method string, routeId string, start time.Time) {
a.prometheus.MeasureResponse(code, method, routeId, start)
a.codaHale.MeasureResponse(code, method, routeId, start)
}
func (a *All) MeasureResponseSize(host string, size int64) {
a.prometheus.MeasureResponseSize(host, size)
a.codaHale.MeasureResponseSize(host, size)
}
func (a *All) MeasureProxy(requestDuration, responseDuration time.Duration) {
a.prometheus.MeasureProxy(requestDuration, responseDuration)
a.codaHale.MeasureProxy(requestDuration, responseDuration)
}
func (a *All) MeasureServe(routeId, host, method string, code int, start time.Time) {
a.prometheus.MeasureServe(routeId, host, method, code, start)
a.codaHale.MeasureServe(routeId, host, method, code, start)
}
func (a *All) IncRoutingFailures() {
a.prometheus.IncRoutingFailures()
a.codaHale.IncRoutingFailures()
}
func (a *All) IncErrorsBackend(routeId string) {
a.prometheus.IncErrorsBackend(routeId)
a.codaHale.IncErrorsBackend(routeId)
}
func (a *All) MeasureBackend5xx(t time.Time) {
a.prometheus.MeasureBackend5xx(t)
a.codaHale.MeasureBackend5xx(t)
}
func (a *All) IncErrorsStreaming(routeId string) {
a.prometheus.IncErrorsStreaming(routeId)
a.codaHale.IncErrorsStreaming(routeId)
}
func (a *All) UpdateInvalidRoute(reasonCounts map[string]int) {
a.prometheus.UpdateInvalidRoute(reasonCounts)
a.codaHale.UpdateInvalidRoute(reasonCounts)
}
func (a *All) Close() {
a.codaHale.Close()
a.prometheus.Close()
}
func (a *All) RegisterHandler(path string, handler *http.ServeMux) {
a.prometheusHandler = a.prometheus.getHandler()
a.codaHaleHandler = a.codaHale.getHandler(path)
handler.Handle(path, a.newHandler())
}
func (a *All) newHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Accept") == "application/codahale+json" {
a.codaHaleHandler.ServeHTTP(w, req)
} else {
a.prometheusHandler.ServeHTTP(w, req)
}
})
}
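// Note added for clarity: with the combined backend both collectors run in
// parallel, and the shared endpoint serves the CodaHale JSON only when the
// client sends "Accept: application/codahale+json"; all other requests
// receive the Prometheus exposition format.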
package metrics
import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/rcrowley/go-metrics"
)
const (
KeyRouteLookup = "routelookup"
KeyRouteFailure = "routefailure"
KeyFilterCreate = "filter.%s.create"
KeyFilterRequest = "filter.%s.request"
KeyFiltersRequest = "allfilters.request.%s"
KeyAllFiltersRequestCombined = "allfilters.combined.request"
KeyProxyBackend = "backend.%s"
KeyProxyBackendCombined = "all.backend"
KeyProxyBackendHost = "backendhost.%s"
KeyFilterResponse = "filter.%s.response"
KeyFiltersResponse = "allfilters.response.%s"
KeyAllFiltersResponseCombined = "allfilters.combined.response"
KeyResponse = "response.%d.%s.skipper.%s"
KeyResponseCombined = "all.response.%d.%s.skipper"
Key5xxsBackend = "all.backend.5xx"
KeyProxyTotal = "proxy.total"
KeyProxyRequest = "proxy.request"
KeyProxyResponse = "proxy.response"
KeyErrorsBackend = "errors.backend.%s"
KeyErrorsStreaming = "errors.streaming.%s"
KeyInvalidRoutes = "route.invalid.%s"
statsRefreshDuration = time.Duration(5 * time.Second)
defaultUniformReservoirSize = 1024
defaultExpDecayReservoirSize = 1028
defaultExpDecayAlpha = 0.015
)
// CodaHale is the CodaHale format backend; it implements the Metrics interface in DropWizard's CodaHale metrics format.
type CodaHale struct {
reg metrics.Registry
createTimer func() metrics.Timer
createCounter func() metrics.Counter
createGauge func() metrics.GaugeFloat64
options Options
handler http.Handler
quit chan struct{}
}
// NewCodaHale returns a new CodaHale backend of metrics.
func NewCodaHale(o Options) *CodaHale {
o = applyCompatibilityDefaults(o)
c := &CodaHale{}
c.quit = make(chan struct{})
c.reg = metrics.NewRegistry()
var createSample func() metrics.Sample
if o.UseExpDecaySample {
createSample = newExpDecaySample
} else {
createSample = newUniformSample
}
c.createTimer = func() metrics.Timer { return createTimer(createSample()) }
c.createCounter = metrics.NewCounter
c.createGauge = metrics.NewGaugeFloat64
c.options = o
if o.EnableDebugGcMetrics {
metrics.RegisterDebugGCStats(c.reg)
go c.collectStats(metrics.CaptureDebugGCStatsOnce)
}
if o.EnableRuntimeMetrics {
metrics.RegisterRuntimeMemStats(c.reg)
go c.collectStats(metrics.CaptureRuntimeMemStatsOnce)
}
return c
}
func NewVoid() *CodaHale {
c := &CodaHale{}
c.reg = metrics.NewRegistry()
c.createTimer = func() metrics.Timer { return metrics.NilTimer{} }
c.createCounter = func() metrics.Counter { return metrics.NilCounter{} }
c.createGauge = func() metrics.GaugeFloat64 { return metrics.NilGaugeFloat64{} }
return c
}
func (c *CodaHale) getTimer(key string) metrics.Timer {
return c.reg.GetOrRegister(key, c.createTimer).(metrics.Timer)
}
func (c *CodaHale) updateTimer(key string, d time.Duration) {
c.getTimer(key).Update(d)
}
func (c *CodaHale) MeasureSince(key string, start time.Time) {
c.measureSince(key, start)
}
func (c *CodaHale) getGauge(key string) metrics.GaugeFloat64 {
return c.reg.GetOrRegister(key, c.createGauge).(metrics.GaugeFloat64)
}
func (c *CodaHale) UpdateGauge(key string, v float64) {
c.getGauge(key).Update(v)
}
func (c *CodaHale) IncCounter(key string) {
c.incCounter(key, 1)
}
func (c *CodaHale) IncCounterBy(key string, value int64) {
c.incCounter(key, value)
}
func (c *CodaHale) IncFloatCounterBy(key string, value float64) {
// Dropped. CodaHale does not support float counters.
}
func (c *CodaHale) measureSince(key string, start time.Time) {
c.updateTimer(key, time.Since(start))
}
func (c *CodaHale) MeasureRouteLookup(start time.Time) {
c.measureSince(KeyRouteLookup, start)
}
func (c *CodaHale) MeasureFilterCreate(filterName string, start time.Time) {
c.measureSince(fmt.Sprintf(KeyFilterCreate, filterName), start)
}
func (c *CodaHale) MeasureFilterRequest(filterName string, start time.Time) {
c.measureSince(fmt.Sprintf(KeyFilterRequest, filterName), start)
}
func (c *CodaHale) MeasureAllFiltersRequest(routeId string, start time.Time) {
c.measureSince(KeyAllFiltersRequestCombined, start)
if c.options.EnableAllFiltersMetrics {
c.measureSince(fmt.Sprintf(KeyFiltersRequest, routeId), start)
}
}
func (c *CodaHale) MeasureBackendRequestHeader(host string, size int) {
// not implemented, see https://github.com/zalando/skipper/issues/3530
}
func (c *CodaHale) MeasureBackend(routeId string, start time.Time) {
c.measureSince(KeyProxyBackendCombined, start)
if c.options.EnableRouteBackendMetrics {
c.measureSince(fmt.Sprintf(KeyProxyBackend, routeId), start)
}
}
func (c *CodaHale) MeasureBackendHost(routeBackendHost string, start time.Time) {
if c.options.EnableBackendHostMetrics {
c.measureSince(fmt.Sprintf(KeyProxyBackendHost, hostForKey(routeBackendHost)), start)
}
}
func (c *CodaHale) MeasureFilterResponse(filterName string, start time.Time) {
c.measureSince(fmt.Sprintf(KeyFilterResponse, filterName), start)
}
func (c *CodaHale) MeasureAllFiltersResponse(routeId string, start time.Time) {
c.measureSince(KeyAllFiltersResponseCombined, start)
if c.options.EnableAllFiltersMetrics {
c.measureSince(fmt.Sprintf(KeyFiltersResponse, routeId), start)
}
}
func (c *CodaHale) MeasureResponse(code int, method string, routeId string, start time.Time) {
method = measuredMethod(method)
if c.options.EnableCombinedResponseMetrics {
c.measureSince(fmt.Sprintf(KeyResponseCombined, code, method), start)
}
if c.options.EnableRouteResponseMetrics {
c.measureSince(fmt.Sprintf(KeyResponse, code, method, routeId), start)
}
}
func (c *CodaHale) MeasureResponseSize(host string, size int64) {
// not implemented, see https://github.com/zalando/skipper/issues/3530
}
func (c *CodaHale) MeasureProxy(requestDuration, responseDuration time.Duration) {
skipperDuration := requestDuration + responseDuration
c.updateTimer(KeyProxyTotal, skipperDuration)
if c.options.EnableProxyRequestMetrics {
c.updateTimer(KeyProxyRequest, requestDuration)
}
if c.options.EnableProxyResponseMetrics {
c.updateTimer(KeyProxyResponse, responseDuration)
}
}
func (c *CodaHale) MeasureServe(routeId, host, method string, code int, start time.Time) {
if !(c.options.EnableServeRouteMetrics || c.options.EnableServeHostMetrics) {
return
}
var keyServeRoute, keyServeHost string
method = measuredMethod(method)
hfk := hostForKey(host)
switch {
case c.options.EnableServeMethodMetric && c.options.EnableServeStatusCodeMetric:
keyServeHost = fmt.Sprintf("servehost.%s.%s.%d", hfk, method, code)
keyServeRoute = fmt.Sprintf("serveroute.%s.%s.%d", routeId, method, code)
case c.options.EnableServeMethodMetric:
keyServeHost = fmt.Sprintf("servehost.%s.%s", hfk, method)
keyServeRoute = fmt.Sprintf("serveroute.%s.%s", routeId, method)
case c.options.EnableServeStatusCodeMetric:
keyServeHost = fmt.Sprintf("servehost.%s.%d", hfk, code)
keyServeRoute = fmt.Sprintf("serveroute.%s.%d", routeId, code)
default:
keyServeHost = fmt.Sprintf("servehost.%s", hfk)
keyServeRoute = fmt.Sprintf("serveroute.%s", routeId)
}
if c.options.EnableServeRouteMetrics {
c.measureSince(keyServeRoute, start)
}
if c.options.EnableServeHostMetrics {
c.measureSince(keyServeHost, start)
}
}
func (c *CodaHale) getCounter(key string) metrics.Counter {
return c.reg.GetOrRegister(key, c.createCounter).(metrics.Counter)
}
func (c *CodaHale) incCounter(key string, value int64) {
c.getCounter(key).Inc(value)
}
func (c *CodaHale) IncRoutingFailures() {
c.incCounter(KeyRouteFailure, 1)
}
func (c *CodaHale) IncErrorsBackend(routeId string) {
if c.options.EnableRouteBackendErrorsCounters {
c.incCounter(fmt.Sprintf(KeyErrorsBackend, routeId), 1)
}
}
func (c *CodaHale) MeasureBackend5xx(t time.Time) {
c.measureSince(Key5xxsBackend, t)
}
func (c *CodaHale) IncErrorsStreaming(routeId string) {
if c.options.EnableRouteStreamingErrorsCounters {
c.incCounter(fmt.Sprintf(KeyErrorsStreaming, routeId), 1)
}
}
func (c *CodaHale) UpdateInvalidRoute(reasonCounts map[string]int) {
for reason, count := range reasonCounts {
c.UpdateGauge(fmt.Sprintf(KeyInvalidRoutes, reason), float64(count))
}
}
func (c *CodaHale) Close() {
close(c.quit)
}
func (c *CodaHale) collectStats(capture func(r metrics.Registry)) {
ticker := time.NewTicker(statsRefreshDuration)
defer ticker.Stop()
for {
select {
case <-ticker.C:
capture(c.reg)
case <-c.quit:
return
}
}
}
func (c *CodaHale) RegisterHandler(path string, handler *http.ServeMux) {
h := c.getHandler(path)
handler.Handle(path, h)
}
func (c *CodaHale) CreateHandler(path string) http.Handler {
return &codaHaleMetricsHandler{path: path, registry: c.reg, options: c.options}
}
func (c *CodaHale) getHandler(path string) http.Handler {
if c.handler != nil {
return c.handler
}
c.handler = c.CreateHandler(path)
return c.handler
}
type codaHaleMetricsHandler struct {
path string
registry metrics.Registry
options Options
}
func (c *codaHaleMetricsHandler) sendMetrics(w http.ResponseWriter, p string) {
_, k := path.Split(p)
metrics := filterMetrics(c.registry, c.options.Prefix, k)
if len(metrics) > 0 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(metrics)
} else {
http.NotFound(w, nil)
}
}
// ServeHTTP is only used to expose the collected metrics.
func (c *codaHaleMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
p := r.URL.Path
c.sendMetrics(w, strings.TrimPrefix(p, c.path))
}
func filterMetrics(reg metrics.Registry, prefix, key string) skipperMetrics {
metrics := make(skipperMetrics)
canonicalKey := strings.TrimPrefix(key, prefix)
m := reg.Get(canonicalKey)
if m != nil {
metrics[key] = m
} else {
reg.Each(func(name string, i interface{}) {
if key == "" || (strings.HasPrefix(name, canonicalKey)) {
metrics[prefix+name] = i
}
})
}
return metrics
}
type skipperMetrics map[string]interface{}
// MarshalJSON groups the collected metrics by type family and marshals them to JSON.
func (sm skipperMetrics) MarshalJSON() ([]byte, error) {
data := make(map[string]map[string]interface{})
for name, metric := range sm {
values := make(map[string]interface{})
var metricsFamily string
switch m := metric.(type) {
case metrics.Gauge:
metricsFamily = "gauges"
values["value"] = m.Value()
case metrics.GaugeFloat64:
t := m.Snapshot()
metricsFamily = "gauges"
values["value"] = t.Value()
case metrics.Histogram:
metricsFamily = "histograms"
h := m.Snapshot()
ps := h.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999})
values["count"] = h.Count()
values["min"] = h.Min()
values["max"] = h.Max()
values["mean"] = h.Mean()
values["stddev"] = h.StdDev()
values["median"] = ps[0]
values["75%"] = ps[1]
values["95%"] = ps[2]
values["99%"] = ps[3]
values["99.9%"] = ps[4]
case metrics.Timer:
metricsFamily = "timers"
t := m.Snapshot()
ps := t.Percentiles([]float64{0.5, 0.75, 0.95, 0.99, 0.999})
values["count"] = t.Count()
values["min"] = t.Min()
values["max"] = t.Max()
values["mean"] = t.Mean()
values["stddev"] = t.StdDev()
values["median"] = ps[0]
values["75%"] = ps[1]
values["95%"] = ps[2]
values["99%"] = ps[3]
values["99.9%"] = ps[4]
values["1m.rate"] = t.Rate1()
values["5m.rate"] = t.Rate5()
values["15m.rate"] = t.Rate15()
values["mean.rate"] = t.RateMean()
case metrics.Counter:
metricsFamily = "counters"
t := m.Snapshot()
values["count"] = t.Count()
default:
metricsFamily = "unknown"
values["error"] = fmt.Sprintf("unknown metrics type %T", m)
}
if data[metricsFamily] == nil {
data[metricsFamily] = make(map[string]interface{})
}
data[metricsFamily][name] = values
}
return json.Marshal(data)
}
package metrics
import (
"net/http"
"net/http/pprof"
"runtime"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
)
const (
defaultMetricsPath = "/metrics"
)
// Kind is the type of the metrics backend to be exposed.
type Kind int
const (
UnkownKind Kind = 0
CodaHaleKind Kind = 1 << iota
PrometheusKind
AllKind = CodaHaleKind | PrometheusKind
)
func (k Kind) String() string {
switch k {
case CodaHaleKind:
return "codahale"
case PrometheusKind:
return "prometheus"
case AllKind:
return "all"
default:
return "unknown"
}
}
// ParseMetricsKind parses a string and returns the matching metrics Kind.
func ParseMetricsKind(t string) Kind {
t = strings.ToLower(t)
switch t {
case "codahale":
return CodaHaleKind
case "prometheus":
return PrometheusKind
case "all":
return AllKind
default:
return UnkownKind
}
}
// Metrics is the generic interface that all the required backends
// should implement to be a skipper compatible metrics backend.
type Metrics interface {
// Implements the `filter.Metrics` interface.
MeasureSince(key string, start time.Time)
IncCounter(key string)
IncCounterBy(key string, value int64)
IncFloatCounterBy(key string, value float64)
// Additional methods
MeasureRouteLookup(start time.Time)
MeasureFilterCreate(filterName string, start time.Time)
MeasureFilterRequest(filterName string, start time.Time)
MeasureAllFiltersRequest(routeId string, start time.Time)
MeasureBackendRequestHeader(host string, size int)
MeasureBackend(routeId string, start time.Time)
MeasureBackendHost(routeBackendHost string, start time.Time)
MeasureFilterResponse(filterName string, start time.Time)
MeasureAllFiltersResponse(routeId string, start time.Time)
MeasureResponse(code int, method string, routeId string, start time.Time)
MeasureResponseSize(host string, size int64)
MeasureProxy(requestDuration, responseDuration time.Duration)
MeasureServe(routeId, host, method string, code int, start time.Time)
IncRoutingFailures()
IncErrorsBackend(routeId string)
MeasureBackend5xx(t time.Time)
IncErrorsStreaming(routeId string)
RegisterHandler(path string, handler *http.ServeMux)
UpdateGauge(key string, value float64)
UpdateInvalidRoute(reasonCounts map[string]int)
Close()
}
// Options for initializing metrics collection.
type Options struct {
// the metrics exposing format.
Format Kind
// Common prefix for the keys of the different
// collected metrics.
Prefix string
// If set, garbage collector metrics are collected
// in addition to the http traffic metrics.
EnableDebugGcMetrics bool
// If set, Go runtime metrics are collected in
// addition to the http traffic metrics.
EnableRuntimeMetrics bool
// If set, detailed total response time metrics will be collected
// for each route, additionally grouped by status and method.
EnableServeRouteMetrics bool
// If set, a counter for each route is generated, additionally
// grouped by status and method. It differs from the automatically
// generated counter from `EnableServeRouteMetrics` because it will
// always contain the status and method labels, independently of the
// `EnableServeMethodMetric` and `EnableServeStatusCodeMetric`
// flags.
EnableServeRouteCounter bool
// If set, detailed total response time metrics will be collected
// for each host, additionally grouped by status and method.
EnableServeHostMetrics bool
// If set, a counter for each host is generated, additionally
// grouped by status and method. It differs from the automatically
// generated counter from `EnableServeHostMetrics` because it will
// always contain the status and method labels, independently of the
// `EnableServeMethodMetric` and `EnableServeStatusCodeMetric` flags.
EnableServeHostCounter bool
// If set, the detailed total response time metrics will contain the
// HTTP method as a domain of the metric. It affects both route and
// host split metrics.
EnableServeMethodMetric bool
// If set, the detailed total response time metrics will contain the
// HTTP Response status code as a domain of the metric. It affects
// both route and host split metrics.
EnableServeStatusCodeMetric bool
// If set, the total request handling time taken by skipper will be
// collected. It measures the duration taken by skipper to process
// the request, from the start, excluding the filters processing,
// until the backend round trip is started.
EnableProxyRequestMetrics bool
// If set, the total response handling time taken by skipper will be
// collected. It measures the duration taken by skipper to process the
// response, from after the backend round trip is finished, excluding
// the filters processing, until just before the response is served.
EnableProxyResponseMetrics bool
// If set, detailed response time metrics will be collected
// for each backend host
EnableBackendHostMetrics bool
// EnableAllFiltersMetrics enables collecting combined filter
// metrics per each route. Without the DisableCompatibilityDefaults,
// it is enabled by default.
EnableAllFiltersMetrics bool
// EnableCombinedResponseMetrics enables collecting response time
// metrics combined for every route.
EnableCombinedResponseMetrics bool
// EnableRouteResponseMetrics enables collecting response time
// metrics per each route. Without the DisableCompatibilityDefaults,
// it is enabled by default.
EnableRouteResponseMetrics bool
// EnableRouteBackendErrorsCounters enables counters for backend
// errors per each route. Without the DisableCompatibilityDefaults,
// it is enabled by default.
EnableRouteBackendErrorsCounters bool
// EnableRouteStreamingErrorsCounters enables counters for streaming
// errors per each route. Without the DisableCompatibilityDefaults,
// it is enabled by default.
EnableRouteStreamingErrorsCounters bool
// EnableRouteBackendMetrics enables backend response time metrics
// per each route. Without the DisableCompatibilityDefaults, it is
// enabled by default.
EnableRouteBackendMetrics bool
// UseExpDecaySample, when set, makes the histograms use an exponentially
// decaying sample instead of the default uniform one.
UseExpDecaySample bool
// HistogramBuckets defines buckets into which the observations are counted for
// histogram metrics.
HistogramBuckets []float64
// The following options, for backwards compatibility, are true
// by default: EnableAllFiltersMetrics, EnableRouteResponseMetrics,
// EnableRouteBackendErrorsCounters, EnableRouteStreamingErrorsCounters,
// EnableRouteBackendMetrics. With this compatibility flag, the default
// for these options can be set to false.
DisableCompatibilityDefaults bool
// EnableProfile exposes profiling information on /pprof of the
// metrics listener.
EnableProfile bool
// BlockProfileRate calls runtime.SetBlockProfileRate(BlockProfileRate) if != 0 (<0 will disable) and profiling is enabled
BlockProfileRate int
// MutexProfileFraction calls runtime.SetMutexProfileFraction(MutexProfileFraction) if != 0 (<0 will disable) and profiling is enabled
MutexProfileFraction int
// MemProfileRate calls runtime.SetMemProfileRate(MemProfileRate) if != 0 (<0 will disable) and profiling is enabled
MemProfileRate int
// An instance of a Prometheus registry. It allows registering and serving custom metrics when skipper is used as a
// library.
// A new registry is created if this option is nil.
PrometheusRegistry *prometheus.Registry
// EnablePrometheusStartLabel adds start label to each prometheus counter with the value of counter creation
// timestamp as unix nanoseconds.
EnablePrometheusStartLabel bool
}
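// A minimal sketch of an Options value for a Prometheus backend; the field
// values below are illustrative examples, not recommended defaults:
//
//	o := Options{
//		Format:               PrometheusKind,
//		Prefix:               "skipper.",
//		EnableRuntimeMetrics: true,
//		HistogramBuckets:     []float64{0.01, 0.1, 0.5, 1, 5},
//	}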
var (
Default Metrics
Void Metrics
)
func init() {
Void = NewVoid()
Default = Void
}
// NewDefaultHandler returns a default metrics handler.
func NewDefaultHandler(o Options) http.Handler {
m := NewMetrics(o)
return NewHandler(o, m)
}
// NewMetrics creates a metrics collector instance based on the Format option.
func NewMetrics(o Options) Metrics {
var m Metrics
switch o.Format {
case AllKind:
m = NewAll(o)
case PrometheusKind:
m = NewPrometheus(o)
default:
// CodaHale is the default metrics implementation.
m = NewCodaHale(o)
}
return m
}
// NewHandler returns a collection of metrics handlers.
func NewHandler(o Options, m Metrics) http.Handler {
mux := http.NewServeMux()
if o.EnableProfile {
mux.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index))
mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline))
mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile))
mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol))
mux.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace))
switch n := o.BlockProfileRate; {
case n > 0:
runtime.SetBlockProfileRate(o.BlockProfileRate)
case n < 0:
runtime.SetBlockProfileRate(0)
default:
// 0 keeps default
}
switch n := o.MutexProfileFraction; {
case n > 0:
runtime.SetMutexProfileFraction(o.MutexProfileFraction)
case n < 0:
runtime.SetMutexProfileFraction(0)
default:
// 0 keeps default
}
switch n := o.MemProfileRate; {
case n > 0:
runtime.MemProfileRate = o.MemProfileRate
case n < 0:
runtime.MemProfileRate = 0
default:
// 0 keeps default
}
}
// Root path should return 404.
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
Default = m
m.RegisterHandler(defaultMetricsPath, mux)
m.RegisterHandler(defaultMetricsPath+"/", mux)
return mux
}
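// A minimal usage sketch, serving the metrics endpoints on a dedicated
// listener; the port is an arbitrary example:
//
//	handler := NewDefaultHandler(Options{Format: PrometheusKind})
//	go func() {
//		log.Fatal(http.ListenAndServe(":9911", handler))
//	}()
//
// The returned mux then serves the /metrics path and, with EnableProfile
// set, the /debug/pprof endpoints.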
package metrics
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
dto "github.com/prometheus/client_model/go"
"google.golang.org/protobuf/proto"
)
const (
promNamespace = "skipper"
promRouteSubsystem = "route"
promFilterSubsystem = "filter"
promBackendSubsystem = "backend"
promStreamingSubsystem = "streaming"
promProxySubsystem = "proxy"
promResponseSubsystem = "response"
promServeSubsystem = "serve"
promCustomSubsystem = "custom"
)
const (
KiB = 1024
MiB = 1024 * KiB
GiB = 1024 * MiB
)
// headerSizeBuckets are chosen to cover typical max request header sizes:
// - 64 KiB for [AWS ELB](https://docs.aws.amazon.com/elasticloadbalancing/latest/userguide/how-elastic-load-balancing-works.html#http-header-limits)
// - 16 KiB for [NodeJS](https://nodejs.org/api/cli.html#cli_max_http_header_size_size)
// - 8 KiB for [Nginx](https://nginx.org/en/docs/http/ngx_http_core_module.html#large_client_header_buffers)
// - 8 KiB for [Spring Boot](https://docs.spring.io/spring-boot/appendix/application-properties/index.html#application-properties.server.server.max-http-request-header-size)
var headerSizeBuckets = []float64{4 * KiB, 8 * KiB, 16 * KiB, 64 * KiB}
// responseSizeBuckets are chosen to cover 2^(10*n) sizes up to 1 GiB and halves of those.
var responseSizeBuckets = []float64{1, 512, 1 * KiB, 512 * KiB, 1 * MiB, 512 * MiB, 1 * GiB}
// Prometheus implements the prometheus metrics backend.
type Prometheus struct {
// Metrics.
routeLookupM *prometheus.HistogramVec
routeErrorsM *prometheus.CounterVec
responseM *prometheus.HistogramVec
responseSizeM *prometheus.HistogramVec
filterCreateM *prometheus.HistogramVec
filterRequestM *prometheus.HistogramVec
filterAllRequestM *prometheus.HistogramVec
filterAllCombinedRequestM *prometheus.HistogramVec
backendRequestHeadersM *prometheus.HistogramVec
backendM *prometheus.HistogramVec
backendCombinedM *prometheus.HistogramVec
filterResponseM *prometheus.HistogramVec
filterAllResponseM *prometheus.HistogramVec
filterAllCombinedResponseM *prometheus.HistogramVec
serveRouteM *prometheus.HistogramVec
serveRouteCounterM *prometheus.CounterVec
serveHostM *prometheus.HistogramVec
serveHostCounterM *prometheus.CounterVec
proxyTotalM *prometheus.HistogramVec
proxyRequestM *prometheus.HistogramVec
proxyResponseM *prometheus.HistogramVec
backend5xxM *prometheus.HistogramVec
backendErrorsM *prometheus.CounterVec
proxyStreamingErrorsM *prometheus.CounterVec
customHistogramM *prometheus.HistogramVec
customCounterM *prometheus.CounterVec
customGaugeM *prometheus.GaugeVec
invalidRouteM *prometheus.GaugeVec
opts Options
registry *prometheus.Registry
handler http.Handler
}
// NewPrometheus returns a new Prometheus metric backend.
func NewPrometheus(opts Options) *Prometheus {
opts = applyCompatibilityDefaults(opts)
p := &Prometheus{
registry: opts.PrometheusRegistry,
opts: opts,
}
if p.registry == nil {
p.registry = prometheus.NewRegistry()
}
namespace := promNamespace
if opts.Prefix != "" {
namespace = strings.TrimSuffix(opts.Prefix, ".")
}
p.routeLookupM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promRouteSubsystem,
Name: "lookup_duration_seconds",
Help: "Duration in seconds of a route lookup.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.routeErrorsM = register(p, prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: promRouteSubsystem,
Name: "error_total",
Help: "The total of route lookup errors.",
}, []string{}))
p.responseM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promResponseSubsystem,
Name: "duration_seconds",
Help: "Duration in seconds of a response.",
Buckets: opts.HistogramBuckets,
}, []string{"code", "method", "route"}))
p.responseSizeM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promResponseSubsystem,
Name: "size_bytes",
Help: "Size of response in bytes.",
Buckets: responseSizeBuckets,
}, []string{"host"}))
p.filterCreateM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "create_duration_seconds",
Help: "Duration in seconds of filter creation.",
Buckets: opts.HistogramBuckets,
}, []string{"filter"}))
p.filterRequestM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "request_duration_seconds",
Help: "Duration in seconds of a filter request.",
Buckets: opts.HistogramBuckets,
}, []string{"filter"}))
p.filterAllRequestM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "all_request_duration_seconds",
Help: "Duration in seconds of a filter request by all filters.",
Buckets: opts.HistogramBuckets,
}, []string{"route"}))
p.filterAllCombinedRequestM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "all_combined_request_duration_seconds",
Help: "Duration in seconds of a filter request combined by all filters.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.backendRequestHeadersM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promBackendSubsystem,
Name: "request_header_bytes",
Help: "Size of a backend request header.",
Buckets: headerSizeBuckets,
}, []string{"host"}))
p.backendM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promBackendSubsystem,
Name: "duration_seconds",
Help: "Duration in seconds of a proxy backend.",
Buckets: opts.HistogramBuckets,
}, []string{"route", "host"}))
p.backendCombinedM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promBackendSubsystem,
Name: "combined_duration_seconds",
Help: "Duration in seconds of a proxy backend combined.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.filterResponseM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "response_duration_seconds",
Help: "Duration in seconds of a filter response.",
Buckets: opts.HistogramBuckets,
}, []string{"filter"}))
p.filterAllResponseM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "all_response_duration_seconds",
Help: "Duration in seconds of a filter response by all filters.",
Buckets: opts.HistogramBuckets,
}, []string{"route"}))
p.filterAllCombinedResponseM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promFilterSubsystem,
Name: "all_combined_response_duration_seconds",
Help: "Duration in seconds of a filter response combined by all filters.",
Buckets: opts.HistogramBuckets,
}, []string{}))
metrics := []string{}
if opts.EnableServeStatusCodeMetric {
metrics = append(metrics, "code")
}
if opts.EnableServeMethodMetric {
metrics = append(metrics, "method")
}
p.serveRouteM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promServeSubsystem,
Name: "route_duration_seconds",
Help: "Duration in seconds of serving a route.",
Buckets: opts.HistogramBuckets,
}, append(metrics, "route")))
p.serveRouteCounterM = register(p, prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: promServeSubsystem,
Name: "route_count",
Help: "Total number of requests of serving a route.",
}, []string{"code", "method", "route"}))
p.serveHostM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promServeSubsystem,
Name: "host_duration_seconds",
Help: "Duration in seconds of serving a host.",
Buckets: opts.HistogramBuckets,
}, append(metrics, "host")))
p.serveHostCounterM = register(p, prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: promServeSubsystem,
Name: "host_count",
Help: "Total number of requests of serving a host.",
}, []string{"code", "method", "host"}))
p.proxyTotalM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promProxySubsystem,
Name: "total_duration_seconds",
Help: "Total duration in seconds of skipper latency.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.proxyRequestM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promProxySubsystem,
Name: "request_duration_seconds",
Help: "Duration in seconds of skipper latency for request.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.proxyResponseM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promProxySubsystem,
Name: "response_duration_seconds",
Help: "Duration in seconds of skipper latency for response.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.backend5xxM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promBackendSubsystem,
Name: "5xx_duration_seconds",
Help: "Duration in seconds of backend 5xx.",
Buckets: opts.HistogramBuckets,
}, []string{}))
p.backendErrorsM = register(p, prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: promBackendSubsystem,
Name: "error_total",
Help: "Total number of backend route errors.",
}, []string{"route"}))
p.proxyStreamingErrorsM = register(p, prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: promStreamingSubsystem,
Name: "error_total",
Help: "Total number of streaming route errors.",
}, []string{"route"}))
p.customCounterM = register(p, prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: promCustomSubsystem,
Name: "total",
Help: "Total number of custom metrics.",
}, []string{"key"}))
p.customGaugeM = register(p, prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: promCustomSubsystem,
Name: "gauges",
Help: "Gauges number of custom metrics.",
}, []string{"key"}))
p.customHistogramM = register(p, prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: promCustomSubsystem,
Name: "duration_seconds",
Help: "Duration in seconds of custom metrics.",
Buckets: opts.HistogramBuckets,
}, []string{"key"}))
p.invalidRouteM = register(p, prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: promRouteSubsystem,
Name: "invalid",
Help: "Number of invalid routes by reason.",
}, []string{"reason"}))
// Register prometheus runtime collectors if required.
if opts.EnableRuntimeMetrics {
register(p, collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
register(p, collectors.NewGoCollector())
}
return p
}
// sinceS returns the seconds passed since the start time until now.
func (p *Prometheus) sinceS(start time.Time) float64 {
return time.Since(start).Seconds()
}
func register[T prometheus.Collector](p *Prometheus, cs T) T {
p.registry.MustRegister(cs)
return cs
}
func (p *Prometheus) CreateHandler() http.Handler {
var gatherer prometheus.Gatherer = p.registry
if p.opts.EnablePrometheusStartLabel {
gatherer = withStartLabelGatherer{p.registry}
}
return promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{})
}
func (p *Prometheus) getHandler() http.Handler {
if p.handler != nil {
return p.handler
}
p.handler = p.CreateHandler()
return p.handler
}
// RegisterHandler satisfies Metrics interface.
func (p *Prometheus) RegisterHandler(path string, mux *http.ServeMux) {
promHandler := p.getHandler()
mux.Handle(path, promHandler)
}
// MeasureSince satisfies Metrics interface.
func (p *Prometheus) MeasureSince(key string, start time.Time) {
t := p.sinceS(start)
p.customHistogramM.WithLabelValues(key).Observe(t)
}
// IncCounter satisfies Metrics interface.
func (p *Prometheus) IncCounter(key string) {
p.customCounterM.WithLabelValues(key).Inc()
}
// IncCounterBy satisfies Metrics interface.
func (p *Prometheus) IncCounterBy(key string, value int64) {
f := float64(value)
p.customCounterM.WithLabelValues(key).Add(f)
}
// IncFloatCounterBy satisfies Metrics interface.
func (p *Prometheus) IncFloatCounterBy(key string, value float64) {
p.customCounterM.WithLabelValues(key).Add(value)
}
// UpdateGauge satisfies Metrics interface.
func (p *Prometheus) UpdateGauge(key string, v float64) {
p.customGaugeM.WithLabelValues(key).Set(v)
}
// MeasureRouteLookup satisfies Metrics interface.
func (p *Prometheus) MeasureRouteLookup(start time.Time) {
t := p.sinceS(start)
p.routeLookupM.WithLabelValues().Observe(t)
}
func (p *Prometheus) MeasureFilterCreate(filterName string, start time.Time) {
t := p.sinceS(start)
p.filterCreateM.WithLabelValues(filterName).Observe(t)
}
// MeasureFilterRequest satisfies Metrics interface.
func (p *Prometheus) MeasureFilterRequest(filterName string, start time.Time) {
t := p.sinceS(start)
p.filterRequestM.WithLabelValues(filterName).Observe(t)
}
// MeasureAllFiltersRequest satisfies Metrics interface.
func (p *Prometheus) MeasureAllFiltersRequest(routeID string, start time.Time) {
t := p.sinceS(start)
p.filterAllCombinedRequestM.WithLabelValues().Observe(t)
if p.opts.EnableAllFiltersMetrics {
p.filterAllRequestM.WithLabelValues(routeID).Observe(t)
}
}
func (p *Prometheus) MeasureBackendRequestHeader(host string, size int) {
p.backendRequestHeadersM.WithLabelValues(hostForKey(host)).Observe(float64(size))
}
// MeasureBackend satisfies Metrics interface.
func (p *Prometheus) MeasureBackend(routeID string, start time.Time) {
t := p.sinceS(start)
p.backendCombinedM.WithLabelValues().Observe(t)
if p.opts.EnableRouteBackendMetrics {
p.backendM.WithLabelValues(routeID, "").Observe(t)
}
}
// MeasureBackendHost satisfies Metrics interface.
func (p *Prometheus) MeasureBackendHost(routeBackendHost string, start time.Time) {
t := p.sinceS(start)
if p.opts.EnableBackendHostMetrics {
p.backendM.WithLabelValues("", routeBackendHost).Observe(t)
}
}
// MeasureFilterResponse satisfies Metrics interface.
func (p *Prometheus) MeasureFilterResponse(filterName string, start time.Time) {
t := p.sinceS(start)
p.filterResponseM.WithLabelValues(filterName).Observe(t)
}
// MeasureAllFiltersResponse satisfies Metrics interface.
func (p *Prometheus) MeasureAllFiltersResponse(routeID string, start time.Time) {
t := p.sinceS(start)
p.filterAllCombinedResponseM.WithLabelValues().Observe(t)
if p.opts.EnableAllFiltersMetrics {
p.filterAllResponseM.WithLabelValues(routeID).Observe(t)
}
}
// MeasureResponse satisfies Metrics interface.
func (p *Prometheus) MeasureResponse(code int, method string, routeID string, start time.Time) {
method = measuredMethod(method)
t := p.sinceS(start)
if p.opts.EnableCombinedResponseMetrics {
p.responseM.WithLabelValues(fmt.Sprint(code), method, "").Observe(t)
}
if p.opts.EnableRouteResponseMetrics {
p.responseM.WithLabelValues(fmt.Sprint(code), method, routeID).Observe(t)
}
}
func (p *Prometheus) MeasureResponseSize(host string, size int64) {
p.responseSizeM.WithLabelValues(hostForKey(host)).Observe(float64(size))
}
func (p *Prometheus) MeasureProxy(requestDuration, responseDuration time.Duration) {
skipperDuration := requestDuration + responseDuration
p.proxyTotalM.WithLabelValues().Observe(skipperDuration.Seconds())
if p.opts.EnableProxyRequestMetrics {
p.proxyRequestM.WithLabelValues().Observe(requestDuration.Seconds())
}
if p.opts.EnableProxyResponseMetrics {
p.proxyResponseM.WithLabelValues().Observe(responseDuration.Seconds())
}
}
// MeasureServe satisfies Metrics interface.
func (p *Prometheus) MeasureServe(routeID, host, method string, code int, start time.Time) {
method = measuredMethod(method)
t := p.sinceS(start)
if p.opts.EnableServeRouteMetrics || p.opts.EnableServeHostMetrics {
metrics := []string{}
if p.opts.EnableServeStatusCodeMetric {
metrics = append(metrics, fmt.Sprint(code))
}
if p.opts.EnableServeMethodMetric {
metrics = append(metrics, method)
}
if p.opts.EnableServeRouteMetrics {
p.serveRouteM.WithLabelValues(append(metrics, routeID)...).Observe(t)
}
if p.opts.EnableServeHostMetrics {
p.serveHostM.WithLabelValues(append(metrics, hostForKey(host))...).Observe(t)
}
}
if p.opts.EnableServeRouteCounter {
p.serveRouteCounterM.WithLabelValues(fmt.Sprint(code), method, routeID).Inc()
}
if p.opts.EnableServeHostCounter {
p.serveHostCounterM.WithLabelValues(fmt.Sprint(code), method, hostForKey(host)).Inc()
}
}
// IncRoutingFailures satisfies Metrics interface.
func (p *Prometheus) IncRoutingFailures() {
p.routeErrorsM.WithLabelValues().Inc()
}
// IncErrorsBackend satisfies Metrics interface.
func (p *Prometheus) IncErrorsBackend(routeID string) {
p.backendErrorsM.WithLabelValues(routeID).Inc()
}
// MeasureBackend5xx satisfies Metrics interface.
func (p *Prometheus) MeasureBackend5xx(start time.Time) {
t := p.sinceS(start)
p.backend5xxM.WithLabelValues().Observe(t)
}
// IncErrorsStreaming satisfies Metrics interface.
func (p *Prometheus) IncErrorsStreaming(routeID string) {
p.proxyStreamingErrorsM.WithLabelValues(routeID).Inc()
}
// UpdateInvalidRoute satisfies Metrics interface.
func (p *Prometheus) UpdateInvalidRoute(reasonCounts map[string]int) {
for reason, count := range reasonCounts {
p.invalidRouteM.WithLabelValues(reason).Set(float64(count))
}
}
func (p *Prometheus) Close() {}
// withStartLabelGatherer adds a "start" label to all counters with
// the value of counter creation timestamp as unix nanoseconds.
type withStartLabelGatherer struct {
*prometheus.Registry
}
func (g withStartLabelGatherer) Gather() ([]*dto.MetricFamily, error) {
metricFamilies, err := g.Registry.Gather()
for _, metricFamily := range metricFamilies {
if metricFamily.GetType() == dto.MetricType_COUNTER {
for _, metric := range metricFamily.Metric {
metric.Label = append(metric.Label, &dto.LabelPair{
Name: proto.String("start"),
Value: proto.String(fmt.Sprintf("%d", metric.Counter.CreatedTimestamp.AsTime().UnixNano())),
})
}
}
}
return metricFamilies, err
}
package metrics
import (
"strings"
metrics "github.com/rcrowley/go-metrics"
)
func newUniformSample() metrics.Sample {
return metrics.NewUniformSample(defaultUniformReservoirSize)
}
func newExpDecaySample() metrics.Sample {
return metrics.NewExpDecaySample(defaultExpDecayReservoirSize, defaultExpDecayAlpha)
}
func createTimer(sample metrics.Sample) metrics.Timer {
return metrics.NewCustomTimer(metrics.NewHistogram(sample), metrics.NewMeter())
}
func hostForKey(h string) string {
h = strings.ReplaceAll(h, ".", "_")
h = strings.ReplaceAll(h, ":", "__")
return h
}
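// For example, hostForKey("example.com:443") yields "example_com__443",
// which keeps the host usable as a segment of a dotted metrics key.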
func measuredMethod(m string) string {
switch m {
case "OPTIONS",
"GET",
"HEAD",
"POST",
"PUT",
"PATCH",
"DELETE",
"TRACE",
"CONNECT":
return m
default:
return "_unknownmethod_"
}
}
func applyCompatibilityDefaults(o Options) Options {
if o.DisableCompatibilityDefaults {
return o
}
o.EnableAllFiltersMetrics = true
o.EnableRouteResponseMetrics = true
o.EnableRouteBackendErrorsCounters = true
o.EnableRouteStreamingErrorsCounters = true
o.EnableRouteBackendMetrics = true
return o
}
package net
import (
"context"
"fmt"
"net"
"net/http"
"time"
"github.com/zalando/skipper/metrics"
)
// ConnManager tracks creation of HTTP server connections and
// closes connections when their age or number of requests served reaches configured limits.
// Use [ConnManager.Configure] method to setup ConnManager for an [http.Server].
type ConnManager struct {
// Metrics is an optional metrics registry to count connection events.
Metrics metrics.Metrics
// Keepalive is the duration after which a server connection is closed.
Keepalive time.Duration
// KeepaliveRequests is the number of requests after which a server connection is closed.
KeepaliveRequests int
handler http.Handler
}
type connState struct {
expiresAt time.Time
requests int
}
type contextKey struct{}
var connection contextKey
func (cm *ConnManager) Configure(server *http.Server) {
cm.handler = server.Handler
server.Handler = http.HandlerFunc(cm.serveHTTP)
if cc := server.ConnContext; cc != nil {
server.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
ctx = cc(ctx, c)
return cm.connContext(ctx, c)
}
} else {
server.ConnContext = cm.connContext
}
if cs := server.ConnState; cs != nil {
server.ConnState = func(c net.Conn, state http.ConnState) {
cs(c, state)
cm.connState(c, state)
}
} else {
server.ConnState = cm.connState
}
}
func (cm *ConnManager) serveHTTP(w http.ResponseWriter, r *http.Request) {
state, _ := r.Context().Value(connection).(*connState)
state.requests++
if cm.KeepaliveRequests > 0 && state.requests >= cm.KeepaliveRequests {
w.Header().Set("Connection", "close")
cm.count("lb-conn-closed.keepalive-requests")
}
if cm.Keepalive > 0 && time.Now().After(state.expiresAt) {
w.Header().Set("Connection", "close")
cm.count("lb-conn-closed.keepalive")
}
cm.handler.ServeHTTP(w, r)
}
func (cm *ConnManager) connContext(ctx context.Context, _ net.Conn) context.Context {
state := &connState{
expiresAt: time.Now().Add(cm.Keepalive),
}
return context.WithValue(ctx, connection, state)
}
func (cm *ConnManager) connState(_ net.Conn, state http.ConnState) {
cm.count(fmt.Sprintf("lb-conn-%s", state))
}
func (cm *ConnManager) count(name string) {
if cm.Metrics != nil {
cm.Metrics.IncCounter(name)
}
}
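// A minimal usage sketch, assuming an existing server and handler; the
// limits are arbitrary examples:
//
//	server := &http.Server{Addr: ":8080", Handler: handler}
//	cm := &ConnManager{
//		Keepalive:         5 * time.Minute, // close connections older than 5m
//		KeepaliveRequests: 1000,            // or after 1000 requests
//	}
//	cm.Configure(server) // must run before the server starts serving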
package net
import (
"net"
"net/http"
)
// ForwardedHeaders sets non-standard X-Forwarded-* Headers
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#proxies and https://github.com/authelia/authelia
type ForwardedHeaders struct {
// For sets or appends request remote IP to the X-Forwarded-For header
For bool
// PrependFor sets or prepends request remote IP to the X-Forwarded-For header, overrides For
PrependFor bool
// Host sets X-Forwarded-Host to the request host
Host bool
// Method sets the HTTP method as the X-Forwarded-Method request header
Method bool
// Uri sets the request path and query as the X-Forwarded-Uri request header
Uri bool
// Sets X-Forwarded-Port value
Port string
// Sets X-Forwarded-Proto value
Proto string
}
func (h *ForwardedHeaders) Set(req *http.Request) {
if (h.For || h.PrependFor) && req.RemoteAddr != "" {
addr := req.RemoteAddr
if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
addr = host
}
v := req.Header.Get("X-Forwarded-For")
if v == "" {
v = addr
} else if h.PrependFor {
v = addr + ", " + v
} else {
v = v + ", " + addr
}
req.Header.Set("X-Forwarded-For", v)
}
if h.Host {
req.Header.Set("X-Forwarded-Host", req.Host)
}
if h.Method {
req.Header.Set("X-Forwarded-Method", req.Method)
}
if h.Uri {
req.Header.Set("X-Forwarded-Uri", req.RequestURI)
}
if h.Port != "" {
req.Header.Set("X-Forwarded-Port", h.Port)
}
if h.Proto != "" {
req.Header.Set("X-Forwarded-Proto", h.Proto)
}
}
type ForwardedHeadersHandler struct {
Headers ForwardedHeaders
Exclude IPNets
Handler http.Handler
}
func (h *ForwardedHeadersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil && !h.Exclude.Contain(net.ParseIP(host)) {
h.Headers.Set(r)
}
h.Handler.ServeHTTP(w, r)
}
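// A minimal sketch of wrapping a handler so X-Forwarded-* headers are set
// only for clients outside an excluded network (values are illustrative):
//
//	exclude, _ := ParseCIDRs([]string{"10.0.0.0/8"})
//	wrapped := &ForwardedHeadersHandler{
//		Headers: ForwardedHeaders{For: true, Host: true, Proto: "https"},
//		Exclude: exclude,
//		Handler: next, // next is the assumed upstream http.Handler
//	}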
package net
import (
"net"
"net/http"
"strings"
)
// HostPatch is used to modify host[:port] string
type HostPatch struct {
// Remove port if present
RemovePort bool
// Remove trailing dot if present
RemoteTrailingDot bool
// Convert to lowercase
ToLower bool
}
func (h *HostPatch) Apply(original string) string {
host, port := original, ""
// avoid net.SplitHostPort for value without port
if strings.IndexByte(original, ':') != -1 {
if sh, sp, err := net.SplitHostPort(original); err == nil {
host, port = sh, sp
}
}
if h.RemovePort {
port = ""
}
if h.RemoteTrailingDot {
last := len(host) - 1
if last >= 0 && host[last] == '.' {
host = host[:last]
}
}
if h.ToLower {
host = strings.ToLower(host)
}
if port != "" {
return net.JoinHostPort(host, port)
} else {
return host
}
}
type HostPatchHandler struct {
Patch HostPatch
Handler http.Handler
}
func (h *HostPatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.Host = h.Patch.Apply(r.Host)
h.Handler.ServeHTTP(w, r)
}
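// Illustrative Apply results, assuming all patch options are enabled:
//
//	p := HostPatch{RemovePort: true, RemoteTrailingDot: true, ToLower: true}
//	p.Apply("EXAMPLE.org.:8080") // "example.org"
//	p.Apply("example.com")       // "example.com" (unchanged)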
package net
import (
"crypto/tls"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/secrets"
)
const (
defaultIdleConnTimeout = 30 * time.Second
defaultRefreshInterval = 5 * time.Minute
)
// Client wraps http.Client, exposing the same interface as the
// stdlib http.Client and adding features like Bearer token injection
// and opentracing.
type Client struct {
once sync.Once
client http.Client
tr *Transport
sr secrets.SecretsReader
}
// NewClient creates a wrapped http.Client and uses Transport to
// support OpenTracing. On teardown you have to use Close() to
// not leak a goroutine.
//
// If secrets.SecretsReader is nil, but BearerTokenFile is a non-empty
// string, it creates a StaticDelegateSecret with a wrapped
// secrets.SecretPaths, which can be used with Kubernetes secrets to
// read an automatically updated Bearer token from the secret.
func NewClient(o Options) *Client {
if o.Log == nil {
o.Log = &logging.DefaultLog{}
}
tr := NewTransport(o)
sr := o.SecretsReader
if sr == nil && o.BearerTokenFile != "" {
if o.BearerTokenRefreshInterval == 0 {
o.BearerTokenRefreshInterval = defaultRefreshInterval
}
sp := secrets.NewSecretPaths(o.BearerTokenRefreshInterval)
if err := sp.Add(o.BearerTokenFile); err != nil {
o.Log.Errorf("failed to read secret: %v", err)
}
sr = secrets.NewStaticDelegateSecret(sp, o.BearerTokenFile)
}
c := &Client{
once: sync.Once{},
client: http.Client{
Transport: tr,
CheckRedirect: o.CheckRedirect,
},
tr: tr,
sr: sr,
}
return c
}
func (c *Client) Close() {
c.once.Do(func() {
c.tr.Close()
if c.sr != nil {
c.sr.Close()
}
})
}
func (c *Client) Head(url string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
return c.Do(req)
}
func (c *Client) Get(url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return c.Do(req)
}
func (c *Client) Post(url, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return c.Do(req)
}
func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) {
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
// Do delegates the given http.Request to the underlying http.Client
// and adds a Bearer token to the authorization header, if Client has
// a secrets.SecretsReader and the request does not contain an
// Authorization header.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
if c.sr != nil && req.Header.Get("Authorization") == "" {
if b, ok := c.sr.GetSecret(req.URL.String()); ok {
req.Header.Set("Authorization", "Bearer "+string(b))
}
}
return c.client.Do(req)
}
// CloseIdleConnections delegates the call to the underlying
// http.Client.
func (c *Client) CloseIdleConnections() {
c.client.CloseIdleConnections()
}
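// A minimal usage sketch; the token file path and URL are placeholders:
//
//	client := NewClient(Options{
//		Timeout:         3 * time.Second,
//		BearerTokenFile: "/path/to/token",
//	})
//	defer client.Close() // required to not leak the transport's goroutine
//	rsp, err := client.Get("https://api.example.org/health")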
// Options are mostly passed to the http.Transport of the same
// name. Options.Timeout can be used as the default for all timeouts
// that are not set. You can pass an opentracing.Tracer
// https://godoc.org/github.com/opentracing/opentracing-go#Tracer,
// which can be nil to get the
// https://godoc.org/github.com/opentracing/opentracing-go#NoopTracer.
type Options struct {
// Transport see https://golang.org/pkg/net/http/#Transport
// In case Transport is not nil, the Transport arguments are used below.
Transport *http.Transport
// CheckRedirect see https://golang.org/pkg/net/http/#Client
CheckRedirect func(req *http.Request, via []*http.Request) error
// Proxy see https://golang.org/pkg/net/http/#Transport.Proxy
Proxy func(req *http.Request) (*url.URL, error)
// DisableKeepAlives see https://golang.org/pkg/net/http/#Transport.DisableKeepAlives
DisableKeepAlives bool
// DisableCompression see https://golang.org/pkg/net/http/#Transport.DisableCompression
DisableCompression bool
// ForceAttemptHTTP2 see https://golang.org/pkg/net/http/#Transport.ForceAttemptHTTP2
ForceAttemptHTTP2 bool
// MaxIdleConns see https://golang.org/pkg/net/http/#Transport.MaxIdleConns
MaxIdleConns int
// MaxIdleConnsPerHost see https://golang.org/pkg/net/http/#Transport.MaxIdleConnsPerHost
MaxIdleConnsPerHost int
// MaxConnsPerHost see https://golang.org/pkg/net/http/#Transport.MaxConnsPerHost
MaxConnsPerHost int
// WriteBufferSize see https://golang.org/pkg/net/http/#Transport.WriteBufferSize
WriteBufferSize int
// ReadBufferSize see https://golang.org/pkg/net/http/#Transport.ReadBufferSize
ReadBufferSize int
// MaxResponseHeaderBytes see
// https://golang.org/pkg/net/http/#Transport.MaxResponseHeaderBytes
MaxResponseHeaderBytes int64
// Timeout is the default timeout value: every timeout below that is
// left at 0 is set to this value.
Timeout time.Duration
// TLSHandshakeTimeout see
// https://golang.org/pkg/net/http/#Transport.TLSHandshakeTimeout,
// if not set or set to 0, Options.Timeout is used.
TLSHandshakeTimeout time.Duration
// IdleConnTimeout see
// https://golang.org/pkg/net/http/#Transport.IdleConnTimeout,
// if not set or set to 0, Options.Timeout is used.
IdleConnTimeout time.Duration
// ResponseHeaderTimeout see
// https://golang.org/pkg/net/http/#Transport.ResponseHeaderTimeout,
// if not set or set to 0, Options.Timeout is used.
ResponseHeaderTimeout time.Duration
// ExpectContinueTimeout see
// https://golang.org/pkg/net/http/#Transport.ExpectContinueTimeout,
// if not set or set to 0, Options.Timeout is used.
ExpectContinueTimeout time.Duration
// Tracer instance, can be nil to not enable tracing
Tracer opentracing.Tracer
// OpentracingComponentTag sets component tag for all requests
OpentracingComponentTag string
// OpentracingSpanName sets span name for all requests
OpentracingSpanName string
// BearerTokenFile injects a bearer token read from the file at the
// given path. In case SecretsReader is provided, BearerTokenFile
// will be ignored.
BearerTokenFile string
// BearerTokenRefreshInterval is the interval at which the bearer
// token is re-read from BearerTokenFile. In case SecretsReader is
// provided, it will be ignored.
BearerTokenRefreshInterval time.Duration
// SecretsReader is used to read and refresh bearer tokens
SecretsReader secrets.SecretsReader
// Log is used for error logging
Log logging.Logger
// BeforeSend is a hook function that runs just before executing RoundTrip(*http.Request)
BeforeSend func(*http.Request)
// AfterResponse is a hook function that runs just after executing RoundTrip(*http.Request)
AfterResponse func(*http.Response, error)
}
// Transport wraps an http.Transport and adds support for tracing and
// bearerToken injection.
type Transport struct {
once sync.Once
quit chan struct{}
tr *http.Transport
tracer opentracing.Tracer
spanName string
componentName string
bearerToken string
beforeSend func(*http.Request)
afterResponse func(*http.Response, error)
}
// NewTransport creates a wrapped http.Transport that forces regular
// DNS lookups by calling CloseIdleConnections on every
// IdleConnTimeout. You can optionally add tracing. On teardown you
// have to use Close() to not leak a goroutine.
func NewTransport(options Options) *Transport {
// set default tracer
if options.Tracer == nil {
options.Tracer = &opentracing.NoopTracer{}
}
// set timeout defaults
if options.TLSHandshakeTimeout == 0 {
options.TLSHandshakeTimeout = options.Timeout
}
if options.IdleConnTimeout == 0 {
if options.Timeout != 0 {
options.IdleConnTimeout = options.Timeout
} else {
options.IdleConnTimeout = defaultIdleConnTimeout
}
}
if options.ResponseHeaderTimeout == 0 {
options.ResponseHeaderTimeout = options.Timeout
}
if options.ExpectContinueTimeout == 0 {
options.ExpectContinueTimeout = options.Timeout
}
if options.Proxy == nil {
options.Proxy = http.ProxyFromEnvironment
}
var htransport *http.Transport
if options.Transport != nil {
htransport = options.Transport
} else {
htransport = &http.Transport{
Proxy: options.Proxy,
DisableKeepAlives: options.DisableKeepAlives,
DisableCompression: options.DisableCompression,
ForceAttemptHTTP2: options.ForceAttemptHTTP2,
MaxIdleConns: options.MaxIdleConns,
MaxIdleConnsPerHost: options.MaxIdleConnsPerHost,
MaxConnsPerHost: options.MaxConnsPerHost,
WriteBufferSize: options.WriteBufferSize,
ReadBufferSize: options.ReadBufferSize,
MaxResponseHeaderBytes: options.MaxResponseHeaderBytes,
ResponseHeaderTimeout: options.ResponseHeaderTimeout,
TLSHandshakeTimeout: options.TLSHandshakeTimeout,
IdleConnTimeout: options.IdleConnTimeout,
ExpectContinueTimeout: options.ExpectContinueTimeout,
}
}
t := &Transport{
once: sync.Once{},
quit: make(chan struct{}),
tr: htransport,
tracer: options.Tracer,
beforeSend: options.BeforeSend,
afterResponse: options.AfterResponse,
}
if t.tracer != nil {
if options.OpentracingComponentTag != "" {
t = WithComponentTag(t, options.OpentracingComponentTag)
}
if options.OpentracingSpanName != "" {
t = WithSpanName(t, options.OpentracingSpanName)
}
}
go func() {
ticker := time.NewTicker(options.IdleConnTimeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
htransport.CloseIdleConnections()
case <-t.quit:
htransport.CloseIdleConnections()
return
}
}
}()
return t
}
// WithSpanName sets the name of the span, if you have an enabled
// tracing Transport.
func WithSpanName(t *Transport, spanName string) *Transport {
tt := t.shallowCopy()
tt.spanName = spanName
return tt
}
// WithComponentTag sets the component name, if you have an enabled
// tracing Transport.
func WithComponentTag(t *Transport, componentName string) *Transport {
tt := t.shallowCopy()
tt.componentName = componentName
return tt
}
// WithBearerToken adds an Authorization header with "Bearer " prefix
// and the given bearerToken as value to all requests. To regularly
// update your token, call this method again and use the returned
// Transport.
func WithBearerToken(t *Transport, bearerToken string) *Transport {
tt := t.shallowCopy()
tt.bearerToken = bearerToken
return tt
}
func (t *Transport) shallowCopy() *Transport {
return &Transport{
once: sync.Once{},
quit: t.quit,
tr: t.tr,
tracer: t.tracer,
spanName: t.spanName,
componentName: t.componentName,
bearerToken: t.bearerToken,
beforeSend: t.beforeSend,
afterResponse: t.afterResponse,
}
}
func (t *Transport) Close() {
t.once.Do(func() {
close(t.quit)
})
}
func (t *Transport) CloseIdleConnections() {
t.tr.CloseIdleConnections()
}
// RoundTrip performs the request with tracing and bearer token
// injection, and adds client traces for DNS, TCP/IP, TLS handshake
// and connection pool access. Client traces are added as logs to the
// created span.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
var span opentracing.Span
if t.spanName != "" {
req, span = t.injectSpan(req)
defer span.Finish()
req = injectClientTrace(req, span)
span.LogKV("http_do", "start")
}
if t.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+t.bearerToken)
}
if t.beforeSend != nil {
t.beforeSend(req)
}
rsp, err := t.tr.RoundTrip(req)
if t.afterResponse != nil {
t.afterResponse(rsp, err)
}
if span != nil {
span.LogKV("http_do", "stop")
if rsp != nil {
ext.HTTPStatusCode.Set(span, uint16(rsp.StatusCode))
}
}
return rsp, err
}
func (t *Transport) injectSpan(req *http.Request) (*http.Request, opentracing.Span) {
spanOpts := []opentracing.StartSpanOption{opentracing.Tags{
string(ext.Component): t.componentName,
string(ext.SpanKind): "client",
string(ext.HTTPMethod): req.Method,
string(ext.HTTPUrl): req.URL.String(),
}}
if parentSpan := opentracing.SpanFromContext(req.Context()); parentSpan != nil {
spanOpts = append(spanOpts, opentracing.ChildOf(parentSpan.Context()))
}
span := t.tracer.StartSpan(t.spanName, spanOpts...)
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
_ = t.tracer.Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
return req, span
}
func injectClientTrace(req *http.Request, span opentracing.Span) *http.Request {
trace := &httptrace.ClientTrace{
DNSStart: func(httptrace.DNSStartInfo) {
span.LogKV("DNS", "start")
},
DNSDone: func(httptrace.DNSDoneInfo) {
span.LogKV("DNS", "end")
},
ConnectStart: func(string, string) {
span.LogKV("connect", "start")
},
ConnectDone: func(string, string, error) {
span.LogKV("connect", "end")
},
TLSHandshakeStart: func() {
span.LogKV("TLS", "start")
},
TLSHandshakeDone: func(tls.ConnectionState, error) {
span.LogKV("TLS", "end")
},
GetConn: func(string) {
span.LogKV("get_conn", "start")
},
GotConn: func(httptrace.GotConnInfo) {
span.LogKV("get_conn", "end")
},
WroteHeaders: func() {
span.LogKV("wrote_headers", "done")
},
WroteRequest: func(wri httptrace.WroteRequestInfo) {
if wri.Err != nil {
span.LogKV("wrote_request", ensureUTF8(wri.Err.Error()))
} else {
span.LogKV("wrote_request", "done")
}
},
GotFirstResponseByte: func() {
span.LogKV("got_first_byte", "done")
},
}
return req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
}
func ensureUTF8(s string) string {
if utf8.ValidString(s) {
return s
}
return fmt.Sprintf("invalid utf-8: %q", s)
}
package net
import (
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"go4.org/netipx"
)
// strip port from addresses with hostname, ipv4 or ipv6
func stripPort(address string) string {
if h, _, err := net.SplitHostPort(address); err == nil {
return h
}
return address
}
func parse(addr string) net.IP {
if addr != "" {
res := net.ParseIP(stripPort(addr))
return res
}
return nil
}
// RemoteAddr returns the remote address of the client. When the
// 'X-Forwarded-For' header is set, then it is used instead. This is
// how proxies most often behave. Wikipedia shows the format
// https://en.wikipedia.org/wiki/X-Forwarded-For#Format
//
// Example:
//
// X-Forwarded-For: client, proxy1, proxy2
func RemoteAddr(r *http.Request) netip.Addr {
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
s, _, _ := strings.Cut(xff, ",")
if addr, err := netip.ParseAddr(stripPort(s)); err == nil {
return addr
}
}
addr, _ := netip.ParseAddr(stripPort(r.RemoteAddr))
return addr
}
// Deprecated: RemoteHost is deprecated, use RemoteAddr instead.
func RemoteHost(r *http.Request) net.IP {
ffs := r.Header.Get("X-Forwarded-For")
ff, _, _ := strings.Cut(ffs, ",")
if ffh := parse(ff); ffh != nil {
return ffh
}
return parse(r.RemoteAddr)
}
// RemoteAddrFromLast returns the remote address of the client. When
// the 'X-Forwarded-For' header is set, its last entry is used
// instead. This is known to be the behavior of the AWS Application
// Load Balancer. AWS docs
// https://docs.aws.amazon.com/elasticloadbalancing/latest/classic/x-forwarded-headers.html
//
// Example:
//
// X-Forwarded-For: ip-address-1, ip-address-2, client-ip-address
func RemoteAddrFromLast(r *http.Request) netip.Addr {
ffs := r.Header.Get("X-Forwarded-For")
if ffs == "" {
addr, _ := netip.ParseAddr(stripPort(r.RemoteAddr))
return addr
}
last := ffs
if i := strings.LastIndex(ffs, ","); i != -1 {
last = ffs[i+1:]
}
addr, err := netip.ParseAddr(stripPort(strings.TrimSpace(last)))
if err != nil {
addr, _ := netip.ParseAddr(stripPort(r.RemoteAddr))
return addr
}
return addr
}
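// For a request carrying "X-Forwarded-For: 1.2.3.4, 10.0.0.1, 10.0.0.2",
// RemoteAddr above returns 1.2.3.4 (the first entry), while
// RemoteAddrFromLast returns 10.0.0.2 (the last entry). Without the header,
// both fall back to the address parsed from r.RemoteAddr.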
// Deprecated: RemoteHostFromLast is deprecated, use RemoteAddrFromLast instead.
func RemoteHostFromLast(r *http.Request) net.IP {
ffs := r.Header.Get("X-Forwarded-For")
ff := ffs
if i := strings.LastIndex(ffs, ","); i != -1 {
ff = ffs[i+1:]
}
if ff != "" {
if ip := parse(strings.TrimSpace(ff)); ip != nil {
return ip
}
}
return parse(r.RemoteAddr)
}
// Deprecated: IPNets is deprecated, use netipx.IPSet instead.
type IPNets []*net.IPNet
// Deprecated: Contain is deprecated, use netipx.IPSet.Contains() instead.
func (nets IPNets) Contain(ip net.IP) bool {
for _, net := range nets {
if net.Contains(ip) {
return true
}
}
return false
}
// Deprecated: ParseCIDRs is deprecated, use ParseIPCIDRs instead.
func ParseCIDRs(cidrs []string) (nets IPNets, err error) {
for _, cidr := range cidrs {
if !strings.Contains(cidr, "/") {
cidr += "/32"
}
_, net, err := net.ParseCIDR(cidr)
if err != nil {
return nil, err
}
nets = append(nets, net)
}
return nets, nil
}
// ParseIPCIDRs returns a valid IPSet in case there is no parsing
// error.
func ParseIPCIDRs(cidrs []string) (*netipx.IPSet, error) {
var b netipx.IPSetBuilder
for _, w := range cidrs {
if strings.Contains(w, "/") {
pref, err := netip.ParsePrefix(w)
if err != nil {
return nil, err
}
b.AddPrefix(pref)
} else if addr, err := netip.ParseAddr(w); err != nil {
return nil, err
} else if addr.IsUnspecified() {
return nil, fmt.Errorf("failed to parse cidr: addr is unspecified: %s", w)
} else {
b.Add(addr)
}
}
ips, err := b.IPSet()
if err != nil {
return nil, err
}
return ips, nil
}
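// A minimal usage sketch combining ParseIPCIDRs with netipx.IPSet (the
// addresses are illustrative):
//
//	set, err := ParseIPCIDRs([]string{"10.0.0.0/8", "192.168.1.1"})
//	if err != nil {
//		// handle the error
//	}
//	set.Contains(netip.MustParseAddr("10.1.2.3"))    // true
//	set.Contains(netip.MustParseAddr("192.168.1.2")) // false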
// SchemeHost parses a URI string (without the #fragment part) and returns the scheme of the URI as the first return value and the
// host[:port] part as the second return value. The port is never omitted for HTTP(S): if no port is specified in the URI, the default port for the given
// scheme is used. If the URI is invalid, an error is returned.
func SchemeHost(input string) (string, string, error) {
u, err := url.ParseRequestURI(input)
if err != nil {
return "", "", err
}
if u.Scheme == "" {
return "", "", fmt.Errorf(`parse %q: missing scheme`, input)
}
if u.Host == "" {
return "", "", fmt.Errorf(`parse %q: missing host`, input)
}
// endpoint address cannot contain path, the rest is not case sensitive
s, h := strings.ToLower(u.Scheme), strings.ToLower(u.Host)
hh, p, err := net.SplitHostPort(h)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
// Trim is needed to remove brackets from IPv6 addresses, JoinHostPort will add them in case of any IPv6 address,
// so we need to remove them to avoid duplicate pairs of brackets.
h = strings.Trim(h, "[]")
switch s {
case "http":
p = "80"
case "https":
p = "443"
default:
p = ""
}
} else {
return "", "", err
}
} else {
h = hh
}
if p != "" {
h = net.JoinHostPort(h, p)
}
return s, h, nil
}
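// Illustrative SchemeHost results:
//
//	SchemeHost("https://example.org/path") // "https", "example.org:443", nil
//	SchemeHost("http://example.org:8080")  // "http", "example.org:8080", nil
//	SchemeHost("https://[2001:db8::1]/x")  // "https", "[2001:db8::1]:443", nil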
package net
import (
"net/http"
"net/url"
log "github.com/sirupsen/logrus"
)
type ValidateQueryHandler struct {
Handler http.Handler
}
func (q *ValidateQueryHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := validateQuery(r.URL.RawQuery); err != nil {
http.Error(w, "Invalid query", http.StatusBadRequest)
return
}
q.Handler.ServeHTTP(w, r)
}
type ValidateQueryLogHandler struct {
Handler http.Handler
}
func (q *ValidateQueryLogHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := validateQuery(r.URL.RawQuery); err != nil {
log.Infof("Invalid query: %s -> %s %s %s", r.RemoteAddr, r.Host, r.URL.Path, r.Method)
}
q.Handler.ServeHTTP(w, r)
}
func validateQuery(s string) error {
_, err := url.ParseQuery(s)
return err
}
package net
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/cenkalti/backoff"
"github.com/opentracing/opentracing-go"
"github.com/redis/go-redis/v9"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/metrics"
xxhash "github.com/cespare/xxhash/v2"
rendezvous "github.com/dgryski/go-rendezvous"
jump "github.com/dgryski/go-jump"
"github.com/dchest/siphash"
mpchash "github.com/dgryski/go-mpchash"
)
// RedisOptions is used to configure the redis.Ring
type RedisOptions struct {
// Addrs are the list of redis shards
Addrs []string
// AddrUpdater is a func that is regularly called to update
// redis address list. This func should return a list of redis
// shards.
AddrUpdater func() ([]string, error)
// UpdateInterval is the interval at which AddrUpdater is
// triggered and SetAddrs is used to update the redis shards
UpdateInterval time.Duration
// Password is the password needed to connect to Redis server
Password string
// ReadTimeout for redis socket reads
ReadTimeout time.Duration
// WriteTimeout for redis socket writes
WriteTimeout time.Duration
// DialTimeout is the max time.Duration to dial a new connection
DialTimeout time.Duration
// PoolTimeout is the max time.Duration to get a connection from pool
PoolTimeout time.Duration
// IdleTimeout requires a non-zero IdleCheckFrequency
IdleTimeout time.Duration
// IdleCheckFrequency - reaper frequency, only used if IdleTimeout > 0
IdleCheckFrequency time.Duration
// MaxConnAge is the connection age at which the client closes the connection
MaxConnAge time.Duration
// MinIdleConns is the minimum number of socket connections to redis
MinIdleConns int
// MaxIdleConns is the maximum number of socket connections to redis
MaxIdleConns int
// HeartbeatFrequency frequency of PING commands sent to check
// shards availability.
HeartbeatFrequency time.Duration
// ConnMetricsInterval defines the frequency of updating the redis
// connection related metrics. Defaults to 60 seconds.
ConnMetricsInterval time.Duration
// MetricsPrefix is the prefix for redis ring client metrics,
// defaults to "swarm.redis." if not set
MetricsPrefix string
// Tracer provides OpenTracing for Redis queries.
Tracer opentracing.Tracer
// Log is the logger that is used
Log logging.Logger
// HashAlgorithm is one of rendezvous, rendezvousVnodes, jump, mpchash; defaults to the github.com/go-redis/redis default
HashAlgorithm string
}
// RedisRingClient is a redis client that accesses redis shards by
// computing a ring hash. It logs to the logging.Logger interface
// that you can pass. It adds metrics, and operations are traced with
// opentracing. You can set timeouts; the defaults are chosen to be
// safe in the hot path of low latency production requests.
type RedisRingClient struct {
ring *redis.Ring
log logging.Logger
metrics metrics.Metrics
metricsPrefix string
options *RedisOptions
tracer opentracing.Tracer
quit chan struct{}
once sync.Once
closed bool
}
type RedisScript struct {
script *redis.Script
}
const (
// DefaultReadTimeout is the default socket read timeout
DefaultReadTimeout = 25 * time.Millisecond
// DefaultWriteTimeout is the default socket write timeout
DefaultWriteTimeout = 25 * time.Millisecond
// DefaultPoolTimeout is the default timeout to access the connection pool
DefaultPoolTimeout = 25 * time.Millisecond
// DefaultDialTimeout is the default dial timeout
DefaultDialTimeout = 25 * time.Millisecond
// DefaultMinConns is the default minimum of connections
DefaultMinConns = 100
// DefaultMaxConns is the default maximum of connections
DefaultMaxConns = 100
defaultConnMetricsInterval = 60 * time.Second
defaultUpdateInterval = 10 * time.Second
)
// https://arxiv.org/pdf/1406.2294.pdf
type jumpHash struct {
shards []string
}
func NewJumpHash(shards []string) redis.ConsistentHash {
return &jumpHash{
shards: shards,
}
}
func (j *jumpHash) Get(k string) string {
key := xxhash.Sum64String(k)
h := jump.Hash(key, len(j.shards))
return j.shards[int(h)]
}
// Multi-probe consistent hashing - mpchash
// https://arxiv.org/pdf/1505.00062.pdf
type multiprobe struct {
hash *mpchash.Multi
}
func NewMultiprobe(shards []string) redis.ConsistentHash {
return &multiprobe{
// 2 seeds and k=21 taken from the library defaults
hash: mpchash.New(shards, siphash64seed, [2]uint64{1, 2}, 21),
}
}
func (mc *multiprobe) Get(k string) string {
return mc.hash.Hash(k)
}
func siphash64seed(b []byte, s uint64) uint64 { return siphash.Hash(s, 0, b) }
// rendezvous copied from github.com/go-redis/redis/v8@v8.3.3/ring.go
type rendezvousWrapper struct {
*rendezvous.Rendezvous
}
func (w rendezvousWrapper) Get(key string) string {
return w.Lookup(key)
}
func NewRendezvous(shards []string) redis.ConsistentHash {
return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}
// rendezvous vnodes
type rendezvousVnodes struct {
*rendezvous.Rendezvous
table map[string]string
}
const vnodePerShard = 100
func (w rendezvousVnodes) Get(key string) string {
k := w.Lookup(key)
v, ok := w.table[k]
if !ok {
log.Printf("not found: %s in table for input: %s, so return %s", k, key, v)
}
return v
}
func NewRendezvousVnodes(shards []string) redis.ConsistentHash {
vshards := make([]string, vnodePerShard*len(shards))
table := make(map[string]string)
for i := 0; i < vnodePerShard; i++ {
for j, shard := range shards {
vshard := fmt.Sprintf("%s%d", shard, i) // suffix
table[vshard] = shard
vshards[i*len(shards)+j] = vshard
}
}
return rendezvousVnodes{rendezvous.New(vshards, xxhash.Sum64String), table}
}
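// For two shards A and B, the constructor above builds vnodes "A0".."A99"
// and "B0".."B99", each mapping back to its shard, so the rendezvous lookup
// distributes keys over 200 points instead of 2 and smooths the distribution.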
func NewRedisRingClient(ro *RedisOptions) *RedisRingClient {
const backOffTime = 2 * time.Second
const retryCount = 5
r := &RedisRingClient{
once: sync.Once{},
quit: make(chan struct{}),
metrics: metrics.Default,
tracer: &opentracing.NoopTracer{},
}
ringOptions := &redis.RingOptions{
Addrs: map[string]string{},
}
if ro != nil {
switch ro.HashAlgorithm {
case "rendezvous":
ringOptions.NewConsistentHash = NewRendezvous
case "rendezvousVnodes":
ringOptions.NewConsistentHash = NewRendezvousVnodes
case "jump":
ringOptions.NewConsistentHash = NewJumpHash
case "mpchash":
ringOptions.NewConsistentHash = NewMultiprobe
}
if ro.Log == nil {
ro.Log = &logging.DefaultLog{}
}
if ro.AddrUpdater != nil {
address, err := ro.AddrUpdater()
for i := 0; i < retryCount; i++ {
if err == nil {
break
}
time.Sleep(backOffTime)
address, err = ro.AddrUpdater()
}
if err != nil {
ro.Log.Errorf("Failed to fetch redis addresses at startup: %v", err)
}
ringOptions.Addrs = createAddressMap(address)
} else {
ringOptions.Addrs = createAddressMap(ro.Addrs)
}
ro.Log.Infof("Created ring with addresses: %v", ringOptions.Addrs)
ringOptions.ReadTimeout = ro.ReadTimeout
ringOptions.WriteTimeout = ro.WriteTimeout
ringOptions.PoolTimeout = ro.PoolTimeout
ringOptions.DialTimeout = ro.DialTimeout
ringOptions.MinIdleConns = ro.MinIdleConns
ringOptions.PoolSize = ro.MaxIdleConns
ringOptions.Password = ro.Password
if ro.ConnMetricsInterval <= 0 {
ro.ConnMetricsInterval = defaultConnMetricsInterval
}
if ro.Tracer != nil {
r.tracer = ro.Tracer
}
r.options = ro
r.ring = redis.NewRing(ringOptions)
r.log = ro.Log
r.metricsPrefix = ro.MetricsPrefix
if ro.AddrUpdater != nil {
if ro.UpdateInterval == 0 {
ro.UpdateInterval = defaultUpdateInterval
}
go r.startUpdater(context.Background())
}
}
return r
}
func createAddressMap(addrs []string) map[string]string {
res := make(map[string]string)
for _, addr := range addrs {
res[addr] = addr
}
return res
}
func hasAll(a []string, set map[string]struct{}) bool {
if len(a) != len(set) {
return false
}
for _, w := range a {
if _, ok := set[w]; !ok {
return false
}
}
return true
}
func (r *RedisRingClient) startUpdater(ctx context.Context) {
old := make(map[string]struct{})
for _, addr := range r.options.Addrs {
old[addr] = struct{}{}
}
r.log.Infof("Start goroutine to update redis instances every %s", r.options.UpdateInterval)
defer r.log.Info("Stopped goroutine to update redis")
ticker := time.NewTicker(r.options.UpdateInterval)
defer ticker.Stop()
for {
select {
case <-r.quit:
return
case <-ticker.C:
}
addrs, err := r.options.AddrUpdater()
if err != nil {
r.log.Errorf("Failed to update redis addresses: %v", err)
continue
}
if !hasAll(addrs, old) {
r.log.Infof("Redis updater updating old(%d) != new(%d)", len(old), len(addrs))
r.SetAddrs(ctx, addrs)
r.metrics.UpdateGauge(r.metricsPrefix+"shards", float64(r.ring.Len()))
old = make(map[string]struct{})
for _, addr := range addrs {
old[addr] = struct{}{}
}
}
}
}
func (r *RedisRingClient) RingAvailable() bool {
var err error
err = backoff.Retry(func() error {
_, err = r.ring.Ping(context.Background()).Result()
if err != nil {
r.log.Infof("Failed to ping redis, retry with backoff: %v", err)
}
return err
}, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 7))
return err == nil
}
func (r *RedisRingClient) StartMetricsCollection() {
go func() {
ticker := time.NewTicker(r.options.ConnMetricsInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
stats := r.ring.PoolStats()
// cumulative counters, exported as gauges
r.metrics.UpdateGauge(r.metricsPrefix+"hits", float64(stats.Hits))
r.metrics.UpdateGauge(r.metricsPrefix+"misses", float64(stats.Misses))
r.metrics.UpdateGauge(r.metricsPrefix+"timeouts", float64(stats.Timeouts))
// cumulative counter of stale connections reaped and closed by the pool
r.metrics.UpdateGauge(r.metricsPrefix+"staleconns", float64(stats.StaleConns))
// gauges
r.metrics.UpdateGauge(r.metricsPrefix+"idleconns", float64(stats.IdleConns))
r.metrics.UpdateGauge(r.metricsPrefix+"totalconns", float64(stats.TotalConns))
case <-r.quit:
return
}
}
}()
}
func (r *RedisRingClient) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
return r.tracer.StartSpan(operationName, opts...)
}
func (r *RedisRingClient) Close() {
r.once.Do(func() {
r.closed = true
close(r.quit)
if r.ring != nil {
r.ring.Close()
}
})
}
func (r *RedisRingClient) SetAddrs(ctx context.Context, addrs []string) {
if len(addrs) == 0 {
return
}
r.ring.SetAddrs(createAddressMap(addrs))
}
func (r *RedisRingClient) Get(ctx context.Context, key string) (string, error) {
res := r.ring.Get(ctx, key)
return res.Val(), res.Err()
}
func (r *RedisRingClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) (string, error) {
res := r.ring.Set(ctx, key, value, expiration)
return res.Result()
}
func (r *RedisRingClient) ZAdd(ctx context.Context, key string, val int64, score float64) (int64, error) {
res := r.ring.ZAdd(ctx, key, redis.Z{Member: val, Score: score})
return res.Val(), res.Err()
}
func (r *RedisRingClient) ZRem(ctx context.Context, key string, members ...interface{}) (int64, error) {
res := r.ring.ZRem(ctx, key, members...)
return res.Val(), res.Err()
}
func (r *RedisRingClient) Expire(ctx context.Context, key string, expiration time.Duration) (bool, error) {
res := r.ring.Expire(ctx, key, expiration)
return res.Val(), res.Err()
}
func (r *RedisRingClient) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
res := r.ring.ZRemRangeByScore(ctx, key, fmt.Sprint(min), fmt.Sprint(max))
return res.Val(), res.Err()
}
func (r *RedisRingClient) ZCard(ctx context.Context, key string) (int64, error) {
res := r.ring.ZCard(ctx, key)
return res.Val(), res.Err()
}
func (r *RedisRingClient) ZRangeByScoreWithScoresFirst(ctx context.Context, key string, min, max float64, offset, count int64) (interface{}, error) {
opt := redis.ZRangeBy{
Min: fmt.Sprint(min),
Max: fmt.Sprint(max),
Offset: offset,
Count: count,
}
res := r.ring.ZRangeByScoreWithScores(ctx, key, &opt)
zs, err := res.Result()
if err != nil {
return nil, err
}
if len(zs) == 0 {
return nil, nil
}
return zs[0].Member, nil
}
func (r *RedisRingClient) NewScript(source string) *RedisScript {
return &RedisScript{redis.NewScript(source)}
}
func (r *RedisRingClient) RunScript(ctx context.Context, s *RedisScript, keys []string, args ...interface{}) (interface{}, error) {
return s.script.Run(ctx, r.ring, keys, args...).Result()
}
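// exampleRedisRingClient is a hedged usage sketch, not part of the original
// source: it assumes a reachable shard at "127.0.0.1:6379" and shows the
// minimal create/use/close cycle of the client.
func exampleRedisRingClient() {
client := NewRedisRingClient(&RedisOptions{
Addrs: []string{"127.0.0.1:6379"},
})
defer client.Close()
ctx := context.Background()
if _, err := client.Set(ctx, "greeting", "hello", time.Minute); err == nil {
value, _ := client.Get(ctx, "greeting")
fmt.Println(value)
}
}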
package net
import (
"net/http"
"net/url"
"strings"
)
type RequestMatchHandler struct {
Match []string
Handler http.Handler
}
func (h *RequestMatchHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if h.matchesRequest(r) {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Invalid request\n"))
return
}
h.Handler.ServeHTTP(w, r)
}
func (h *RequestMatchHandler) matchesRequest(r *http.Request) bool {
if h.matches(r.RequestURI) {
return true
}
unescapedURI, _ := url.QueryUnescape(r.RequestURI)
if h.matches(unescapedURI) {
return true
}
for name, values := range r.Header {
if h.matches(name) {
return true
}
for _, value := range values {
if h.matches(value) {
return true
}
}
}
return false
}
func (h *RequestMatchHandler) matches(value string) bool {
for _, v := range h.Match {
if strings.Contains(value, v) {
return true
}
}
return false
}
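// exampleRequestMatchHandler is a hedged usage sketch, not part of the
// original source; the deny-list entries are hypothetical. Requests whose
// URI or headers contain a listed substring get a 400 response.
func exampleRequestMatchHandler() http.Handler {
return &RequestMatchHandler{
Match: []string{"%00", "%2e%2e"},
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok\n"))
}),
}
}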
package net
import (
"context"
"net/http"
)
// SizeOfRequestHeader returns the estimated size of the HTTP request header
func SizeOfRequestHeader(r *http.Request) (size int) {
return len(valueOrDefault(r.Method, "GET")) + len(" ") + len(r.URL.RequestURI()) + len(" HTTP/1.1\r\n") +
len("Host: ") + len(valueOrDefault(r.Host, r.URL.Host)) + len("\r\n") +
sizeOfHeader(r.Header) + len("\r\n")
}
func valueOrDefault(value, def string) string {
if value != "" {
return value
}
return def
}
func sizeOfHeader(h http.Header) (size int) {
for k, vv := range h {
for _, v := range vv {
size += len(k) + len(": ") + len(v) + len("\r\n")
}
}
return
}
func exactSizeOfRequestHeader(r *http.Request) (size int) {
r = r.Clone(context.Background())
r.Body = nil // discard body
w := &countingWriter{}
r.Write(w)
return w.size
}
type countingWriter struct {
size int
}
func (w *countingWriter) Write(p []byte) (n int, err error) {
n = len(p)
w.size += n
return n, nil
}
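// exampleHeaderSizes is a hedged sketch, not part of the original source,
// comparing the fast estimate with the exact serialized size for a synthetic
// request. The estimate assumes HTTP/1.1 framing, while the exact variant
// serializes a clone of the request and counts the bytes.
func exampleHeaderSizes() (estimated, exact int) {
r, _ := http.NewRequest("GET", "http://example.org/foo", nil)
r.Header.Set("Accept", "application/json")
return SizeOfRequestHeader(r), exactSizeOfRequestHeader(r)
}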
/*
Package pathmux implements a tree lookup for values associated with
paths.
This package is a fork of https://github.com/dimfeld/httptreemux.
*/
package pathmux
import (
"fmt"
"net/url"
"strings"
)
// Matcher objects, when using the LookupMatcher function, can be used for additional checks and to override the
// default result in case of path matches. The argument passed to the Match function is the original value
// passed to the Tree.Add function.
type Matcher interface {
// Match should return true and the object to be returned by the lookup, when the argument value fulfils the
// conditions defined by the custom logic in the matcher itself. If it returns false, it instructs the
// lookup to continue with backtracking from the current tree position.
Match(value interface{}) (bool, interface{})
}
type trueMatcher struct{}
func (m *trueMatcher) Match(value interface{}) (bool, interface{}) {
return true, value
}
var tm *trueMatcher
func init() {
tm = &trueMatcher{}
}
type node struct {
path string
priority int
// The list of static children to check.
staticIndices []byte
staticChild []*node
// If none of the above match, check the wildcard children
wildcardChild *node
// If none of the above match, then we use the catch-all, if applicable.
catchAllChild *node
isCatchAll bool
leafValue interface{}
}
// Tree structure to store values associated with paths.
type Tree node
func (n *node) sortStaticChild(i int) {
for i > 0 && n.staticChild[i].priority > n.staticChild[i-1].priority {
n.staticChild[i], n.staticChild[i-1] = n.staticChild[i-1], n.staticChild[i]
n.staticIndices[i], n.staticIndices[i-1] = n.staticIndices[i-1], n.staticIndices[i]
i -= 1
}
}
func (n *node) addPath(path string) (*node, error) {
leaf := len(path) == 0
if leaf {
return n, nil
}
c := path[0]
nextSlash := strings.Index(path, "/")
var thisToken string
var tokenEnd int
if c == '/' {
thisToken = "/"
tokenEnd = 1
} else if nextSlash == -1 {
thisToken = path
tokenEnd = len(path)
} else {
thisToken = path[0:nextSlash]
tokenEnd = nextSlash
}
remainingPath := path[tokenEnd:]
if c == '*' {
// Token starts with a *, so it's a catch-all
thisToken = thisToken[1:]
if n.catchAllChild == nil {
n.catchAllChild = &node{path: thisToken, isCatchAll: true}
}
if path[1:] != n.catchAllChild.path {
return nil, fmt.Errorf(
"catch-all name in %s doesn't match %s",
path, n.catchAllChild.path)
}
if nextSlash != -1 {
return nil, fmt.Errorf("/ after catch-all found in %s", path)
}
return n.catchAllChild, nil
} else if c == ':' {
// Token starts with a :
if n.wildcardChild == nil {
n.wildcardChild = &node{path: "wildcard"}
}
return n.wildcardChild.addPath(remainingPath)
} else {
if strings.ContainsAny(thisToken, ":*") {
return nil, fmt.Errorf("* or : in middle of path component %s", path)
}
// Do we have an existing node that starts with the same byte?
for i, index := range n.staticIndices {
if c == index {
// Yes. Split it based on the common prefix of the existing
// node and the new one.
child, prefixSplit := n.splitCommonPrefix(i, thisToken)
child.priority++
n.sortStaticChild(i)
return child.addPath(path[prefixSplit:])
}
}
// No existing node starting with this byte, so create it.
child := &node{path: thisToken}
if n.staticIndices == nil {
n.staticIndices = []byte{c}
n.staticChild = []*node{child}
} else {
n.staticIndices = append(n.staticIndices, c)
n.staticChild = append(n.staticChild, child)
}
return child.addPath(remainingPath)
}
}
func (n *node) splitCommonPrefix(existingNodeIndex int, path string) (*node, int) {
childNode := n.staticChild[existingNodeIndex]
if strings.HasPrefix(path, childNode.path) {
// No split needs to be done. Rather, the new path shares the entire
// prefix with the existing node, so the new node is just a child of
// the existing one. Or the new path is the same as the existing path,
// which means that we just move on to the next token. Either way,
// this return accomplishes that
return childNode, len(childNode.path)
}
// Find the length of the common prefix of the child node and the new path.
i := commonPrefixLen(childNode.path, path)
commonPrefix := path[0:i]
childNode.path = childNode.path[i:]
// Create a new intermediary node in the place of the existing node, with
// the existing node as a child.
newNode := &node{
path: commonPrefix,
priority: childNode.priority,
// Index is the first byte of the non-common part of the path.
staticIndices: []byte{childNode.path[0]},
staticChild: []*node{childNode},
}
n.staticChild[existingNodeIndex] = newNode
return newNode, i
}
func commonPrefixLen(x, y string) int {
n := 0
for n < len(x) && n < len(y) && x[n] == y[n] {
n++
}
return n
}
func (n *node) search(path string, m Matcher) (found *node, params []string, value interface{}) {
pathLen := len(path)
if pathLen == 0 {
if n.leafValue == nil {
return nil, nil, nil
}
var match bool
match, value = m.Match(n.leafValue)
if !match {
return nil, nil, nil
}
return n, nil, value
}
// First see if this matches a static token.
firstChar := path[0]
for i, staticIndex := range n.staticIndices {
if staticIndex == firstChar {
child := n.staticChild[i]
childPathLen := len(child.path)
if pathLen >= childPathLen && child.path == path[:childPathLen] {
nextPath := path[childPathLen:]
found, params, value = child.search(nextPath, m)
}
break
}
}
if found != nil {
return
}
if n.wildcardChild != nil {
// Didn't find a static token, so check for a wildcard.
nextSlash := 0
for nextSlash < pathLen && path[nextSlash] != '/' {
nextSlash++
}
thisToken := path[0:nextSlash]
nextToken := path[nextSlash:]
if len(thisToken) > 0 { // Don't match on empty tokens.
found, params, value = n.wildcardChild.search(nextToken, m)
if found != nil {
unescaped, err := url.QueryUnescape(thisToken)
if err != nil {
unescaped = thisToken
}
if params == nil {
params = []string{unescaped}
} else {
params = append(params, unescaped)
}
return
}
}
}
catchAllChild := n.catchAllChild
if catchAllChild != nil {
// Hit the catchall, so just assign the whole remaining path.
unescaped, err := url.QueryUnescape(path)
if err != nil {
unescaped = path
}
var match bool
match, value = m.Match(catchAllChild.leafValue)
if !match {
return nil, nil, nil
}
return catchAllChild, []string{unescaped}, value
}
return nil, nil, nil
}
// Add a value to the tree associated with a path. Paths may contain
// wildcards. Wildcards can be of two types:
//
// - simple wildcard: e.g. /some/:wildcard/path, where a wildcard is
// matched to a single name in the path.
//
// - free wildcard: e.g. /some/path/*wildcard, where a wildcard at the
// end of a path matches anything.
func (t *Tree) Add(path string, value interface{}) error {
n, err := (*node)(t).addPath(path[1:])
if err != nil {
return err
}
n.leafValue = value
return nil
}
// Lookup tries to find a value in the tree associated with a path. If the found path definition contains
// wildcards, the values of the wildcards are returned in the second argument.
func (t *Tree) Lookup(path string) (interface{}, []string) {
node, params, _ := t.LookupMatcher(path, tm)
return node, params
}
// LookupMatcher tries to find a value in the tree associated with a path. If the found path definition contains
// wildcards, the values of the wildcards are returned in the second argument. When a value is found,
// the matcher is called to check if the value meets the conditions implemented by the custom matcher. If it
// returns true, then the lookup is done and the additional return value from the matcher is returned as the
// lookup result. If it returns false, the lookup continues with backtracking from the current tree position.
func (t *Tree) LookupMatcher(path string, m Matcher) (interface{}, []string, interface{}) {
if path == "" {
path = "/"
}
node, params, value := (*node)(t).search(path[1:], m)
if node == nil {
return nil, nil, nil
}
return node.leafValue, params, value
}
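// exampleTree is a hedged usage sketch, not part of the original source; the
// paths and values are hypothetical.
func exampleTree() {
t := &Tree{}
t.Add("/users/:id", "user route")
t.Add("/static/*path", "static route")
value, params := t.Lookup("/users/42")
fmt.Println(value, params) // prints: user route [42]
}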
package pathmux
// Exploded version of the pathmux tree, designed to be easy to use in a visualization.
// Simple wildcard nodes are represented by the ':' prefix and free wildcard nodes with the '*' prefix.
type VizTree struct {
Path string // string representation of the node path
Children []*VizTree // children nodes of the current node
CanMatch bool // flag that is set to true if the node has a matcher
}
// NewVizTree creates a new visualization tree from a pathmux.Tree.
func NewVizTree(tree *Tree) *VizTree {
panic("not implemented anymore")
}
package skipper
import (
"fmt"
"os"
"path/filepath"
"plugin"
"strings"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/routing"
)
func (o *Options) findAndLoadPlugins() error {
found := make(map[string]string)
done := make(map[string][]string)
for _, dir := range o.PluginDirs {
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
// don't fail when default plugin dir is missing
if _, ok := err.(*os.PathError); ok && dir == DefaultPluginDir {
return err
}
log.Fatalf("failed to search for plugins: %s", err)
}
if info.IsDir() {
return nil
}
if strings.HasSuffix(path, ".so") {
name := filepath.Base(path)
name = name[:len(name)-3] // strip suffix
found[name] = path
log.Printf("found plugin %s at %s", name, path)
}
return nil
})
}
if err := o.loadPlugins(found, done); err != nil {
return err
}
if err := o.loadFilterPlugins(found, done); err != nil {
return err
}
if err := o.loadPredicatePlugins(found, done); err != nil {
return err
}
if err := o.loadDataClientPlugins(found, done); err != nil {
return err
}
for name, path := range found {
log.Printf("attempting to load plugin from %s", path)
mod, err := plugin.Open(path)
if err != nil {
return fmt.Errorf("open plugin %s from %s: %s", name, path, err)
}
conf, err := readPluginConfig(path)
if err != nil {
return fmt.Errorf("failed to read config for %s: %s", path, err)
}
if !pluginIsLoaded(done, name, "InitPlugin") {
if sym, err := mod.Lookup("InitPlugin"); err == nil {
fltrs, preds, dcs, err := initPlugin(sym, path, conf)
if err != nil {
return fmt.Errorf("filter plugin %s returned: %s", path, err)
}
o.CustomFilters = append(o.CustomFilters, fltrs...)
o.CustomPredicates = append(o.CustomPredicates, preds...)
o.CustomDataClients = append(o.CustomDataClients, dcs...)
log.Printf("multitype plugin %s loaded from %s (filter: %d, predicate: %d, dataclient: %d)",
name, path, len(fltrs), len(preds), len(dcs))
markPluginLoaded(done, name, "InitPlugin")
}
} else {
log.Printf("plugin %s already loaded with InitPlugin", name)
}
if !pluginIsLoaded(done, name, "InitFilter") {
if sym, err := mod.Lookup("InitFilter"); err == nil {
spec, err := initFilterPlugin(sym, path, conf)
if err != nil {
return fmt.Errorf("filter plugin %s returned: %s", path, err)
}
o.CustomFilters = append(o.CustomFilters, spec)
log.Printf("filter plugin %s loaded from %s", name, path)
markPluginLoaded(done, name, "InitFilter")
}
} else {
log.Printf("plugin %s already loaded with InitFilter", name)
}
if !pluginIsLoaded(done, name, "InitPredicate") {
if sym, err := mod.Lookup("InitPredicate"); err == nil {
spec, err := initPredicatePlugin(sym, path, conf)
if err != nil {
return fmt.Errorf("predicate plugin %s returned: %s", path, err)
}
o.CustomPredicates = append(o.CustomPredicates, spec)
log.Printf("predicate plugin %s loaded from %s", name, path)
markPluginLoaded(done, name, "InitPredicate")
}
} else {
log.Printf("plugin %s already loaded with InitPredicate", name)
}
if !pluginIsLoaded(done, name, "InitDataClient") {
if sym, err := mod.Lookup("InitDataClient"); err == nil {
spec, err := initDataClientPlugin(sym, path, conf)
if err != nil {
return fmt.Errorf("data client plugin %s returned: %s", path, err)
}
o.CustomDataClients = append(o.CustomDataClients, spec)
log.Printf("data client plugin %s loaded from %s", name, path)
markPluginLoaded(done, name, "InitDataClient")
}
} else {
log.Printf("plugin %s already loaded with InitDataClient", name)
}
}
var implementsMultiple []string
for name, specs := range done {
if len(specs) > 1 {
implementsMultiple = append(implementsMultiple, name)
}
}
if len(implementsMultiple) != 0 {
return fmt.Errorf("found plugins implementing multiple Init* functions: %v", implementsMultiple)
}
return nil
}
func pluginIsLoaded(done map[string][]string, name, spec string) bool {
loaded, ok := done[name]
if !ok {
return false
}
for _, s := range loaded {
if s == spec {
return true
}
}
return false
}
func markPluginLoaded(done map[string][]string, name, spec string) {
done[name] = append(done[name], spec)
}
func (o *Options) loadPlugins(found map[string]string, done map[string][]string) error {
for _, plug := range o.Plugins {
name := plug[0]
path, ok := found[name]
if !ok {
return fmt.Errorf("multitype plugin %s not found in plugin dirs", name)
}
fltrs, preds, dcs, err := loadPlugin(path, plug[1:])
if err != nil {
return fmt.Errorf("failed to load plugin %s: %s", path, err)
}
o.CustomFilters = append(o.CustomFilters, fltrs...)
o.CustomPredicates = append(o.CustomPredicates, preds...)
o.CustomDataClients = append(o.CustomDataClients, dcs...)
log.Printf("multitype plugin %s loaded from %s (filter: %d, predicate: %d, dataclient: %d)",
name, path, len(fltrs), len(preds), len(dcs))
markPluginLoaded(done, name, "InitPlugin")
}
return nil
}
func loadPlugin(path string, args []string) ([]filters.Spec, []routing.PredicateSpec, []routing.DataClient, error) {
mod, err := plugin.Open(path)
if err != nil {
return nil, nil, nil, fmt.Errorf("open multitype plugin %s: %s", path, err)
}
conf, err := readPluginConfig(path)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to read config for %s: %s", path, err)
}
sym, err := mod.Lookup("InitPlugin")
if err != nil {
return nil, nil, nil, fmt.Errorf("lookup module symbol failed for %s: %s", path, err)
}
return initPlugin(sym, path, append(conf, args...))
}
func initPlugin(sym plugin.Symbol, path string, args []string) ([]filters.Spec, []routing.PredicateSpec, []routing.DataClient, error) {
fn, ok := sym.(func([]string) ([]filters.Spec, []routing.PredicateSpec, []routing.DataClient, error))
if !ok {
return nil, nil, nil, fmt.Errorf("plugin %s's InitPlugin function has wrong signature", path)
}
fltrs, preds, dcs, err := fn(args)
if err != nil {
return nil, nil, nil, fmt.Errorf("plugin %s returned: %s", path, err)
}
return fltrs, preds, dcs, nil
}
func (o *Options) loadFilterPlugins(found map[string]string, done map[string][]string) error {
for _, fltr := range o.FilterPlugins {
name := fltr[0]
path, ok := found[name]
if !ok {
return fmt.Errorf("filter plugin %s not found in plugin dirs", name)
}
spec, err := loadFilterPlugin(path, fltr[1:])
if err != nil {
return fmt.Errorf("failed to load plugin %s: %s", path, err)
}
o.CustomFilters = append(o.CustomFilters, spec)
log.Printf("loaded plugin %s (%s) from %s", name, spec.Name(), path)
markPluginLoaded(done, name, "InitFilter")
}
return nil
}
func loadFilterPlugin(path string, args []string) (filters.Spec, error) {
mod, err := plugin.Open(path)
if err != nil {
return nil, fmt.Errorf("open filter plugin %s: %s", path, err)
}
conf, err := readPluginConfig(path)
if err != nil {
return nil, fmt.Errorf("failed to read config for %s: %s", path, err)
}
sym, err := mod.Lookup("InitFilter")
if err != nil {
return nil, fmt.Errorf("lookup module symbol failed for %s: %s", path, err)
}
return initFilterPlugin(sym, path, append(conf, args...))
}
func initFilterPlugin(sym plugin.Symbol, path string, args []string) (filters.Spec, error) {
fn, ok := sym.(func([]string) (filters.Spec, error))
if !ok {
return nil, fmt.Errorf("plugin %s's InitFilter function has wrong signature", path)
}
spec, err := fn(args)
if err != nil {
return nil, fmt.Errorf("plugin %s returned: %s", path, err)
}
return spec, nil
}
func (o *Options) loadPredicatePlugins(found map[string]string, done map[string][]string) error {
for _, pred := range o.PredicatePlugins {
name := pred[0]
path, ok := found[name]
if !ok {
return fmt.Errorf("predicate plugin %s not found in plugin dirs", name)
}
spec, err := loadPredicatePlugin(path, pred[1:])
if err != nil {
return fmt.Errorf("failed to load plugin %s: %s", path, err)
}
o.CustomPredicates = append(o.CustomPredicates, spec)
log.Printf("loaded plugin %s (%s) from %s", name, spec.Name(), path)
markPluginLoaded(done, name, "InitPredicate")
}
return nil
}
func loadPredicatePlugin(path string, args []string) (routing.PredicateSpec, error) {
mod, err := plugin.Open(path)
if err != nil {
return nil, fmt.Errorf("open predicate module %s: %s", path, err)
}
conf, err := readPluginConfig(path)
if err != nil {
return nil, fmt.Errorf("failed to read config for %s: %s", path, err)
}
sym, err := mod.Lookup("InitPredicate")
if err != nil {
return nil, fmt.Errorf("lookup module symbol failed for %s: %s", path, err)
}
return initPredicatePlugin(sym, path, append(conf, args...))
}
func initPredicatePlugin(sym plugin.Symbol, path string, args []string) (routing.PredicateSpec, error) {
fn, ok := sym.(func([]string) (routing.PredicateSpec, error))
if !ok {
return nil, fmt.Errorf("plugin %s's InitPredicate function has wrong signature", path)
}
spec, err := fn(args)
if err != nil {
return nil, fmt.Errorf("plugin %s returned: %s", path, err)
}
return spec, nil
}
func (o *Options) loadDataClientPlugins(found map[string]string, done map[string][]string) error {
for _, pred := range o.DataClientPlugins {
name := pred[0]
path, ok := found[name]
if !ok {
return fmt.Errorf("data client plugin %s not found in plugin dirs", name)
}
spec, err := loadDataClientPlugin(path, pred[1:])
if err != nil {
return fmt.Errorf("failed to load plugin %s: %s", path, err)
}
o.CustomDataClients = append(o.CustomDataClients, spec)
log.Printf("loaded plugin %s from %s", name, path)
markPluginLoaded(done, name, "InitDataClient")
}
return nil
}
func loadDataClientPlugin(path string, args []string) (routing.DataClient, error) {
mod, err := plugin.Open(path)
if err != nil {
return nil, fmt.Errorf("open data client module %s: %s", path, err)
}
conf, err := readPluginConfig(path)
if err != nil {
return nil, fmt.Errorf("failed to read config for %s: %s", path, err)
}
sym, err := mod.Lookup("InitDataClient")
if err != nil {
return nil, fmt.Errorf("lookup module symbol failed for %s: %s", path, err)
}
return initDataClientPlugin(sym, path, append(conf, args...))
}
func initDataClientPlugin(sym plugin.Symbol, path string, args []string) (routing.DataClient, error) {
fn, ok := sym.(func([]string) (routing.DataClient, error))
if !ok {
return nil, fmt.Errorf("plugin %s's InitDataClient function has wrong signature", path)
}
spec, err := fn(args)
if err != nil {
return nil, fmt.Errorf("module %s returned: %s", path, err)
}
return spec, nil
}
func readPluginConfig(plugin string) (conf []string, err error) {
data, err := os.ReadFile(plugin[:len(plugin)-3] + ".conf")
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if line != "" && line[0] != '#' {
conf = append(conf, line)
}
}
return conf, nil
}
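// What follows is a hedged sketch, not part of the original source, of the
// plugin-author side that the loaders above expect: a main package built with
// "go build -buildmode=plugin" exporting InitFilter with exactly the asserted
// signature. The mySpec value is hypothetical and stands for any type
// implementing filters.Spec.
//
// package main
//
// import "github.com/zalando/skipper/filters"
//
// func InitFilter(opts []string) (filters.Spec, error) {
// return &mySpec{}, nil
// }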
package auth
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
type headerSha256Spec struct{}
type headerSha256Predicate struct {
name string
hashes [][sha256.Size]byte
}
// NewHeaderSHA256 creates a predicate specification, whose instances match SHA-256 hash of the header value.
// The HeaderSHA256 predicate requires the header name and one or more hex-encoded SHA-256 hash values of the matching header.
func NewHeaderSHA256() routing.PredicateSpec {
return &headerSha256Spec{}
}
func (*headerSha256Spec) Name() string {
return predicates.HeaderSHA256Name
}
// Create a predicate instance matching SHA256 hash of the header value
func (*headerSha256Spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) < 2 {
return nil, predicates.ErrInvalidPredicateParameters
}
name, ok := args[0].(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
args = args[1:]
hashes := make([][sha256.Size]byte, 0, len(args))
for _, arg := range args {
hexHash, ok := arg.(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
hash, err := hex.DecodeString(hexHash)
if err != nil {
return nil, err
}
if len(hash) != sha256.Size {
return nil, predicates.ErrInvalidPredicateParameters
}
hashes = append(hashes, ([sha256.Size]byte)(hash))
}
return &headerSha256Predicate{name, hashes}, nil
}
func (p *headerSha256Predicate) Match(r *http.Request) bool {
value := r.Header.Get(p.name)
if value == "" {
return false
}
valueHash := sha256.Sum256([]byte(value))
for _, hash := range p.hashes {
if valueHash == hash {
return true
}
}
return false
}
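// A hedged sketch, not part of the original source, of how an operator could
// produce the hex-encoded hash argument for this predicate; the header value
// "secret" is hypothetical:
//
// sum := sha256.Sum256([]byte("secret"))
// arg := hex.EncodeToString(sum[:])
// // use the printed arg in eskip: HeaderSHA256("X-Secret", "<arg>")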
/*
Package auth implements custom predicates to match based on content
of the HTTP Authorization header.
These predicates can be used to match a route based on data in the second
part (the payload) of a JWT token, for example based on the issuer.
Examples:
// one key value pair has to match
example1: JWTPayloadAnyKV("iss", "https://accounts.google.com", "email", "skipper-router@googlegroups.com")
-> "http://example.org/";
// all key value pairs have to match
example2: * && JWTPayloadAllKV("iss", "https://accounts.google.com", "email", "skipper-router@googlegroups.com")
-> "http://example.org/";
*/
package auth
import (
"net/http"
"regexp"
"strings"
"github.com/zalando/skipper/jwt"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
const (
authHeaderName = "Authorization"
authHeaderPrefix = "Bearer "
)
type (
matchBehavior int
matchMode int
)
type valueMatcher interface {
Match(jwtValue string) bool
}
const (
matchBehaviorAll matchBehavior = iota
matchBehaviorAny
matchModeExact matchMode = iota
matchModeRegexp
)
type (
spec struct {
name string
matchBehavior matchBehavior
matchMode matchMode
}
predicate struct {
kv map[string][]valueMatcher
matchBehavior matchBehavior
}
exactMatcher struct {
expected string
}
regexMatcher struct {
regexp *regexp.Regexp
}
)
func NewJWTPayloadAnyKV() routing.PredicateSpec {
return &spec{
name: predicates.JWTPayloadAnyKVName,
matchBehavior: matchBehaviorAny,
matchMode: matchModeExact,
}
}
func NewJWTPayloadAllKV() routing.PredicateSpec {
return &spec{
name: predicates.JWTPayloadAllKVName,
matchBehavior: matchBehaviorAll,
matchMode: matchModeExact,
}
}
func NewJWTPayloadAnyKVRegexp() routing.PredicateSpec {
return &spec{
name: predicates.JWTPayloadAnyKVRegexpName,
matchBehavior: matchBehaviorAny,
matchMode: matchModeRegexp,
}
}
func NewJWTPayloadAllKVRegexp() routing.PredicateSpec {
return &spec{
name: predicates.JWTPayloadAllKVRegexpName,
matchBehavior: matchBehaviorAll,
matchMode: matchModeRegexp,
}
}
func (s *spec) Name() string {
return s.name
}
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) == 0 || len(args)%2 != 0 {
return nil, predicates.ErrInvalidPredicateParameters
}
kv := make(map[string][]valueMatcher)
for i := 0; i < len(args); i += 2 {
key, keyOk := args[i].(string)
value, valueOk := args[i+1].(string)
if !keyOk || !valueOk {
return nil, predicates.ErrInvalidPredicateParameters
}
var matcher valueMatcher
switch s.matchMode {
case matchModeExact:
matcher = exactMatcher{expected: value}
case matchModeRegexp:
re, err := regexp.Compile(value)
if err != nil {
return nil, predicates.ErrInvalidPredicateParameters
}
matcher = regexMatcher{regexp: re}
default:
return nil, predicates.ErrInvalidPredicateParameters
}
kv[key] = append(kv[key], matcher)
}
return &predicate{
kv: kv,
matchBehavior: s.matchBehavior,
}, nil
}
func (m exactMatcher) Match(jwtValue string) bool {
return jwtValue == m.expected
}
func (m regexMatcher) Match(jwtValue string) bool {
return m.regexp.MatchString(jwtValue)
}
func (p *predicate) Match(r *http.Request) bool {
ahead := r.Header.Get(authHeaderName)
tv := strings.TrimPrefix(ahead, authHeaderPrefix)
if tv == ahead {
return false
}
token, err := jwt.Parse(tv)
if err != nil {
return false
}
switch p.matchBehavior {
case matchBehaviorAll:
return allMatch(p.kv, token.Claims)
case matchBehaviorAny:
return anyMatch(p.kv, token.Claims)
default:
return false
}
}
func stringValue(payload map[string]interface{}, key string) (string, bool) {
if value, ok := payload[key]; ok {
result, ok := value.(string)
return result, ok
}
return "", false
}
func allMatch(expected map[string][]valueMatcher, payload map[string]interface{}) bool {
if len(expected) > len(payload) {
return false
}
for key, expectedValues := range expected {
payloadValue, ok := stringValue(payload, key)
if !ok {
return false
}
// all matchers registered for this key must match
for _, expectedValue := range expectedValues {
if !expectedValue.Match(payloadValue) {
return false
}
}
}
return true
}
func anyMatch(expected map[string][]valueMatcher, payload map[string]interface{}) bool {
if len(expected) == 0 {
return true
}
for key, expectedValues := range expected {
if payloadValue, ok := stringValue(payload, key); ok {
for _, expectedValue := range expectedValues {
if expectedValue.Match(payloadValue) {
return true
}
}
}
}
return false
}
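// A hedged eskip sketch, not part of the original source, for the regexp
// flavor of these predicates; the issuer and email patterns are hypothetical:
//
// example3: JWTPayloadAllKVRegexp("iss", "^https://accounts\\.google\\.com$", "email", "@googlegroups\\.com$")
// -> "http://example.org/";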
package content
import (
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
"net/http"
)
type contentLengthBetweenSpec struct{}
type contentLengthBetweenPredicate struct {
min int64
max int64
}
// NewContentLengthBetween creates a predicate specification,
// whose instances match the Content-Length header value in the range from min (inclusive) to max (exclusive).
// example: ContentLengthBetween(0, 5000)
func NewContentLengthBetween() routing.PredicateSpec { return &contentLengthBetweenSpec{} }
func (*contentLengthBetweenSpec) Name() string {
return predicates.ContentLengthBetweenName
}
// Create a predicate instance that evaluates content length header value range
func (*contentLengthBetweenSpec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 2 {
return nil, predicates.ErrInvalidPredicateParameters
}
x, ok := args[0].(float64)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
minLength := int64(x)
x, ok = args[1].(float64)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
maxLength := int64(x)
if minLength < 0 || maxLength < 0 || minLength >= maxLength {
return nil, predicates.ErrInvalidPredicateParameters
}
return &contentLengthBetweenPredicate{
min: minLength,
max: maxLength,
}, nil
}
func (p *contentLengthBetweenPredicate) Match(req *http.Request) bool {
return req.ContentLength >= p.min && req.ContentLength < p.max
}
/*
Package cookie implements predicate to check parsed cookie headers by name and value.
*/
package cookie
import (
"net/http"
"regexp"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
// Name is the name by which the predicate can be referenced in eskip: "Cookie".
// Deprecated, use predicates.CookieName instead
const Name = predicates.CookieName
type (
spec struct{}
predicate struct {
name string
valueExp *regexp.Regexp
}
)
// New creates a predicate specification, whose instances can be used to match parsed request cookies.
//
// The cookie predicate accepts two arguments, the cookie name, with what a cookie must exist in the request,
// and an expression that the cookie value needs to match.
//
// Eskip example:
//
// Cookie("tcial", /^enabled$/) -> "https://www.example.org";
func New() routing.PredicateSpec { return &spec{} }
func (s *spec) Name() string { return predicates.CookieName }
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 2 {
return nil, predicates.ErrInvalidPredicateParameters
}
name, ok := args[0].(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
value, ok := args[1].(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
valueExp, err := regexp.Compile(value)
if err != nil {
return nil, err
}
return &predicate{name, valueExp}, nil
}
func (p *predicate) Match(r *http.Request) bool {
c, err := r.Cookie(p.name)
if err != nil {
return false
}
return p.valueExp.MatchString(c.Value)
}
/*
Package cron implements custom predicates to match routes
only when the system time matches the given
cron-like expressions.
Package includes a single predicate: Cron.
For supported & unsupported features refer to the "cronmask" package
documentation (https://github.com/sarslanhan/cronmask).
*/
package cron
import (
"github.com/sarslanhan/cronmask"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
"net/http"
"time"
)
type clock func() time.Time
type spec struct {
}
func (*spec) Name() string {
return predicates.CronName
}
func (*spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 1 {
return nil, predicates.ErrInvalidPredicateParameters
}
expr, ok := args[0].(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
mask, err := cronmask.New(expr)
if err != nil {
return nil, err
}
return &predicate{
mask: mask,
getTime: time.Now,
}, nil
}
type predicate struct {
mask *cronmask.CronMask
getTime clock
}
func (p *predicate) Match(r *http.Request) bool {
now := p.getTime()
return p.mask.Match(now)
}
func New() routing.PredicateSpec {
return &spec{}
}
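// A hedged eskip sketch, not part of the original source; the expression and
// backend are hypothetical. The fields follow the usual minute, hour, day of
// month, month, day of week order supported by cronmask:
//
// officeHours: Cron("* 9-17 * * 1-5") -> "https://example.org";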
/*
Package forwarded implements a set of custom predicates to match routes
based on the standardized Forwarded header.
https://datatracker.ietf.org/doc/html/rfc7239
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded
Examples:
// only match requests to "example.com"
example1: ForwardedHost("example.com") -> "http://example.org";
// only match requests to http
example2: ForwardedProtocol("http") -> "http://example.org";
// only match requests to https
example3: ForwardedProtocol("https") -> "http://example.org";
*/
package forwarded
import (
"net/http"
"regexp"
"strings"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
const (
// Deprecated, use predicates.ForwardedHostName instead
NameHost = predicates.ForwardedHostName
// Deprecated, use predicates.ForwardedProtocolName instead
NameProto = predicates.ForwardedProtocolName
)
type hostPredicateSpec struct{}
type protoPredicateSpec struct{}
type hostPredicate struct {
host *regexp.Regexp
}
func (p *hostPredicateSpec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 1 {
return nil, predicates.ErrInvalidPredicateParameters
}
value, ok := args[0].(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
if value == "" {
return nil, predicates.ErrInvalidPredicateParameters
}
re, err := regexp.Compile(value)
if err != nil {
return nil, err
}
return hostPredicate{host: re}, err
}
type protoPredicate struct {
proto string
}
func (p *protoPredicateSpec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 1 {
return nil, predicates.ErrInvalidPredicateParameters
}
value, ok := args[0].(string)
if !ok {
return nil, predicates.ErrInvalidPredicateParameters
}
switch value {
case "http", "https":
return protoPredicate{proto: value}, nil
default:
return nil, predicates.ErrInvalidPredicateParameters
}
}
func NewForwardedHost() routing.PredicateSpec { return &hostPredicateSpec{} }
func NewForwardedProto() routing.PredicateSpec { return &protoPredicateSpec{} }
func (p *hostPredicateSpec) Name() string {
return predicates.ForwardedHostName
}
func (p *protoPredicateSpec) Name() string {
return predicates.ForwardedProtocolName
}
func (p hostPredicate) Match(r *http.Request) bool {
fh := r.Header.Get("Forwarded")
if fh == "" {
return false
}
fw := parseForwarded(fh)
return p.host.MatchString(fw.host)
}
func (p protoPredicate) Match(r *http.Request) bool {
fh := r.Header.Get("Forwarded")
if fh == "" {
return false
}
fw := parseForwarded(fh)
return p.proto == fw.proto
}
type forwarded struct {
host string
proto string
}
func parseForwarded(fh string) *forwarded {
f := &forwarded{}
for forwardedFull := range splitSeq(fh, ",") {
for forwardedPair := range splitSeq(strings.TrimSpace(forwardedFull), ";") {
token, value, found := strings.Cut(forwardedPair, "=")
value = strings.Trim(value, `"`)
if found && value != "" {
switch token {
case "proto":
f.proto = value
case "host":
f.host = value
}
}
}
}
return f
}
// TODO: use [strings.SplitSeq] added in go1.24 once go1.25 is released.
func splitSeq(s string, sep string) func(yield func(string) bool) {
return func(yield func(string) bool) {
for {
i := strings.Index(s, sep)
if i < 0 {
break
}
frag := s[:i]
if !yield(frag) {
return
}
s = s[i+len(sep):]
}
yield(s)
}
}
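// exampleSplitSeq is a hedged usage sketch, not part of the original source:
// splitSeq yields every separator-delimited fragment, including the final
// (possibly empty) one, without allocating an intermediate slice.
func exampleSplitSeq() []string {
var elems []string
for e := range splitSeq(`host="example.com";proto=https, host="example.org"`, ",") {
elems = append(elems, strings.TrimSpace(e))
}
return elems
}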
package host
import (
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
type anySpec struct{}
type anyPredicate struct {
hosts []string
}
// NewAny creates a predicate specification, whose instances match request host.
//
// The HostAny predicate requires one or more string hostnames and matches if the request host
// exactly equals any of the hostnames.
func NewAny() routing.PredicateSpec { return &anySpec{} }
func (*anySpec) Name() string {
return predicates.HostAnyName
}
// Create a predicate instance that matches when the request host equals any of the given hostnames
func (*anySpec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) == 0 {
return nil, predicates.ErrInvalidPredicateParameters
}
p := &anyPredicate{}
for _, arg := range args {
if host, ok := arg.(string); ok {
p.hosts = append(p.hosts, host)
} else {
return nil, predicates.ErrInvalidPredicateParameters
}
}
return p, nil
}
func (ap *anyPredicate) Match(r *http.Request) bool {
for _, host := range ap.hosts {
if host == r.Host {
return true
}
}
return false
}
/*
Package interval implements custom predicates to match routes
only during some period of time.
Package includes three predicates: Between, Before and After.
All predicates can be created using the date represented as:
- a string in RFC3339 format (see https://golang.org/pkg/time/#pkg-constants)
- a string in RFC3339 format without numeric timezone offset and a location name (see https://golang.org/pkg/time/#LoadLocation)
- an int64 or float64 number corresponding to the given Unix time in seconds since January 1, 1970 UTC.
float64 number will be converted into int64 number.
Between predicate matches only if current date is inside the specified
range of dates. Between predicate requires two dates to be constructed.
Upper boundary must be after lower boundary. Range includes the lower
boundary, but excludes the upper boundary.
Before predicate matches only if current date is before the specified
date. Only one date is required to construct the predicate.
After predicate matches only if current date is after or equal to
the specified date. Only one date is required to construct the predicate.
Examples:
example1: Path("/zalando") && Between("2016-01-01T12:00:00+02:00", "2016-02-01T12:00:00+02:00") -> "https://www.zalando.de";
example2: Path("/zalando") && Between(1451642400, 1454320800) -> "https://www.zalando.de";
example3: Path("/zalando") && Before("2016-02-01T12:00:00+02:00") -> "https://www.zalando.de";
example4: Path("/zalando") && Before(1454320800) -> "https://www.zalando.de";
example5: Path("/zalando") && After("2016-01-01T12:00:00+02:00") -> "https://www.zalando.de";
example6: Path("/zalando") && After(1451642400) -> "https://www.zalando.de";
example7: Path("/zalando") && Between("2021-02-18T00:00:00", "2021-02-18T01:00:00", "Europe/Berlin") -> "https://www.zalando.de";
example8: Path("/zalando") && Before("2021-02-18T00:00:00", "Europe/Berlin") -> "https://www.zalando.de";
example9: Path("/zalando") && After("2021-02-18T00:00:00", "Europe/Berlin") -> "https://www.zalando.de";
*/
package interval
import (
"net/http"
"time"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
type spec int
const (
between spec = iota
before
after
)
const rfc3339nz = "2006-01-02T15:04:05" // RFC3339 without numeric timezone offset
type predicate struct {
typ spec
begin time.Time
end time.Time
getTime func() time.Time
}
// NewBetween creates the Between predicate.
func NewBetween() routing.PredicateSpec { return between }
// NewBefore creates the Before predicate.
func NewBefore() routing.PredicateSpec { return before }
// NewAfter creates the After predicate.
func NewAfter() routing.PredicateSpec { return after }
func (s spec) Name() string {
switch s {
case between:
return predicates.BetweenName
case before:
return predicates.BeforeName
case after:
return predicates.AfterName
default:
panic("invalid interval predicate type")
}
}
func (s spec) Create(args []interface{}) (routing.Predicate, error) {
p := predicate{typ: s, getTime: time.Now}
var loc *time.Location
switch {
case
s == between && len(args) == 3 && parseLocation(args[2], &loc) && parseRFCnz(args[0], &p.begin, loc) && parseRFCnz(args[1], &p.end, loc) && p.begin.Before(p.end),
s == between && len(args) == 2 && parseRFC(args[0], &p.begin) && parseRFC(args[1], &p.end) && p.begin.Before(p.end),
s == between && len(args) == 2 && parseUnix(args[0], &p.begin) && parseUnix(args[1], &p.end) && p.begin.Before(p.end),
s == before && len(args) == 2 && parseLocation(args[1], &loc) && parseRFCnz(args[0], &p.end, loc),
s == before && len(args) == 1 && parseRFC(args[0], &p.end),
s == before && len(args) == 1 && parseUnix(args[0], &p.end),
s == after && len(args) == 2 && parseLocation(args[1], &loc) && parseRFCnz(args[0], &p.begin, loc),
s == after && len(args) == 1 && parseRFC(args[0], &p.begin),
s == after && len(args) == 1 && parseUnix(args[0], &p.begin):
return &p, nil
}
return nil, predicates.ErrInvalidPredicateParameters
}
func parseUnix(arg interface{}, t *time.Time) bool {
switch a := arg.(type) {
case float64:
*t = time.Unix(int64(a), 0)
return true
case int64:
*t = time.Unix(a, 0)
return true
}
return false
}
func parseRFC(arg interface{}, t *time.Time) bool {
if s, ok := arg.(string); ok {
tt, err := time.Parse(time.RFC3339, s)
if err == nil {
*t = tt
return true
}
}
return false
}
func parseRFCnz(arg interface{}, t *time.Time, loc *time.Location) bool {
if s, ok := arg.(string); ok {
tt, err := time.ParseInLocation(rfc3339nz, s, loc)
if err == nil {
*t = tt
return true
}
}
return false
}
func parseLocation(arg interface{}, loc **time.Location) bool {
if s, ok := arg.(string); ok {
location, err := time.LoadLocation(s)
if err == nil {
*loc = location
return true
}
}
return false
}
func (p *predicate) Match(r *http.Request) bool {
now := p.getTime()
switch p.typ {
case between:
return (p.begin.Before(now) || p.begin.Equal(now)) && p.end.After(now)
case before:
return p.end.After(now)
case after:
return p.begin.Before(now) || p.begin.Equal(now)
default:
return false
}
}
/*
Package methods implements a custom predicate to match routes
based on the HTTP method of the request.
It supports multiple HTTP methods, with case-insensitive input.
Examples:
// matches GET request
example1: Methods("GET") -> "http://example.org";
// matches GET or POST request
example2: Methods("GET", "post") -> "http://example.org";
*/
package methods
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
// Name is the predicate name
// Deprecated, use predicates.MethodsName instead
const Name = predicates.MethodsName
var ErrInvalidArgumentsCount = errors.New("at least one method should be specified")
var ErrInvalidArgumentType = errors.New("only string values are allowed")
type (
spec struct {
allowedMethods map[string]bool
}
predicate struct {
methods map[string]bool
}
)
// New creates a new Methods predicate specification
func New() routing.PredicateSpec {
return &spec{allowedMethods: map[string]bool{
http.MethodGet: true,
http.MethodHead: true,
http.MethodPost: true,
http.MethodPut: true,
http.MethodPatch: true,
http.MethodDelete: true,
http.MethodConnect: true,
http.MethodOptions: true,
http.MethodTrace: true,
}}
}
func (s *spec) Name() string { return predicates.MethodsName }
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) == 0 {
return nil, ErrInvalidArgumentsCount
}
predicate := predicate{}
predicate.methods = map[string]bool{}
for _, arg := range args {
method, isString := arg.(string)
if !isString {
return nil, ErrInvalidArgumentType
}
method = strings.ToUpper(method)
if s.allowedMethods[method] {
predicate.methods[method] = true
} else {
return nil, fmt.Errorf("method: %s is not allowed", method)
}
}
return &predicate, nil
}
func (p *predicate) Match(r *http.Request) bool {
return p.methods[strings.ToUpper(r.Method)]
}
package primitive
import (
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
const (
// Deprecated, use predicates.FalseName instead
NameFalse = predicates.FalseName
)
type falseSpec struct{}
type falsePredicate struct{}
// NewFalse provides a predicate spec to create a Predicate instance that evaluates to false
func NewFalse() routing.PredicateSpec { return &falseSpec{} }
func (*falseSpec) Name() string {
return predicates.FalseName
}
// Create a predicate instance that always evaluates to false
func (*falseSpec) Create(args []interface{}) (routing.Predicate, error) {
return &falsePredicate{}, nil
}
func (*falsePredicate) Match(*http.Request) bool {
return false
}
package primitive
import (
"net/http"
"os"
"os/signal"
"sync/atomic"
"syscall"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
log "github.com/sirupsen/logrus"
)
type shutdown struct {
inShutdown int32
}
// NewShutdown provides a predicate spec to create predicates
// that evaluate to true if Skipper is shutting down
func NewShutdown() routing.PredicateSpec {
s, _ := newShutdown()
return s
}
func newShutdown() (routing.PredicateSpec, chan os.Signal) {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGTERM)
s := &shutdown{}
go func() {
<-sigs
log.Infof("Got shutdown signal for %s predicates", s.Name())
atomic.StoreInt32(&s.inShutdown, 1)
}()
return s, sigs
}
func (*shutdown) Name() string { return predicates.ShutdownName }
// Create returns a Predicate that evaluates to true if Skipper is shutting down
func (s *shutdown) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 0 {
return nil, predicates.ErrInvalidPredicateParameters
}
return s, nil
}
func (s *shutdown) Match(*http.Request) bool {
return atomic.LoadInt32(&s.inShutdown) != 0
}
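// A hedged eskip sketch, not part of the original source: fail a health
// endpoint once SIGTERM was received, so that load balancers drain the
// instance; the path and the status filter argument are hypothetical:
//
// health_down: Path("/healthz") && Shutdown() -> status(503) -> <shunt>;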
package primitive
import (
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
const (
// Deprecated, use predicates.TrueName instead
NameTrue = predicates.TrueName
)
type trueSpec struct{}
type truePredicate struct{}
// NewTrue provides a predicate spec to create a Predicate instance that evaluates to true
func NewTrue() routing.PredicateSpec { return &trueSpec{} }
func (*trueSpec) Name() string {
return predicates.TrueName
}
// Create a predicate instance that always evaluates to true
func (*trueSpec) Create(args []interface{}) (routing.Predicate, error) {
return &truePredicate{}, nil
}
func (*truePredicate) Match(*http.Request) bool {
return true
}
/*
Package query implements a custom predicate to match routes
based on the query params in the URL.
It supports checking the existence of query params and also checking whether
a query param's value matches a given regular expression.
Examples:
// Checking existence of a query param
// matches http://example.org?bb=a&query=withvalue
example1: QueryParam("query") -> "http://example.org";
// Even a query param without a value
// matches http://example.org?bb=a&query=
example2: QueryParam("query") -> "http://example.org";
// matches with regexp
// matches http://example.org?bb=a&query=example
example3: QueryParam("query", "^example$") -> "http://example.org";
// matches with regexp and multiple values of query param
// matches http://example.org?bb=a&query=testing&query=example
example4: QueryParam("query", "^example$") -> "http://example.org";
*/
package query
import (
"net/http"
"regexp"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
type matchType int
const (
exists matchType = iota + 1
matches
)
type predicate struct {
typ matchType
paramName string
valueExp *regexp.Regexp
}
type spec struct{}
// New creates a new QueryParam predicate specification.
func New() routing.PredicateSpec { return &spec{} }
func (s *spec) Name() string {
return predicates.QueryParamName
}
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) == 0 || len(args) > 2 {
return nil, predicates.ErrInvalidPredicateParameters
}
name, ok1 := args[0].(string)
switch {
case !ok1:
return nil, predicates.ErrInvalidPredicateParameters
case len(args) == 1:
return &predicate{exists, name, nil}, nil
case len(args) == 2:
value, ok2 := args[1].(string)
if !ok2 {
return nil, predicates.ErrInvalidPredicateParameters
}
valueExp, err := regexp.Compile(value)
if err != nil {
return nil, err
}
return &predicate{matches, name, valueExp}, nil
default:
return nil, predicates.ErrInvalidPredicateParameters
}
}
func (p *predicate) Match(r *http.Request) bool {
queryMap := r.URL.Query()
vals, ok := queryMap[p.paramName]
switch p.typ {
case exists:
return ok
case matches:
if !ok {
return false
} else {
for _, v := range vals {
if p.valueExp.MatchString(v) {
return true
}
}
return false
}
}
return false
}
/*
Package source implements a custom predicate to match routes
based on the source IP of a request.
It is similar in function and usage to the header predicate but
has explicit support for IP addresses and netmasks to conveniently
create routes based on a whole network of addresses, like a company
network or something similar.
It is important to note, that this predicate should not be used as
the only gatekeeper for secure endpoints. Always use proper authorization
and authentication for access control!
To enable usage of this predicate behind load balancers or proxies, the
X-Forwarded-For header is used to determine the source of a request if it
is available. If the X-Forwarded-For header is not present or does not contain
a valid source address, the source IP of the incoming request is used for
matching.
The source predicate supports one or more IP addresses with or without a netmask.
There are two flavors of this predicate: Source() and SourceFromLast().
The difference is that Source() finds the remote host as first entry from
the X-Forwarded-For header and SourceFromLast() as last entry.
Examples:
// only match requests from 1.2.3.4
example1: Source("1.2.3.4") -> "http://example.org";
// only match requests from 1.2.3.0 - 1.2.3.255
example2: Source("1.2.3.0/24") -> "http://example.org";
// only match requests from 1.2.3.4 and the 2.2.2.0/24 network
example3: Source("1.2.3.4", "2.2.2.0/24") -> "http://example.org";
// same as example3, only match requests from 1.2.3.4 and the 2.2.2.0/24 network
example4: SourceFromLast("1.2.3.4", "2.2.2.0/24") -> "http://example.org";
*/
package source
import (
"errors"
"net"
"net/http"
"net/netip"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
"go4.org/netipx"
)
const (
// Deprecated, use predicates.SourceName instead
Name = predicates.SourceName
// Deprecated, use predicates.SourceFromLastName instead
NameLast = predicates.SourceFromLastName
// Deprecated, use predicates.ClientIPName instead
NameClientIP = predicates.ClientIPName
)
var errInvalidArgs = errors.New("invalid arguments")
type sourcePred int
const (
source sourcePred = iota
sourceFromLast
clientIP
)
type spec struct {
typ sourcePred
}
type predicate struct {
typ sourcePred
nets *netipx.IPSet
}
func New() routing.PredicateSpec { return &spec{typ: source} }
func NewFromLast() routing.PredicateSpec { return &spec{typ: sourceFromLast} }
func NewClientIP() routing.PredicateSpec { return &spec{typ: clientIP} }
func (s *spec) Name() string {
switch s.typ {
case sourceFromLast:
return predicates.SourceFromLastName
case clientIP:
return predicates.ClientIPName
default:
return predicates.SourceName
}
}
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) == 0 {
return nil, errInvalidArgs
}
var cidrs []string
for i := range args {
if s, ok := args[i].(string); ok {
cidrs = append(cidrs, s)
} else {
return nil, errInvalidArgs
}
}
nets, err := snet.ParseIPCIDRs(cidrs)
if err != nil {
return nil, err
}
return &predicate{s.typ, nets}, nil
}
func (p *predicate) Match(r *http.Request) bool {
var src netip.Addr
switch p.typ {
case sourceFromLast:
src = snet.RemoteAddrFromLast(r)
case clientIP:
h, _, _ := net.SplitHostPort(r.RemoteAddr)
src, _ = netip.ParseAddr(h)
default:
src = snet.RemoteAddr(r)
}
return p.nets.Contains(src)
}
package tee
import (
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
const (
// Deprecated, use predicates.TeeName instead
PredicateName = predicates.TeeName
HeaderKey = "x-tee-loopback-key"
)
type spec struct{}
type predicate struct {
key string
}
func New() routing.PredicateSpec { return &spec{} }
func (s *spec) Name() string { return predicates.TeeName }
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if len(args) != 1 {
return nil, predicates.ErrInvalidPredicateParameters
}
teeKey, _ := args[0].(string)
if teeKey == "" {
return nil, predicates.ErrInvalidPredicateParameters
}
return &predicate{
key: teeKey,
}, nil
}
func (p *predicate) Match(r *http.Request) bool {
v := r.Header.Get(HeaderKey)
return v == p.key
}
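// A hedged eskip sketch, not part of the original source: the route only
// matches requests whose x-tee-loopback-key header equals the configured key;
// the key value and backend are hypothetical:
//
// shadow: Tee("tee-group-a") -> "http://shadow.example.org";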
package traffic
import (
"math/rand"
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
type (
segmentSpec struct {
randFloat64 func() float64
}
segmentPredicate struct {
randFloat64 func() float64
min, max float64
}
)
type contextKey struct{}
var randomValue contextKey
// NewSegment creates a new traffic segment predicate specification
func NewSegment() routing.WeightedPredicateSpec {
return &segmentSpec{rand.Float64}
}
func (*segmentSpec) Name() string {
return predicates.TrafficSegmentName
}
// Create creates a new predicate instance with two number arguments, _min_ and _max_,
// from the interval [0, 1] (both boundaries included), where _min_ <= _max_.
//
// Let _r_ be a uniform random number from [0, 1), generated once per request.
// This predicate matches if _r_ belongs to the interval [_min_, _max_).
// The upper interval boundary _max_ is excluded to simplify the definition of
// adjacent intervals: the upper boundary of one interval then equals the
// lower boundary of the next, e.g. [0, 0.25) and [0.25, 1).
//
// This predicate has a weight of -1 and therefore does not affect the route weight.
//
// Example of routes splitting traffic in 50%+30%+20% proportion:
//
// r50: Path("/test") && TrafficSegment(0.0, 0.5) -> <shunt>;
// r30: Path("/test") && TrafficSegment(0.5, 0.8) -> <shunt>;
// r20: Path("/test") && TrafficSegment(0.8, 1.0) -> <shunt>;
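//
// All TrafficSegment predicates evaluate the same per-request random
// value _r_ (see Match below), so the adjacent intervals above route each
// request to exactly one of the three routes.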
func (s *segmentSpec) Create(args []any) (routing.Predicate, error) {
if len(args) != 2 {
return nil, predicates.ErrInvalidPredicateParameters
}
p, ok := &segmentPredicate{randFloat64: s.randFloat64}, false
if p.min, ok = args[0].(float64); !ok || p.min < 0 || p.min > 1 {
return nil, predicates.ErrInvalidPredicateParameters
}
if p.max, ok = args[1].(float64); !ok || p.max < 0 || p.max > 1 {
return nil, predicates.ErrInvalidPredicateParameters
}
// min == max defines a never-matching, empty interval, e.g. [0, 0)
if p.min > p.max {
return nil, predicates.ErrInvalidPredicateParameters
}
return p, nil
}
// Weight returns -1.
// By returning -1 this predicate does not affect route weight.
func (*segmentSpec) Weight() int {
return -1
}
func (p *segmentPredicate) Match(req *http.Request) bool {
r := routing.FromContext(req.Context(), randomValue, p.randFloat64)
// min == max defines a never-matching interval and always yields false
return p.min <= r && r < p.max
}
/*
Package traffic implements a predicate to control the matching
probability for a given route by setting its weight.
The probability for matching a route is defined by the mandatory first
parameter, which must be a decimal number between 0.0 and 1.0 (both
inclusive).
The optional second argument specifies the cookie name for the traffic
group, in case you want to use stickiness. Stickiness allows all
subsequent requests from the same client to match the same route. It is
controlled by the optional third parameter, which names the traffic
group of the current route: if the cookie carries this value, the
predicate matches and the chance argument is ignored.
You always have to specify one argument if you do not need stickiness,
and three arguments if your service requires stickiness.
Predicates cannot modify the response, so the responsibility of setting
the traffic group cookie remains with either a filter or the backend
service.
The example below shows a possible eskip document used for green-blue
deployments of APIs, which usually don't require stickiness:
// hit by 10% chance
v2:
Traffic(.1) ->
"https://api-test-green";
// hit by remaining chance
v1:
* ->
"https://api-test-blue";
The example below shows a possible eskip document with two independent,
traffic-controlled route sets, using session stickiness:
// hit by 5% chance
cartTest:
Traffic(.05, "cart-test", "test") && Path("/cart") ->
responseCookie("cart-test", "test") ->
"https://cart-test";
// hit by remaining chance
cart:
Path("/cart") ->
responseCookie("cart-test", "default") ->
"https://cart";
// hit by 15% chance
catalogTestA:
Traffic(.15, "catalog-test", "A") ->
responseCookie("catalog-test", "A") ->
"https://catalog-test-a";
// hit by 30% chance
catalogTestB:
Traffic(.3, "catalog-test", "B") ->
responseCookie("catalog-test", "B") ->
"https://catalog-test-b";
// hit by remaining chance
catalog:
* ->
responseCookie("catalog-test", "default") ->
"https://catalog";
*/
package traffic
import (
"math/rand"
"net/http"
"github.com/zalando/skipper/predicates"
"github.com/zalando/skipper/routing"
)
const (
// Deprecated: use predicates.TrafficName instead
PredicateName = predicates.TrafficName
)
type spec struct{}
type predicate struct {
chance float64
trafficGroup string
trafficGroupCookie string
}
// New creates a new traffic control predicate specification.
func New() routing.PredicateSpec { return &spec{} }
func (s *spec) Name() string { return predicates.TrafficName }
func (s *spec) Create(args []interface{}) (routing.Predicate, error) {
if !(len(args) == 1 || len(args) == 3) {
return nil, predicates.ErrInvalidPredicateParameters
}
p := &predicate{}
if c, ok := args[0].(float64); ok && 0.0 <= c && c <= 1.0 {
p.chance = c
} else {
return nil, predicates.ErrInvalidPredicateParameters
}
if len(args) == 3 {
if tgc, ok := args[1].(string); ok {
p.trafficGroupCookie = tgc
} else {
return nil, predicates.ErrInvalidPredicateParameters
}
if tg, ok := args[2].(string); ok {
p.trafficGroup = tg
} else {
return nil, predicates.ErrInvalidPredicateParameters
}
}
return p, nil
}
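// takeChance draws a uniform random number in [0, 1) and compares it to
// the configured chance. The non-cryptographic math/rand is sufficient
// here, because the value only controls traffic splitting (hence the
// #nosec annotation below).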
func (p *predicate) takeChance() bool {
return rand.Float64() < p.chance // #nosec
}
func (p *predicate) Match(r *http.Request) bool {
if p.trafficGroup == "" {
return p.takeChance()
}
if c, err := r.Cookie(p.trafficGroupCookie); err == nil {
return c.Value == p.trafficGroup
} else {
return p.takeChance()
}
}
package proxy
import (
"bytes"
stdlibcontext "context"
"errors"
"io"
"net/http"
"net/url"
"time"
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/routing"
"github.com/zalando/skipper/tracing"
log "github.com/sirupsen/logrus"
)
const unknownHost = "_unknownhost_"
type flushedResponseWriter interface {
http.ResponseWriter
http.Flusher
Unwrap() http.ResponseWriter
}
type context struct {
responseWriter flushedResponseWriter
request *http.Request
response *http.Response
route *routing.Route
deprecatedServed bool
servedWithResponse bool // to support the deprecated way independently
successfulUpgrade bool
pathParams map[string]string
stateBag map[string]interface{}
originalRequest *http.Request
originalResponse *http.Response
outgoingHost string
outgoingDebugRequest *http.Request
executionCounter int
startServe time.Time
metrics *filterMetrics
tracer opentracing.Tracer
initialSpan opentracing.Span
proxySpan opentracing.Span
parentSpan opentracing.Span
proxy *Proxy
routeLookup *routing.RouteLookup
cancelBackendContext stdlibcontext.CancelFunc
logger filters.FilterContextLogger
proxyWatch stopWatch
proxyRequestLatency time.Duration
}
type filterMetrics struct {
prefix string
impl metrics.Metrics
}
type noopFlushedResponseWriter struct {
ignoredHeader http.Header
}
func defaultBody() io.ReadCloser {
return io.NopCloser(&bytes.Buffer{})
}
func defaultResponse(r *http.Request) *http.Response {
return &http.Response{
StatusCode: http.StatusNotFound,
Header: make(http.Header),
Body: defaultBody(),
Request: r,
}
}
func cloneURL(u *url.URL) *url.URL {
uc := *u
return &uc
}
func cloneRequestMetadata(r *http.Request) *http.Request {
return &http.Request{
Method: r.Method,
URL: cloneURL(r.URL),
Proto: r.Proto,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Header: cloneHeader(r.Header),
Trailer: cloneHeader(r.Trailer),
Body: defaultBody(),
ContentLength: r.ContentLength,
TransferEncoding: r.TransferEncoding,
Close: r.Close,
Host: r.Host,
RemoteAddr: r.RemoteAddr,
RequestURI: r.RequestURI,
TLS: r.TLS,
}
}
func cloneResponseMetadata(r *http.Response) *http.Response {
return &http.Response{
Status: r.Status,
StatusCode: r.StatusCode,
Proto: r.Proto,
ProtoMajor: r.ProtoMajor,
ProtoMinor: r.ProtoMinor,
Header: cloneHeader(r.Header),
Trailer: cloneHeader(r.Trailer),
Body: defaultBody(),
ContentLength: r.ContentLength,
TransferEncoding: r.TransferEncoding,
Close: r.Close,
Request: r.Request,
TLS: r.TLS,
}
}
// this is required during looping to preserve the original set of
// params in the outer routes
func appendParams(to, from map[string]string) map[string]string {
if to == nil {
to = make(map[string]string)
}
for k, v := range from {
to[k] = v
}
return to
}
func newContext(
w flushedResponseWriter,
r *http.Request,
p *Proxy,
watch *stopWatch,
) *context {
c := &context{
responseWriter: w,
request: r,
stateBag: make(map[string]interface{}),
outgoingHost: r.Host,
metrics: &filterMetrics{impl: p.metrics},
proxy: p,
routeLookup: p.routing.Get(),
proxyWatch: *watch,
}
if p.flags.PreserveOriginal() {
c.originalRequest = cloneRequestMetadata(r)
}
return c
}
func (c *context) ResponseController() *http.ResponseController {
return http.NewResponseController(c.responseWriter)
}
func (c *context) applyRoute(route *routing.Route, params map[string]string, preserveHost bool) {
c.route = route
if preserveHost {
c.outgoingHost = c.request.Host
} else {
c.outgoingHost = route.Host
}
c.pathParams = appendParams(c.pathParams, params)
}
func (c *context) ensureDefaultResponse() {
if c.response == nil {
c.response = defaultResponse(c.request)
return
}
if c.response.Header == nil {
c.response.Header = make(http.Header)
}
if c.response.Body == nil {
c.response.Body = defaultBody()
}
}
func (c *context) deprecatedShunted() bool {
return c.deprecatedServed
}
func (c *context) shunted() bool {
return c.servedWithResponse
}
func (c *context) setResponse(r *http.Response, preserveOriginal bool) {
c.response = r
if preserveOriginal {
c.originalResponse = cloneResponseMetadata(r)
}
}
func (c *context) ResponseWriter() http.ResponseWriter { return c.responseWriter }
func (c *context) Request() *http.Request { return c.request }
func (c *context) Response() *http.Response { return c.response }
func (c *context) MarkServed() { c.deprecatedServed = true }
func (c *context) Served() bool { return c.deprecatedServed || c.servedWithResponse }
func (c *context) PathParam(key string) string { return c.pathParams[key] }
func (c *context) StateBag() map[string]interface{} { return c.stateBag }
func (c *context) BackendUrl() string { return c.route.Backend }
func (c *context) OriginalRequest() *http.Request { return c.originalRequest }
func (c *context) OriginalResponse() *http.Response { return c.originalResponse }
func (c *context) OutgoingHost() string { return c.outgoingHost }
func (c *context) SetOutgoingHost(h string) { c.outgoingHost = h }
func (c *context) Metrics() filters.Metrics { return c.metrics }
func (c *context) Tracer() opentracing.Tracer { return c.tracer }
func (c *context) ParentSpan() opentracing.Span { return c.parentSpan }
func (c *context) Logger() filters.FilterContextLogger {
if c.logger == nil {
traceId := tracing.GetTraceID(c.initialSpan)
if traceId != "" {
c.logger = log.WithFields(log.Fields{"trace_id": traceId})
} else {
c.logger = log.StandardLogger()
}
}
return c.logger
}
func (c *context) Serve(r *http.Response) {
r.Request = c.Request()
if r.Header == nil {
r.Header = make(http.Header)
}
if r.Body == nil {
r.Body = defaultBody()
}
c.servedWithResponse = true
c.response = r
}
func (c *context) metricsHost() string {
if c.route == nil || len(c.route.HostRegexps) == 0 {
return unknownHost
}
return c.request.Host
}
func (c *context) clone() *context {
cc := *c
// preserve the original path params by cloning the set:
cc.pathParams = appendParams(nil, c.pathParams)
return &cc
}
func (c *context) wasExecuted() bool {
return c.executionCounter != 0
}
func (c *context) setMetricsPrefix(prefix string) {
c.metrics.prefix = prefix + ".custom."
}
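// Split clones the context for shadow traffic, e.g. for loopback-based
// request duplication: the clone gets an empty state bag, a no-op response
// writer and its own copy of the request body, so the split route cannot
// interfere with serving the original request.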
func (c *context) Split() (filters.FilterContext, error) {
originalRequest := c.Request()
if c.proxy.experimentalUpgrade && isUpgradeRequest(originalRequest) {
return nil, errors.New("context: cannot split the context that contains an upgrade request")
}
cc := c.clone()
cc.stateBag = map[string]interface{}{}
cc.responseWriter = noopFlushedResponseWriter{}
cc.metrics = &filterMetrics{
prefix: cc.metrics.prefix,
impl: cc.proxy.metrics,
}
u := new(url.URL)
*u = *originalRequest.URL
u.Host = originalRequest.Host
cr, body, err := cloneRequestForSplit(u, originalRequest)
if err != nil {
c.Logger().Errorf("context: failed to clone request: %v", err)
return nil, err
}
serverSpan := opentracing.SpanFromContext(originalRequest.Context())
cr = cr.WithContext(opentracing.ContextWithSpan(cr.Context(), serverSpan))
cr = cr.WithContext(routing.NewContext(cr.Context()))
originalRequest.Body = body
cc.request = cr
return cc, nil
}
func (c *context) Loopback() {
loopSpan := c.tracer.StartSpan(c.proxy.tracing.initialOperationName, opentracing.ChildOf(c.parentSpan.Context()))
defer loopSpan.Finish()
err := c.proxy.do(c, loopSpan)
if c.response != nil && c.response.Body != nil {
if _, err := io.Copy(io.Discard, c.response.Body); err != nil {
c.Logger().Errorf("context: error while discarding remainder response body: %v.", err)
}
err := c.response.Body.Close()
if err != nil {
c.Logger().Errorf("context: error during closing the response body: %v", err)
}
}
if c.proxySpan != nil {
c.proxy.tracing.setTag(c.proxySpan, "shadow", "true")
c.proxySpan.Finish()
}
perr, ok := err.(*proxyError)
if ok && perr.handled {
return
}
if err != nil {
c.Logger().Errorf("context: failed to execute loopback request: %v", err)
}
}
func (m *filterMetrics) IncCounter(key string) {
m.impl.IncCounter(m.prefix + key)
}
func (m *filterMetrics) IncCounterBy(key string, value int64) {
m.impl.IncCounterBy(m.prefix+key, value)
}
func (m *filterMetrics) MeasureSince(key string, start time.Time) {
m.impl.MeasureSince(m.prefix+key, start)
}
func (m *filterMetrics) IncFloatCounterBy(key string, value float64) {
m.impl.IncFloatCounterBy(m.prefix+key, value)
}
func (w noopFlushedResponseWriter) Header() http.Header {
if w.ignoredHeader == nil {
w.ignoredHeader = make(http.Header)
}
return w.ignoredHeader
}
func (w noopFlushedResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w noopFlushedResponseWriter) WriteHeader(_ int) {}
func (w noopFlushedResponseWriter) Flush() {}
func (w noopFlushedResponseWriter) Unwrap() http.ResponseWriter { return nil }
package proxy
import (
"encoding/json"
"io"
"net/http"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
)
type (
debugRequest struct {
Method string `json:"method"`
Uri string `json:"uri"`
Proto string `json:"proto"`
Header http.Header `json:"header,omitempty"`
Host string `json:"host,omitempty"`
RemoteAddress string `json:"remote_address,omitempty"`
}
debugResponseMod struct {
Status *int `json:"status,omitempty"`
Header http.Header `json:"header,omitempty"`
}
debugDocument struct {
RouteId string `json:"route_id,omitempty"`
Route string `json:"route,omitempty"`
Incoming *debugRequest `json:"incoming,omitempty"`
Outgoing *debugRequest `json:"outgoing,omitempty"`
ResponseMod *debugResponseMod `json:"response_mod,omitempty"`
RequestBody string `json:"request_body,omitempty"`
RequestErr string `json:"request_error,omitempty"`
ResponseModBody string `json:"response_mod_body,omitempty"`
ResponseModErr string `json:"response_mod_error,omitempty"`
ProxyError string `json:"proxy_error,omitempty"`
Filters []*eskip.Filter `json:"filters,omitempty"`
Predicates []*eskip.Predicate `json:"predicates,omitempty"`
}
)
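// debugInfo collects the state of a single request handled by the debug
// proxy; convertDebugInfo turns it into a debugDocument that dbgResponse
// writes out as JSON.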
type debugInfo struct {
route *eskip.Route
incoming *http.Request
outgoing *http.Request
response *http.Response
err error
}
func convertRequest(r *http.Request) *debugRequest {
return &debugRequest{
Method: r.Method,
Uri: r.RequestURI,
Proto: r.Proto,
Header: r.Header,
Host: r.Host,
RemoteAddress: r.RemoteAddr}
}
func convertBody(body io.Reader) (string, string) {
b, err := io.ReadAll(body)
out := string(b)
var errstr string
if err == nil {
errstr = ""
} else {
errstr = err.Error()
}
return out, errstr
}
func convertDebugInfo(d *debugInfo) debugDocument {
doc := debugDocument{}
if d.route != nil {
doc.RouteId = d.route.Id
doc.Route = d.route.String()
doc.Filters = d.route.Filters
doc.Predicates = d.route.Predicates
}
var requestBody io.Reader
if d.incoming == nil {
log.Error("[debug response] missing incoming request")
} else {
doc.Incoming = convertRequest(d.incoming)
requestBody = d.incoming.Body
}
if d.outgoing != nil {
doc.Outgoing = convertRequest(d.outgoing)
// if there is an outgoing request, use the body from there
requestBody = d.outgoing.Body
}
if requestBody != nil {
doc.RequestBody, doc.RequestErr = convertBody(requestBody)
}
if d.response != nil {
if d.response.StatusCode != 0 || len(d.response.Header) != 0 {
doc.ResponseMod = &debugResponseMod{Header: d.response.Header}
if d.response.StatusCode != 0 {
s := d.response.StatusCode
doc.ResponseMod.Status = &s
}
}
if d.response.Body != nil {
doc.ResponseModBody, doc.ResponseModErr = convertBody(d.response.Body)
}
}
if d.err != nil {
doc.ProxyError = d.err.Error()
}
return doc
}
func dbgResponse(w http.ResponseWriter, d *debugInfo) {
w.Header().Set("Content-Type", "application/json")
doc := convertDebugInfo(d)
enc := json.NewEncoder(w)
if err := enc.Encode(&doc); err != nil {
log.Error("[debug response]", err)
}
}
package proxy
import (
"math"
"math/rand"
"time"
"github.com/zalando/skipper/routing"
)
type fadeIn struct {
rnd *rand.Rand
}
func (f *fadeIn) fadeInScore(lifetime time.Duration, duration time.Duration, exponent float64) float64 {
fadingIn := lifetime > 0 && lifetime < duration
if !fadingIn {
return 1
}
return math.Pow(float64(lifetime)/float64(duration), exponent)
}
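// filterFadeIn probabilistically filters out endpoints that are still
// fading in: a single uniform threshold is drawn per request, and an
// endpoint is kept when its fade-in score exceeds the threshold. For
// example, with duration=3m and exponent=1, an endpoint detected 1m ago
// is kept with probability 1/3. If no endpoint would remain, the full
// list is used instead.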
func (f *fadeIn) filterFadeIn(endpoints []routing.LBEndpoint, rt *routing.Route) []routing.LBEndpoint {
if rt.LBFadeInDuration <= 0 {
return endpoints
}
now := time.Now()
threshold := f.rnd.Float64()
filtered := make([]routing.LBEndpoint, 0, len(endpoints))
for _, e := range endpoints {
age := now.Sub(e.Metrics.DetectedTime())
f := f.fadeInScore(
age,
rt.LBFadeInDuration,
rt.LBFadeInExponent,
)
if threshold < f {
filtered = append(filtered, e)
}
}
if len(filtered) == 0 {
return endpoints
}
return filtered
}
package fastcgi
import (
"bytes"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"github.com/yookoala/gofast"
"github.com/zalando/skipper/logging"
)
type RoundTripper struct {
log logging.Logger
client gofast.Client
handler gofast.SessionHandler
}
func NewRoundTripper(log logging.Logger, addr, filename string) (*RoundTripper, error) {
connFactory := gofast.SimpleConnFactory("tcp", addr)
client, err := gofast.SimpleClientFactory(connFactory)()
if err != nil {
return nil, fmt.Errorf("gofast: failed creating client: %w", err)
}
chain := gofast.Chain(
gofast.BasicParamsMap,
gofast.MapHeader,
gofast.MapEndpoint(filename),
func(handler gofast.SessionHandler) gofast.SessionHandler {
return func(client gofast.Client, req *gofast.Request) (*gofast.ResponsePipe, error) {
req.Params["HTTP_HOST"] = req.Params["SERVER_NAME"]
req.Params["SERVER_SOFTWARE"] = "Skipper"
// Gofast sets this param to `fastcgi` which is not what the backend will expect.
delete(req.Params, "REQUEST_SCHEME")
return handler(client, req)
}
},
)
return &RoundTripper{
log: log,
client: client,
handler: chain(gofast.BasicSession),
}, nil
}
func (rt *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
defer func() {
if rt.client == nil {
return
}
if err := rt.client.Close(); err != nil {
rt.log.Errorf("gofast: error closing client: %s", err.Error())
}
}()
resp, err := rt.handler(rt.client, gofast.NewRequest(req))
if err != nil {
return nil, fmt.Errorf("gofast: failed to process request: %w", err)
}
rr := httptest.NewRecorder()
errBuffer := new(bytes.Buffer)
resp.WriteTo(rr, errBuffer)
if errBuffer.Len() > 0 {
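// "Primary script unknown" is the error FastCGI applications (e.g.
// PHP-FPM) report when the requested script does not exist; map it
// to a plain 404 instead of a gateway error.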
if strings.Contains(errBuffer.String(), "Primary script unknown") {
body := http.StatusText(http.StatusNotFound)
return &http.Response{
Status: body,
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewBufferString(body)),
ContentLength: int64(len(body)),
Request: req,
Header: make(http.Header),
}, nil
} else {
return nil, fmt.Errorf("gofast: error stream from application process %s", errBuffer.String())
}
}
return rr.Result(), nil
}
package proxy
import (
"math/rand"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/routing"
)
type healthyEndpoints struct {
rnd *rand.Rand
maxUnhealthyEndpointsRatio float64
}
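// filterHealthyEndpoints drops endpoints whose passive-health-check drop
// probability exceeds a random value p, drawn once per request. If more
// than maxUnhealthyEndpointsRatio of the endpoints would be dropped, or
// none would remain, the full endpoint list is used instead.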
func (h *healthyEndpoints) filterHealthyEndpoints(ctx *context, endpoints []routing.LBEndpoint, metrics metrics.Metrics) []routing.LBEndpoint {
if h == nil {
return endpoints
}
p := h.rnd.Float64()
unhealthyEndpointsCount := 0
maxUnhealthyEndpointsCount := float64(len(endpoints)) * h.maxUnhealthyEndpointsRatio
filtered := make([]routing.LBEndpoint, 0, len(endpoints))
for _, e := range endpoints {
dropProbability := e.Metrics.HealthCheckDropProbability()
if p < dropProbability {
ctx.Logger().Debugf("Dropping endpoint %q due to passive health check: p=%0.2f, dropProbability=%0.2f",
e.Host, p, dropProbability)
metrics.IncCounter("passive-health-check.endpoints.dropped")
unhealthyEndpointsCount++
} else {
filtered = append(filtered, e)
}
if float64(unhealthyEndpointsCount) > maxUnhealthyEndpointsCount {
return endpoints
}
}
if len(filtered) == 0 {
return endpoints
}
if len(filtered) < len(endpoints) {
metrics.IncCounter("passive-health-check.requests.passed")
}
return filtered
}
package proxy
import (
"bytes"
stdlibcontext "context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"time"
"unicode/utf8"
"golang.org/x/exp/maps"
"golang.org/x/time/rate"
ot "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/zalando/skipper/circuit"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
al "github.com/zalando/skipper/filters/accesslog"
circuitfilters "github.com/zalando/skipper/filters/circuit"
flowidFilter "github.com/zalando/skipper/filters/flowid"
filterslog "github.com/zalando/skipper/filters/log"
ratelimitfilters "github.com/zalando/skipper/filters/ratelimit"
tracingfilter "github.com/zalando/skipper/filters/tracing"
skpio "github.com/zalando/skipper/io"
"github.com/zalando/skipper/loadbalancer"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/metrics"
snet "github.com/zalando/skipper/net"
"github.com/zalando/skipper/proxy/fastcgi"
"github.com/zalando/skipper/ratelimit"
"github.com/zalando/skipper/rfc"
"github.com/zalando/skipper/routing"
"github.com/zalando/skipper/tracing"
)
const (
proxyBufferSize = 8192
unknownRouteID = "_unknownroute_"
unknownRouteBackendType = "<unknown>"
unknownRouteBackend = "<unknown>"
// Number of loops allowed by default.
DefaultMaxLoopbacks = 9
// The default value set for http.Transport.MaxIdleConnsPerHost.
DefaultIdleConnsPerHost = 64
// The default period at which the idle connections are forcibly
// closed.
DefaultCloseIdleConnsPeriod = 20 * time.Second
// DefaultResponseHeaderTimeout is the default response header timeout
DefaultResponseHeaderTimeout = 60 * time.Second
// DefaultExpectContinueTimeout is the default timeout to expect
// a response for a 100 Continue request
DefaultExpectContinueTimeout = 30 * time.Second
)
// Flags control the behavior of the proxy.
type Flags uint
const (
FlagsNone Flags = 0
// Insecure causes the proxy to ignore the verification of
// the TLS certificates of the backend services.
Insecure Flags = 1 << iota
// PreserveOriginal indicates that filters require the
// preserved original metadata of the request and the response.
PreserveOriginal
// PreserveHost indicates whether the outgoing request to the
// backend should use by default the 'Host' header of the incoming
// request, or the host part of the backend address, in case filters
// don't change it.
PreserveHost
// Debug indicates that the current proxy instance will be used as a
// debug proxy. Debug proxies don't forward the request to the
// route backends, but they execute all filters, and return a
// JSON document with the changes the filters make to the request
// and with the approximate changes they would make to the
// response.
Debug
// HopHeadersRemoval indicates whether the Hop Headers should be removed
// in compliance with RFC 2616
HopHeadersRemoval
// PatchPath instructs the proxy to patch the parsed request path
// if the reserved characters according to RFC 2616 and RFC 3986
// were unescaped by the parser.
PatchPath
)
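// Flags combine as a bit set, e.g. an illustrative combination:
//
//	flags := PreserveHost | HopHeadersRemoval
//	flags.PreserveHost()      // true
//	flags.HopHeadersRemoval() // true
//	flags.Insecure()          // false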
// Options is a deprecated alias for Flags.
type Options Flags
const (
OptionsNone = Options(FlagsNone)
OptionsInsecure = Options(Insecure)
OptionsPreserveOriginal = Options(PreserveOriginal)
OptionsPreserveHost = Options(PreserveHost)
OptionsDebug = Options(Debug)
OptionsHopHeadersRemoval = Options(HopHeadersRemoval)
)
type OpenTracingParams struct {
// Tracer holds the tracer enabled for this proxy instance
Tracer ot.Tracer
// InitialSpan can override the default initial, pre-routing, span name.
// Default: "ingress".
InitialSpan string
// DisableFilterSpans disables creation of spans representing request and response filters.
// Default: false
DisableFilterSpans bool
// LogFilterEvents enables marking the start and completion times of filters
// on the span representing request/response filters being processed.
// Default: false
LogFilterEvents bool
// LogStreamEvents enables the logs that mark the times when response headers and payload are streamed to
// the client
// Default: false
LogStreamEvents bool
// ExcludeTags controls what tags are disabled. Any tag that is listed here will be ignored.
ExcludeTags []string
}
type PassiveHealthCheck struct {
// The period of time after which the endpoint registry begins to calculate endpoint statistics
// from scratch
Period time.Duration
// The minimum number of total requests that should be sent to an endpoint in a single period to
// potentially opt out the endpoint from the list of healthy endpoints
MinRequests int64
// The minimal ratio of failed requests in a single period required to potentially opt out the
// endpoint from the list of healthy endpoints. This ratio equals the minimal non-zero probability
// of dropping the endpoint from load balancing for any given request
MinDropProbability float64
// The maximum probability of an unhealthy endpoint being dropped from load balancing for any given request
MaxDropProbability float64
// MaxUnhealthyEndpointsRatio is the maximum ratio of endpoints that the passive health check may
// drop from load balancing for a single request, e.g. in case all endpoints appear unhealthy
MaxUnhealthyEndpointsRatio float64
}
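// InitPassiveHealthChecker parses the passive health check options from a
// string map, e.g. (a sketch of the recognized keys):
//
//	enabled, phc, err := InitPassiveHealthChecker(map[string]string{
//		"period":                        "1m",   // required
//		"min-requests":                  "10",   // required
//		"max-drop-probability":          "0.9",  // required
//		"min-drop-probability":          "0.05", // optional
//		"max-unhealthy-endpoints-ratio": "0.3",  // optional
//	})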
func InitPassiveHealthChecker(o map[string]string) (bool, *PassiveHealthCheck, error) {
if len(o) == 0 {
return false, &PassiveHealthCheck{}, nil
}
result := &PassiveHealthCheck{
MinDropProbability: 0.0,
MaxUnhealthyEndpointsRatio: 1.0,
}
requiredParams := map[string]struct{}{
"period": {},
"max-drop-probability": {},
"min-requests": {},
}
for key, value := range o {
delete(requiredParams, key)
switch key {
/* required parameters */
case "period":
period, err := time.ParseDuration(value)
if err != nil {
return false, nil, fmt.Errorf("passive health check: invalid period value: %s", value)
}
if period < 0 {
return false, nil, fmt.Errorf("passive health check: invalid period value: %s", value)
}
result.Period = period
case "min-requests":
minRequests, err := strconv.Atoi(value)
if err != nil {
return false, nil, fmt.Errorf("passive health check: invalid minRequests value: %s", value)
}
if minRequests < 0 {
return false, nil, fmt.Errorf("passive health check: invalid minRequests value: %s", value)
}
result.MinRequests = int64(minRequests)
case "max-drop-probability":
maxDropProbability, err := strconv.ParseFloat(value, 64)
if err != nil {
return false, nil, fmt.Errorf("passive health check: invalid maxDropProbability value: %s", value)
}
if maxDropProbability < 0 || maxDropProbability > 1 {
return false, nil, fmt.Errorf("passive health check: invalid maxDropProbability value: %s", value)
}
result.MaxDropProbability = maxDropProbability
/* optional parameters */
case "min-drop-probability":
minDropProbability, err := strconv.ParseFloat(value, 64)
if err != nil {
return false, nil, fmt.Errorf("passive health check: invalid minDropProbability value: %s", value)
}
if minDropProbability < 0 || minDropProbability > 1 {
return false, nil, fmt.Errorf("passive health check: invalid minDropProbability value: %s", value)
}
result.MinDropProbability = minDropProbability
case "max-unhealthy-endpoints-ratio":
maxUnhealthyEndpointsRatio, err := strconv.ParseFloat(value, 64)
if err != nil {
return false, nil, fmt.Errorf("passive health check: invalid maxUnhealthyEndpointsRatio value: %q", value)
}
if maxUnhealthyEndpointsRatio < 0 || maxUnhealthyEndpointsRatio > 1 {
return false, nil, fmt.Errorf("passive health check: invalid maxUnhealthyEndpointsRatio value: %q", value)
}
result.MaxUnhealthyEndpointsRatio = maxUnhealthyEndpointsRatio
default:
return false, nil, fmt.Errorf("passive health check: invalid parameter: key=%s,value=%s", key, value)
}
}
if len(requiredParams) != 0 {
return false, nil, fmt.Errorf("passive health check: missing required parameters %+v", maps.Keys(requiredParams))
}
if result.MinDropProbability >= result.MaxDropProbability {
return false, nil, fmt.Errorf("passive health check: minDropProbability should be less than maxDropProbability")
}
return true, result, nil
}
// Proxy initialization options.
type Params struct {
// The proxy expects a routing instance that is used to match
// the incoming requests to routes.
Routing *routing.Routing
// Control flags. See the Flags values.
Flags Flags
// Metrics collector.
// If not specified proxy uses global metrics.Default.
Metrics metrics.Metrics
// An optional list of priority routes to be used for matching
// before the general lookup tree.
PriorityRoutes []PriorityRoute
// Enable the experimental upgrade protocol feature
ExperimentalUpgrade bool
// ExperimentalUpgradeAudit enables audit log of both the request line
// and the response messages during web socket upgrades.
ExperimentalUpgradeAudit bool
// When set, no access log is printed.
AccessLogDisabled bool
// DualStack sets whether the proxy's TCP connections to the backend should be dual stack
DualStack bool
// DefaultHTTPStatus is the HTTP status used when no routes are found
// for a request.
DefaultHTTPStatus int
// MaxLoopbacks sets the maximum number of allowed loops. If 0,
// the default (9) is applied. To disable looping, set it to
// -1. Note that disabling looping by this option may result in
// wrong routing depending on the current configuration.
MaxLoopbacks int
// Same as net/http.Transport.MaxIdleConnsPerHost, but the default
// is 64. This value supports scenarios with relatively few remote
// hosts. When the routing table contains different hosts in the
// range of hundreds, it is recommended to set this option to a
// lower value.
IdleConnectionsPerHost int
// MaxIdleConns limits the number of idle connections to all backends, 0 means no limit
MaxIdleConns int
// DisableHTTPKeepalives forces the proxy to create a new connection for every backend request
DisableHTTPKeepalives bool
// CircuitBreakers provides a registry that skipper can use to
// find the matching circuit breaker for backend requests. If not
// set, no circuit breakers are used.
CircuitBreakers *circuit.Registry
// RateLimiters provides a registry that skipper can use to
// find the matching ratelimiter for backend requests. If not
// set, no ratelimits are used.
RateLimiters *ratelimit.Registry
// Defines the time period of how often the idle connections are
// forcibly closed. The default is 20 seconds (see
// DefaultCloseIdleConnsPeriod). When set to less than 0, the proxy
// doesn't force closing the idle connections.
CloseIdleConnsPeriod time.Duration
// The Flush interval for copying upgraded connections
FlushInterval time.Duration
// Timeout sets the TCP client connection timeout for proxy http connections to the backend
Timeout time.Duration
// ResponseHeaderTimeout sets the HTTP response timeout for
// proxy http connections to the backend.
ResponseHeaderTimeout time.Duration
// ExpectContinueTimeout sets the HTTP timeout to expect a
// response for status Code 100 for proxy http connections to
// the backend.
ExpectContinueTimeout time.Duration
// KeepAlive sets the TCP keepalive for proxy http connections to the backend
KeepAlive time.Duration
// TLSHandshakeTimeout sets the TLS handshake timeout for proxy connections to the backend
TLSHandshakeTimeout time.Duration
// Client TLS to connect to Backends
ClientTLS *tls.Config
// OpenTracing contains parameters related to OpenTracing instrumentation. For default values
// check OpenTracingParams
OpenTracing *OpenTracingParams
// CustomHttpRoundTripperWrap provides the ability to wrap the http.RoundTripper created by skipper.
// The http.RoundTripper is used for making outgoing requests (to backends).
// It allows adding additional logic (for example tracing) by providing a wrapper function
// which accepts the original skipper http.RoundTripper as an argument and returns a wrapped round tripper
CustomHttpRoundTripperWrap func(http.RoundTripper) http.RoundTripper
// Registry provides key-value API which uses "host:port" string as a key
// and returns some metadata about endpoint. Information about the metadata
// returned from the registry could be found in routing.Metrics interface.
EndpointRegistry *routing.EndpointRegistry
// EnablePassiveHealthCheck enables the healthy endpoints checker
EnablePassiveHealthCheck bool
// PassiveHealthCheck defines the parameters for the healthy endpoints checker.
PassiveHealthCheck *PassiveHealthCheck
}
type (
ratelimitError string
routeLookupError string
)
func (e ratelimitError) Error() string { return string(e) }
func (e routeLookupError) Error() string { return string(e) }
const (
errRatelimit = ratelimitError("ratelimited")
errRouteLookup = routeLookupError("route lookup failed")
)
var (
errRouteLookupFailed = &proxyError{err: errRouteLookup}
errCircuitBreakerOpen = &proxyError{
err: errors.New("circuit breaker open"),
code: http.StatusServiceUnavailable,
additionalHeader: http.Header{"X-Circuit-Open": []string{"true"}},
}
disabledAccessLog = al.AccessLogFilter{Enable: false, Prefixes: nil}
enabledAccessLog = al.AccessLogFilter{Enable: true, Prefixes: nil}
hopHeaders = map[string]bool{
"Te": true,
"Connection": true,
"Proxy-Connection": true,
"Keep-Alive": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"Trailer": true,
"Transfer-Encoding": true,
"Upgrade": true,
}
)
// When set, the proxy will skip the TLS verification on outgoing requests.
func (f Flags) Insecure() bool { return f&Insecure != 0 }
// When set, the filters will receive an unmodified clone of the original
// incoming request and response.
func (f Flags) PreserveOriginal() bool { return f&(PreserveOriginal|Debug) != 0 }
// When set, the proxy will, by default, set the Host header value
// of the outgoing requests to the one of the incoming request.
func (f Flags) PreserveHost() bool { return f&PreserveHost != 0 }
// When set, the proxy runs in debug mode.
func (f Flags) Debug() bool { return f&Debug != 0 }
// When set, the proxy will remove the Hop Headers
func (f Flags) HopHeadersRemoval() bool { return f&HopHeadersRemoval != 0 }
func (f Flags) patchPath() bool { return f&PatchPath != 0 }
// Priority routes are custom route implementations that are matched against
// each request before the routes in the general lookup tree.
type PriorityRoute interface {
// If the request is matched, returns a route, otherwise nil.
// Additionally it may return a parameter map used by the filters
// in the route.
Match(*http.Request) (*routing.Route, map[string]string)
}
// Proxy instances implement Skipper proxying functionality. For
// initialization, see the WithParams constructor and Params.
type Proxy struct {
experimentalUpgrade bool
experimentalUpgradeAudit bool
accessLogDisabled bool
maxLoops int
defaultHTTPStatus int
routing *routing.Routing
registry *routing.EndpointRegistry
fadein *fadeIn
heathlyEndpoints *healthyEndpoints
roundTripper http.RoundTripper
priorityRoutes []PriorityRoute
flags Flags
metrics metrics.Metrics
quit chan struct{}
flushInterval time.Duration
breakers *circuit.Registry
limiters *ratelimit.Registry
log logging.Logger
tracing *proxyTracing
upgradeAuditLogOut io.Writer
upgradeAuditLogErr io.Writer
auditLogHook chan struct{}
clientTLS *tls.Config
hostname string
onPanicSometimes rate.Sometimes
}
// proxyError is used to wrap errors during proxying and to indicate
// the required status code for the response sent from the main
// ServeHTTP method. Alternatively, it can indicate that the request
// was already handled, e.g. in case of deprecated shunting or the
// upgrade request.
type proxyError struct {
err error
code int
handled bool
dialingFailed bool
additionalHeader http.Header
}
func (e proxyError) Error() string {
if e.err != nil {
return fmt.Sprintf("dialing failed %v: %v", e.DialError(), e.err.Error())
}
if e.handled {
return "request handled in a non-standard way"
}
code := e.code
if code == 0 {
code = http.StatusInternalServerError
}
return fmt.Sprintf("proxy error: %d", code)
}
// DialError returns true if the error was caused while dialing TCP or
// TLS connections, before HTTP data was sent. It is safe to retry
// a call if this returns true.
func (e *proxyError) DialError() bool {
return e.dialingFailed
}
func copyHeader(to, from http.Header) {
for k, v := range from {
to[http.CanonicalHeaderKey(k)] = v
}
}
func copyHeaderExcluding(to, from http.Header, excludeHeaders map[string]bool) {
for k, v := range from {
// The http package converts header names to their canonical version.
// Meaning that the lookup below will be done using the canonical version of the header.
if _, ok := excludeHeaders[k]; !ok {
to[http.CanonicalHeaderKey(k)] = v
}
}
}
func cloneHeader(h http.Header) http.Header {
hh := make(http.Header)
copyHeader(hh, h)
return hh
}
func cloneHeaderExcluding(h http.Header, excludeList map[string]bool) http.Header {
hh := make(http.Header)
copyHeaderExcluding(hh, h, excludeList)
return hh
}
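// flusher wraps a flushedResponseWriter and flushes after every
// successful write, so that streamed response payloads reach the client
// without buffering delays (see copyStream below).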
type flusher struct {
w flushedResponseWriter
}
func (f *flusher) Write(p []byte) (n int, err error) {
n, err = f.w.Write(p)
if err == nil {
f.w.Flush()
}
return
}
func copyStream(to flushedResponseWriter, from io.Reader) (int64, error) {
b := make([]byte, proxyBufferSize)
return io.CopyBuffer(&flusher{to}, from, b)
}
func schemeFromRequest(r *http.Request) string {
if r.TLS != nil {
return "https"
}
return "http"
}
func setRequestURLFromRequest(u *url.URL, r *http.Request) {
if u.Host == "" {
u.Host = r.Host
}
if u.Scheme == "" {
u.Scheme = schemeFromRequest(r)
}
}
func setRequestURLForDynamicBackend(u *url.URL, stateBag map[string]interface{}) {
dbu, ok := stateBag[filters.DynamicBackendURLKey].(string)
if ok && dbu != "" {
bu, err := url.ParseRequestURI(dbu)
if err == nil {
u.Host = bu.Host
u.Scheme = bu.Scheme
}
} else {
host, ok := stateBag[filters.DynamicBackendHostKey].(string)
if ok && host != "" {
u.Host = host
}
scheme, ok := stateBag[filters.DynamicBackendSchemeKey].(string)
if ok && scheme != "" {
u.Scheme = scheme
}
}
}
func (p *Proxy) selectEndpoint(ctx *context) *routing.LBEndpoint {
rt := ctx.route
endpoints := rt.LBEndpoints
endpoints = p.fadein.filterFadeIn(endpoints, rt)
endpoints = p.heathlyEndpoints.filterHealthyEndpoints(ctx, endpoints, p.metrics)
lbctx := &routing.LBContext{
Request: ctx.request,
Route: rt,
LBEndpoints: endpoints,
Params: ctx.StateBag(),
}
e := rt.LBAlgorithm.Apply(lbctx)
return &e
}
// creates an outgoing http request to be forwarded to the route endpoint
// based on the augmented incoming request
func (p *Proxy) mapRequest(ctx *context, requestContext stdlibcontext.Context) (*http.Request, routing.Metrics, error) {
var endpointMetrics routing.Metrics
r := ctx.request
rt := ctx.route
host := ctx.outgoingHost
stateBag := ctx.StateBag()
u := r.URL
switch rt.BackendType {
case eskip.DynamicBackend:
setRequestURLFromRequest(u, r)
setRequestURLForDynamicBackend(u, stateBag)
case eskip.LBBackend:
endpoint := p.selectEndpoint(ctx)
endpointMetrics = endpoint.Metrics
u.Scheme = endpoint.Scheme
u.Host = endpoint.Host
case eskip.NetworkBackend:
endpointMetrics = p.registry.GetMetrics(rt.Host)
fallthrough
default:
u.Scheme = rt.Scheme
u.Host = rt.Host
}
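// forward an empty body as nil: a non-nil body of unknown length would
// make net/http use chunked transfer encoding towards the backend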
body := r.Body
if r.ContentLength == 0 {
body = nil
}
rr, err := http.NewRequestWithContext(requestContext, r.Method, u.String(), body)
if err != nil {
return nil, nil, err
}
rr.ContentLength = r.ContentLength
if p.flags.HopHeadersRemoval() {
rr.Header = cloneHeaderExcluding(r.Header, hopHeaders)
} else {
rr.Header = cloneHeader(r.Header)
}
// Disable default net/http user agent when user agent is not specified
if _, ok := rr.Header["User-Agent"]; !ok {
rr.Header["User-Agent"] = []string{""}
}
rr.Host = host
// If there is basic auth configured in the URL, we add it as an Authorization header
if u.User != nil {
up := u.User.String()
upBase64 := base64.StdEncoding.EncodeToString([]byte(up))
rr.Header.Add("Authorization", fmt.Sprintf("Basic %s", upBase64))
}
ctxspan := ot.SpanFromContext(r.Context())
if ctxspan != nil {
rr = rr.WithContext(ot.ContextWithSpan(rr.Context(), ctxspan))
}
if _, ok := stateBag[filters.BackendIsProxyKey]; ok {
rr = forwardToProxy(r, rr)
}
return rr, endpointMetrics, nil
}
type proxyUrlContextKey struct{}
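// forwardToProxy prepares the outgoing request for delivery through
// another proxy: the resolved backend address is stored in the request
// context to be used as the HTTP proxy (see proxyFromContext, wired as
// http.Transport.Proxy below), while the request URL is reset to the
// incoming host and scheme.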
func forwardToProxy(incoming, outgoing *http.Request) *http.Request {
proxyURL := &url.URL{
Scheme: outgoing.URL.Scheme,
Host: outgoing.URL.Host,
}
outgoing.URL.Host = incoming.Host
outgoing.URL.Scheme = schemeFromRequest(incoming)
return outgoing.WithContext(stdlibcontext.WithValue(outgoing.Context(), proxyUrlContextKey{}, proxyURL))
}
func proxyFromContext(req *http.Request) (*url.URL, error) {
proxyURL, _ := req.Context().Value(proxyUrlContextKey{}).(*url.URL)
if proxyURL != nil {
return proxyURL, nil
}
return nil, nil
}
type skipperDialer struct {
net.Dialer
f func(ctx stdlibcontext.Context, network, addr string) (net.Conn, error)
}
func newSkipperDialer(d net.Dialer) *skipperDialer {
return &skipperDialer{
Dialer: d,
f: d.DialContext,
}
}
// DialContext wraps net.Dialer's DialContext and returns an error
// that can be checked to tell whether it was a transport (TCP/TLS
// handshake) error or timeout, or a timeout from http, which is in
// general not possible to retry.
func (dc *skipperDialer) DialContext(ctx stdlibcontext.Context, network, addr string) (net.Conn, error) {
span := ot.SpanFromContext(ctx)
if span != nil {
span.LogKV("dial_context", "start")
}
con, err := dc.f(ctx, network, addr)
if span != nil {
span.LogKV("dial_context", "done")
}
if err != nil {
return nil, &proxyError{
err: err,
code: -1, // omit 0 handling in proxy.Error()
dialingFailed: true, // indicate error happened before http
}
} else if cerr := ctx.Err(); cerr != nil {
// unclear when this is being triggered
return nil, &proxyError{
err: fmt.Errorf("err from dial context: %w", cerr),
code: http.StatusGatewayTimeout,
}
}
return con, nil
}
// New returns an initialized Proxy.
// Deprecated, see WithParams and Params instead.
func New(r *routing.Routing, options Options, pr ...PriorityRoute) *Proxy {
return WithParams(Params{
Routing: r,
Flags: Flags(options),
PriorityRoutes: pr,
CloseIdleConnsPeriod: -time.Second,
})
}
// WithParams returns an initialized Proxy.
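// A minimal construction sketch (rt is a *routing.Routing created
// elsewhere; all other Params fields fall back to their defaults):
//
//	p := WithParams(Params{
//		Routing: rt,
//		Flags:   PreserveHost,
//	})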
func WithParams(p Params) *Proxy {
if p.IdleConnectionsPerHost <= 0 {
p.IdleConnectionsPerHost = DefaultIdleConnsPerHost
}
if p.CloseIdleConnsPeriod == 0 {
p.CloseIdleConnsPeriod = DefaultCloseIdleConnsPeriod
}
if p.ResponseHeaderTimeout == 0 {
p.ResponseHeaderTimeout = DefaultResponseHeaderTimeout
}
if p.ExpectContinueTimeout == 0 {
p.ExpectContinueTimeout = DefaultExpectContinueTimeout
}
if p.CustomHttpRoundTripperWrap == nil {
// default wrapper which does nothing
p.CustomHttpRoundTripperWrap = func(original http.RoundTripper) http.RoundTripper {
return original
}
}
tr := &http.Transport{
DialContext: newSkipperDialer(net.Dialer{
Timeout: p.Timeout,
KeepAlive: p.KeepAlive,
DualStack: p.DualStack,
}).DialContext,
TLSHandshakeTimeout: p.TLSHandshakeTimeout,
ResponseHeaderTimeout: p.ResponseHeaderTimeout,
ExpectContinueTimeout: p.ExpectContinueTimeout,
MaxIdleConns: p.MaxIdleConns,
MaxIdleConnsPerHost: p.IdleConnectionsPerHost,
IdleConnTimeout: p.CloseIdleConnsPeriod,
DisableKeepAlives: p.DisableHTTPKeepalives,
Proxy: proxyFromContext,
}
quit := make(chan struct{})
// We need this to reliably fade on DNS change, which is right
// now not fixed with IdleConnTimeout in the http.Transport.
// https://github.com/golang/go/issues/23427
if p.CloseIdleConnsPeriod > 0 {
go func() {
ticker := time.NewTicker(p.CloseIdleConnsPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
tr.CloseIdleConnections()
case <-quit:
return
}
}
}()
}
if p.ClientTLS != nil {
tr.TLSClientConfig = p.ClientTLS
}
if p.Flags.Insecure() {
if tr.TLSClientConfig == nil {
/* #nosec */
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
} else {
/* #nosec */
tr.TLSClientConfig.InsecureSkipVerify = true
}
}
m := p.Metrics
if m == nil {
m = metrics.Default
}
if p.Flags.Debug() {
m = metrics.Void
}
if p.MaxLoopbacks == 0 {
p.MaxLoopbacks = DefaultMaxLoopbacks
} else if p.MaxLoopbacks < 0 {
p.MaxLoopbacks = 0
}
defaultHTTPStatus := http.StatusNotFound
if p.DefaultHTTPStatus >= http.StatusContinue && p.DefaultHTTPStatus <= http.StatusNetworkAuthenticationRequired {
defaultHTTPStatus = p.DefaultHTTPStatus
}
if p.EndpointRegistry == nil {
p.EndpointRegistry = routing.NewEndpointRegistry(routing.RegistryOptions{})
}
hostname := os.Getenv("HOSTNAME")
var healthyEndpointsChooser *healthyEndpoints
if p.EnablePassiveHealthCheck {
healthyEndpointsChooser = &healthyEndpoints{
rnd: rand.New(loadbalancer.NewLockedSource()),
maxUnhealthyEndpointsRatio: p.PassiveHealthCheck.MaxUnhealthyEndpointsRatio,
}
}
return &Proxy{
routing: p.Routing,
registry: p.EndpointRegistry,
fadein: &fadeIn{
rnd: rand.New(loadbalancer.NewLockedSource()),
},
heathlyEndpoints: healthyEndpointsChooser,
roundTripper: p.CustomHttpRoundTripperWrap(tr),
priorityRoutes: p.PriorityRoutes,
flags: p.Flags,
metrics: m,
quit: quit,
flushInterval: p.FlushInterval,
experimentalUpgrade: p.ExperimentalUpgrade,
experimentalUpgradeAudit: p.ExperimentalUpgradeAudit,
maxLoops: p.MaxLoopbacks,
breakers: p.CircuitBreakers,
limiters: p.RateLimiters,
log: &logging.DefaultLog{},
defaultHTTPStatus: defaultHTTPStatus,
tracing: newProxyTracing(p.OpenTracing),
accessLogDisabled: p.AccessLogDisabled,
upgradeAuditLogOut: os.Stdout,
upgradeAuditLogErr: os.Stderr,
clientTLS: tr.TLSClientConfig,
hostname: hostname,
onPanicSometimes: rate.Sometimes{First: 3, Interval: 1 * time.Minute},
}
}
// applies filters to a request
func (p *Proxy) applyFiltersToRequest(f []*routing.RouteFilter, ctx *context) []*routing.RouteFilter {
if len(f) == 0 {
return f
}
filtersStart := time.Now()
filterTracing := p.tracing.startFilterTracing("request_filters", ctx)
defer filterTracing.finish()
var filters = make([]*routing.RouteFilter, 0, len(f))
for _, fi := range f {
start := time.Now()
filterTracing.logStart(fi.Name)
ctx.setMetricsPrefix(fi.Name)
fi.Request(ctx)
p.metrics.MeasureFilterRequest(fi.Name, start)
filterTracing.logEnd(fi.Name)
filters = append(filters, fi)
if ctx.deprecatedShunted() || ctx.shunted() {
break
}
}
p.metrics.MeasureAllFiltersRequest(ctx.route.Id, filtersStart)
return filters
}
// applies filters to a response in reverse order
func (p *Proxy) applyFiltersToResponse(filters []*routing.RouteFilter, ctx *context) {
filtersStart := time.Now()
filterTracing := p.tracing.startFilterTracing("response_filters", ctx)
defer filterTracing.finish()
for i := len(filters) - 1; i >= 0; i-- {
fi := filters[i]
start := time.Now()
filterTracing.logStart(fi.Name)
ctx.setMetricsPrefix(fi.Name)
fi.Response(ctx)
p.metrics.MeasureFilterResponse(fi.Name, start)
filterTracing.logEnd(fi.Name)
}
p.metrics.MeasureAllFiltersResponse(ctx.route.Id, filtersStart)
}
// addBranding sets the Server header to "Skipper", unless a Server header is already present
func addBranding(headerMap http.Header) {
if headerMap.Get("Server") == "" {
headerMap.Set("Server", "Skipper")
}
}
func (p *Proxy) lookupRoute(ctx *context) (rt *routing.Route, params map[string]string) {
for _, prt := range p.priorityRoutes {
rt, params = prt.Match(ctx.request)
if rt != nil {
return rt, params
}
}
return ctx.routeLookup.Do(ctx.request)
}
func (p *Proxy) makeUpgradeRequest(ctx *context, req *http.Request) {
backendURL := req.URL
reverseProxy := httputil.NewSingleHostReverseProxy(backendURL)
reverseProxy.FlushInterval = p.flushInterval
upgradeProxy := upgradeProxy{
backendAddr: backendURL,
reverseProxy: reverseProxy,
insecure: p.flags.Insecure(),
tlsClientConfig: p.clientTLS,
useAuditLog: p.experimentalUpgradeAudit,
auditLogOut: p.upgradeAuditLogOut,
auditLogErr: p.upgradeAuditLogErr,
auditLogHook: p.auditLogHook,
}
upgradeProxy.serveHTTP(ctx.responseWriter, req)
ctx.successfulUpgrade = true
ctx.Logger().Debugf("finished upgraded protocol %s session", getUpgradeRequest(ctx.request))
}
func (p *Proxy) makeBackendRequest(ctx *context, requestContext stdlibcontext.Context) (*http.Response, *proxyError) {
payloadProtocol := getUpgradeRequest(ctx.Request())
req, endpointMetrics, err := p.mapRequest(ctx, requestContext)
if err != nil {
return nil, &proxyError{err: fmt.Errorf("could not map backend request: %w", err)}
}
if res, ok := p.rejectBackend(ctx, req); ok {
return res, nil
}
if endpointMetrics != nil {
endpointMetrics.IncInflightRequest()
defer endpointMetrics.DecInflightRequest()
}
if p.experimentalUpgrade && payloadProtocol != "" {
// see also https://github.com/golang/go/blob/9159cd4ec6b0e9475dc9c71c830035c1c4c13483/src/net/http/httputil/reverseproxy.go#L423-L428
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", payloadProtocol)
p.makeUpgradeRequest(ctx, req)
// We are not owner of the connection anymore.
return nil, &proxyError{handled: true}
}
roundTripper, err := p.getRoundTripper(ctx, req)
if err != nil {
return nil, &proxyError{err: fmt.Errorf("failed to get roundtripper: %w", err), code: http.StatusBadGateway}
}
bag := ctx.StateBag()
spanName, ok := bag[tracingfilter.OpenTracingProxySpanKey].(string)
if !ok {
spanName = "proxy"
}
proxySpanOpts := []ot.StartSpanOption{ot.Tags{
SpanKindTag: SpanKindClient,
}}
if parentSpan := ot.SpanFromContext(req.Context()); parentSpan != nil {
proxySpanOpts = append(proxySpanOpts, ot.ChildOf(parentSpan.Context()))
}
ctx.proxySpan = p.tracing.tracer.StartSpan(spanName, proxySpanOpts...)
u := cloneURL(req.URL)
u.RawQuery = ""
p.tracing.
setTag(ctx.proxySpan, HTTPUrlTag, u.String()).
setTag(ctx.proxySpan, SkipperRouteIDTag, ctx.route.Id).
setTag(ctx.proxySpan, NetworkPeerAddressTag, u.Host)
p.setCommonSpanInfo(u, req, ctx.proxySpan)
carrier := ot.HTTPHeadersCarrier(req.Header)
_ = p.tracing.tracer.Inject(ctx.proxySpan.Context(), ot.HTTPHeaders, carrier)
req = req.WithContext(ot.ContextWithSpan(req.Context(), ctx.proxySpan))
p.metrics.IncCounter("outgoing." + req.Proto)
ctx.proxySpan.LogKV("http_roundtrip", StartEvent)
req = injectClientTrace(req, ctx.proxySpan)
p.metrics.MeasureBackendRequestHeader(ctx.metricsHost(), snet.SizeOfRequestHeader(req))
ctx.proxyWatch.Stop()
ctx.proxyRequestLatency = ctx.proxyWatch.Elapsed()
response, err := roundTripper.RoundTrip(req)
ctx.proxyWatch.Reset()
ctx.proxyWatch.Start()
if endpointMetrics != nil {
endpointMetrics.IncRequests(routing.IncRequestsOptions{FailedRoundTrip: err != nil})
}
ctx.proxySpan.LogKV("http_roundtrip", EndEvent)
if err != nil {
if errors.Is(err, skpio.ErrBlocked) {
p.tracing.setTag(ctx.proxySpan, BlockTag, true)
p.tracing.setTag(ctx.proxySpan, HTTPStatusCodeTag, uint16(http.StatusBadRequest))
return nil, &proxyError{err: err, code: http.StatusBadRequest}
}
p.tracing.setTag(ctx.proxySpan, ErrorTag, true)
// Check if the request has been cancelled or timed out
// The roundtrip error `err` may be different:
// - for `Canceled` it could be either the same `context canceled` or `unexpected EOF` (net.OpError)
// - for `DeadlineExceeded` it is net.Error(timeout=true, temporary=true) wrapping this `context deadline exceeded`
if cerr := req.Context().Err(); cerr != nil {
ctx.proxySpan.LogKV("event", "error", "message", ensureUTF8(cerr.Error()))
if cerr == stdlibcontext.Canceled {
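// 499 is the nginx convention for "client closed request"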
return nil, &proxyError{err: cerr, code: 499}
} else if cerr == stdlibcontext.DeadlineExceeded {
return nil, &proxyError{err: cerr, code: http.StatusGatewayTimeout}
}
}
errMessage := err.Error()
ctx.proxySpan.LogKV("event", "error", "message", ensureUTF8(errMessage))
if perr, ok := err.(*proxyError); ok {
perr.err = fmt.Errorf("failed to do backend roundtrip to %s: %w", req.URL.Host, perr.err)
return nil, perr
} else if nerr, ok := err.(net.Error); ok {
var status int
if nerr.Timeout() {
status = http.StatusGatewayTimeout
} else {
status = http.StatusServiceUnavailable
}
p.tracing.setTag(ctx.proxySpan, HTTPStatusCodeTag, uint16(status))
//lint:ignore SA1019 Temporary is deprecated in Go 1.18, but keep it for now (https://github.com/zalando/skipper/issues/1992)
return nil, &proxyError{err: fmt.Errorf("net.Error during backend roundtrip to %s: timeout=%v temporary='%v': %w", req.URL.Host, nerr.Timeout(), nerr.Temporary(), err), code: status}
}
switch errMessage {
case // net/http/internal/chunked.go
"header line too long",
"chunked encoding contains too much non-data",
"malformed chunked encoding",
"empty hex number for chunk length",
"invalid byte in chunk length",
"http chunk length too large":
return nil, &proxyError{code: http.StatusBadRequest, err: fmt.Errorf("failed to do backend roundtrip due to invalid request: %w", err)}
}
return nil, &proxyError{err: fmt.Errorf("unexpected error from Go stdlib net/http package during roundtrip: %w", err)}
}
p.tracing.setTag(ctx.proxySpan, HTTPStatusCodeTag, uint16(response.StatusCode))
if response.Uncompressed {
p.metrics.IncCounter("experimental.uncompressed")
}
return response, nil
}
func (p *Proxy) getRoundTripper(ctx *context, req *http.Request) (http.RoundTripper, error) {
switch req.URL.Scheme {
case "fastcgi":
f := "index.php"
if sf, ok := ctx.StateBag()["fastCgiFilename"]; ok {
f = sf.(string)
} else if len(req.URL.Path) > 1 && req.URL.Path != "/" {
f = req.URL.Path[1:]
}
rt, err := fastcgi.NewRoundTripper(p.log, req.URL.Host, f)
if err != nil {
return nil, err
}
// FastCGI expects the Host to be in the form host:port.
// It will then be split and added as two separate params to the backend process
if _, _, err := net.SplitHostPort(req.Host); err != nil {
req.Host = req.Host + ":" + req.URL.Port()
}
// RemoteAddr is needed to pass to the backend process as param
req.RemoteAddr = ctx.request.RemoteAddr
return rt, nil
default:
return p.roundTripper, nil
}
}
func (p *Proxy) rejectBackend(ctx *context, req *http.Request) (*http.Response, bool) {
limit, ok := ctx.StateBag()[filters.BackendRatelimit].(*ratelimitfilters.BackendRatelimit)
if ok {
s := req.URL.Scheme + "://" + req.URL.Host
if !p.limiters.Get(limit.Settings).Allow(req.Context(), s) {
return &http.Response{
StatusCode: limit.StatusCode,
Header: http.Header{"Content-Length": []string{"0"}},
Body: io.NopCloser(&bytes.Buffer{}),
}, true
}
}
return nil, false
}
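// checkBreaker looks up the circuit breaker for the route settings from
// the state bag, merged with the outgoing host. It reports whether the
// request is allowed; the returned callback must be invoked with the
// success status of the backend roundtrip.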
func (p *Proxy) checkBreaker(c *context) (func(bool), bool) {
if p.breakers == nil {
return nil, true
}
settings, _ := c.stateBag[circuitfilters.RouteSettingsKey].(circuit.BreakerSettings)
settings.Host = c.outgoingHost
b := p.breakers.Get(settings)
if b == nil {
return nil, true
}
done, ok := b.Allow()
if !ok && c.request.Body != nil {
// consume the body to prevent goroutine leaks
io.Copy(io.Discard, c.request.Body)
}
return done, ok
}
func newRatelimitError(settings ratelimit.Settings, retryAfter int) *proxyError {
return &proxyError{
err: errRatelimit,
code: http.StatusTooManyRequests,
additionalHeader: ratelimit.Headers(settings.MaxHits, settings.TimeWindow, retryAfter),
}
}
// copied from debug.PrintStack
func stack() []byte {
buf := make([]byte, 1024)
for {
n := runtime.Stack(buf, false)
if n < len(buf) {
return buf[:n]
}
buf = make([]byte, 2*len(buf))
}
}
func (p *Proxy) do(ctx *context, parentSpan ot.Span) (err error) {
defer func() {
if r := recover(); r != nil {
p.onPanicSometimes.Do(func() {
ctx.Logger().Errorf("stacktrace of panic caused by: %v:\n%s", r, stack())
})
perr := &proxyError{
err: fmt.Errorf("panic caused by: %v", r),
}
p.makeErrorResponse(ctx, perr)
err = perr
}
}()
if ctx.executionCounter > p.maxLoops {
// TODO(sszuecs): think about setting status code to 463 or 465 (check what AWS ALB sets for redirect loop) or similar
perr := &proxyError{
err: fmt.Errorf("max loopbacks reached after route %s", ctx.route.Id),
}
p.makeErrorResponse(ctx, perr)
return perr
}
// proxy global setting
if !ctx.wasExecuted() {
if settings, retryAfter := p.limiters.Check(ctx.request); retryAfter > 0 {
perr := newRatelimitError(settings, retryAfter)
p.makeErrorResponse(ctx, perr)
return perr
}
}
// Every time the context is used for a request, its executionCounter is incremented.
// An executionCounter equal to zero represents a root context.
ctx.executionCounter++
lookupStart := time.Now()
route, params := p.lookupRoute(ctx)
p.metrics.MeasureRouteLookup(lookupStart)
if route == nil {
p.metrics.IncRoutingFailures()
ctx.Logger().Debugf("could not find a route for %v", ctx.request.URL)
p.makeErrorResponse(ctx, errRouteLookupFailed)
return errRouteLookupFailed
}
parentSpan.SetTag(SkipperRouteIDTag, route.Id)
ctx.applyRoute(route, params, p.flags.PreserveHost())
ctx.proxyWatch.Stop()
processedFilters := p.applyFiltersToRequest(ctx.route.Filters, ctx)
ctx.proxyWatch.Start()
if ctx.deprecatedShunted() {
ctx.Logger().Debugf("deprecated shunting detected in route: %s", ctx.route.Id)
return &proxyError{handled: true}
} else if ctx.shunted() || ctx.route.Shunt || ctx.route.BackendType == eskip.ShuntBackend {
// consume the body to prevent goroutine leaks
if ctx.request.Body != nil {
if _, err := io.Copy(io.Discard, ctx.request.Body); err != nil {
ctx.Logger().Debugf("error while discarding remainder request body: %v.", err)
}
}
ctx.ensureDefaultResponse()
} else if ctx.route.BackendType == eskip.LoopBackend {
loopCTX := ctx.clone()
loopSpanOpts := []ot.StartSpanOption{ot.Tags{
SpanKindTag: SpanKindServer,
}}
if parentSpan := ot.SpanFromContext(ctx.request.Context()); parentSpan != nil {
loopSpanOpts = append(loopSpanOpts, ot.ChildOf(parentSpan.Context()))
}
loopSpan := p.tracing.tracer.StartSpan("loopback", loopSpanOpts...)
p.tracing.setTag(loopSpan, SkipperRouteIDTag, ctx.route.Id)
p.setCommonSpanInfo(ctx.Request().URL, ctx.Request(), loopSpan)
ctx.parentSpan = loopSpan
r := loopCTX.Request()
r = r.WithContext(ot.ContextWithSpan(r.Context(), loopSpan))
loopCTX.request = r
defer loopSpan.Finish()
if err := p.do(loopCTX, loopSpan); err != nil {
// in case of an error we have to copy the response while unwinding this recursion
ctx.response = loopCTX.response
p.applyFiltersOnError(ctx, processedFilters)
return err
}
ctx.setResponse(loopCTX.response, p.flags.PreserveOriginal())
ctx.proxySpan = loopCTX.proxySpan
} else if p.flags.Debug() {
debugReq, _, err := p.mapRequest(ctx, ctx.request.Context())
if err != nil {
perr := &proxyError{err: err}
p.makeErrorResponse(ctx, perr)
p.applyFiltersOnError(ctx, processedFilters)
return perr
}
ctx.outgoingDebugRequest = debugReq
ctx.setResponse(&http.Response{Header: make(http.Header)}, p.flags.PreserveOriginal())
} else {
done, allow := p.checkBreaker(ctx)
if !allow {
tracing.LogKV("circuit_breaker", "open", ctx.request.Context())
p.makeErrorResponse(ctx, errCircuitBreakerOpen)
p.applyFiltersOnError(ctx, processedFilters)
return errCircuitBreakerOpen
}
backendContext := ctx.request.Context()
if timeout, ok := ctx.StateBag()[filters.BackendTimeout]; ok {
backendContext, ctx.cancelBackendContext = stdlibcontext.WithTimeout(backendContext, timeout.(time.Duration))
}
backendStart := time.Now()
if d, ok := ctx.StateBag()[filters.ReadTimeout].(time.Duration); ok {
e := ctx.ResponseController().SetReadDeadline(backendStart.Add(d))
if e != nil {
ctx.Logger().Errorf("Failed to set read deadline: %v", e)
}
}
rsp, perr := p.makeBackendRequest(ctx, backendContext)
if perr != nil {
if done != nil {
done(false)
}
p.metrics.IncErrorsBackend(ctx.route.Id)
if retryable(ctx, perr) {
if ctx.proxySpan != nil {
ctx.proxySpan.Finish()
ctx.proxySpan = nil
}
tracing.LogKV("retry", ctx.route.Id, ctx.Request().Context())
perr = nil
var perr2 *proxyError
rsp, perr2 = p.makeBackendRequest(ctx, backendContext)
if perr2 != nil {
ctx.Logger().Errorf("Failed to retry backend request: %v", perr2)
if perr2.code >= http.StatusInternalServerError {
p.metrics.MeasureBackend5xx(backendStart)
}
p.makeErrorResponse(ctx, perr2)
p.applyFiltersOnError(ctx, processedFilters)
return perr2
}
} else {
p.makeErrorResponse(ctx, perr)
p.applyFiltersOnError(ctx, processedFilters)
return perr
}
}
if rsp.StatusCode >= http.StatusInternalServerError {
p.metrics.MeasureBackend5xx(backendStart)
}
if done != nil {
done(rsp.StatusCode < http.StatusInternalServerError)
}
ctx.setResponse(rsp, p.flags.PreserveOriginal())
p.metrics.MeasureBackend(ctx.route.Id, backendStart)
p.metrics.MeasureBackendHost(ctx.route.Host, backendStart)
}
addBranding(ctx.response.Header)
ctx.proxyWatch.Stop()
p.applyFiltersToResponse(processedFilters, ctx)
ctx.proxyWatch.Start()
return nil
}
func retryable(ctx *context, perr *proxyError) bool {
req := ctx.Request()
return perr.code != 499 && perr.DialError() &&
ctx.route.BackendType == eskip.LBBackend &&
req != nil && (req.Body == nil || req.Body == http.NoBody)
}
func (p *Proxy) serveResponse(ctx *context) {
if p.flags.Debug() {
dbgResponse(ctx.responseWriter, &debugInfo{
route: &ctx.route.Route,
incoming: ctx.originalRequest,
outgoing: ctx.outgoingDebugRequest,
response: ctx.response,
})
return
}
start := time.Now()
p.tracing.logStreamEvent(ctx.proxySpan, StreamHeadersEvent, StartEvent)
copyHeader(ctx.responseWriter.Header(), ctx.response.Header)
if err := ctx.Request().Context().Err(); err != nil {
// deadline exceeded or canceled in stdlib, client closed request
// see https://github.com/zalando/skipper/pull/864
ctx.Logger().Debugf("Client request: %v", err)
ctx.response.StatusCode = 499
p.tracing.setTag(ctx.proxySpan, ClientRequestStateTag, ClientRequestCanceled)
}
p.tracing.setTag(ctx.initialSpan, HTTPStatusCodeTag, uint16(ctx.response.StatusCode))
ctx.responseWriter.WriteHeader(ctx.response.StatusCode)
ctx.responseWriter.Flush()
p.tracing.logStreamEvent(ctx.proxySpan, StreamHeadersEvent, EndEvent)
ctx.proxyWatch.Stop()
n, err := copyStream(ctx.responseWriter, ctx.response.Body)
ctx.proxyWatch.Start()
p.tracing.logStreamEvent(ctx.proxySpan, StreamBodyEvent, strconv.FormatInt(n, 10))
if err != nil {
p.metrics.IncErrorsStreaming(ctx.route.Id)
ctx.Logger().Debugf("error while copying the response stream: %v", err)
p.tracing.setTag(ctx.proxySpan, ErrorTag, true)
p.tracing.setTag(ctx.proxySpan, StreamBodyEvent, StreamBodyError)
p.tracing.logStreamEvent(ctx.proxySpan, StreamBodyEvent, fmt.Sprintf("Failed to stream response: %v", err))
} else {
p.metrics.MeasureResponse(ctx.response.StatusCode, ctx.request.Method, ctx.route.Id, start)
p.metrics.MeasureResponseSize(ctx.metricsHost(), n)
}
p.metrics.MeasureServe(ctx.route.Id, ctx.metricsHost(), ctx.request.Method, ctx.response.StatusCode, ctx.startServe)
}
func (p *Proxy) errorResponse(ctx *context, err error) {
perr, ok := err.(*proxyError)
if ok && perr.handled {
return
}
flowIdLog := ""
flowId := ctx.Request().Header.Get(flowidFilter.HeaderName)
if flowId != "" {
flowIdLog = fmt.Sprintf(", flow id %s", flowId)
}
id := unknownRouteID
backendType := unknownRouteBackendType
backend := unknownRouteBackend
if ctx.route != nil {
id = ctx.route.Id
backendType = ctx.route.BackendType.String()
backend = fmt.Sprintf("%s://%s", ctx.request.URL.Scheme, ctx.request.URL.Host)
}
if err == errRouteLookupFailed {
ctx.initialSpan.LogKV("event", "error", "message", errRouteLookup.Error())
}
p.tracing.setTag(ctx.initialSpan, ErrorTag, true)
p.tracing.setTag(ctx.initialSpan, HTTPStatusCodeTag, ctx.response.StatusCode)
if p.flags.Debug() {
di := &debugInfo{
incoming: ctx.originalRequest,
outgoing: ctx.outgoingDebugRequest,
response: ctx.response,
err: err,
}
if ctx.route != nil {
di.route = &ctx.route.Route
}
dbgResponse(ctx.responseWriter, di)
return
}
msgPrefix := "error while proxying"
logFunc := ctx.Logger().Errorf
if ctx.response.StatusCode == 499 {
msgPrefix = "client canceled"
logFunc = ctx.Logger().Infof
if p.accessLogDisabled {
logFunc = ctx.Logger().Debugf
}
}
if id != unknownRouteID {
req := ctx.Request()
remoteAddr := remoteHost(req)
uri := req.RequestURI
if i := strings.IndexRune(uri, '?'); i >= 0 {
uri = uri[:i]
}
logFunc(
`%s after %v, route %s with backend %s %s%s, status code %d: %v, remote host: %s, request: "%s %s %s", host: %s, user agent: "%s"`,
msgPrefix,
time.Since(ctx.startServe),
id,
backendType,
backend,
flowIdLog,
ctx.response.StatusCode,
err,
remoteAddr,
req.Method,
uri,
req.Proto,
req.Host,
req.UserAgent(),
)
}
copyHeader(ctx.responseWriter.Header(), ctx.response.Header)
ctx.responseWriter.WriteHeader(ctx.response.StatusCode)
ctx.responseWriter.Flush()
ctx.proxyWatch.Stop()
_, _ = copyStream(ctx.responseWriter, ctx.response.Body)
ctx.proxyWatch.Start()
p.metrics.MeasureServe(
id,
ctx.metricsHost(),
ctx.request.Method,
ctx.response.StatusCode,
ctx.startServe,
)
}
// strip port from addresses with hostname, ipv4 or ipv6
func stripPort(address string) string {
if h, _, err := net.SplitHostPort(address); err == nil {
return h
}
return address
}
// remoteAddr returns the remote address of the client. When the
// 'X-Forwarded-For' header is set, its value is used instead.
func remoteAddr(r *http.Request) string {
ff := r.Header.Get("X-Forwarded-For")
if ff != "" {
return ff
}
return r.RemoteAddr
}
func remoteHost(r *http.Request) string {
a := remoteAddr(r)
return stripPort(a)
}
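// Illustrative sketch (not part of the original source): stripPort keeps the
// host part for all supported address forms, and remoteHost prefers the
// X-Forwarded-For header when present:
//
// stripPort("10.0.0.1:8080")    // "10.0.0.1"
// stripPort("[2001:db8::1]:80") // "2001:db8::1"
// stripPort("example.com")      // "example.com" (no port, returned as-is)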
func shouldLog(statusCode int, filter *al.AccessLogFilter) bool {
if len(filter.Prefixes) == 0 {
return filter.Enable
}
match := false
for _, prefix := range filter.Prefixes {
switch {
case prefix < 10:
match = (statusCode >= prefix*100 && statusCode < (prefix+1)*100)
case prefix < 100:
match = (statusCode >= prefix*10 && statusCode < (prefix+1)*10)
default:
match = statusCode == prefix
}
if match {
break
}
}
return match == filter.Enable
}
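// Illustrative example (assumed filter values): a one-digit prefix matches a
// whole status class, a two-digit prefix a decade, and a three-digit prefix a
// single code:
//
// f := &al.AccessLogFilter{Enable: true, Prefixes: []int{4, 51, 200}}
// shouldLog(404, f) // true: 4 matches 400-499
// shouldLog(511, f) // true: 51 matches 510-519
// shouldLog(200, f) // true: exact match
// shouldLog(302, f) // false: no prefix matches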
// http.Handler implementation
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
proxyStopWatch := newStopWatch()
proxyStopWatch.Start()
lw := logging.NewLoggingWriter(w)
p.metrics.IncCounter("incoming." + r.Proto)
var ctx *context
spanOpts := []ot.StartSpanOption{ot.Tags{
SpanKindTag: SpanKindServer,
}}
if wireContext, err := p.tracing.tracer.Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(r.Header)); err == nil {
spanOpts = append(spanOpts, ext.RPCServerOption(wireContext))
}
span := p.tracing.tracer.StartSpan(p.tracing.initialOperationName, spanOpts...)
defer func() {
if ctx != nil && ctx.proxySpan != nil {
ctx.proxySpan.Finish()
}
span.Finish()
ctx.proxyWatch.Stop()
skipperResponseLatency := ctx.proxyWatch.Elapsed()
p.metrics.MeasureProxy(ctx.proxyRequestLatency, skipperResponseLatency)
}()
defer func() {
accessLogEnabled, ok := ctx.stateBag[al.AccessLogEnabledKey].(*al.AccessLogFilter)
if !ok {
if p.accessLogDisabled {
accessLogEnabled = &disabledAccessLog
} else {
accessLogEnabled = &enabledAccessLog
}
}
statusCode := lw.GetCode()
if shouldLog(statusCode, accessLogEnabled) {
authUser, _ := ctx.stateBag[filterslog.AuthUserKey].(string)
entry := &logging.AccessEntry{
Request: r,
ResponseSize: lw.GetBytes(),
StatusCode: statusCode,
RequestTime: ctx.startServe,
Duration: time.Since(ctx.startServe),
AuthUser: authUser,
}
additionalData, _ := ctx.stateBag[al.AccessLogAdditionalDataKey].(map[string]interface{})
logging.LogAccess(entry, additionalData)
}
// This flush is required in case of an I/O error
if !ctx.successfulUpgrade {
lw.Flush()
}
}()
if p.flags.patchPath() {
r.URL.Path = rfc.PatchPath(r.URL.Path, r.URL.RawPath)
}
p.tracing.setTag(span, HTTPRemoteIPTag, stripPort(r.RemoteAddr))
p.setCommonSpanInfo(r.URL, r, span)
r = r.WithContext(ot.ContextWithSpan(r.Context(), span))
r = r.WithContext(routing.NewContext(r.Context()))
ctx = newContext(lw, r, p, proxyStopWatch)
ctx.startServe = time.Now()
ctx.tracer = p.tracing.tracer
ctx.initialSpan = span
ctx.parentSpan = span
defer func() {
if ctx.response != nil && ctx.response.Body != nil {
err := ctx.response.Body.Close()
if err != nil {
ctx.Logger().Errorf("error during closing the response body: %v", err)
}
}
}()
err := p.do(ctx, span)
// writeTimeout() filter
if d, ok := ctx.StateBag()[filters.WriteTimeout].(time.Duration); ok {
e := ctx.ResponseController().SetWriteDeadline(time.Now().Add(d))
if e != nil {
ctx.Logger().Errorf("Failed to set write deadline: %v", e)
}
}
// stream response body to client
if err != nil {
p.errorResponse(ctx, err)
} else {
p.serveResponse(ctx)
}
// fifoWithBody() filter
if sbf, ok := ctx.StateBag()[filters.FifoWithBodyName]; ok {
if fs, ok := sbf.([]func()); ok {
for i := len(fs) - 1; i >= 0; i-- {
fs[i]()
}
}
}
if ctx.cancelBackendContext != nil {
ctx.cancelBackendContext()
}
}
// Close causes the proxy to stop closing idle
// connections and, currently, has no other effect.
// Its primary purpose is to support testing.
func (p *Proxy) Close() error {
close(p.quit)
p.registry.Close()
return nil
}
func (p *Proxy) setCommonSpanInfo(u *url.URL, r *http.Request, s ot.Span) {
p.tracing.
setTag(s, ComponentTag, "skipper").
setTag(s, HTTPMethodTag, r.Method).
setTag(s, HostnameTag, p.hostname).
setTag(s, HTTPPathTag, u.Path).
setTag(s, HTTPHostTag, r.Host)
if val := r.Header.Get("X-Flow-Id"); val != "" {
p.tracing.setTag(s, FlowIDTag, val)
}
}
// TODO(sszuecs): copy from net.Client, we should refactor this to use net.Client
func injectClientTrace(req *http.Request, span ot.Span) *http.Request {
trace := &httptrace.ClientTrace{
DNSStart: func(httptrace.DNSStartInfo) {
span.LogKV("DNS", "start")
},
DNSDone: func(httptrace.DNSDoneInfo) {
span.LogKV("DNS", "end")
},
ConnectStart: func(string, string) {
span.LogKV("connect", "start")
},
ConnectDone: func(string, string, error) {
span.LogKV("connect", "end")
},
TLSHandshakeStart: func() {
span.LogKV("TLS", "start")
},
TLSHandshakeDone: func(tls.ConnectionState, error) {
span.LogKV("TLS", "end")
},
GetConn: func(string) {
span.LogKV("get_conn", "start")
},
GotConn: func(httptrace.GotConnInfo) {
span.LogKV("get_conn", "end")
},
WroteHeaders: func() {
span.LogKV("wrote_headers", "done")
},
WroteRequest: func(wri httptrace.WroteRequestInfo) {
if wri.Err != nil {
span.LogKV("wrote_request", ensureUTF8(wri.Err.Error()))
} else {
span.LogKV("wrote_request", "done")
}
},
GotFirstResponseByte: func() {
span.LogKV("got_first_byte", "done")
},
}
return req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
}
func ensureUTF8(s string) string {
if utf8.ValidString(s) {
return s
}
return fmt.Sprintf("invalid utf-8: %q", s)
}
func (p *Proxy) makeErrorResponse(ctx *context, perr *proxyError) {
ctx.response = &http.Response{
Header: http.Header{},
}
if len(perr.additionalHeader) > 0 {
copyHeader(ctx.response.Header, perr.additionalHeader)
}
addBranding(ctx.response.Header)
ctx.response.Header.Set("Content-Type", "text/plain; charset=utf-8")
ctx.response.Header.Set("X-Content-Type-Options", "nosniff")
code := http.StatusInternalServerError
switch {
case perr == errRouteLookupFailed:
code = p.defaultHTTPStatus
case perr.code == -1:
// -1 == dial connection refused
code = http.StatusBadGateway
case perr.code != 0:
code = perr.code
}
text := http.StatusText(code) + "\n"
ctx.response.Header.Set("Content-Length", strconv.Itoa(len(text)))
ctx.response.StatusCode = code
ctx.response.Body = io.NopCloser(bytes.NewBufferString(text))
}
// errorHandlerFilter is an opt-in interface for filters that want their
// Response(ctx) called in case of errors.
type errorHandlerFilter interface {
// HandleErrorResponse returns true in case a filter wants its Response
// method called for error responses
HandleErrorResponse() bool
}
func (p *Proxy) applyFiltersOnError(ctx *context, filters []*routing.RouteFilter) {
filtersStart := time.Now()
filterTracing := p.tracing.startFilterTracing("response_filters", ctx)
defer filterTracing.finish()
for i := len(filters) - 1; i >= 0; i-- {
fi := filters[i]
if ehf, ok := fi.Filter.(errorHandlerFilter); !ok || !ehf.HandleErrorResponse() {
continue
}
ctx.Logger().Debugf("filter %s handles error", fi.Name)
start := time.Now()
filterTracing.logStart(fi.Name)
ctx.setMetricsPrefix(fi.Name)
fi.Response(ctx)
p.metrics.MeasureFilterResponse(fi.Name, start)
filterTracing.logEnd(fi.Name)
}
p.metrics.MeasureAllFiltersResponse(ctx.route.Id, filtersStart)
}
package proxy
import "time"
type stopWatch struct {
now func() time.Time
started time.Time
elapsed time.Duration
}
func newStopWatch() *stopWatch {
return &stopWatch{
now: time.Now,
}
}
func (s *stopWatch) Start() {
if s.started.IsZero() {
s.started = s.now()
}
}
func (s *stopWatch) Stop() {
if !s.started.IsZero() {
now := s.now()
s.elapsed += now.Sub(s.started)
s.started = time.Time{}
}
}
func (s *stopWatch) Reset() {
s.started = time.Time{}
s.elapsed = 0
}
func (s *stopWatch) Elapsed() time.Duration {
return s.elapsed
}
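// Usage sketch (illustrative only): the proxy pairs Stop()/Start() calls to
// exclude, for example, filter execution time from the measured proxy latency.
// Start is a no-op while already running, and Stop accumulates into elapsed:
//
// w := newStopWatch()
// w.Start()
// // ... proxy work ...
// w.Stop() // pause, e.g. while filters run
// w.Start()
// // ... more proxy work ...
// w.Stop()
// total := w.Elapsed() // sum of both measured intervals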
package proxy
import (
"io"
"net/http"
"net/url"
log "github.com/sirupsen/logrus"
)
type teeTie struct {
r io.Reader
w *io.PipeWriter
}
func (tt *teeTie) Read(b []byte) (int, error) {
n, err := tt.r.Read(b)
if err != nil && err != io.EOF {
tt.w.CloseWithError(err)
return n, err
}
if n > 0 {
if _, werr := tt.w.Write(b[:n]); werr != nil {
log.Error("tee: error while tee request", werr)
}
}
if err == io.EOF {
tt.w.Close()
}
return n, err
}
func (tt *teeTie) Close() error { return nil }
// Returns the cloned request and the tee body to be used on the main request.
func cloneRequestForSplit(u *url.URL, req *http.Request) (*http.Request, io.ReadCloser, error) {
h := make(http.Header)
for k, v := range req.Header {
h[k] = v
}
var teeBody io.ReadCloser
mainBody := req.Body
if req.ContentLength != 0 {
pr, pw := io.Pipe()
teeBody = pr
mainBody = &teeTie{mainBody, pw}
}
clone, err := http.NewRequest(req.Method, u.String(), teeBody)
if err != nil {
return nil, nil, err
}
clone.RequestURI = req.RequestURI
clone.Header = h
clone.ContentLength = req.ContentLength
clone.RemoteAddr = req.RemoteAddr
return clone, mainBody, nil
}
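// Usage sketch (hypothetical caller, not part of the original source): the
// clone must be consumed concurrently with the main request body, because the
// teeTie writes into the pipe on every read of the main body and the pipe
// writer blocks until the pipe reader drains it:
//
// clone, mainBody, err := cloneRequestForSplit(shadowURL, req)
// if err == nil {
//     req.Body = mainBody
//     go http.DefaultClient.Do(clone) // sketch: drains the pipe, errors ignored
// }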
package proxy
import (
ot "github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/tracing"
)
const (
ClientRequestStateTag = "client.request"
ComponentTag = "component"
ErrorTag = "error"
BlockTag = "blocked"
FlowIDTag = "flow_id"
HostnameTag = "hostname"
HTTPHostTag = "http.host"
HTTPMethodTag = "http.method"
HTTPRemoteIPTag = "http.remote_ip"
HTTPPathTag = "http.path"
HTTPUrlTag = "http.url"
NetworkPeerAddressTag = "network.peer.address"
HTTPStatusCodeTag = "http.status_code"
SkipperRouteIDTag = "skipper.route_id"
SpanKindTag = "span.kind"
ClientRequestCanceled = "canceled"
SpanKindClient = "client"
SpanKindServer = "server"
EndEvent = "end"
StartEvent = "start"
StreamHeadersEvent = "stream_Headers"
StreamBodyEvent = "streamBody.byte"
StreamBodyError = "streamBody error"
)
type proxyTracing struct {
tracer ot.Tracer
initialOperationName string
disableFilterSpans bool
logFilterLifecycleEvents bool
logStreamEvents bool
excludeTags map[string]bool
}
type filterTracing struct {
span ot.Span
logEvents bool
}
func newProxyTracing(p *OpenTracingParams) *proxyTracing {
if p == nil {
p = &OpenTracingParams{}
}
if p.InitialSpan == "" {
p.InitialSpan = "ingress"
}
if p.Tracer == nil {
p.Tracer = &ot.NoopTracer{}
}
excludedTags := map[string]bool{}
for _, t := range p.ExcludeTags {
excludedTags[t] = true
}
return &proxyTracing{
tracer: p.Tracer,
initialOperationName: p.InitialSpan,
disableFilterSpans: p.DisableFilterSpans,
logFilterLifecycleEvents: p.LogFilterEvents,
logStreamEvents: p.LogStreamEvents,
excludeTags: excludedTags,
}
}
func (t *proxyTracing) logEvent(span ot.Span, eventName, eventValue string) {
if span == nil {
return
}
span.LogKV(eventName, ensureUTF8(eventValue))
}
func (t *proxyTracing) setTag(span ot.Span, key string, value interface{}) *proxyTracing {
if span == nil {
return t
}
if !t.excludeTags[key] {
if s, ok := value.(string); ok {
span.SetTag(key, ensureUTF8(s))
} else {
span.SetTag(key, value)
}
}
return t
}
func (t *proxyTracing) logStreamEvent(span ot.Span, eventName, eventValue string) {
if !t.logStreamEvents {
return
}
t.logEvent(span, eventName, ensureUTF8(eventValue))
}
func (t *proxyTracing) startFilterTracing(operation string, ctx *context) *filterTracing {
if t.disableFilterSpans {
return nil
}
span := tracing.CreateSpan(operation, ctx.request.Context(), t.tracer)
ctx.parentSpan = span
return &filterTracing{span, t.logFilterLifecycleEvents}
}
func (t *filterTracing) finish() {
if t != nil {
t.span.Finish()
}
}
func (t *filterTracing) logStart(filterName string) {
if t != nil && t.logEvents {
t.span.LogKV(filterName, StartEvent)
}
}
func (t *filterTracing) logEnd(filterName string) {
if t != nil && t.logEvents {
t.span.LogKV(filterName, EndEvent)
}
}
package proxy
import (
"bufio"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"maps"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
)
// isUpgradeRequest returns true if and only if there is a "Connection"
// key with the value "Upgrade" in the header of the given request.
func isUpgradeRequest(req *http.Request) bool {
for _, h := range req.Header[http.CanonicalHeaderKey("Connection")] {
if strings.Contains(strings.ToLower(h), "upgrade") {
return true
}
}
return false
}
// getUpgradeRequest returns the protocol name from the upgrade header
func getUpgradeRequest(req *http.Request) string {
for _, h := range req.Header[http.CanonicalHeaderKey("Connection")] {
if strings.Contains(strings.ToLower(h), "upgrade") {
return strings.Join(req.Header[h], " ")
}
}
return ""
}
// UpgradeProxy stores everything needed to make the connection upgrade.
type upgradeProxy struct {
backendAddr *url.URL
reverseProxy *httputil.ReverseProxy
tlsClientConfig *tls.Config
insecure bool
useAuditLog bool
auditLogOut io.Writer
auditLogErr io.Writer
auditLogHook chan struct{}
}
// TODO: add user here
type auditLog struct {
Method string `json:"method"`
Path string `json:"path"`
Query string `json:"query"`
Fragment string `json:"fragment"`
}
// serveHTTP establishes a bidirectional connection, creates an
// auditlog for the request target, copies the data back and forth and
// writes data to an auditlog. It will not return until the connection
// is closed.
func (p *upgradeProxy) serveHTTP(w http.ResponseWriter, req *http.Request) {
// The following check is based on
// https://tools.ietf.org/html/rfc2616#section-14.42
// https://tools.ietf.org/html/rfc7230#section-6.7
// and https://tools.ietf.org/html/rfc6455 (websocket)
if (req.ProtoMajor <= 1 && req.ProtoMinor < 1) ||
!isUpgradeRequest(req) ||
req.Header.Get("Upgrade") == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(http.StatusText(http.StatusBadRequest)))
return
}
backendConn, err := p.dialBackend(req)
if err != nil {
log.Errorf("Error connecting to backend: %s", err)
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
return
}
defer backendConn.Close()
err = req.Write(backendConn)
if err != nil {
log.Errorf("Error writing request to backend: %s", err)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
return
}
// Audit-Log
if p.useAuditLog {
auditlog := &auditLog{req.Method, req.URL.Path, req.URL.RawQuery, req.URL.Fragment}
auditJSON, err := json.Marshal(auditlog)
if err == nil {
_, err = p.auditLogErr.Write(auditJSON)
}
if err != nil {
log.Errorf("Could not write audit-log, caused by: %v", err)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
return
}
}
resp, err := http.ReadResponse(bufio.NewReader(backendConn), req)
if err != nil {
log.Errorf("Error reading response from backend: %s", err)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
return
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized {
log.Debugf("Got unauthorized error from backend for: %s %s", req.Method, req.URL)
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(http.StatusText(http.StatusUnauthorized)))
return
}
if resp.StatusCode != http.StatusSwitchingProtocols {
log.Debugf("Got invalid status code from backend: %d", resp.StatusCode)
maps.Copy(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
_, err := io.Copy(w, resp.Body)
if err != nil {
log.Errorf("Error writing body to client: %s", err)
return
}
return
}
// Backend sent Connection: close
if resp.Close {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
return
}
requestHijackedConn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Errorf("Error hijacking request connection: %s", err)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
return
}
defer requestHijackedConn.Close()
// NOTE: from this point forward, we own the connection and we can't use
// w.Header(), w.Write(), or w.WriteHeader any more
err = resp.Write(requestHijackedConn)
if err != nil {
log.Errorf("Error writing backend response to client: %s", err)
return
}
done := make(chan struct{}, 2)
if p.useAuditLog {
copyAsync("backend->request+audit", backendConn, io.MultiWriter(p.auditLogOut, requestHijackedConn), done)
} else {
copyAsync("backend->request", backendConn, requestHijackedConn, done)
}
copyAsync("request->backend", requestHijackedConn, backendConn, done)
log.Debugf("Successfully upgraded to protocol %s by user request", getUpgradeRequest(req))
// Wait for either copyAsync to complete.
// Returning from this method closes both the request and backend connections
// via defer and thus unblocks the second copyAsync.
<-done
if p.useAuditLog {
select {
case p.auditLogHook <- struct{}{}:
default:
}
}
}
func (p *upgradeProxy) dialBackend(req *http.Request) (net.Conn, error) {
dialAddr := canonicalAddr(req.URL)
switch p.backendAddr.Scheme {
case "http":
return net.Dial("tcp", dialAddr)
case "https":
tlsConn, err := tls.Dial("tcp", dialAddr, p.tlsClientConfig)
if err != nil {
return nil, err
}
if !p.insecure {
hostToVerify, _, err := net.SplitHostPort(dialAddr)
if err != nil {
return nil, err
}
err = tlsConn.VerifyHostname(hostToVerify)
if err != nil {
tlsConn.Close()
return nil, err
}
}
return tlsConn, nil
default:
return nil, fmt.Errorf("unknown scheme: %s", p.backendAddr.Scheme)
}
}
func copyAsync(dir string, src io.Reader, dst io.Writer, done chan<- struct{}) {
go func() {
_, err := io.Copy(dst, src)
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Errorf("error copying data %s: %v", dir, err)
}
done <- struct{}{}
}()
}
// FROM: http://golang.org/src/net/http/client.go
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
// return true if the string includes a port.
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
// FROM: http://golang.org/src/net/http/transport.go
var portMap = map[string]string{
"http": "80",
"https": "443",
}
// FROM: http://golang.org/src/net/http/transport.go
// canonicalAddr returns url.Host but always with a ":port" suffix
func canonicalAddr(url *url.URL) string {
addr := url.Host
if !hasPort(addr) {
return addr + ":" + portMap[url.Scheme]
}
return addr
}
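// Illustrative examples (not part of the original source):
//
// hasPort("example.com")      // false
// hasPort("example.com:8080") // true
// hasPort("[::1]")            // false
// hasPort("[::1]:443")        // true
//
// u, _ := url.Parse("https://example.com/ws")
// canonicalAddr(u) // "example.com:443"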
package queuelistener
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/metrics"
)
const (
initialBounceDelay = 500 * time.Microsecond
maxBounceDelay = 100 * time.Millisecond
defaultMemoryLimitBytes = 150 * 1000 * 1000
defaultConnectionBytes = 50 * 1000
queueTimeoutPrecisionPercentage = 5
maxCalculatedQueueSize = 50_000
acceptedConnectionsKey = "listener.accepted.connections"
queuedConnectionsKey = "listener.queued.connections"
acceptLatencyKey = "listener.accept.latency"
)
type external struct {
net.Conn
accepted time.Time
}
type connection struct {
*external
queueDeadline time.Time
release chan<- struct{}
quit <-chan struct{}
once sync.Once
closeErr error
}
// Options are used to initialize the queue listener.
type Options struct {
// Network sets the name of the network. Same as for net.Listen().
Network string
// Address sets the listener address, e.g. :9090. Same as for net.Listen().
Address string
// MaxConcurrency sets the maximum accepted connections.
MaxConcurrency int
// MaxQueueSize sets the maximum allowed queue size of pending connections. When
// not set, it is derived from the MaxConcurrency value.
MaxQueueSize int
// MemoryLimitBytes sets the approximated maximum memory used by the accepted
// connections, calculated together with the ConnectionBytes value. Defaults to
// 150MB.
//
// When MaxConcurrency is set, this field is ignored.
MemoryLimitBytes int64
// ConnectionBytes is used to calculate the MaxConcurrency when MaxConcurrency is
// not set explicitly but calculated from MemoryLimitBytes.
ConnectionBytes int
// QueueTimeout sets the time limit pending connections may spend in the queue. It
// should be set to a similar value as the ReadHeaderTimeout of net/http.Server.
QueueTimeout time.Duration
// Metrics is used to collect monitoring data about the queue, including the current
// concurrent connections and the number of connections in the queue.
Metrics metrics.Metrics
// Log is used to log unexpected, non-fatal errors. It defaults to logging.DefaultLog.
Log logging.Logger
testQueueChangeHook chan struct{}
}
type listener struct {
options Options
maxConcurrency int64
maxQueueSize int64
externalListener net.Listener
acceptExternal chan *external
externalError chan error
acceptInternal chan *connection
internalError chan error
releaseConnection chan struct{}
quit chan struct{}
closeMx sync.Mutex
closedHook chan struct{} // for testing
}
var (
token struct{}
errListenerClosed = errors.New("listener closed")
)
func (c *connection) Close() error {
c.once.Do(func() {
select {
case c.release <- token:
case <-c.quit:
}
c.closeErr = c.external.Close()
})
return c.closeErr
}
func (o Options) maxConcurrency() int64 {
if o.MaxConcurrency > 0 {
return int64(o.MaxConcurrency)
}
maxConcurrency := o.MemoryLimitBytes / int64(o.ConnectionBytes)
// theoretical minimum, relevant mostly for testing. When the max concurrency is
// not set, the TCP-LIFO listener should not be used at all.
if maxConcurrency <= 0 {
maxConcurrency = 1
}
return maxConcurrency
}
func (o Options) maxQueueSize() int64 {
if o.MaxQueueSize > 0 {
return int64(o.MaxQueueSize)
}
maxQueueSize := 10 * o.maxConcurrency()
if maxQueueSize > maxCalculatedQueueSize {
maxQueueSize = maxCalculatedQueueSize
}
return maxQueueSize
}
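// Worked example (assuming the package defaults and no explicit
// MaxConcurrency or MaxQueueSize): with MemoryLimitBytes = 150MB and
// ConnectionBytes = 50KB,
//
// maxConcurrency = 150_000_000 / 50_000 = 3000
// maxQueueSize   = 10 * 3000 = 30000 (below the 50_000 cap)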
func listenWith(nl net.Listener, o Options) (net.Listener, error) {
if o.Log == nil {
o.Log = &logging.DefaultLog{}
}
if o.MemoryLimitBytes <= 0 {
o.MemoryLimitBytes = defaultMemoryLimitBytes
}
if o.ConnectionBytes <= 0 {
o.ConnectionBytes = defaultConnectionBytes
}
l := &listener{
options: o,
maxConcurrency: o.maxConcurrency(),
maxQueueSize: o.maxQueueSize(),
externalListener: nl,
acceptExternal: make(chan *external),
externalError: make(chan error),
acceptInternal: make(chan *connection),
internalError: make(chan error),
releaseConnection: make(chan struct{}),
quit: make(chan struct{}),
}
o.Log.Infof("TCP lifo listener config: %s", l)
go l.listenExternal()
go l.listenInternal()
return l, nil
}
// Listen creates and initializes a listener that can be used to limit the
// concurrently accepted incoming client connections.
//
// The queue listener will return only a limited number of concurrent connections
// by its Accept() method, defined by the max concurrency configuration. When the
// max concurrency is reached, the Accept() method will block until one or more
// accepted connections are closed. When the max concurrency limit is reached, the
// new incoming client connections are stored in a queue. When an active (accepted)
// connection is closed, the listener will return the most recent one from the
// queue (LIFO). When the queue is full, the oldest pending connection is closed
// and dropped, and the new one is inserted into the queue.
//
// The listener needs to be closed in order to release local resources. After it is
// closed, Accept() returns an error without blocking.
//
// See type Options for info about the configuration of the listener.
func Listen(o Options) (net.Listener, error) {
nl, err := net.Listen(o.Network, o.Address)
if err != nil {
return nil, err
}
return listenWith(nl, o)
}
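// Usage sketch (hypothetical values, not part of the original source):
// serving HTTP behind the LIFO queue listener instead of a plain listener:
//
// l, err := queuelistener.Listen(queuelistener.Options{
//     Network:        "tcp",
//     Address:        ":9090",
//     MaxConcurrency: 1000,
//     QueueTimeout:   5 * time.Second,
// })
// if err != nil {
//     log.Fatal(err)
// }
// defer l.Close()
// srv := &http.Server{ReadHeaderTimeout: 5 * time.Second}
// srv.Serve(l)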
func bounce(delay time.Duration) time.Duration {
if delay == 0 {
return initialBounceDelay
}
delay *= 2
if delay > maxBounceDelay {
delay = maxBounceDelay
}
return delay
}
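// Illustrative sequence (not part of the original source): successive
// temporary accept errors back off exponentially, starting at 500µs and
// capped at 100ms:
//
// delay := time.Duration(0)
// for i := 0; i < 10; i++ {
//     delay = bounce(delay) // 500µs, 1ms, 2ms, 4ms, ..., 64ms, 100ms, 100ms
// }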
func (l *listener) String() string {
return fmt.Sprintf("concurrency: %d, queue size: %d, memory limit: %d, bytes per connection: %d, queue timeout: %s", l.maxConcurrency, l.maxQueueSize, l.options.MemoryLimitBytes, l.options.ConnectionBytes, l.options.QueueTimeout)
}
// this function turns net.Listener.Accept() into a channel, so that we can use select{} while it is blocked
func (l *listener) listenExternal() {
var (
c net.Conn
ex *external
err error
delay time.Duration
acceptExternal chan<- *external
externalError chan<- error
retry <-chan time.Time
)
for {
c, err = l.externalListener.Accept()
if err != nil {
// based on net/http.Server.Serve():
//lint:ignore SA1019 Temporary is deprecated in Go 1.18, but keep it for now (https://github.com/zalando/skipper/issues/1992)
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
delay = bounce(delay)
l.options.Log.Errorf(
"queue listener: accept error: %v, retrying in %v",
err,
delay,
)
err = nil
acceptExternal = nil
externalError = nil
retry = time.After(delay)
} else {
acceptExternal = nil
externalError = l.externalError
retry = nil
delay = 0
}
} else {
acceptExternal = l.acceptExternal
ex = &external{c, time.Now()}
externalError = nil
retry = nil
delay = 0
}
select {
case acceptExternal <- ex:
case externalError <- err:
// we cannot accept anymore, but we have returned the permanent error
return
case <-retry:
case <-l.quit:
if c != nil {
c.Close()
}
return
}
}
}
func (l *listener) listenInternal() {
var (
concurrency int64
queue *ring
err error
acceptInternal chan<- *connection
internalError chan<- error
nextTimeout <-chan time.Time
)
queue = newRing(l.maxQueueSize)
for {
var nextConn *connection
if queue.size > 0 && concurrency < l.maxConcurrency {
acceptInternal = l.acceptInternal
nextConn = queue.peek().(*connection)
} else {
acceptInternal = nil
}
if err != nil && queue.size == 0 {
internalError = l.internalError
} else {
internalError = nil
}
// Set the timer period to a fixed minimum value that is a percentage of the
// queue timeout. This way we avoid, first, too many rapid timeout events for
// stalled connections, and second, we ensure a certain precision of the
// timeouts and a minimum queue timeout. E.g. with QueueTimeout=1s the timer
// fires every 50ms.
if l.options.QueueTimeout > 0 && nextTimeout == nil {
nextTimeout = time.After(
l.options.QueueTimeout * queueTimeoutPrecisionPercentage / 100,
)
}
if l.options.Metrics != nil {
l.options.Metrics.UpdateGauge(acceptedConnectionsKey, float64(concurrency))
l.options.Metrics.UpdateGauge(queuedConnectionsKey, float64(queue.size))
}
select {
case conn := <-l.acceptExternal:
cc := &connection{
external: conn,
release: l.releaseConnection,
quit: l.quit,
once: sync.Once{},
}
if l.options.QueueTimeout > 0 {
cc.queueDeadline = time.Now().Add(l.options.QueueTimeout)
}
drop := queue.enqueue(cc)
if drop != nil {
drop.(*connection).external.Close()
}
l.testNotifyQueueChange()
case err = <-l.externalError:
case acceptInternal <- nextConn:
queue.dequeue()
concurrency++
l.testNotifyQueueChange()
case internalError <- err:
// we cannot accept anymore, but we returned the permanent error
err = nil
l.Close()
case <-l.releaseConnection:
concurrency--
case now := <-nextTimeout:
var dropped int
for queue.size > 0 && queue.peekOldest().(*connection).queueDeadline.Before(now) {
drop := queue.dequeueOldest()
drop.(*connection).external.Close()
dropped++
}
nextTimeout = nil
if dropped > 0 {
l.testNotifyQueueChange()
}
case <-l.quit:
queue.rangeOver(func(c net.Conn) { c.(*connection).external.Close() })
// Closing the real listener in a separate goroutine is based on inspecting the
// stdlib. It's fair to just log the errors.
if err := l.externalListener.Close(); err != nil {
l.options.Log.Errorf("Failed to close network listener: %v.", err)
}
if l.closedHook != nil {
close(l.closedHook)
}
return
}
}
}
func (l *listener) Accept() (net.Conn, error) {
select {
case c := <-l.acceptInternal:
if l.options.Metrics != nil {
l.options.Metrics.MeasureSince(acceptLatencyKey, c.external.accepted)
}
return c, nil
case err := <-l.internalError:
return nil, err
case <-l.quit:
return nil, errListenerClosed
}
}
func (l *listener) Addr() net.Addr {
return l.externalListener.Addr()
}
func (l *listener) Close() error {
// allow closing concurrently as net/http.Server may or may not close it and avoid panic on
// close(l.quit)
l.closeMx.Lock()
defer l.closeMx.Unlock()
select {
case <-l.quit:
default:
close(l.quit)
}
return nil
}
func (l *listener) testNotifyQueueChange() {
if l.options.testQueueChangeHook == nil {
return
}
select {
case l.options.testQueueChangeHook <- token:
default:
}
}
func (l *listener) clearQueueChangeHook() {
if l.options.testQueueChangeHook == nil {
return
}
for {
select {
case <-l.options.testQueueChangeHook:
default:
return
}
}
}
package queuelistener
import "net"
type ring struct {
connections []net.Conn
next int
size int
}
func newRing(maxSize int64) *ring {
return &ring{connections: make([]net.Conn, maxSize)}
}
func (r *ring) peek() net.Conn {
i := r.next - 1
if i < 0 {
i = len(r.connections) - 1
}
return r.connections[i]
}
func (r *ring) peekOldest() net.Conn {
i := r.next - r.size
if i < 0 {
i += len(r.connections)
}
return r.connections[i]
}
func (r *ring) enqueue(c net.Conn) (oldest net.Conn) {
if r.size == len(r.connections) {
oldest = r.connections[r.next]
} else {
r.size++
}
r.connections[r.next] = c
r.next++
if r.next == len(r.connections) {
r.next = 0
}
return
}
func (r *ring) dequeue() net.Conn {
r.next--
if r.next < 0 {
r.next = len(r.connections) - 1
}
var c net.Conn
c, r.connections[r.next] = r.connections[r.next], nil
r.size--
return c
}
func (r *ring) dequeueOldest() net.Conn {
i := r.next - r.size
if i < 0 {
i += len(r.connections)
}
var c net.Conn
c, r.connections[i] = r.connections[i], nil
r.size--
return c
}
func (r *ring) rangeOver(f func(net.Conn)) {
start := r.next - r.size
if start < 0 {
start = len(r.connections) + start
}
finish := start + r.size
if finish >= len(r.connections) {
finish = len(r.connections)
}
for i := start; i < finish; i++ {
f(r.connections[i])
}
finish = r.size + start - finish
start = 0
for i := start; i < finish; i++ {
f(r.connections[i])
}
}
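// Illustrative example (hypothetical connections c1..c4): the ring serves the
// newest pending connection first (LIFO) while dropping and timing out from
// the oldest end (FIFO):
//
// r := newRing(3)
// r.enqueue(c1)     // queue: c1
// r.enqueue(c2)     // queue: c1 c2
// r.enqueue(c3)     // queue: c1 c2 c3
// r.enqueue(c4)     // full: returns c1 as the dropped oldest, queue: c2 c3 c4
// r.dequeue()       // c4 (newest first, LIFO)
// r.dequeueOldest() // c2 (oldest, used for queue timeouts)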
package ratelimit
import "github.com/zalando/skipper/net"
const (
swarmPrefix = `ratelimit.`
swarmKeyFormat = swarmPrefix + "%s.%s"
)
// newClusterRateLimiter returns a limiter instance that has
// cluster-wide knowledge of ongoing requests. Settings are the normal
// ratelimit settings, Swarmer is an instance satisfying the Swarmer
// interface (e.g. swarm.Swarm or a noop swarmer), ring is the redis
// ring client used for redis-based cluster ratelimits, and group is
// the ratelimit group that can span one or multiple routes.
func newClusterRateLimiter(s Settings, sw Swarmer, ring *net.RedisRingClient, group string) limiter {
if sw != nil {
if l := newClusterRateLimiterSwim(s, sw, group); l != nil {
return l
}
}
if ring != nil {
if l := newClusterRateLimiterRedis(s, ring, group); l != nil {
return l
}
}
return voidRatelimit{}
}
package ratelimit
import (
"context"
_ "embed"
"fmt"
"time"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/net"
)
type ClusterLeakyBucket struct {
capacity int
emission time.Duration
labelPrefix string
script *net.RedisScript
ringClient *net.RedisRingClient
metrics metrics.Metrics
now func() time.Time
}
const (
leakyBucketRedisKeyPrefix = "lkb."
leakyBucketMetricPrefix = "leakybucket.redis."
leakyBucketMetricLatency = leakyBucketMetricPrefix + "latency"
leakyBucketSpanName = "redis_leakybucket"
)
// Implements the leaky bucket algorithm as a Redis Lua script.
// Redis guarantees that a script is executed in an atomic way:
// no other script or Redis command will be executed while a script is being executed.
//
// Possible optimization: substitute capacity and emission in script source code
// on script creation in order not to send them over the wire on each call.
// This way every distinct bucket configuration will get its own script.
//
// See https://redis.io/commands/eval
//
//go:embed leakybucket.lua
var leakyBucketScript string
// NewClusterLeakyBucket creates a class of leaky buckets of a given capacity and emission.
// Emission is the reciprocal of the leak rate and equals the time to leak one unit.
//
// The leaky bucket is an algorithm based on an analogy of how a bucket with a constant leak will overflow if either
// the average rate at which water is poured in exceeds the rate at which the bucket leaks or if more water than
// the capacity of the bucket is poured in all at once.
// See https://en.wikipedia.org/wiki/Leaky_bucket
func NewClusterLeakyBucket(r *Registry, capacity int, emission time.Duration) *ClusterLeakyBucket {
return newClusterLeakyBucket(r.redisRing, capacity, emission, time.Now)
}
func newClusterLeakyBucket(ringClient *net.RedisRingClient, capacity int, emission time.Duration, now func() time.Time) *ClusterLeakyBucket {
return &ClusterLeakyBucket{
capacity: capacity,
emission: emission,
labelPrefix: fmt.Sprintf("%d-%v-", capacity, emission),
script: ringClient.NewScript(leakyBucketScript),
ringClient: ringClient,
metrics: metrics.Default,
now: now,
}
}
// Add adds an increment amount to the bucket identified by the label.
// It returns whether the amount was successfully added and, if not, a
// time to wait before the next attempt. It also returns any error that
// occurred during the attempt.
func (b *ClusterLeakyBucket) Add(ctx context.Context, label string, increment int) (added bool, retry time.Duration, err error) {
if increment > b.capacity {
// not allowed to add more than capacity and retry is not possible
return false, 0, nil
}
now := b.now()
span := b.startSpan(ctx)
defer span.Finish()
defer b.metrics.MeasureSince(leakyBucketMetricLatency, now)
added, retry, err = b.add(ctx, label, increment, now)
if err != nil {
ext.Error.Set(span, true)
}
return
}
func (b *ClusterLeakyBucket) add(ctx context.Context, label string, increment int, now time.Time) (added bool, retry time.Duration, err error) {
r, err := b.ringClient.RunScript(ctx, b.script,
[]string{b.getBucketId(label)},
b.capacity,
b.emission.Microseconds(),
increment,
now.UnixMicro(),
)
if err == nil {
x := r.(int64)
if x >= 0 {
added, retry = true, 0
} else {
added, retry = false, -time.Duration(x)*time.Microsecond
}
}
return
}
func (b *ClusterLeakyBucket) getBucketId(label string) string {
return leakyBucketRedisKeyPrefix + getHashedKey(b.labelPrefix+label)
}
func (b *ClusterLeakyBucket) startSpan(ctx context.Context) (span opentracing.Span) {
spanOpts := []opentracing.StartSpanOption{opentracing.Tags{
string(ext.Component): "skipper",
string(ext.SpanKind): "client",
}}
if parent := opentracing.SpanFromContext(ctx); parent != nil {
spanOpts = append(spanOpts, opentracing.ChildOf(parent.Context()))
}
return b.ringClient.StartSpan(leakyBucketSpanName, spanOpts...)
}
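// Usage sketch (hypothetical parameters, not part of the original source): a
// bucket of capacity 10 that leaks one unit every 6 seconds, i.e. a sustained
// rate of 10 requests per minute:
//
// b := NewClusterLeakyBucket(registry, 10, 6*time.Second)
// added, retry, err := b.Add(ctx, "client-42", 1)
// switch {
// case err != nil:
//     // redis failure: the caller decides fail-open vs. fail-closed
// case added:
//     // request allowed
// default:
//     // over capacity: ask the client to retry after `retry`
// }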
package ratelimit
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
circularbuffer "github.com/szuecs/rate-limit-buffer"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/net"
)
const (
// Header is the name of the response header used to expose the rate limit (normalized to requests per hour)
Header = "X-Rate-Limit"
// RetryAfterHeader is the name of the header which is used to indicate how
// long a client should wait before making a new request
RetryAfterHeader = "Retry-After"
// Deprecated, use filters.RatelimitName instead
ServiceRatelimitName = filters.RatelimitName
// LocalRatelimitName *DEPRECATED*, use ClientRatelimitName instead
LocalRatelimitName = "localRatelimit"
// Deprecated, use filters.ClientRatelimitName instead
ClientRatelimitName = filters.ClientRatelimitName
// Deprecated, use filters.ClusterRatelimitName instead
ClusterServiceRatelimitName = filters.ClusterRatelimitName
// Deprecated, use filters.ClusterClientRatelimitName instead
ClusterClientRatelimitName = filters.ClusterClientRatelimitName
// Deprecated, use filters.DisableRatelimitName instead
DisableRatelimitName = filters.DisableRatelimitName
// Deprecated, use filters.UnknownRatelimitName instead
UknownRatelimitName = filters.UnknownRatelimitName
sameBucket = "s"
)
// RatelimitType defines the type of the used ratelimit
type RatelimitType int
func (rt *RatelimitType) UnmarshalYAML(unmarshal func(interface{}) error) error {
var value string
if err := unmarshal(&value); err != nil {
return err
}
switch value {
case "local":
log.Warning("LocalRatelimit is deprecated, please use ClientRatelimit instead")
fallthrough
case "client":
*rt = ClientRatelimit
case "service":
*rt = ServiceRatelimit
case "clusterClient":
*rt = ClusterClientRatelimit
case "clusterService":
*rt = ClusterServiceRatelimit
case "disabled":
*rt = DisableRatelimit
default:
return fmt.Errorf("invalid ratelimit type %v (allowed values are: client, service or disabled)", value)
}
return nil
}
const (
// NoRatelimit is not used
NoRatelimit RatelimitType = iota
// ServiceRatelimit is used to have a simple rate limit for a
// backend service, which is calculated and measured within
// each instance
ServiceRatelimit
// LocalRatelimit *DEPRECATED* will be replaced by ClientRatelimit
LocalRatelimit
// ClientRatelimit is used to have a simple local rate limit
// per user for a backend, which is calculated and measured
// within each instance. One filter consumes memory calculated
// by the following formula, where N is the number of
// individual clients put into the same bucket, M the maximum
// number of requests allowed:
//
// memory = N * M * 15 byte
//
// For example, /login protection for 100,000 attackers, with 10
// requests allowed per hour, will use roughly 14.6 MB.
ClientRatelimit
// ClusterServiceRatelimit is used to calculate a rate limit
// for a whole skipper fleet for a backend service, needs
// swarm to be enabled with -enable-swarm.
ClusterServiceRatelimit
// ClusterClientRatelimit is used to calculate a rate limit
// for a whole skipper fleet per user for a backend, needs
// swarm to be enabled with -enable-swarm. In case of redis it
// will not consume more memory.
// In case of swim based cluster ratelimit, one filter
// consumes memory calculated by the following formula, where
// N is the number of individual clients put into the same
// bucket, M the maximum number of requests allowed, S the
// number of skipper peers:
//
// memory = N * M * 15 + S * len(peername)
//
// For example, /login protection for 100,000 attackers, with 10
// requests allowed per hour and 100 skipper peers each with an
// 8-character name, will use roughly 14.7 MB.
ClusterClientRatelimit
// DisableRatelimit is used to disable rate limit
DisableRatelimit
)
func (rt RatelimitType) String() string {
switch rt {
case DisableRatelimit:
return filters.DisableRatelimitName
case ClientRatelimit:
return filters.ClientRatelimitName
case ClusterClientRatelimit:
return filters.ClusterClientRatelimitName
case ClusterServiceRatelimit:
return filters.ClusterRatelimitName
case LocalRatelimit:
return LocalRatelimitName
case ServiceRatelimit:
return filters.RatelimitName
default:
return filters.UnknownRatelimitName
}
}
// Lookuper makes ratelimiting more flexible by defining how requests
// are mapped to buckets.
type Lookuper interface {
// Lookup is used to get the string that defines the bucket of a
// ratelimiter, which is used to decide whether to ratelimit or not.
// For example you can use the X-Forwarded-For header if you want to
// rate limit based on the source ip behind a proxy/loadbalancer, or
// the Authorization header for requests per token or user.
Lookup(*http.Request) string
}
// SameBucketLookuper implements Lookuper interface and will always
// match to the same bucket.
type SameBucketLookuper struct{}
// NewSameBucketLookuper returns a SameBucketLookuper.
func NewSameBucketLookuper() SameBucketLookuper {
return SameBucketLookuper{}
}
// Lookup will always return "s" to select the same bucket.
func (SameBucketLookuper) Lookup(*http.Request) string {
return sameBucket
}
func (SameBucketLookuper) String() string {
return "SameBucketLookuper"
}
// XForwardedForLookuper implements Lookuper interface and will
// select a bucket by X-Forwarded-For header or clientIP.
type XForwardedForLookuper struct{}
// NewXForwardedForLookuper returns an empty XForwardedForLookuper
func NewXForwardedForLookuper() XForwardedForLookuper {
return XForwardedForLookuper{}
}
// Lookup returns the content of the X-Forwarded-For header or the
// clientIP if not set.
func (XForwardedForLookuper) Lookup(req *http.Request) string {
return net.RemoteHost(req).String()
}
func (XForwardedForLookuper) String() string {
return "XForwardedForLookuper"
}
// HeaderLookuper implements the Lookuper interface and selects a bucket
// by the value of a configured header.
type HeaderLookuper struct {
key string
}
// NewHeaderLookuper returns HeaderLookuper configured to lookup header named k
func NewHeaderLookuper(k string) HeaderLookuper {
return HeaderLookuper{key: k}
}
// Lookup returns the content of the configured header.
func (h HeaderLookuper) Lookup(req *http.Request) string {
return req.Header.Get(h.key)
}
func (h HeaderLookuper) String() string {
return "HeaderLookuper"
}
// Lookupers is a slice of Lookuper, required to get a hashable member
// in the TupleLookuper.
type Lookupers []Lookuper
// TupleLookuper implements Lookuper interface and will select a
// bucket that is defined by all combined Lookupers.
type TupleLookuper struct {
// pointer is required to be hashable from Registry lookup table
l *Lookupers
}
// NewTupleLookuper returns TupleLookuper configured to lookup the
// combined result of all given Lookuper
func NewTupleLookuper(args ...Lookuper) TupleLookuper {
var ls Lookupers = args
return TupleLookuper{l: &ls}
}
// Lookup returns the combined string of all Lookupers part of the
// tuple
func (t TupleLookuper) Lookup(req *http.Request) string {
if t.l == nil {
return ""
}
buf := bytes.Buffer{}
for _, l := range *(t.l) {
buf.WriteString(l.Lookup(req))
}
return buf.String()
}
func (t TupleLookuper) String() string {
return "TupleLookuper"
}
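// Usage sketch (hypothetical combination): rate limit per client IP and
// Authorization header combined, so the same token used from different
// addresses lands in different buckets:
//
// l := NewTupleLookuper(
//     NewXForwardedForLookuper(),
//     NewHeaderLookuper("Authorization"),
// )
// key := l.Lookup(req) // e.g. "203.0.113.7" + the Authorization value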
// RoundRobinLookuper matches one of n buckets selected by round robin algorithm
type RoundRobinLookuper struct {
// pointer is required to be hashable from Registry lookup table
c *uint64
// number of buckets, unchanged after creation
n uint64
}
// NewRoundRobinLookuper returns a RoundRobinLookuper.
func NewRoundRobinLookuper(n uint64) Lookuper {
return &RoundRobinLookuper{c: new(uint64), n: n}
}
// Lookup will return one of n distinct keys in round robin fashion
func (rrl *RoundRobinLookuper) Lookup(*http.Request) string {
next := atomic.AddUint64(rrl.c, 1) % rrl.n
return fmt.Sprintf("RoundRobin%d", next)
}
func (rrl *RoundRobinLookuper) String() string {
return "RoundRobinLookuper"
}
// Settings configures the chosen rate limiter
type Settings struct {
// FailClosed allows to decide what happens on failures to
// query the ratelimit, for example when redis is down: fail
// open or fail closed. FailClosed set to true will deny the
// request, set to false will allow the request. The default
// is to fail open.
FailClosed bool `yaml:"fail-closed"`
// Type of the chosen rate limiter
Type RatelimitType `yaml:"type"`
// Lookuper to decide which data to use to identify the same
// bucket (for example how to lookup the client identifier)
Lookuper Lookuper `yaml:"-"`
// MaxHits the maximum number of hits for a time duration
// allowed in the same bucket.
MaxHits int `yaml:"max-hits"`
// TimeWindow is the time duration that is valid for hits to
// be counted in the rate limit.
TimeWindow time.Duration `yaml:"time-window"`
// CleanInterval is the interval used to expire and clean up old
// data, needed for example by client ratelimits.
CleanInterval time.Duration `yaml:"-"`
// Group is a string to group ratelimiters of Type
// ClusterServiceRatelimit or ClusterClientRatelimit.
// A ratelimit group considers all hits to the same group as
// one target.
Group string `yaml:"group"`
}
func (s Settings) Empty() bool {
return s == Settings{}
}
func (s Settings) String() string {
switch s.Type {
case DisableRatelimit:
return "disable"
case ServiceRatelimit:
return fmt.Sprintf("ratelimit(type=service,max-hits=%d,time-window=%s)", s.MaxHits, s.TimeWindow)
case LocalRatelimit:
fallthrough
case ClientRatelimit:
return fmt.Sprintf("ratelimit(type=client,max-hits=%d,time-window=%s)", s.MaxHits, s.TimeWindow)
case ClusterServiceRatelimit:
return fmt.Sprintf("ratelimit(type=clusterService,max-hits=%d,time-window=%s,group=%s)", s.MaxHits, s.TimeWindow, s.Group)
case ClusterClientRatelimit:
return fmt.Sprintf("ratelimit(type=clusterClient,max-hits=%d,time-window=%s,group=%s)", s.MaxHits, s.TimeWindow, s.Group)
default:
return "non"
}
}
// limiter defines the requirement to be used as a ratelimit implementation.
type limiter interface {
// Allow is used to get a decision whether the call with the given
// context should pass or be ratelimited
Allow(context.Context, string) bool
// Close is used to clean up underlying limiter
// implementations, if you want to stop a Ratelimiter
Close()
// Delta is used to get the duration until the next call is
// possible, negative durations allow immediate calls
Delta(string) time.Duration
// Oldest returns the oldest timestamp for string
Oldest(string) time.Time
// Resize is used to resize the buffer depending on the number
// of nodes available
Resize(string, int)
// RetryAfter is used to inform the client how many seconds it
// should wait before making a new request
RetryAfter(string) int
}
// Ratelimit is a proxy object that delegates to limiter
// implementations and stores settings for the ratelimiter
type Ratelimit struct {
settings Settings
impl limiter
}
// Allow is used to get a decision whether to allow the call; the
// context is passed through, e.g. to support OpenTracing.
func (l *Ratelimit) Allow(ctx context.Context, s string) bool {
if l == nil {
return true
}
return l.impl.Allow(ctx, s)
}
// Close will stop any cleanup goroutines in underlying limiter implementation.
func (l *Ratelimit) Close() {
l.impl.Close()
}
// RetryAfter informs how many seconds to wait for the next request
func (l *Ratelimit) RetryAfter(s string) int {
if l == nil {
return 0
}
return l.impl.RetryAfter(s)
}
func (l *Ratelimit) Delta(s string) time.Duration {
return l.impl.Delta(s)
}
func (l *Ratelimit) Resize(s string, i int) {
l.impl.Resize(s, i)
}
type voidRatelimit struct{}
func (voidRatelimit) Allow(context.Context, string) bool { return true }
func (voidRatelimit) Close() {}
func (voidRatelimit) Oldest(string) time.Time { return time.Time{} }
func (voidRatelimit) RetryAfter(string) int { return 0 }
func (voidRatelimit) Delta(string) time.Duration { return -1 * time.Second }
func (voidRatelimit) Resize(string, int) {}
type zeroRatelimit struct{}
const (
// Delta() and RetryAfter() should return consistent values of type int64 and int respectively.
//
// News had just come over,
// We had five years left to cry in
zeroDelta time.Duration = 5 * 365 * 24 * time.Hour
zeroRetry int = int(zeroDelta / time.Second)
)
func (zeroRatelimit) Allow(context.Context, string) bool { return false }
func (zeroRatelimit) Close() {}
func (zeroRatelimit) Oldest(string) time.Time { return time.Time{} }
func (zeroRatelimit) RetryAfter(string) int { return zeroRetry }
func (zeroRatelimit) Delta(string) time.Duration { return zeroDelta }
func (zeroRatelimit) Resize(string, int) {}
func newRatelimit(s Settings, sw Swarmer, redisRing *net.RedisRingClient) *Ratelimit {
var impl limiter
if s.MaxHits == 0 {
impl = zeroRatelimit{}
} else {
switch s.Type {
case ServiceRatelimit:
impl = circularbuffer.NewRateLimiter(s.MaxHits, s.TimeWindow)
case LocalRatelimit:
log.Warning("LocalRatelimit is deprecated, please use ClientRatelimit instead")
fallthrough
case ClientRatelimit:
impl = circularbuffer.NewClientRateLimiter(s.MaxHits, s.TimeWindow, s.CleanInterval)
case ClusterServiceRatelimit:
s.CleanInterval = 0
fallthrough
case ClusterClientRatelimit:
impl = newClusterRateLimiter(s, sw, redisRing, s.Group)
default:
impl = voidRatelimit{}
}
}
return &Ratelimit{
settings: s,
impl: impl,
}
}
func Headers(maxHits int, timeWindow time.Duration, retryAfter int) http.Header {
limitPerHour := int64(maxHits) * int64(time.Hour) / int64(timeWindow)
return http.Header{
Header: []string{strconv.FormatInt(limitPerHour, 10)},
RetryAfterHeader: []string{strconv.Itoa(retryAfter)},
}
}
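// Worked example (illustrative only): the X-Rate-Limit header reports the
// limit normalized to requests per hour. For maxHits=10 and
// timeWindow=time.Minute:
//
// limitPerHour = 10 * int64(time.Hour) / int64(time.Minute) = 600
//
// so the response carries "X-Rate-Limit: 600" together with a Retry-After
// value in seconds.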
func getHashedKey(clearText string) string {
h := sha256.Sum256([]byte(clearText))
return hex.EncodeToString(h[:])
}
package ratelimit
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/net"
"golang.org/x/time/rate"
)
// clusterLimitRedis stores all data required for the cluster ratelimit.
type clusterLimitRedis struct {
failClosed bool
typ string
group string
maxHits int64
window time.Duration
ringClient *net.RedisRingClient
metrics metrics.Metrics
sometimes rate.Sometimes
}
const (
redisMetricsPrefix = "swarm.redis."
allowMetricsFormat = redisMetricsPrefix + "query.allow.%s"
retryAfterMetricsFormat = redisMetricsPrefix + "query.retryafter.%s"
allowMetricsFormatWithGroup = redisMetricsPrefix + "query.allow.%s.%s"
retryAfterMetricsFormatWithGroup = redisMetricsPrefix + "query.retryafter.%s.%s"
allowSpanName = "redis_allow"
oldestScoreSpanName = "redis_oldest_score"
)
// newClusterRateLimiterRedis creates a new clusterLimitRedis for the given
// Settings. Group identifies the ratelimit instance, is used in log
// messages and has to be the same in all skipper instances.
func newClusterRateLimiterRedis(s Settings, r *net.RedisRingClient, group string) *clusterLimitRedis {
if r == nil {
return nil
}
rl := &clusterLimitRedis{
failClosed: s.FailClosed,
typ: s.Type.String(),
group: group,
maxHits: int64(s.MaxHits),
window: s.TimeWindow,
ringClient: r,
metrics: metrics.Default,
sometimes: rate.Sometimes{First: 3, Interval: 1 * time.Second},
}
return rl
}
func (c *clusterLimitRedis) prefixKey(clearText string) string {
return fmt.Sprintf(swarmKeyFormat, c.group, clearText)
}
func (c *clusterLimitRedis) measureQuery(format, groupFormat string, fail *bool, start time.Time) {
result := "success"
if fail != nil && *fail {
result = "failure"
}
var key string
if c.group == "" {
key = fmt.Sprintf(format, result)
} else {
key = fmt.Sprintf(groupFormat, result, c.group)
}
c.metrics.MeasureSince(key, start)
}
func parentSpan(ctx context.Context) opentracing.Span {
return opentracing.SpanFromContext(ctx)
}
func (c *clusterLimitRedis) commonTags() opentracing.Tags {
return opentracing.Tags{
string(ext.Component): "skipper",
string(ext.SpanKind): "client",
"ratelimit_type": c.typ,
"group": c.group,
"max_hits": c.maxHits,
"window": c.window.String(),
}
}
// Allow returns true if the request, calculated across the cluster of
// skippers, should be allowed and false otherwise. It shares its own
// data and uses the current cluster information to calculate global
// rates to decide whether to allow the request.
//
// Performance considerations:
//
// In case of deny it will use ZREMRANGEBYSCORE and ZCARD commands in
// one pipeline to remove old items in the list of hits.
// In case of allow it will additionally use ZADD with a second
// roundtrip.
//
// Uses provided context for creating an OpenTracing span.
func (c *clusterLimitRedis) Allow(ctx context.Context, clearText string) bool {
c.metrics.IncCounter(redisMetricsPrefix + "total")
now := time.Now()
var span opentracing.Span
if parentSpan := parentSpan(ctx); parentSpan != nil {
span = c.ringClient.StartSpan(allowSpanName, opentracing.ChildOf(parentSpan.Context()), c.commonTags())
defer span.Finish()
}
allow, err := c.allow(ctx, clearText)
failed := err != nil
if failed {
allow = !c.failClosed
msgFmt := "Failed to determine if operation is allowed: %v"
setError(span, fmt.Sprintf(msgFmt, err))
c.logError(msgFmt, err)
}
if span != nil {
span.SetTag("allowed", allow)
}
c.measureQuery(allowMetricsFormat, allowMetricsFormatWithGroup, &failed, now)
if allow {
c.metrics.IncCounter(redisMetricsPrefix + "allows")
} else {
c.metrics.IncCounter(redisMetricsPrefix + "forbids")
}
return allow
}
func (c *clusterLimitRedis) allow(ctx context.Context, clearText string) (bool, error) {
s := getHashedKey(clearText)
key := c.prefixKey(s)
now := time.Now()
nowNanos := now.UnixNano()
clearBefore := now.Add(-c.window).UnixNano()
// drop all elements of the set which occurred before one interval ago.
_, err := c.ringClient.ZRemRangeByScore(ctx, key, 0.0, float64(clearBefore))
if err != nil {
return false, err
}
// get cardinality
count, err := c.ringClient.ZCard(ctx, key)
if err != nil {
return false, err
}
// deny when already at the limit; allowed hits are recorded below with ZADD
if count >= c.maxHits {
return false, nil
}
_, err = c.ringClient.ZAdd(ctx, key, nowNanos, float64(nowNanos))
if err != nil {
return false, err
}
_, err = c.ringClient.Expire(ctx, key, c.window+time.Second)
if err != nil {
return false, err
}
return true, nil
}
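// The commands above implement a sliding window over a Redis sorted set
// where each hit's timestamp is both member and score. A hedged trace,
// assuming maxHits=2, window=1s and an initially empty key:
//
//	t=0.0s ZREMRANGEBYSCORE removes 0, ZCARD=0 -> ZADD, allow
//	t=0.1s ZREMRANGEBYSCORE removes 0, ZCARD=1 -> ZADD, allow
//	t=0.2s ZREMRANGEBYSCORE removes 0, ZCARD=2 >= maxHits -> deny
//	t=1.1s ZREMRANGEBYSCORE removes 2, ZCARD=0 -> ZADD, allow
//
// The EXPIRE of window+1s keeps keys of idle clients from accumulating.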
// Close does not tear down the redis ring, because it is not its owner.
func (c *clusterLimitRedis) Close() {}
func (c *clusterLimitRedis) deltaFrom(ctx context.Context, clearText string, from time.Time) (time.Duration, error) {
oldest, err := c.oldest(ctx, clearText)
if err != nil {
return 0, err
}
gap := from.Sub(oldest)
return c.window - gap, nil
}
// Delta returns the time.Duration until the next call is allowed;
// a negative value means calls are allowed immediately
func (c *clusterLimitRedis) Delta(clearText string) time.Duration {
now := time.Now()
d, err := c.deltaFrom(context.Background(), clearText, now)
if err != nil {
c.logError("Failed to get the duration until the next call is allowed: %v", err)
// Earlier, we returned duration since time=0 in these error cases. It is more graceful to the
// client applications to return 0.
return 0
}
return d
}
func setError(span opentracing.Span, msg string) {
if span != nil {
ext.Error.Set(span, true)
span.LogKV("log", msg)
}
}
func (c *clusterLimitRedis) logError(format string, err error) {
c.sometimes.Do(func() {
log.Errorf(format, err)
})
}
func (c *clusterLimitRedis) oldest(ctx context.Context, clearText string) (time.Time, error) {
s := getHashedKey(clearText)
key := c.prefixKey(s)
now := time.Now()
var span opentracing.Span
if parentSpan := parentSpan(ctx); parentSpan != nil {
span = c.ringClient.StartSpan(oldestScoreSpanName, opentracing.ChildOf(parentSpan.Context()), c.commonTags())
defer span.Finish()
}
res, err := c.ringClient.ZRangeByScoreWithScoresFirst(ctx, key, 0.0, float64(now.UnixNano()), 0, 1)
if err != nil {
setError(span, fmt.Sprintf("Failed to execute ZRangeByScoreWithScoresFirst: %v", err))
return time.Time{}, err
}
if res == nil {
return time.Time{}, nil
}
s, ok := res.(string)
if !ok {
msg := "failed to evaluate redis data"
setError(span, msg)
return time.Time{}, errors.New(msg)
}
oldest, err := strconv.ParseInt(s, 10, 64)
if err != nil {
setError(span, fmt.Sprintf("failed to convert value to int64: %v", err))
return time.Time{}, fmt.Errorf("failed to convert value to int64: %w", err)
}
return time.Unix(0, oldest), nil
}
// Oldest returns the oldest known request time.
//
// Performance considerations:
//
// It will use ZRANGEBYSCORE with offset 0 and count 1 to get the
// oldest item stored in redis.
func (c *clusterLimitRedis) Oldest(clearText string) time.Time {
t, err := c.oldest(context.Background(), clearText)
if err != nil {
c.logError("Failed to get the oldest known request time: %v", err)
return time.Time{}
}
return t
}
// Resize is a noop, present only to satisfy the limiter interface
func (*clusterLimitRedis) Resize(string, int) {}
// RetryAfterContext returns the seconds until the next call is allowed,
// similar to Delta(), but at least 1 in all cases. Returning 0 would
// effectively drop the ratelimiting in the proxy, because cluster
// ratelimits are not strongly consistent across calls to Allow()
// and RetryAfter() (or Allow and RetryAfterContext accordingly).
//
// Uses context for creating an OpenTracing span.
func (c *clusterLimitRedis) RetryAfterContext(ctx context.Context, clearText string) int {
// wait at least one second, even if the remaining delta is shorter
const minWait = 1
now := time.Now()
var queryFailure bool
defer c.measureQuery(retryAfterMetricsFormat, retryAfterMetricsFormatWithGroup, &queryFailure, now)
retr, err := c.deltaFrom(ctx, clearText, now)
if err != nil {
c.logError("Failed to get the duration to wait until the next request: %v", err)
queryFailure = true
return minWait
}
res := int(retr / time.Second)
if res > 0 {
return res + 1
}
return minWait
}
// RetryAfter is like RetryAfterContext, but not using a context.
func (c *clusterLimitRedis) RetryAfter(clearText string) int {
return c.RetryAfterContext(context.Background(), clearText)
}
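// A short worked example of the rounding above: a remaining delta of
// 2.5s yields int(retr/time.Second) == 2 and a returned value of 3,
// while any delta below one second yields the minimum of 1, so callers
// always back off at least one full second.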
package ratelimit
import (
"context"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/net"
)
const (
DefaultMaxhits = 20
DefaultTimeWindow = 1 * time.Second
DefaultCleanInterval = 60 * time.Second
)
// Registry objects hold the active ratelimiters, ensure synchronized
// access to them, apply default settings and recycle the idle
// ratelimiters.
type Registry struct {
sync.Mutex
once sync.Once
global Settings
lookup map[Settings]*Ratelimit
swarm Swarmer
redisRing *net.RedisRingClient
}
// NewRegistry initializes a registry with the provided default settings.
func NewRegistry(settings ...Settings) *Registry {
return NewSwarmRegistry(nil, nil, settings...)
}
// NewSwarmRegistry initializes a registry with an optional swarm and
// the provided default settings. If swarm is nil, clusterRatelimits
// will be replaced by voidRatelimit, which is a noop limiter implementation.
func NewSwarmRegistry(swarm Swarmer, ro *net.RedisOptions, settings ...Settings) *Registry {
defaults := Settings{
Type: DisableRatelimit,
MaxHits: DefaultMaxhits,
TimeWindow: DefaultTimeWindow,
CleanInterval: DefaultCleanInterval,
}
if ro != nil && ro.MetricsPrefix == "" {
ro.MetricsPrefix = redisMetricsPrefix
}
r := &Registry{
once: sync.Once{},
global: defaults,
lookup: make(map[Settings]*Ratelimit),
swarm: swarm,
redisRing: net.NewRedisRingClient(ro),
}
if ro != nil {
r.redisRing.StartMetricsCollection()
}
if len(settings) > 0 {
r.global = settings[0]
}
return r
}
// Close tears down the Registry and its dependent resources
func (r *Registry) Close() {
r.once.Do(func() {
r.redisRing.Close()
for _, rl := range r.lookup {
rl.Close()
}
})
}
func (r *Registry) get(s Settings) *Ratelimit {
r.Lock()
defer r.Unlock()
rl, ok := r.lookup[s]
if !ok {
rl = newRatelimit(s, r.swarm, r.redisRing)
r.lookup[s] = rl
}
return rl
}
// Get returns a Ratelimit instance for provided Settings
func (r *Registry) Get(s Settings) *Ratelimit {
if s.Type == DisableRatelimit || s.Type == NoRatelimit {
return nil
}
return r.get(s)
}
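// A minimal usage sketch of the registry (the settings values are
// illustrative only):
//
//	registry := NewRegistry(Settings{
//		Type:       ServiceRatelimit,
//		MaxHits:    100,
//		TimeWindow: time.Second,
//	})
//	defer registry.Close()
//
//	rl := registry.Get(Settings{Type: ServiceRatelimit, MaxHits: 100, TimeWindow: time.Second})
//	if !rl.Allow(context.Background(), "") {
//		// deny, e.g. with status 429 and a Retry-After of rl.RetryAfter("")
//	}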
// Check returns the Settings used and the retry-after seconds in case the
// request is ratelimited, otherwise the Settings and 0. It is only
// used by the global ratelimit facility.
func (r *Registry) Check(req *http.Request) (Settings, int) {
if r == nil {
return Settings{}, 0
}
s := r.global
rlimit := r.Get(s)
switch s.Type {
case ClusterServiceRatelimit:
fallthrough
case ServiceRatelimit:
if rlimit.Allow(context.Background(), "") {
return s, 0
}
return s, rlimit.RetryAfter("")
case LocalRatelimit:
log.Warning("LocalRatelimit is deprecated, please use ClientRatelimit instead")
fallthrough
case ClusterClientRatelimit:
fallthrough
case ClientRatelimit:
ip := net.RemoteHost(req)
if !rlimit.Allow(context.Background(), ip.String()) {
return s, rlimit.RetryAfter(ip.String())
}
}
return Settings{}, 0
}
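// A hedged sketch of wiring Check into an HTTP middleware (the status
// code and handler shape are this example's choice):
//
//	func limit(reg *Registry, next http.Handler) http.Handler {
//		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			if s, retryAfter := reg.Check(r); retryAfter > 0 {
//				for k, v := range Headers(s.MaxHits, s.TimeWindow, retryAfter) {
//					w.Header()[k] = v
//				}
//				w.WriteHeader(http.StatusTooManyRequests)
//				return
//			}
//			next.ServeHTTP(w, r)
//		})
//	}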
package ratelimit
import (
"context"
"math"
"time"
log "github.com/sirupsen/logrus"
circularbuffer "github.com/szuecs/rate-limit-buffer"
)
// Swarmer defines the requirements for a swarm, used as the data
// exchange mechanism for the cluster ratelimits
// ratelimit.ClusterServiceRatelimit and
// ratelimit.ClusterClientRatelimit.
type Swarmer interface {
// ShareValue is used to share the local information with its peers.
ShareValue(string, interface{}) error
// Values is used to get global information about current rates.
Values(string) map[string]interface{}
}
// clusterLimitSwim stores all data required for the cluster ratelimit.
type clusterLimitSwim struct {
group string
local limiter
maxHits int
window time.Duration
swarm Swarmer
resize chan resizeLimit
quit chan struct{}
}
type resizeLimit struct {
s string
n int
}
// newClusterRateLimiterSwim creates a new clusterLimitSwim for the given
// Settings and Swarmer. Group is used in log messages to identify
// the ratelimit instance and has to be the same in all skipper instances.
func newClusterRateLimiterSwim(s Settings, sw Swarmer, group string) *clusterLimitSwim {
rl := &clusterLimitSwim{
group: group,
swarm: sw,
maxHits: s.MaxHits,
window: s.TimeWindow,
resize: make(chan resizeLimit),
quit: make(chan struct{}),
}
switch s.Type {
case ClusterServiceRatelimit:
log.Infof("new backend clusterRateLimiter")
rl.local = circularbuffer.NewRateLimiter(s.MaxHits, s.TimeWindow)
case ClusterClientRatelimit:
log.Infof("new client clusterRateLimiter")
rl.local = circularbuffer.NewClientRateLimiter(s.MaxHits, s.TimeWindow, s.CleanInterval)
default:
log.Errorf("Unknown ratelimit type: %s", s.Type)
return nil
}
// TODO(sszuecs): we might want to have one goroutine for all of these
go func() {
for {
select {
case size := <-rl.resize:
log.Debugf("%s resize clusterRatelimit: %v", group, size)
// TODO(sszuecs): call with "go" ?
rl.Resize(size.s, rl.maxHits/size.n)
case <-rl.quit:
log.Debugf("%s: quit clusterRatelimit", group)
close(rl.resize)
return
}
}
}()
return rl
}
// Allow returns true if the request with the given context, calculated
// across the cluster of skippers, should be allowed and false otherwise.
// It shares its own data and uses the current cluster information to
// calculate global rates to decide whether to allow the request.
func (c *clusterLimitSwim) Allow(ctx context.Context, clearText string) bool {
s := getHashedKey(clearText)
key := swarmPrefix + c.group + "." + s
// t0 is the oldest entry in the local circularbuffer
// [ t3, t4, t0, t1, t2]
// ^- current pointer to oldest
// now - t0
t0 := c.Oldest(s).UTC().UnixNano()
_ = c.local.Allow(ctx, s) // update local rate limit
if err := c.swarm.ShareValue(key, t0); err != nil {
log.Errorf("clusterRatelimit '%s' disabled, failed to share value: %v", c.group, err)
return true // unsafe to continue otherwise
}
swarmValues := c.swarm.Values(key)
log.Debugf("%s: clusterRatelimit swarmValues(%d) for '%s': %v", c.group, len(swarmValues), swarmPrefix+s, swarmValues)
c.resize <- resizeLimit{s: s, n: len(swarmValues)}
now := time.Now().UTC().UnixNano()
rate := c.calcTotalRequestRate(now, swarmValues)
result := rate < float64(c.maxHits)
log.Debugf("%s clusterRatelimit: Allow=%v, %v < %d", c.group, result, rate, c.maxHits)
return result
}
func (c *clusterLimitSwim) calcTotalRequestRate(now int64, swarmValues map[string]interface{}) float64 {
var requestRate float64
maxNodeHits := math.Max(1.0, float64(c.maxHits)/(float64(len(swarmValues))))
for _, v := range swarmValues {
t0, ok := v.(int64)
if !ok || t0 == 0 {
continue
}
delta := time.Duration(now - t0)
adjusted := float64(delta) / float64(c.window)
log.Debugf("%s: %0.2f += %0.2f / %0.2f", c.group, requestRate, maxNodeHits, adjusted)
requestRate += maxNodeHits / adjusted
}
log.Debugf("%s requestRate: %0.2f", c.group, requestRate)
return requestRate
}
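// A worked example of the estimate above: with maxHits=100, window=1m
// and two peers in the swarm, maxNodeHits is 50. A peer whose oldest
// buffered hit is 30s old yields adjusted = 30s/60s = 0.5 and
// contributes 50/0.5 = 100 to the projected rate; a peer whose oldest
// hit is a full window old contributes exactly maxNodeHits. Allow then
// compares the summed rate against maxHits.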
// Close should be called to teardown the clusterLimitSwim.
func (c *clusterLimitSwim) Close() {
close(c.quit)
c.local.Close()
}
func (c *clusterLimitSwim) Delta(s string) time.Duration { return c.local.Delta(s) }
func (c *clusterLimitSwim) Oldest(s string) time.Time { return c.local.Oldest(s) }
func (c *clusterLimitSwim) Resize(s string, n int) { c.local.Resize(s, n) }
func (c *clusterLimitSwim) RetryAfter(s string) int { return c.local.RetryAfter(s) }
package rfc
import "strings"
// PatchHost returns the host string with the trailing dot removed, also
// when the host carries a port. For details see also the discussion in
// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JanMar/0430.html.
func PatchHost(host string) string {
host = strings.ReplaceAll(host, ".:", ":")
return strings.TrimSuffix(host, ".")
}
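// For example:
//
//	PatchHost("example.org.")     // "example.org"
//	PatchHost("example.org.:443") // "example.org:443"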
package rfc
const escapeLength = 3 // always, because we only handle the below cases
const (
semicolon = ';'
slash = '/'
questionMark = '?'
colon = ':'
at = '@'
and = '&'
eq = '='
plus = '+'
dollar = '$'
comma = ','
)
// https://tools.ietf.org/html/rfc3986#section-2.2
func unescape(seq []byte) (byte, bool) {
switch string(seq) {
case "%3B", "%3b":
return semicolon, true
case "%2F", "%2f":
return slash, true
case "%3F", "%3f":
return questionMark, true
case "%3A", "%3a":
return colon, true
case "%40":
return at, true
case "%26":
return and, true
case "%3D", "%3d":
return eq, true
case "%2B", "%2b":
return plus, true
case "%24":
return dollar, true
case "%2C", "%2c":
return comma, true
default:
return 0, false
}
}
// PatchPath attempts to patch a request path based on an interpretation of the standards
// RFC 2616 and RFC 3986 where the reserved characters should not be unescaped. Currently
// the Go stdlib does unescape these characters (v1.12.5).
//
// It expects the parsed path as found in http.Request.URL.Path and the raw path as found
// in http.Request.URL.RawPath. It returns a path where characters e.g. like '/' have the
// escaped form of %2F, if it was detected that they are unescaped in the raw path.
//
// It only returns the patched variant, if the only difference between the parsed and raw
// paths are the encoding of the chars, according to RFC 3986. If it detects any other
// difference between the two, it returns the original parsed path as provided. It
// tolerates an empty argument for the raw path, which can happen when the URL is
// parsed via the stdlib url package and there is no difference between the parsed
// and the raw path.
// This basically means that the following code is correct:
//
// req.URL.Path = rfc.PatchPath(req.URL.Path, req.URL.RawPath)
//
// Links:
// - https://tools.ietf.org/html/rfc2616#section-3.2.3 and
// - https://tools.ietf.org/html/rfc3986#section-2.2
func PatchPath(parsed, raw string) string {
p, r := []byte(parsed), []byte(raw)
patched := make([]byte, 0, len(r))
var (
escape bool
seq []byte
unescaped byte
handled bool
doPatch bool
modified bool
pi int
)
for i := 0; i < len(r); i++ {
c := r[i]
escape = c == '%'
modified = pi >= len(p) || !escape && p[pi] != c
if modified {
doPatch = false
break
}
if !escape {
patched = append(patched, p[pi])
pi++
continue
}
if len(r) < i+escapeLength {
doPatch = false
break
}
seq = r[i : i+escapeLength]
i += escapeLength - 1
unescaped, handled = unescape(seq)
if !handled {
patched = append(patched, p[pi])
pi++
continue
}
modified = p[pi] != unescaped
if modified {
doPatch = false
break
}
patched = append(patched, seq...)
doPatch = true
pi++
}
if !doPatch {
return parsed
}
modified = pi < len(p)
if modified {
return parsed
}
return string(patched)
}
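// Two illustrative cases of the rules above:
//
//	// the raw path escapes a reserved character that parsing unescaped:
//	PatchPath("/foo/bar", "/foo%2Fbar") // "/foo%2Fbar"
//
//	// the difference is not about reserved characters, keep the parsed path:
//	PatchPath("/foo bar", "/foo%20bar") // "/foo bar"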
package routing
import (
"context"
"sync"
)
type contextKey struct{}
var routingContextKey contextKey
// NewContext returns a new context with associated routing context.
// It does nothing and returns ctx if it already has associated routing context.
func NewContext(ctx context.Context) context.Context {
if _, ok := ctx.Value(routingContextKey).(*sync.Map); ok {
return ctx
}
return context.WithValue(ctx, routingContextKey, &sync.Map{})
}
// FromContext returns value from the routing context stored in ctx.
// It returns value associated with the key or stores result of the defaultValue call.
// defaultValue may be called multiple times but only one result will be used as a default value.
func FromContext[K comparable, V any](ctx context.Context, key K, defaultValue func() V) V {
m, _ := ctx.Value(routingContextKey).(*sync.Map)
// https://github.com/golang/go/issues/44159#issuecomment-780774977
val, ok := m.Load(key)
if !ok {
val = defaultValue()
val, _ = m.LoadOrStore(key, val)
}
return val.(V)
}
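// A minimal usage sketch; the key type and the counter value are this
// example's own:
//
//	type visitsKey struct{}
//
//	ctx := NewContext(req.Context())
//	visits := FromContext(ctx, visitsKey{}, func() *atomic.Int64 { return new(atomic.Int64) })
//	visits.Add(1)
//
// Subsequent FromContext calls with the same key on the same context
// return the single stored value, even if defaultValue was called more
// than once concurrently.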
package routing
import (
"errors"
"fmt"
"sort"
"sync"
"time"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/net"
"github.com/zalando/skipper/predicates"
)
type incomingType uint
const (
incomingReset incomingType = iota
incomingUpdate
)
var errInvalidWeightParams = errors.New("invalid argument for the Weight predicate")
type invalidDefinitionError string
func (e invalidDefinitionError) Error() string { return string(e) }
func (e invalidDefinitionError) Code() string { return string(e) }
var (
errUnknownFilter = invalidDefinitionError("unknown_filter")
errInvalidFilterParams = invalidDefinitionError("invalid_filter_params")
errUnknownPredicate = invalidDefinitionError("unknown_predicate")
errInvalidPredicateParams = invalidDefinitionError("invalid_predicate_params")
errFailedBackendSplit = invalidDefinitionError("failed_backend_split")
errInvalidMatcher = invalidDefinitionError("invalid_matcher")
)
func (it incomingType) String() string {
switch it {
case incomingReset:
return "reset"
case incomingUpdate:
return "update"
default:
return "unknown"
}
}
type routeDefs map[string]*eskip.Route
type incomingData struct {
typ incomingType
client DataClient
upsertedRoutes []*eskip.Route
deletedIds []string
}
func (d *incomingData) log(l logging.Logger, suppress bool) {
if suppress {
l.Infof("route settings, %v, upsert count: %v", d.typ, len(d.upsertedRoutes))
l.Infof("route settings, %v, delete count: %v", d.typ, len(d.deletedIds))
return
}
for _, r := range d.upsertedRoutes {
l.Infof("route settings, %v, route: %v: %v", d.typ, r.Id, r)
}
for _, id := range d.deletedIds {
l.Infof("route settings, %v, deleted id: %v", d.typ, id)
}
}
// continuously receives route definitions from a data client and forwards them on the
// output channel. The function does not return unless quit is closed. When started, it
// requests the whole current set of routes, and continues polling for subsequent
// updates. When a communication error occurs, it re-requests the whole valid set, and
// continues polling. Currently, routes with the same id coming from different sources
// are merged in a nondeterministic way, but this may change in the future.
func receiveFromClient(c DataClient, o Options, out chan<- *incomingData, quit <-chan struct{}) {
initial := true
var ticker *time.Ticker
if o.PollTimeout != 0 {
ticker = time.NewTicker(o.PollTimeout)
} else {
ticker = time.NewTicker(time.Millisecond)
}
defer ticker.Stop()
for {
var (
routes []*eskip.Route
deletedIDs []string
err error
)
if initial {
routes, err = c.LoadAll()
} else {
routes, deletedIDs, err = c.LoadUpdate()
}
switch {
case err != nil && initial:
o.Log.Error("error while receiving initial data;", err)
case err != nil:
o.Log.Error("error while receiving update;", err)
initial = true
continue
case initial || len(routes) > 0 || len(deletedIDs) > 0:
var incoming *incomingData
if initial {
incoming = &incomingData{incomingReset, c, routes, nil}
} else {
incoming = &incomingData{incomingUpdate, c, routes, deletedIDs}
}
initial = false
select {
case out <- incoming:
case <-quit:
return
}
}
select {
case <-ticker.C:
case <-quit:
return
}
}
}
// applies incoming route definitions to a key/route map, where
// the keys are the route ids.
func applyIncoming(defs routeDefs, d *incomingData) routeDefs {
if d.typ == incomingReset || defs == nil {
defs = make(routeDefs)
}
if d.typ == incomingUpdate {
for _, id := range d.deletedIds {
delete(defs, id)
}
}
if d.typ == incomingReset || d.typ == incomingUpdate {
for _, def := range d.upsertedRoutes {
defs[def.Id] = def
}
}
return defs
}
type mergedDefs struct {
routes []*eskip.Route
clients map[DataClient]struct{}
}
// merges the route definitions from multiple data clients by route id
func mergeDefs(defsByClient map[DataClient]routeDefs) mergedDefs {
clients := make(map[DataClient]struct{}, len(defsByClient))
mergeByID := make(routeDefs)
for c, defs := range defsByClient {
clients[c] = struct{}{}
for id, def := range defs {
mergeByID[id] = def
}
}
all := make([]*eskip.Route, 0, len(mergeByID))
for _, def := range mergeByID {
all = append(all, def)
}
return mergedDefs{routes: all, clients: clients}
}
// receives the initial set of route definitions and their
// updates from multiple data clients, merges them by route id
// and sends the merged route definitions to the output channel.
//
// The active set of routes from the last successful update is used until
// the next successful update.
func receiveRouteDefs(o Options, quit <-chan struct{}) <-chan mergedDefs {
in := make(chan *incomingData)
out := make(chan mergedDefs)
defsByClient := make(map[DataClient]routeDefs)
for _, c := range o.DataClients {
go receiveFromClient(c, o, in, quit)
}
go func() {
for {
var incoming *incomingData
select {
case incoming = <-in:
case <-quit:
return
}
incoming.log(o.Log, o.SuppressLogs)
c := incoming.client
defsByClient[c] = applyIncoming(defsByClient[c], incoming)
select {
case out <- mergeDefs(defsByClient):
case <-quit:
return
}
}
}()
return out
}
// splits the backend address of a route definition into separate
// scheme and host variables.
func splitBackend(r *eskip.Route) (string, string, error) {
if r.Shunt || r.BackendType == eskip.ShuntBackend || r.BackendType == eskip.LoopBackend ||
r.BackendType == eskip.DynamicBackend || r.BackendType == eskip.LBBackend {
return "", "", nil
}
return net.SchemeHost(r.Backend)
}
// creates a filter instance based on its definition and its
// specification in the filter registry.
func createFilter(o *Options, def *eskip.Filter, cpm map[string]PredicateSpec) (filters.Filter, error) {
spec, ok := o.FilterRegistry[def.Name]
if !ok {
if isTreePredicate(def.Name) || def.Name == predicates.HostName || def.Name == predicates.PathRegexpName || def.Name == predicates.MethodName || def.Name == predicates.HeaderName || def.Name == predicates.HeaderRegexpName {
return nil, fmt.Errorf("%w: trying to use %q as filter, but it is only available as predicate", errUnknownFilter, def.Name)
}
if _, ok := cpm[def.Name]; ok {
return nil, fmt.Errorf("%w: trying to use %q as filter, but it is only available as predicate", errUnknownFilter, def.Name)
}
return nil, fmt.Errorf("%w: filter %q not found", errUnknownFilter, def.Name)
}
start := time.Now()
f, err := spec.CreateFilter(def.Args)
if o.Metrics != nil {
o.Metrics.MeasureFilterCreate(def.Name, start)
}
if err != nil {
return nil, fmt.Errorf("%w: failed to create filter %q: %w", errInvalidFilterParams, spec.Name(), err)
}
return f, nil
}
// creates filter instances based on their definition
// and the filter registry.
func createFilters(o *Options, defs []*eskip.Filter, cpm map[string]PredicateSpec) ([]*RouteFilter, error) {
fs := make([]*RouteFilter, 0, len(defs))
for i, def := range defs {
f, err := createFilter(o, def, cpm)
if err != nil {
return nil, err
}
fs = append(fs, &RouteFilter{f, def.Name, i})
}
return fs, nil
}
// checks whether a predicate is one of the distinguished path tree predicates
func isTreePredicate(name string) bool {
switch name {
case predicates.PathSubtreeName:
return true
case predicates.PathName:
return true
default:
return false
}
}
func getFreeStringArgs(count int, p *eskip.Predicate) ([]string, error) {
if len(p.Args) != count {
return nil, fmt.Errorf(
"invalid length of predicate args in %s, %d instead of %d",
p.Name,
len(p.Args),
count,
)
}
a := make([]string, 0, len(p.Args))
for i := range p.Args {
s, ok := p.Args[i].(string)
if !ok {
return nil, fmt.Errorf("expected argument of type string, %s", p.Name)
}
a = append(a, s)
}
return a, nil
}
func mergeLegacyNonTreePredicates(r *eskip.Route) (*eskip.Route, error) {
var rest []*eskip.Predicate
c := r.Copy()
for _, p := range c.Predicates {
if isTreePredicate(p.Name) {
rest = append(rest, p)
continue
}
switch p.Name {
case predicates.HostName:
a, err := getFreeStringArgs(1, p)
if err != nil {
return nil, err
}
c.HostRegexps = append(c.HostRegexps, a[0])
case predicates.PathRegexpName:
a, err := getFreeStringArgs(1, p)
if err != nil {
return nil, err
}
c.PathRegexps = append(c.PathRegexps, a[0])
case predicates.MethodName:
a, err := getFreeStringArgs(1, p)
if err != nil {
return nil, err
}
c.Method = a[0]
case predicates.HeaderName:
a, err := getFreeStringArgs(2, p)
if err != nil {
return nil, err
}
if c.Headers == nil {
c.Headers = make(map[string]string)
}
c.Headers[a[0]] = a[1]
case predicates.HeaderRegexpName:
a, err := getFreeStringArgs(2, p)
if err != nil {
return nil, err
}
if c.HeaderRegexps == nil {
c.HeaderRegexps = make(map[string][]string)
}
c.HeaderRegexps[a[0]] = append(c.HeaderRegexps[a[0]], a[1])
default:
rest = append(rest, p)
}
}
c.Predicates = rest
return c, nil
}
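// As an illustration, a route like
//
//	r: Host(/^www[.]example[.]org$/) && Method("GET") -> "https://backend.example.org";
//
// leaves this function with the Host and Method predicates removed from
// c.Predicates and folded into c.HostRegexps and c.Method, which the
// tree matcher evaluates directly.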
func parseWeightPredicateArgs(args []interface{}) (int, error) {
if len(args) != 1 {
return 0, errInvalidWeightParams
}
if weight, ok := args[0].(float64); ok {
return int(weight), nil
}
if weight, ok := args[0].(int); ok {
return weight, nil
}
return 0, errInvalidWeightParams
}
// initialize predicate instances from their spec with the concrete arguments
func processPredicates(o *Options, cpm map[string]PredicateSpec, defs []*eskip.Predicate) ([]Predicate, int, error) {
cps := make([]Predicate, 0, len(defs))
var weight int
for _, def := range defs {
if def.Name == predicates.WeightName {
var w int
var err error
if w, err = parseWeightPredicateArgs(def.Args); err != nil {
return nil, 0, fmt.Errorf("%w: %w", errInvalidPredicateParams, err)
}
weight += w
continue
}
if isTreePredicate(def.Name) {
continue
}
spec, ok := cpm[def.Name]
if !ok {
return nil, 0, fmt.Errorf("%w: predicate %q not found", errUnknownPredicate, def.Name)
}
cp, err := spec.Create(def.Args)
if err != nil {
return nil, 0, fmt.Errorf("%w: failed to create predicate %q: %w", errInvalidPredicateParams, spec.Name(), err)
}
if ws, ok := spec.(WeightedPredicateSpec); ok {
weight += ws.Weight()
}
cps = append(cps, cp)
}
return cps, weight, nil
}
// returns the path or subtree path argument if the predicate definition is valid
func processPathOrSubTree(p *eskip.Predicate) (string, error) {
if len(p.Args) != 1 {
return "", predicates.ErrInvalidPredicateParameters
}
if s, ok := p.Args[0].(string); ok {
return s, nil
}
return "", predicates.ErrInvalidPredicateParameters
}
func validTreePredicates(predicateList []*eskip.Predicate) bool {
var has bool
for _, p := range predicateList {
switch p.Name {
case predicates.PathName, predicates.PathSubtreeName:
if has {
return false
}
has = true
}
}
return true
}
// processes path tree relevant predicates
func processTreePredicates(r *Route, predicateList []*eskip.Predicate) error {
// backwards compatibility
if r.Path != "" {
predicateList = append(predicateList, &eskip.Predicate{Name: predicates.PathName, Args: []interface{}{r.Path}})
}
if !validTreePredicates(predicateList) {
return fmt.Errorf("multiple tree predicates (Path, PathSubtree) in the route: %s", r.Id)
}
for _, p := range predicateList {
switch p.Name {
case predicates.PathName:
path, err := processPathOrSubTree(p)
if err != nil {
return err
}
r.path = path
case predicates.PathSubtreeName:
pst, err := processPathOrSubTree(p)
if err != nil {
return err
}
r.pathSubtree = pst
}
}
return nil
}
// processes a route definition for the routing table
func processRouteDef(o *Options, cpm map[string]PredicateSpec, def *eskip.Route) (*Route, error) {
scheme, host, err := splitBackend(def)
if err != nil {
return nil, fmt.Errorf("%w: %w", errFailedBackendSplit, err)
}
fs, err := createFilters(o, def.Filters, cpm)
if err != nil {
return nil, err
}
def, err = mergeLegacyNonTreePredicates(def)
if err != nil {
return nil, err
}
cps, weight, err := processPredicates(o, cpm, def.Predicates)
if err != nil {
return nil, err
}
r := &Route{Route: *def, Scheme: scheme, Host: host, Predicates: cps, Filters: fs, weight: weight}
if err := processTreePredicates(r, def.Predicates); err != nil {
return nil, err
}
return r, nil
}
// convert a slice of predicate specs to a map keyed by their names
func mapPredicates(cps []PredicateSpec) map[string]PredicateSpec {
cpm := make(map[string]PredicateSpec)
for _, cp := range cps {
cpm[cp.Name()] = cp
}
return cpm
}
// processes a set of route definitions for the routing table
func processRouteDefs(o *Options, defs []*eskip.Route) (routes []*Route, invalidDefs []*eskip.Route) {
cpm := mapPredicates(o.Predicates)
reasonCounts := make(map[string]int)
for _, def := range defs {
route, err := processRouteDef(o, cpm, def)
if err == nil {
routes = append(routes, route)
} else {
invalidDefs = append(invalidDefs, def)
o.Log.Errorf("failed to process route %s: %v", def.Id, err)
var defErr invalidDefinitionError
reason := "other"
if errors.As(err, &defErr) {
reason = defErr.Code()
}
reasonCounts[reason]++
}
}
if o.Metrics != nil {
o.Metrics.UpdateInvalidRoute(reasonCounts)
}
return
}
type routeTable struct {
id int
m *matcher
once sync.Once
routes []*Route // only used for closing
validRoutes []*eskip.Route
invalidRoutes []*eskip.Route
clients map[DataClient]struct{}
created time.Time
}
// close cleans up all underlying resources of the routeTable that
// could otherwise leak goroutines.
func (rt *routeTable) close() {
rt.once.Do(func() {
for _, route := range rt.routes {
for _, f := range route.Filters {
if fc, ok := f.Filter.(filters.FilterCloser); ok {
fc.Close()
}
}
}
})
}
// sends the next version of the routing table to the output channel
// whenever an update is received from one of the data clients.
func receiveRouteMatcher(o Options, out chan<- *routeTable, quit <-chan struct{}) {
updates := receiveRouteDefs(o, quit)
var (
rt *routeTable
outRelay chan<- *routeTable
updatesRelay <-chan mergedDefs
updateId int
)
updatesRelay = updates
for {
select {
case mdefs := <-updatesRelay:
updateId++
start := time.Now()
o.Log.Infof("route settings received, id: %d", updateId)
defs := mdefs.routes
for i := range o.PreProcessors {
defs = o.PreProcessors[i].Do(defs)
}
routes, invalidRoutes := processRouteDefs(&o, defs)
for i := range o.PostProcessors {
routes = o.PostProcessors[i].Do(routes)
}
m, errs := newMatcher(routes, o.MatchingOptions)
invalidRouteIds := make(map[string]struct{})
validRoutes := []*eskip.Route{}
for _, err := range errs {
o.Log.Error(err)
invalidRouteIds[err.ID] = struct{}{}
}
if o.Metrics != nil {
o.Metrics.UpdateInvalidRoute(map[string]int{errInvalidMatcher.Code(): len(errs)})
}
for i := range routes {
r := routes[i]
if _, found := invalidRouteIds[r.Id]; found {
invalidRoutes = append(invalidRoutes, &r.Route)
} else {
validRoutes = append(validRoutes, &r.Route)
}
}
sort.SliceStable(validRoutes, func(i, j int) bool {
return validRoutes[i].Id < validRoutes[j].Id
})
rt = &routeTable{
id: updateId,
m: m,
routes: routes,
validRoutes: validRoutes,
invalidRoutes: invalidRoutes,
clients: mdefs.clients,
created: start,
}
updatesRelay = nil
outRelay = out
case outRelay <- rt:
rt = nil
updatesRelay = updates
outRelay = nil
case <-quit:
return
}
}
}
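// The select above relies on the nil-channel idiom: receiving from or
// sending to a nil channel blocks forever, so nilling updatesRelay or
// outRelay disables the corresponding case. A condensed sketch of the
// same pattern, with updates and out standing in for the real channels:
//
//	var recv <-chan int = updates // enabled
//	var send chan<- int           // nil: disabled
//	var last int
//	for {
//		select {
//		case last = <-recv:
//			recv, send = nil, out // pause reading until delivered
//		case send <- last:
//			recv, send = updates, nil // delivered, resume reading
//		}
//	}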
package routing
import (
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
)
const defaultLastSeenTimeout = 1 * time.Minute
// Metrics describes the data about an endpoint that can be
// used to perform better load balancing, fadeIn, etc.
type Metrics interface {
DetectedTime() time.Time
SetDetected(detected time.Time)
LastSeen() time.Time
SetLastSeen(lastSeen time.Time)
InflightRequests() int64
IncInflightRequest()
DecInflightRequest()
IncRequests(o IncRequestsOptions)
HealthCheckDropProbability() float64
}
type IncRequestsOptions struct {
FailedRoundTrip bool
}
type entry struct {
detected atomic.Value // time.Time
lastSeen atomic.Value // time.Time
inflightRequests atomic.Int64
totalRequests [2]atomic.Int64
totalFailedRoundTrips [2]atomic.Int64
curSlot atomic.Int64
healthCheckDropProbability atomic.Value // float64
}
var _ Metrics = &entry{}
func (e *entry) DetectedTime() time.Time {
return e.detected.Load().(time.Time)
}
func (e *entry) LastSeen() time.Time {
return e.lastSeen.Load().(time.Time)
}
func (e *entry) InflightRequests() int64 {
return e.inflightRequests.Load()
}
func (e *entry) IncInflightRequest() {
e.inflightRequests.Add(1)
}
func (e *entry) DecInflightRequest() {
e.inflightRequests.Add(-1)
}
func (e *entry) SetDetected(detected time.Time) {
e.detected.Store(detected)
}
func (e *entry) SetLastSeen(ts time.Time) {
e.lastSeen.Store(ts)
}
func (e *entry) IncRequests(o IncRequestsOptions) {
curSlot := e.curSlot.Load()
e.totalRequests[curSlot].Add(1)
if o.FailedRoundTrip {
e.totalFailedRoundTrips[curSlot].Add(1)
}
}
func (e *entry) HealthCheckDropProbability() float64 {
return e.healthCheckDropProbability.Load().(float64)
}
func newEntry() *entry {
result := &entry{}
result.healthCheckDropProbability.Store(0.0)
result.SetDetected(time.Time{})
result.SetLastSeen(time.Time{})
return result
}
type EndpointRegistry struct {
lastSeenTimeout time.Duration
statsResetPeriod time.Duration
minRequests int64
minHealthCheckDropProbability float64
maxHealthCheckDropProbability float64
quit chan struct{}
now func() time.Time
data sync.Map // map[string]*entry
}
var _ PostProcessor = &EndpointRegistry{}
type RegistryOptions struct {
LastSeenTimeout time.Duration
PassiveHealthCheckEnabled bool
StatsResetPeriod time.Duration
MinRequests int64
MinHealthCheckDropProbability float64
MaxHealthCheckDropProbability float64
}
func (r *EndpointRegistry) Do(routes []*Route) []*Route {
now := r.now()
for _, route := range routes {
if route.BackendType == eskip.LBBackend {
for i := range route.LBEndpoints {
epi := &route.LBEndpoints[i]
epi.Metrics = r.GetMetrics(epi.Host)
if epi.Metrics.DetectedTime().IsZero() {
epi.Metrics.SetDetected(now)
}
epi.Metrics.SetLastSeen(now)
}
} else if route.BackendType == eskip.NetworkBackend {
entry := r.GetMetrics(route.Host)
if entry.DetectedTime().IsZero() {
entry.SetDetected(now)
}
entry.SetLastSeen(now)
}
}
removeOlder := now.Add(-r.lastSeenTimeout)
r.data.Range(func(key, value any) bool {
e := value.(*entry)
if e.LastSeen().Before(removeOlder) {
r.data.Delete(key)
}
return true
})
return routes
}
func (r *EndpointRegistry) updateStats() {
ticker := time.NewTicker(r.statsResetPeriod)
for {
r.data.Range(func(key, value any) bool {
e := value.(*entry)
curSlot := e.curSlot.Load()
nextSlot := (curSlot + 1) % 2
e.totalFailedRoundTrips[nextSlot].Store(0)
e.totalRequests[nextSlot].Store(0)
newDropProbability := 0.0
failed := e.totalFailedRoundTrips[curSlot].Load()
requests := e.totalRequests[curSlot].Load()
if requests > r.minRequests {
failedRoundTripsRatio := float64(failed) / float64(requests)
if failedRoundTripsRatio > r.minHealthCheckDropProbability {
log.Infof("Passive health check: marking %q as unhealthy due to failed round trips ratio: %0.2f", key, failedRoundTripsRatio)
newDropProbability = min(failedRoundTripsRatio, r.maxHealthCheckDropProbability)
}
}
e.healthCheckDropProbability.Store(newDropProbability)
e.curSlot.Store(nextSlot)
return true
})
select {
case <-r.quit:
return
case <-ticker.C:
}
}
}
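// The totalRequests/totalFailedRoundTrips pairs above form a double
// buffer: request handlers always Add into the slot selected by
// curSlot, while updateStats zeroes the inactive slot, evaluates the
// finished one and then flips curSlot. A worked example: with
// minRequests=10, 100 requests and 30 failed round trips in the
// evaluated slot, the ratio 0.3 exceeds a minHealthCheckDropProbability
// of 0.05, so the stored drop probability becomes
// min(0.3, maxHealthCheckDropProbability).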
func (r *EndpointRegistry) Close() {
close(r.quit)
}
func NewEndpointRegistry(o RegistryOptions) *EndpointRegistry {
if o.LastSeenTimeout == 0 {
o.LastSeenTimeout = defaultLastSeenTimeout
}
registry := &EndpointRegistry{
lastSeenTimeout: o.LastSeenTimeout,
statsResetPeriod: o.StatsResetPeriod,
minRequests: o.MinRequests,
minHealthCheckDropProbability: o.MinHealthCheckDropProbability,
maxHealthCheckDropProbability: o.MaxHealthCheckDropProbability,
quit: make(chan struct{}),
now: time.Now,
data: sync.Map{},
}
if o.PassiveHealthCheckEnabled {
go registry.updateStats()
}
return registry
}
func (r *EndpointRegistry) GetMetrics(hostPort string) Metrics {
// https://github.com/golang/go/issues/44159#issuecomment-780774977
e, ok := r.data.Load(hostPort)
if !ok {
e, _ = r.data.LoadOrStore(hostPort, newEntry())
}
return e.(*entry)
}
func (r *EndpointRegistry) allMetrics() map[string]Metrics {
result := make(map[string]Metrics)
r.data.Range(func(k, v any) bool {
result[k.(string)] = v.(Metrics)
return true
})
return result
}
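// A hedged usage sketch of the registry (the option values and host are
// illustrative):
//
//	reg := NewEndpointRegistry(RegistryOptions{
//		PassiveHealthCheckEnabled:     true,
//		StatsResetPeriod:              30 * time.Second,
//		MinRequests:                   10,
//		MaxHealthCheckDropProbability: 0.9,
//	})
//	defer reg.Close()
//
//	routes = reg.Do(routes) // attaches Metrics to the LB endpoints
//
//	m := reg.GetMetrics("10.2.3.4:8080")
//	m.IncInflightRequest()
//	defer m.DecInflightRequest()
//	m.IncRequests(IncRequestsOptions{FailedRoundTrip: failed})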
package routing
import (
"errors"
"fmt"
"net/http"
"regexp"
"sort"
"strings"
"github.com/dimfeld/httppath"
"github.com/zalando/skipper/pathmux"
)
type leafRequestMatcher struct {
r *http.Request
path string
exactPath string
}
func (m *leafRequestMatcher) Match(value interface{}) (bool, interface{}) {
v, ok := value.(*pathMatcher)
if !ok {
return false, nil
}
l := matchLeaves(v.leaves, m.r, m.path, m.exactPath)
return l != nil, l
}
type leafMatcher struct {
wildcardParamNames []string // in reverse order
hasFreeWildcardParam bool
exactPath string
method string
weight int
hostRxs []*regexp.Regexp
pathRxs []*regexp.Regexp
headersExact map[string]string
headersRegexp map[string][]*regexp.Regexp
predicates []Predicate
route *Route
}
type leafMatchers []*leafMatcher
func leafWeight(l *leafMatcher) int {
w := l.weight
if l.method != "" {
w++
}
w += len(l.hostRxs)
w += len(l.pathRxs)
w += len(l.headersExact)
w += len(l.headersRegexp)
w += len(l.predicates)
return w
}
// Sorting of leaf matchers: the leaf with the higher weight (more conditions) is evaluated first.
func (ls leafMatchers) Len() int { return len(ls) }
func (ls leafMatchers) Swap(i, j int) { ls[i], ls[j] = ls[j], ls[i] }
func (ls leafMatchers) Less(i, j int) bool { return leafWeight(ls[i]) > leafWeight(ls[j]) }
type pathMatcher struct {
leaves leafMatchers
}
// root structure representing the routing tree.
type matcher struct {
paths *pathmux.Tree
rootLeaves leafMatchers
matchingOptions MatchingOptions
}
// An error created if a route definition cannot be processed.
type definitionError struct {
ID string
Index int
Original error
}
func (err *definitionError) Error() string {
if err.Index < 0 {
return err.Original.Error()
}
return fmt.Sprintf("%s [%d]: %v", err.ID, err.Index, err.Original)
}
// rx identifying the 'free form' wildcards at the end of the paths
var freeWildcardRx = regexp.MustCompile("/[*][^/]*$")
// compiles all rxs or fails
func getCompiledRxs(compiled map[string]*regexp.Regexp, exps []string) ([]*regexp.Regexp, error) {
rxs := make([]*regexp.Regexp, 0, len(exps))
for _, exp := range exps {
if rx, ok := compiled[exp]; ok {
rxs = append(rxs, rx)
continue
}
rx, err := regexp.Compile(exp)
if err != nil {
return nil, err
}
compiled[exp] = rx
rxs = append(rxs, rx)
}
return rxs, nil
}
// canonicalizes the keys of the header conditions
func canonicalizeHeaders(h map[string]string) map[string]string {
ch := make(map[string]string)
for k, v := range h {
ch[http.CanonicalHeaderKey(k)] = v
}
return ch
}
// canonicalizes the keys of the header regexp conditions
func canonicalizeHeaderRegexps(hrx map[string][]*regexp.Regexp) map[string][]*regexp.Regexp {
chrx := make(map[string][]*regexp.Regexp)
for k, v := range hrx {
chrx[http.CanonicalHeaderKey(k)] = v
}
return chrx
}
// extracts the expected wildcard param names and returns them in reverse order
func extractWildcardParamNames(r *Route) []string {
path := r.path
if path == "" {
path = r.pathSubtree
}
path = httppath.Clean(path)
var wildcards []string
pathTokens := strings.Split(path, "/")
for _, token := range pathTokens {
if len(token) > 1 && (token[0] == ':' || token[0] == '*') {
// prepend
wildcards = append([]string{token[1:]}, wildcards...)
}
}
if strings.HasSuffix(path, "/*") ||
r.path == "" && r.pathSubtree != "" && !freeWildcardRx.MatchString(path) {
wildcards = append([]string{"*"}, wildcards...)
}
return wildcards
}
func hasFreeWildcardParam(r *Route) bool {
return r.pathSubtree != "" || freeWildcardRx.MatchString(httppath.Clean(r.path))
}
// returns a cleaned path where all wildcard names have been replaced with *
func normalizePath(r *Route) (string, error) {
path := r.path
if path == "" {
path = r.pathSubtree
}
path = httppath.Clean(path)
var sb strings.Builder
for i := 0; i < len(path); i++ {
c := path[i]
if c == '/' {
sb.WriteByte(path[i])
} else {
nextSlash := strings.IndexByte(path[i:], '/')
nextSlashExists := true
if nextSlash == -1 {
nextSlash = len(path)
nextSlashExists = false
} else {
nextSlash += i
}
if c == ':' || c == '*' {
if nextSlashExists && c == '*' {
return "", errors.New("free wildcard param should be last")
} else {
sb.WriteByte(c)
}
sb.WriteByte('*')
} else {
sb.WriteString(path[i:nextSlash])
}
i = nextSlash - 1
}
}
return sb.String(), nil
}
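// Examples of the resulting normalization:
//
//	"/foo/:id"       -> "/foo/:*"
//	"/foo/:id/*rest" -> "/foo/:*/**"
//	"/foo/*rest/bar" -> error, the free wildcard must be last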
// creates a new leaf matcher. It preprocesses the
// Host, PathRegexp, Header and HeaderRegexp
// conditions.
//
// It uses a set of regular expressions shared in
// the current generation to reuse the
// compiled instances.
func newLeaf(r *Route, rxs map[string]*regexp.Regexp) (*leafMatcher, error) {
hostRxs, err := getCompiledRxs(rxs, r.HostRegexps)
if err != nil {
return nil, err
}
pathRxs, err := getCompiledRxs(rxs, r.PathRegexps)
if err != nil {
return nil, err
}
headerExps := r.HeaderRegexps
allHeaderRxs := make(map[string][]*regexp.Regexp)
for k, exps := range headerExps {
headerRxs, err := getCompiledRxs(rxs, exps)
if err != nil {
return nil, err
}
allHeaderRxs[k] = headerRxs
}
return &leafMatcher{
wildcardParamNames: extractWildcardParamNames(r),
hasFreeWildcardParam: hasFreeWildcardParam(r),
weight: r.weight,
method: r.Method,
hostRxs: hostRxs,
pathRxs: pathRxs,
headersExact: canonicalizeHeaders(r.Headers),
headersRegexp: canonicalizeHeaderRegexps(allHeaderRxs),
predicates: r.Predicates,
route: r}, nil
}
func trimTrailingSlash(path string) string {
if len(path) > 1 && path[len(path)-1] == '/' {
return path[:len(path)-1]
}
return path
}
// adds each path matcher to the path tree. Subtree matchers appear in
// the map under their additional paths as well.
func addTreeMatchers(pathTree *pathmux.Tree, matchers map[string]*pathMatcher) []*definitionError {
var errors []*definitionError
for p, m := range matchers {
// sort leaves during construction time, based on their priority
sort.Stable(m.leaves)
if err := pathTree.Add(p, m); err != nil {
errors = append(errors, &definitionError{Index: -1, Original: err})
}
}
return errors
}
func addLeafToPath(pms map[string]*pathMatcher, path string, l *leafMatcher) {
pm, ok := pms[path]
if !ok {
pm = &pathMatcher{}
pms[path] = pm
}
pm.leaves = append(pm.leaves, l)
}
func addSubtreeLeafsToPath(pms map[string]*pathMatcher, path string, l *leafMatcher, o MatchingOptions) {
basePath := freeWildcardRx.ReplaceAllLiteralString(path, "")
basePath = strings.TrimSuffix(basePath, "/")
if basePath == "" {
addLeafToPath(pms, "/", l)
addLeafToPath(pms, "/**", l)
} else {
addLeafToPath(pms, basePath, l)
addLeafToPath(pms, basePath+"/**", l)
if !o.ignoreTrailingSlash() {
addLeafToPath(pms, basePath+"/", l)
}
}
}
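// For example, a subtree route normalized to "/api/**" has the base
// path "/api", so its leaf is registered under "/api", "/api/**" and,
// unless trailing slashes are ignored, "/api/"; a bare "/**" subtree is
// registered under "/" and "/**".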
// constructs a matcher based on the provided definitions.
//
// If `ignoreTrailingSlash` is true, the matcher handles
// paths with or without a trailing slash equally.
//
// It organizes the route definitions in a trie structure
// based on their path condition, if any, and puts the routes
// with the same path condition into a leaf matcher structure,
// where they get evaluated after the leaf was matched based
// on the rest of the conditions, so that the most strict route
// definition matches first.
func newMatcher(rs []*Route, o MatchingOptions) (*matcher, []*definitionError) {
var (
errors []*definitionError
rootLeaves leafMatchers
)
pathMatchers := make(map[string]*pathMatcher)
compiledRxs := make(map[string]*regexp.Regexp)
for i, r := range rs {
l, err := newLeaf(r, compiledRxs)
if err != nil {
errors = append(errors, &definitionError{r.Id, i, err})
continue
}
path, err := normalizePath(r)
if err != nil {
errors = append(errors, &definitionError{r.Id, i, err})
continue
}
if r.pathSubtree != "" {
addSubtreeLeafsToPath(pathMatchers, path, l, o)
continue
}
if r.path == "" {
rootLeaves = append(rootLeaves, l)
continue
}
if o.ignoreTrailingSlash() {
path = trimTrailingSlash(path)
}
addLeafToPath(pathMatchers, path, l)
}
pathTree := &pathmux.Tree{}
errors = append(errors, addTreeMatchers(pathTree, pathMatchers)...)
// sort root leaves during construction time, based on their priority
sort.Stable(rootLeaves)
return &matcher{pathTree, rootLeaves, o}, errors
}
// matches a path in the path trie structure.
func matchPathTree(tree *pathmux.Tree, path string, lrm *leafRequestMatcher) (map[string]string, *leafMatcher) {
v, params, value := tree.LookupMatcher(path, lrm)
if v == nil {
return nil, nil
}
lm := value.(*leafMatcher)
if len(lm.wildcardParamNames) == len(params)+1 {
// prepend an empty string for the path subtree match
params = append([]string{""}, params...)
}
l := len(params)
if l > len(lm.wildcardParamNames) {
l = len(lm.wildcardParamNames)
}
paramsMap := make(map[string]string)
for i := 0; i < l; i += 1 {
paramsMap[lm.wildcardParamNames[i]] = params[i]
}
if l > 0 && lm.hasFreeWildcardParam {
paramsMap[lm.wildcardParamNames[0]] = "/" + params[0]
}
return paramsMap, lm
}
// matches the path regexp conditions in a leaf matcher.
func matchRegexps(rxs []*regexp.Regexp, s string) bool {
for _, rx := range rxs {
if !rx.MatchString(s) {
return false
}
}
return true
}
// matches a set of request headers to a fixed or regexp header condition
func matchHeader(h http.Header, key string, check func(string) bool) bool {
vals, has := h[key]
if !has {
return false
}
for _, val := range vals {
if check(val) {
return true
}
}
return false
}
// matches a set of request headers to the fixed and regexp header conditions
func matchHeaders(exact map[string]string, hrxs map[string][]*regexp.Regexp, h http.Header) bool {
for k, v := range exact {
if !matchHeader(h, k, func(val string) bool { return val == v }) {
return false
}
}
for k, rxs := range hrxs {
for _, rx := range rxs {
if !matchHeader(h, k, rx.MatchString) {
return false
}
}
}
return true
}
// check if all defined custom predicates are matched
func matchPredicates(cps []Predicate, req *http.Request) bool {
for _, cp := range cps {
if !cp.Match(req) {
return false
}
}
return true
}
// matches a request to the conditions in a leaf matcher
func matchLeaf(l *leafMatcher, req *http.Request, path, exactPath string) bool {
if l.exactPath != "" && l.exactPath != path {
return false
}
if l.method != "" && l.method != req.Method {
return false
}
if !matchRegexps(l.hostRxs, req.Host) {
return false
}
if !matchRegexps(l.pathRxs, exactPath) {
return false
}
if !matchHeaders(l.headersExact, l.headersRegexp, req.Header) {
return false
}
if !matchPredicates(l.predicates, req) {
return false
}
return true
}
// matches a request to a set of leaf matchers
func matchLeaves(leaves leafMatchers, req *http.Request, path, exactPath string) *leafMatcher {
for _, l := range leaves {
if matchLeaf(l, req, path, exactPath) {
return l
}
}
return nil
}
// tries to match a request against the available definitions. If a match is found,
// returns the associated value, and the wildcard parameters from the path definition,
// if any.
func (m *matcher) match(r *http.Request) (*Route, map[string]string) {
// normalize the path before matching;
// when ignoring trailing slashes, match without the trailing slash
path := httppath.Clean(r.URL.Path)
exact := path
if m.matchingOptions.ignoreTrailingSlash() {
path = trimTrailingSlash(path)
}
lrm := &leafRequestMatcher{r: r, path: path, exactPath: exact}
// first match fixed and wildcard paths
params, l := matchPathTree(m.paths, path, lrm)
if l != nil {
return l.route, params
}
// if no path match, match root leaves for other conditions
l = matchLeaves(m.rootLeaves, r, path, exact)
if l != nil {
return l.route, nil
}
return nil, nil
}
package routing
import (
"encoding/json"
"fmt"
"maps"
"net/http"
"slices"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/predicates"
)
const (
// Deprecated: use predicates.PathName instead
PathName = predicates.PathName
// Deprecated: use predicates.PathSubtreeName instead
PathSubtreeName = predicates.PathSubtreeName
// Deprecated: use predicates.WeightName instead
WeightPredicateName = predicates.WeightName
routesTimestampName = "X-Timestamp"
RoutesCountName = "X-Count"
defaultRouteListingLimit = 1024
)
// MatchingOptions controls route matching.
type MatchingOptions uint
const (
// MatchingOptionsNone indicates that all options are default.
MatchingOptionsNone MatchingOptions = 0
// IgnoreTrailingSlash indicates that trailing slashes in paths are ignored.
IgnoreTrailingSlash MatchingOptions = 1 << iota
)
func (o MatchingOptions) ignoreTrailingSlash() bool {
return o&IgnoreTrailingSlash > 0
}
// DataClient instances provide data sources for
// route definitions.
type DataClient interface {
LoadAll() ([]*eskip.Route, error)
LoadUpdate() ([]*eskip.Route, []string, error)
}
// Predicate instances are used as custom user defined route
// matching predicates.
type Predicate interface {
// Returns true if the request matches the predicate.
Match(*http.Request) bool
}
// PredicateSpec instances are used to create custom predicates
// (of type Predicate) with concrete arguments during the
// construction of the routing tree.
type PredicateSpec interface {
// Name of the predicate as used in the route definitions.
Name() string
// Creates a predicate instance with concrete arguments.
Create([]interface{}) (Predicate, error)
}
type WeightedPredicateSpec interface {
PredicateSpec
// Weight returns the extra weight of the predicate
Weight() int
}
// Options for initialization for routing.
type Options struct {
// Registry containing the available filter
// specifications that are used during processing
// the filter chains in the route definitions.
FilterRegistry filters.Registry
// Matching options are flags that control the
// route matching.
MatchingOptions MatchingOptions
// The timeout between requests to the data
// clients for route definition updates.
PollTimeout time.Duration
// The set of different data clients where the
// route definitions are read from.
DataClients []DataClient
// Specifications of custom, user defined predicates.
Predicates []PredicateSpec
// Performance tuning option.
//
// When zero, the newly constructed routing
// tree will take effect on the next routing
// query after every update from the data
// clients. In case of higher values, the
// routing queries have priority over the
// update channel, but the next routing tree
// takes effect only a few requests later.
//
// (Currently disabled and used with hard wired
// 0, until the performance benefit is verified
// by benchmarks.)
UpdateBuffer int
// Set a custom logger if necessary.
Log logging.Logger
// SuppressLogs indicates whether to log only a summary of the route changes.
SuppressLogs bool
// Metrics is used to collect monitoring data about the routes health, including
// total number of routes applied and UNIX time of the last routes update.
Metrics metrics.Metrics
// PreProcessors contains custom eskip.Route pre-processors.
PreProcessors []PreProcessor
// PostProcessors contains custom route post-processors.
PostProcessors []PostProcessor
// SignalFirstLoad enables signaling on the first load
// of the routing configuration during the startup.
SignalFirstLoad bool
}
// RouteFilter contains extensions to generic filter
// interface, serving mainly logging/monitoring
// purpose.
type RouteFilter struct {
filters.Filter
Name string
// Deprecated: currently not used, and post-processors may not maintain a correct value
Index int
}
// LBEndpoint represents the scheme and the host of load balanced
// backends.
type LBEndpoint struct {
Scheme, Host string
Metrics Metrics
}
// LBAlgorithm implementations apply a load balancing algorithm
// over the possible endpoints of a load balanced route.
type LBAlgorithm interface {
Apply(*LBContext) LBEndpoint
}
// LBContext is used to pass data to the load balancer to decide based
// on that data which endpoint to call from the backends
type LBContext struct {
Request *http.Request
Route *Route
LBEndpoints []LBEndpoint
Params map[string]interface{}
}
// NewLBContext is used to create a new LBContext, to pass data to the
// load balancer algorithms.
// Deprecated: create LBContext instead
func NewLBContext(r *http.Request, rt *Route) *LBContext {
return &LBContext{
Request: r,
Route: rt,
}
}
// Route object with preprocessed filter instances.
type Route struct {
// Fields from the static route definition.
eskip.Route
// weight used internally, received from the Weight() predicates.
weight int
// path predicate matching an exact path in the routing tree
path string
// path predicate matching a subtree
pathSubtree string
// The backend scheme and host.
Scheme, Host string
// The preprocessed custom predicate instances.
Predicates []Predicate
// The preprocessed filter instances.
Filters []*RouteFilter
// LBEndpoints contain the possible endpoints of a load
// balanced route.
LBEndpoints []LBEndpoint
// LBAlgorithm is the selected load balancing algorithm
// of a load balanced route.
LBAlgorithm LBAlgorithm
// LBFadeInDuration defines the duration of the fade-in
// function to be applied to new LB endpoints associated
// with this route.
LBFadeInDuration time.Duration
// LBExponent defines a secondary exponent modifier of
// the fade-in function configured mainly by the LBFadeInDuration
// field, adjusting the shape of the fade-in. By default,
// its value is usually 1, meaning linear fade-in, and it's
// configured by the post-processor found in the filters/fadein
// package.
LBFadeInExponent float64
}
// PostProcessor is an interface for custom post-processors applying changes
// to the routes after they are created from their data representation and
// before they are passed to the proxy.
type PostProcessor interface {
Do([]*Route) []*Route
}
// PreProcessor is an interface for custom pre-processors applying changes
// to the routes before they are created from their eskip.Route representation.
type PreProcessor interface {
Do([]*eskip.Route) []*eskip.Route
}
// Routing ('router') instance providing live
// updatable request matching.
type Routing struct {
routeTable atomic.Value // of struct routeTable
log logging.Logger
firstLoad chan struct{}
firstLoadSignaled bool
quit chan struct{}
metrics metrics.Metrics
}
// New initializes a routing instance, and starts listening for route
// definition updates.
func New(o Options) *Routing {
if o.Log == nil {
o.Log = &logging.DefaultLog{}
}
uniqueClients := make(map[DataClient]struct{}, len(o.DataClients))
for i, c := range o.DataClients {
if _, ok := uniqueClients[c]; ok {
o.Log.Errorf("Duplicate data clients are not allowed, ignoring client #%d: %T", i, c)
continue
}
uniqueClients[c] = struct{}{}
}
if len(uniqueClients) != len(o.DataClients) {
o.Log.Errorf("Ignored %d duplicate data clients", len(o.DataClients)-len(uniqueClients))
o.DataClients = slices.Collect(maps.Keys(uniqueClients))
}
r := &Routing{log: o.Log, firstLoad: make(chan struct{}), quit: make(chan struct{})}
r.metrics = o.Metrics
if !o.SignalFirstLoad {
close(r.firstLoad)
r.firstLoadSignaled = true
}
initialMatcher, _ := newMatcher(nil, MatchingOptionsNone)
rt := &routeTable{
m: initialMatcher,
created: time.Now(),
}
r.routeTable.Store(rt)
r.startReceivingUpdates(o)
return r
}
// ServeHTTP renders the list of current routes.
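//
// The handler supports the offset and limit query parameters for slicing the
// listing, a timestamp parameter that must match the current routing
// generation, and nopretty to disable pretty printing. A sketch, assuming
// the handler is mounted at /routes:
//
// GET /routes?offset=0&limit=100&nopretty=1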
func (r *Routing) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "GET" && req.Method != "HEAD" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
rt := r.routeTable.Load().(*routeTable)
req.ParseForm()
createdUnix := strconv.FormatInt(rt.created.Unix(), 10)
ts := req.Form.Get("timestamp")
if ts != "" && createdUnix != ts {
http.Error(w, "invalid timestamp", http.StatusBadRequest)
return
}
if req.Method == "HEAD" {
w.Header().Set(routesTimestampName, createdUnix)
w.Header().Set(RoutesCountName, strconv.Itoa(len(rt.validRoutes)))
if strings.Contains(req.Header.Get("Accept"), "application/json") {
w.Header().Set("Content-Type", "application/json")
} else {
w.Header().Set("Content-Type", "text/plain")
}
return
}
offset, err := extractParam(req, "offset", 0)
if err != nil {
http.Error(w, "invalid offset", http.StatusBadRequest)
return
}
limit, err := extractParam(req, "limit", defaultRouteListingLimit)
if err != nil {
http.Error(w, "invalid limit", http.StatusBadRequest)
return
}
w.Header().Set(routesTimestampName, createdUnix)
w.Header().Set(RoutesCountName, strconv.Itoa(len(rt.validRoutes)))
routes := slice(rt.validRoutes, offset, limit)
if strings.Contains(req.Header.Get("Accept"), "application/json") {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(routes); err != nil {
http.Error(
w,
http.StatusText(http.StatusInternalServerError),
http.StatusInternalServerError,
)
}
return
}
w.Header().Set("Content-Type", "text/plain")
eskip.Fprint(w, extractPretty(req), routes...)
}
func (r *Routing) startReceivingUpdates(o Options) {
c := make(chan *routeTable)
go receiveRouteMatcher(o, c, r.quit)
go func() {
for {
select {
case rt := <-c:
r.routeTable.Store(rt)
if !r.firstLoadSignaled {
if len(rt.clients) == len(o.DataClients) {
close(r.firstLoad)
r.firstLoadSignaled = true
}
}
r.log.Infof("route settings applied, id: %d", rt.id)
if r.metrics != nil { // existing codebases might not supply metrics instance
r.metrics.UpdateGauge("routes.total", float64(len(rt.validRoutes)))
r.metrics.UpdateGauge("routes.updated_timestamp", float64(rt.created.Unix()))
r.metrics.MeasureSince("routes.update_latency", rt.created)
}
case <-r.quit:
if rt, ok := r.routeTable.Load().(*routeTable); ok {
rt.close()
}
return
}
}
}()
}
// Route matches a request in the current routing tree.
//
// If the request matches a route, it returns the route and a map of
// parameters constructed from the wildcard parameters in the path
// condition, if any. If there is no match, it returns nil.
func (r *Routing) Route(req *http.Request) (*Route, map[string]string) {
rt := r.routeTable.Load().(*routeTable)
return rt.m.match(req)
}
// FirstLoad returns a channel that, when SignalFirstLoad is enabled, is
// closed once the first routing configuration was received during startup.
// When disabled, the returned channel is already closed and never blocks.
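//
// A usage sketch (assuming a *Routing named r, created with the
// SignalFirstLoad option enabled):
//
// <-r.FirstLoad() // returns once the first update was applied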
func (r *Routing) FirstLoad() <-chan struct{} {
return r.firstLoad
}
// RouteLookup captures a single generation of the lookup tree, allowing multiple
// lookups to the same version of the lookup tree.
//
// Experimental feature. Using this solution can potentially cause large
// memory consumption in extreme cases, typically when: the total number of
// routes is large, the backend responses to a subset of these routes are
// slow, and there's a rapid burst of consecutive updates to the routing
// table. This situation is considered an edge case, but until a protection
// against it is found, the feature is experimental and its exported
// interface may change.
type RouteLookup struct {
rt *routeTable
}
// Do executes the lookup against the captured routing table. Equivalent to
// Routing.Route().
func (rl *RouteLookup) Do(req *http.Request) (*Route, map[string]string) {
return rl.rt.m.match(req)
}
// Get returns a captured generation of the lookup table. This feature is
// experimental. See the description of the RouteLookup type.
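//
// A sketch of multiple lookups against the same generation (r is an
// initialized *Routing, req1 and req2 are illustrative requests):
//
// rl := r.Get()
// route1, params1 := rl.Do(req1)
// route2, params2 := rl.Do(req2) // guaranteed to hit the same routing table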
func (r *Routing) Get() *RouteLookup {
rt := r.routeTable.Load().(*routeTable)
return &RouteLookup{rt: rt}
}
// Close closes the routing, its current route table, and stops the state machine receiving route updates.
func (r *Routing) Close() {
close(r.quit)
}
func slice(r []*eskip.Route, offset int, limit int) []*eskip.Route {
if offset > len(r) {
offset = len(r)
}
end := offset + limit
if end > len(r) {
end = len(r)
}
result := r[offset:end]
if result == nil {
return []*eskip.Route{}
}
return result
}
func extractParam(r *http.Request, key string, defaultValue int) (int, error) {
param := r.Form.Get(key)
if param == "" {
return defaultValue, nil
}
val, err := strconv.Atoi(param)
if err != nil {
return 0, err
}
if val < 0 {
return 0, fmt.Errorf("invalid value `%d` for `%s`", val, key)
}
return val, nil
}
func extractPretty(r *http.Request) eskip.PrettyPrintInfo {
vals, ok := r.Form["nopretty"]
if !ok || len(vals) == 0 {
return eskip.PrettyPrintInfo{Pretty: true, IndentStr: " "}
}
val := vals[0]
if val == "0" || val == "false" {
return eskip.PrettyPrintInfo{Pretty: true, IndentStr: " "}
}
return eskip.PrettyPrintInfo{Pretty: false, IndentStr: ""}
}
// Package scheduler provides a registry to be used as a postprocessor for the routes
// that use a LIFO or FIFO filter.
package scheduler
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/aryszka/jobqueue"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/metrics"
"github.com/zalando/skipper/routing"
"golang.org/x/sync/semaphore"
)
// note: Config must stay comparable because it is used to detect changes in route-specific LIFO config
const (
// LIFOKey used during routing to pass lifo values from the filters to the proxy.
LIFOKey = "lifo"
// FIFOKey used during routing to pass fifo values from the filters to the proxy.
FIFOKey = "fifo"
)
var (
ErrQueueFull = errors.New("queue full")
ErrQueueTimeout = errors.New("queue timeout")
ErrClientCanceled = errors.New("client canceled")
)
// Config can be used to provide configuration of the registry.
type Config struct {
// MaxConcurrency defines how many jobs are allowed to run concurrently.
// Defaults to 1.
MaxConcurrency int
// MaxQueueSize defines how many jobs may be waiting in the queue.
// Defaults to infinite.
MaxQueueSize int
// Timeout defines how long a job can wait in the queue.
// Defaults to infinite.
Timeout time.Duration
// CloseTimeout sets a maximum duration for how long the queue can wait
// for the active and queued jobs to finish. Defaults to infinite.
CloseTimeout time.Duration
}
// QueueStatus reports the current status of a queue. It can be used for metrics.
type QueueStatus struct {
// ActiveRequests represents the number of the requests currently being handled.
ActiveRequests int
// QueuedRequests represents the number of requests waiting to be handled.
QueuedRequests int
// Closed indicates that the queue was closed.
Closed bool
}
// Queue objects implement a LIFO queue for handling requests, with a maximum allowed
// concurrency and queue size. Currently, they can be used from the lifo and lifoGroup
// filters in the filters/scheduler package only.
type Queue struct {
queue *jobqueue.Stack
config Config
metrics metrics.Metrics
activeRequestsMetricsKey string
errorFullMetricsKey string
errorOtherMetricsKey string
errorTimeoutMetricsKey string
queuedRequestsMetricsKey string
}
// FifoQueue objects implement a FIFO queue for handling requests,
// with a maximum allowed concurrency and queue size. Currently, they
// can be used from the fifo filters in the filters/scheduler package
// only.
type FifoQueue struct {
queue *fifoQueue
config Config
metrics metrics.Metrics
activeRequestsMetricsKey string
errorFullMetricsKey string
errorOtherMetricsKey string
errorTimeoutMetricsKey string
queuedRequestsMetricsKey string
}
type fifoQueue struct {
mu sync.RWMutex
counter *atomic.Int64
sem *semaphore.Weighted
timeout time.Duration
maxQueueSize int64
maxConcurrency int64
closed bool
}
func (fq *fifoQueue) status() QueueStatus {
fq.mu.RLock()
maxConcurrency := fq.maxConcurrency
closed := fq.closed
fq.mu.RUnlock()
all := fq.counter.Load()
var queued, active int64
if all > maxConcurrency {
queued = all - maxConcurrency
active = maxConcurrency
} else {
queued = 0
active = all
}
return QueueStatus{
ActiveRequests: int(active),
QueuedRequests: int(queued),
Closed: closed,
}
}
func (fq *fifoQueue) close() {
fq.mu.Lock()
fq.closed = true
fq.mu.Unlock()
}
func (fq *fifoQueue) reconfigure(c Config) {
fq.mu.Lock()
defer fq.mu.Unlock()
fq.maxConcurrency = int64(c.MaxConcurrency)
fq.maxQueueSize = int64(c.MaxQueueSize)
fq.timeout = c.Timeout
fq.sem = semaphore.NewWeighted(int64(c.MaxConcurrency))
fq.counter = new(atomic.Int64)
}
func (fq *fifoQueue) wait(ctx context.Context) (func(), error) {
fq.mu.RLock()
maxConcurrency := fq.maxConcurrency
maxQueueSize := fq.maxQueueSize
timeout := fq.timeout
sem := fq.sem
cnt := fq.counter
fq.mu.RUnlock()
// check request context expired
// https://github.com/golang/go/issues/63615
if err := ctx.Err(); err != nil {
switch err {
case context.DeadlineExceeded:
return nil, ErrQueueTimeout
case context.Canceled:
return nil, ErrClientCanceled
default:
// does not exist yet in Go stdlib as of Go1.18.4
return nil, err
}
}
// handle queue
all := cnt.Add(1)
// queue full?
if all > maxConcurrency+maxQueueSize {
cnt.Add(-1)
return nil, ErrQueueFull
}
// set timeout
c, done := context.WithTimeout(ctx, timeout)
defer done()
// limit concurrency
if err := sem.Acquire(c, 1); err != nil {
cnt.Add(-1)
switch err {
case context.DeadlineExceeded:
return nil, ErrQueueTimeout
case context.Canceled:
return nil, ErrClientCanceled
default:
// does not exist yet in Go stdlib as of Go1.18.4
return nil, err
}
}
return func() {
// postpone release to Response() filter
cnt.Add(-1)
sem.Release(1)
}, nil
}
// Options provides options for the registry.
type Options struct {
// MetricsUpdateTimeout defines the frequency of how often the
// FIFO and LIFO metrics are updated when they are enabled.
// Defaults to 1s.
MetricsUpdateTimeout time.Duration
// EnableRouteLIFOMetrics enables collecting metrics about the LIFO queues.
EnableRouteLIFOMetrics bool
// EnableRouteFIFOMetrics enables collecting metrics about the FIFO queues.
EnableRouteFIFOMetrics bool
// Metrics must be provided to the registry in order to collect the FIFO and LIFO metrics.
Metrics metrics.Metrics
}
// Registry maintains a set of LIFO and FIFO queues. It is used to preserve
// queue instances across multiple generations of the routing. It implements
// the routing.PostProcessor interface, so it is enough to pass it in to
// routing.Routing when initializing it.
//
// When EnableRouteLIFOMetrics is set, the registry starts a background
// goroutine that regularly takes snapshots of the active lifo queues and
// updates the corresponding metrics. This goroutine is started when the
// first lifo filter is detected and returns when the registry is closed.
// Individual metrics objects (keys) are used for each lifo filter, and one
// for each lifo group defined by the lifoGroup filter.
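//
// A wiring sketch (o stands for illustrative routing options):
//
// registry := scheduler.RegistryWith(scheduler.Options{})
// defer registry.Close()
// o.PreProcessors = append(o.PreProcessors, registry.PreProcessor())
// o.PostProcessors = append(o.PostProcessors, registry)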
type Registry struct {
options Options
measuring bool
quit chan struct{}
mu sync.Mutex
lifoQueues map[queueId]*Queue
lifoDeleted map[*Queue]time.Time
fifoQueues map[queueId]*FifoQueue
fifoDeleted map[*FifoQueue]time.Time
}
type queueId struct {
name string
grouped bool
}
// Amount of time to wait before closing the deleted queues
var queueCloseDelay = 1 * time.Minute
// FIFOFilter is the interface that needs to be implemented by the filters that
// use a FIFO queue maintained by the registry.
type FIFOFilter interface {
// SetQueue will be used by the registry to pass in the right queue to
// the filter.
SetQueue(*FifoQueue)
// GetQueue is currently used only by tests.
GetQueue() *FifoQueue
// Config will be called by the registry once during processing the
// routing to get the right queue settings from the filter.
Config() Config
}
// LIFOFilter is the interface that needs to be implemented by the filters that
// use a LIFO queue maintained by the registry.
type LIFOFilter interface {
// SetQueue will be used by the registry to pass in the right queue to
// the filter.
SetQueue(*Queue)
// GetQueue is currently used only by tests.
GetQueue() *Queue
// Config will be called by the registry once during processing the
// routing to get the right queue settings from the filter.
Config() Config
}
// GroupedLIFOFilter is an extension of the LIFOFilter interface for filters
// that use a shared queue.
type GroupedLIFOFilter interface {
LIFOFilter
// Group returns the name of the group.
Group() string
// HasConfig indicates that the current filter provides the
// queue settings for the group.
HasConfig() bool
}
// Wait blocks until a request can be processed or needs to be
// rejected. It returns a done() callback and an error. When the
// request can be processed, calling done() indicates that it has
// finished. It is mandatory to call done() after the request was
// processed. When the request needs to be rejected, an error will
// be returned and done will be nil.
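//
// A minimal sketch (fq is an initialized *FifoQueue, req an illustrative request):
//
// done, err := fq.Wait(req.Context())
// if err != nil {
// return // the request was rejected
// }
// defer done()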
func (fq *FifoQueue) Wait(ctx context.Context) (func(), error) {
f, err := fq.queue.wait(ctx)
if err != nil && fq.metrics != nil {
switch err {
case ErrQueueFull:
fq.metrics.IncCounter(fq.errorFullMetricsKey)
case ErrQueueTimeout:
fq.metrics.IncCounter(fq.errorTimeoutMetricsKey)
case ErrClientCanceled:
// This case is handled in the proxy with status code 499
default:
fq.metrics.IncCounter(fq.errorOtherMetricsKey)
}
}
return f, err
}
// Status returns the current status of a queue.
func (fq *FifoQueue) Status() QueueStatus {
return fq.queue.status()
}
// Config returns the configuration that the queue was created with.
func (fq *FifoQueue) Config() Config {
return fq.config
}
// Reconfigure updates the configuration of the FifoQueue. It will
// reset the current state.
func (fq *FifoQueue) Reconfigure(c Config) {
fq.config = c
fq.queue.reconfigure(c)
}
func (fq *FifoQueue) close() {
fq.queue.close()
}
// Wait blocks until a request can be processed or needs to be rejected.
// When it can be processed, calling done indicates that it has finished.
// It is mandatory to call done() after the request was processed. When
// the request needs to be rejected, an error will be returned.
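//
// A minimal sketch (q is an initialized *Queue):
//
// done, err := q.Wait()
// if err != nil {
// return // the request was rejected
// }
// defer done()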
func (q *Queue) Wait() (done func(), err error) {
done, err = q.queue.Wait()
if q.metrics != nil && err != nil {
switch err {
case jobqueue.ErrStackFull:
q.metrics.IncCounter(q.errorFullMetricsKey)
case jobqueue.ErrTimeout:
q.metrics.IncCounter(q.errorTimeoutMetricsKey)
default:
q.metrics.IncCounter(q.errorOtherMetricsKey)
}
}
return done, err
}
// Status returns the current status of a queue.
func (q *Queue) Status() QueueStatus {
st := q.queue.Status()
return QueueStatus{
ActiveRequests: st.ActiveJobs,
QueuedRequests: st.QueuedJobs,
Closed: st.Closed,
}
}
// Config returns the configuration that the queue was created with.
func (q *Queue) Config() Config {
return q.config
}
func (q *Queue) reconfigure() {
q.queue.Reconfigure(jobqueue.Options{
MaxConcurrency: q.config.MaxConcurrency,
MaxStackSize: q.config.MaxQueueSize,
Timeout: q.config.Timeout,
})
}
func (q *Queue) Close() {
q.queue.Close()
}
// RegistryWith creates a registry with the provided options.
func RegistryWith(o Options) *Registry {
if o.MetricsUpdateTimeout <= 0 {
o.MetricsUpdateTimeout = time.Second
}
return &Registry{
options: o,
quit: make(chan struct{}),
fifoQueues: make(map[queueId]*FifoQueue),
fifoDeleted: make(map[*FifoQueue]time.Time),
lifoQueues: make(map[queueId]*Queue),
lifoDeleted: make(map[*Queue]time.Time),
}
}
// NewRegistry creates a registry with the default options.
func NewRegistry() *Registry {
return RegistryWith(Options{})
}
func (r *Registry) getFifoQueue(id queueId, c Config) *FifoQueue {
r.mu.Lock()
defer r.mu.Unlock()
fq, ok := r.fifoQueues[id]
if ok {
if fq.config != c {
fq.Reconfigure(c)
}
} else {
fq = r.newFifoQueue(id.name, c)
r.fifoQueues[id] = fq
}
return fq
}
func (r *Registry) newFifoQueue(name string, c Config) *FifoQueue {
q := &FifoQueue{
config: c,
queue: &fifoQueue{
counter: new(atomic.Int64),
sem: semaphore.NewWeighted(int64(c.MaxConcurrency)),
maxConcurrency: int64(c.MaxConcurrency),
maxQueueSize: int64(c.MaxQueueSize),
timeout: c.Timeout,
},
}
if r.options.EnableRouteFIFOMetrics {
if name == "" {
name = "unknown"
}
q.activeRequestsMetricsKey = fmt.Sprintf("fifo.%s.active", name)
q.queuedRequestsMetricsKey = fmt.Sprintf("fifo.%s.queued", name)
q.errorFullMetricsKey = fmt.Sprintf("fifo.%s.error.full", name)
q.errorOtherMetricsKey = fmt.Sprintf("fifo.%s.error.other", name)
q.errorTimeoutMetricsKey = fmt.Sprintf("fifo.%s.error.timeout", name)
q.metrics = r.options.Metrics
r.measure()
}
return q
}
func (r *Registry) getQueue(id queueId, c Config) *Queue {
r.mu.Lock()
defer r.mu.Unlock()
q, ok := r.lifoQueues[id]
if ok {
if q.config != c {
q.config = c
q.reconfigure()
}
} else {
q = r.newQueue(id.name, c)
r.lifoQueues[id] = q
}
return q
}
func (r *Registry) newQueue(name string, c Config) *Queue {
q := &Queue{
config: c,
// renaming Stack -> Queue in the jobqueue project will follow
queue: jobqueue.With(jobqueue.Options{
MaxConcurrency: c.MaxConcurrency,
MaxStackSize: c.MaxQueueSize,
Timeout: c.Timeout,
}),
}
if r.options.EnableRouteLIFOMetrics {
if name == "" {
name = "unknown"
}
q.activeRequestsMetricsKey = fmt.Sprintf("lifo.%s.active", name)
q.queuedRequestsMetricsKey = fmt.Sprintf("lifo.%s.queued", name)
q.errorFullMetricsKey = fmt.Sprintf("lifo.%s.error.full", name)
q.errorOtherMetricsKey = fmt.Sprintf("lifo.%s.error.other", name)
q.errorTimeoutMetricsKey = fmt.Sprintf("lifo.%s.error.timeout", name)
q.metrics = r.options.Metrics
r.measure()
}
return q
}
func (r *Registry) deleteUnused(inUse map[queueId]struct{}) {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
closeCutoff := now.Add(-queueCloseDelay)
// fifo
for q, deleted := range r.fifoDeleted {
if deleted.Before(closeCutoff) {
delete(r.fifoDeleted, q)
q.close()
}
}
for id, q := range r.fifoQueues {
if _, ok := inUse[id]; !ok {
delete(r.fifoQueues, id)
r.fifoDeleted[q] = now
}
}
// lifo
for q, deleted := range r.lifoDeleted {
if deleted.Before(closeCutoff) {
delete(r.lifoDeleted, q)
q.Close()
}
}
for id, q := range r.lifoQueues {
if _, ok := inUse[id]; !ok {
delete(r.lifoQueues, id)
r.lifoDeleted[q] = now
}
}
}
// PreProcessor returns a routing.PreProcessor that ensures a single fifo and
// a single lifo filter instance per route.
//
// Registry cannot implement routing.PreProcessor directly due to the
// unfortunate method name clash with routing.PostProcessor.
func (r *Registry) PreProcessor() routing.PreProcessor {
return registryPreProcessor{}
}
type registryPreProcessor struct{}
func (registryPreProcessor) Do(routes []*eskip.Route) []*eskip.Route {
for _, r := range routes {
lifoCount := 0
fifoCount := 0
for _, f := range r.Filters {
switch f.Name {
case filters.FifoName:
fifoCount++
case filters.LifoName:
lifoCount++
}
}
// remove all but last fifo instances
if fifoCount > 1 {
old := r.Filters
r.Filters = make([]*eskip.Filter, 0, len(old)-fifoCount+1)
for _, f := range old {
if fifoCount > 1 && f.Name == filters.FifoName {
log.Debugf("Removing non-last %v from %s", f, r.Id)
fifoCount--
} else {
r.Filters = append(r.Filters, f)
}
}
}
// remove all but last lifo instances
if lifoCount > 1 {
old := r.Filters
r.Filters = make([]*eskip.Filter, 0, len(old)-lifoCount+1)
for _, f := range old {
if lifoCount > 1 && f.Name == filters.LifoName {
log.Debugf("Removing non-last %v from %s", f, r.Id)
lifoCount--
} else {
r.Filters = append(r.Filters, f)
}
}
}
}
return routes
}
// Do implements routing.PostProcessor and sets the queue for the scheduler filters.
//
// It preserves the existing queue when available.
func (r *Registry) Do(routes []*routing.Route) []*routing.Route {
rr := make([]*routing.Route, len(routes))
inUse := make(map[queueId]struct{})
groups := make(map[string][]GroupedLIFOFilter)
for i, ri := range routes {
rr[i] = ri
for _, fi := range ri.Filters {
if ff, ok := fi.Filter.(FIFOFilter); ok {
id := queueId{ri.Id, false}
inUse[id] = struct{}{}
fq := r.getFifoQueue(id, ff.Config())
ff.SetQueue(fq)
continue
}
if glf, ok := fi.Filter.(GroupedLIFOFilter); ok {
groupName := glf.Group()
groups[groupName] = append(groups[groupName], glf)
continue
}
lf, ok := fi.Filter.(LIFOFilter)
if !ok {
continue
}
id := queueId{ri.Id, false}
inUse[id] = struct{}{}
q := r.getQueue(id, lf.Config())
lf.SetQueue(q)
}
}
for name, group := range groups {
var (
c Config
foundConfig bool
)
for _, glf := range group {
if !glf.HasConfig() {
continue
}
if foundConfig && glf.Config() != c {
log.Warnf("Found mismatching configuration for the LIFO group: %s", name)
continue
}
c = glf.Config()
foundConfig = true
}
id := queueId{name, true}
inUse[id] = struct{}{}
q := r.getQueue(id, c)
for _, glf := range group {
glf.SetQueue(q)
}
}
r.deleteUnused(inUse)
return rr
}
func (r *Registry) measure() {
if r.options.Metrics == nil || r.measuring {
return
}
r.measuring = true
go func() {
ticker := time.NewTicker(r.options.MetricsUpdateTimeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
r.updateMetrics()
case <-r.quit:
return
}
}
}()
}
func (r *Registry) updateMetrics() {
r.mu.Lock()
defer r.mu.Unlock()
for _, q := range r.fifoQueues {
s := q.Status()
r.options.Metrics.UpdateGauge(q.activeRequestsMetricsKey, float64(s.ActiveRequests))
r.options.Metrics.UpdateGauge(q.queuedRequestsMetricsKey, float64(s.QueuedRequests))
}
for _, q := range r.lifoQueues {
s := q.Status()
r.options.Metrics.UpdateGauge(q.activeRequestsMetricsKey, float64(s.ActiveRequests))
r.options.Metrics.UpdateGauge(q.queuedRequestsMetricsKey, float64(s.QueuedRequests))
}
}
func (r *Registry) UpdateMetrics() {
if r.options.Metrics != nil {
r.updateMetrics()
}
}
// Close closes the registry, including gracefully tearing down the stored queues.
func (r *Registry) Close() {
r.mu.Lock()
defer r.mu.Unlock()
for q := range r.fifoDeleted {
delete(r.fifoDeleted, q)
q.close()
}
for q := range r.lifoDeleted {
delete(r.lifoDeleted, q)
q.Close()
}
for id, q := range r.lifoQueues {
delete(r.lifoQueues, id)
q.Close()
}
close(r.quit)
}
// Package base64 provides an easy way to encode and decode base64
package base64
import (
"encoding/base64"
lua "github.com/yuin/gopher-lua"
)
func Loader(L *lua.LState) int {
mod := L.SetFuncs(L.NewTable(), map[string]lua.LGFunction{
"decode": decode,
"encode": encode,
})
L.Push(mod)
return 1
}
func encode(L *lua.LState) int {
str := L.CheckString(1)
ret := base64.StdEncoding.EncodeToString([]byte(str))
L.Push(lua.LString(ret))
return 1
}
func decode(L *lua.LState) int {
str := L.CheckString(1)
ret, err := base64.StdEncoding.DecodeString(str)
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(lua.LString(ret))
return 1
}
package script
import (
"strings"
"time"
log "github.com/sirupsen/logrus"
lua "github.com/yuin/gopher-lua"
)
type luaModule struct {
name string
loader lua.LGFunction
disabledSymbols []string
}
var standardModules = []luaModule{
// Load Package and Base first, see lua.LState.OpenLibs()
{lua.LoadLibName, lua.OpenPackage, nil},
{lua.BaseLibName, lua.OpenBase, nil},
{lua.TabLibName, lua.OpenTable, nil},
{lua.IoLibName, lua.OpenIo, nil},
{lua.OsLibName, lua.OpenOs, nil},
{lua.StringLibName, lua.OpenString, nil},
{lua.MathLibName, lua.OpenMath, nil},
{lua.DebugLibName, lua.OpenDebug, nil},
{lua.ChannelLibName, lua.OpenChannel, nil},
{lua.CoroutineLibName, lua.OpenCoroutine, nil},
}
// load loads a standard lua module, see lua.LState.OpenLibs()
func (m luaModule) load(L *lua.LState) {
L.Push(L.NewFunction(m.loader))
L.Push(lua.LString(m.name))
L.Call(1, 0)
if m.name == lua.BaseLibName {
L.SetGlobal("print", L.NewFunction(printToLog))
L.SetGlobal("sleep", L.NewFunction(sleep))
}
if len(m.disabledSymbols) > 0 {
st := m.table(L)
for _, name := range m.disabledSymbols {
st.RawSetString(name, lua.LNil)
}
}
}
// withSymbols returns a copy of the module with only the selected symbols enabled
func (m luaModule) withSymbols(L *lua.LState, enabledSymbols []string) luaModule {
// gopher-lua does not have API to select enabled symbols,
// see https://github.com/yuin/gopher-lua/discussions/393
//
// Instead collect symbols to disable as difference
// between all and enabled module symbols
allSymbols := make(map[string]struct{})
m.load(L)
m.table(L).ForEach(func(k, _ lua.LValue) {
if name, ok := k.(lua.LString); ok {
allSymbols[name.String()] = struct{}{}
}
})
for _, s := range enabledSymbols {
delete(allSymbols, s)
}
result := luaModule{m.name, m.loader, nil}
for s := range allSymbols {
result.disabledSymbols = append(result.disabledSymbols, s)
}
return result
}
func (m luaModule) table(L *lua.LState) *lua.LTable {
name := m.name
if m.name == lua.BaseLibName {
name = "_G"
}
return L.GetGlobal(name).(*lua.LTable)
}
func (m luaModule) preload(L *lua.LState) {
L.PreloadModule(m.name, m.loader)
}
func printToLog(L *lua.LState) int {
top := L.GetTop()
args := make([]interface{}, 0, top)
for i := 1; i <= top; i++ {
args = append(args, L.ToStringMeta(L.Get(i)).String())
}
log.Print(args...)
return 0
}
func sleep(L *lua.LState) int {
time.Sleep(time.Duration(L.CheckInt64(1)) * time.Millisecond)
return 0
}
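// moduleConfig groups enabled "<module>.<symbol>" entries by module name.
// For example (illustrative input):
//
// moduleConfig([]string{"base.print", "base.sleep", "json"})
// // => map[string][]string{"base": {"print", "sleep"}, "json": {}}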
func moduleConfig(modules []string) map[string][]string {
config := make(map[string][]string)
for _, m := range modules {
if module, symbol, found := strings.Cut(m, "."); found {
config[module] = append(config[module], symbol)
} else {
if _, ok := config[module]; !ok {
config[module] = []string{}
}
}
}
return config
}
// Package script provides lua scripting for skipper
package script
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
log "github.com/sirupsen/logrus"
lua "github.com/yuin/gopher-lua"
lua_parse "github.com/yuin/gopher-lua/parse"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/script/base64"
"slices"
"github.com/cjoudrey/gluahttp"
"github.com/cjoudrey/gluaurl"
gjson "layeh.com/gopher-json"
)
var (
// InitialPoolSize is the number of lua states created initially per route
InitialPoolSize int = 3
// MaxPoolSize is the number of lua states stored per route - there may be more parallel
// requests, but only this number is cached.
MaxPoolSize int = 10
errLuaSourcesDisabled = errors.New("lua Sources are disabled")
)
type LuaOptions struct {
// Modules configures enabled standard and additional (preloaded) Lua modules.
// For standard Lua modules (see https://www.lua.org/manual/5.1/manual.html)
// use "<module>.<symbol>" (e.g. "base.print") to selectively enable module symbols.
// Additional modules are preloaded with all symbols.
// Empty value enables all modules.
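// For example (illustrative values):
//
// Modules: []string{"base.print", "table", "json"}
//
// enables only print from the standard base module, the whole standard
// table module, and preloads the additional json module.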
Modules []string
// Sources that are allowed as input sources. Valid sources
// are "none", "file" and "inline". An empty slice enables
// "file" and "inline" as the default. To disable the use of
// lua filters, use "none".
}
type luaScript struct {
modules []string
sources []string
}
// NewLuaScript creates a new filter spec
func NewLuaScript() filters.Spec {
spec, _ := NewLuaScriptWithOptions(LuaOptions{})
return spec
}
// NewLuaScriptWithOptions creates a new filter spec with options
func NewLuaScriptWithOptions(opts LuaOptions) (filters.Spec, error) {
sources := opts.Sources
if len(sources) == 0 {
// backwards compatible default, allow all
sources = append(sources, "file", "inline")
} else if contains("none", opts.Sources) {
sources = nil
}
return &luaScript{
modules: opts.Modules,
sources: sources,
}, nil
}
// Name returns the name of the filter ("lua")
func (ls *luaScript) Name() string {
return filters.LuaName
}
// CreateFilter creates the lua script filter.
func (ls *luaScript) CreateFilter(config []interface{}) (filters.Filter, error) {
if len(ls.sources) == 0 {
return nil, errLuaSourcesDisabled
}
if len(config) == 0 {
return nil, filters.ErrInvalidFilterParameters
}
src, ok := config[0].(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
var params []string
for _, p := range config[1:] {
ps, ok := p.(string)
if !ok {
return nil, filters.ErrInvalidFilterParameters
}
params = append(params, ps)
}
s := &script{source: src, routeParams: params}
if err := s.initScript(ls.modules, ls.sources); err != nil {
return nil, err
}
return s, nil
}
type script struct {
source string
routeParams []string
pool chan *lua.LState
proto *lua.FunctionProto
load []luaModule
preload []luaModule
hasRequest bool
hasResponse bool
}
func (s *script) getState() (*lua.LState, error) {
select {
case L := <-s.pool:
if L == nil {
return nil, errors.New("pool closed")
}
return L, nil
default:
return s.newState()
}
}
func (s *script) putState(L *lua.LState) {
if s.pool == nil { // pool closed
L.Close()
return
}
select {
case s.pool <- L:
default: // pool full, close state
L.Close()
}
}
func (s *script) newState() (*lua.LState, error) {
L := lua.NewState(lua.Options{SkipOpenLibs: true})
for _, m := range s.load {
m.load(L)
}
for _, m := range s.preload {
m.preload(L)
}
L.Push(L.NewFunctionFromProto(s.proto))
err := L.PCall(0, lua.MultRet, nil)
if err != nil {
L.Close()
return nil, err
}
return L, nil
}
func (s *script) initScript(modules, validSources []string) error {
// Compile
var reader io.Reader
var name string
if strings.HasSuffix(s.source, ".lua") {
if !contains("file", validSources) {
return fmt.Errorf(`invalid lua source referenced "file", allowed: "%v"`, validSources)
}
file, err := os.Open(s.source)
if err != nil {
return err
}
defer func() {
if err = file.Close(); err != nil {
log.Errorf("Failed to close lua file %s: %v", s.source, err)
}
}()
reader = bufio.NewReader(file)
name = s.source
} else {
if !contains("inline", validSources) {
return fmt.Errorf(`invalid lua source referenced "inline", allowed: "%v"`, validSources)
}
reader = strings.NewReader(s.source)
name = "<script>"
}
chunk, err := lua_parse.Parse(reader, name)
if err != nil {
return err
}
proto, err := lua.Compile(chunk, name)
if err != nil {
return err
}
s.proto = proto
// Configure enabled modules and symbols
moduleConfig := moduleConfig(modules)
if len(moduleConfig) == 0 {
s.load = append(s.load, standardModules...)
} else {
L := lua.NewState(lua.Options{SkipOpenLibs: true})
defer L.Close()
for _, m := range standardModules {
name := m.name
if m.name == lua.BaseLibName {
name = "base" // alias for empty lua.BaseLibName
}
if symbols, ok := moduleConfig[name]; ok {
if len(symbols) > 0 {
m = m.withSymbols(L, symbols)
}
s.load = append(s.load, m)
}
}
}
// Configure additional modules
additionalModules := []luaModule{
{"base64", base64.Loader, nil},
{"http", gluahttp.NewHttpModule(&http.Client{}).Loader, nil},
{"url", gluaurl.Loader, nil},
{"json", gjson.Loader, nil},
}
if len(moduleConfig) == 0 {
s.preload = append(s.preload, additionalModules...)
} else {
for _, m := range additionalModules {
// TODO: enable selected symbols for preloaded modules
if _, ok := moduleConfig[m.name]; ok {
s.preload = append(s.preload, m)
}
}
}
// Detect request and response functions
// Note: use s.newState() instead of lua.NewState() to load only enabled modules
L, err := s.newState()
if err != nil {
return err
}
defer L.Close()
if fn := L.GetGlobal("request"); fn.Type() == lua.LTFunction {
s.hasRequest = true
}
if fn := L.GetGlobal("response"); fn.Type() == lua.LTFunction {
s.hasResponse = true
}
if !s.hasRequest && !s.hasResponse {
return errors.New("at least one of `request` and `response` function must be present")
}
// Init state pool
s.pool = make(chan *lua.LState, MaxPoolSize) // FIXME make configurable
for i := 0; i < InitialPoolSize; i++ {
L, err := s.newState()
if err != nil {
return err
}
s.putState(L)
}
return nil
}
func (s *script) Request(f filters.FilterContext) {
if s.hasRequest {
s.runFunc("request", f)
}
}
func (s *script) Response(f filters.FilterContext) {
if s.hasResponse {
s.runFunc("response", f)
}
}
func (s *script) runFunc(name string, f filters.FilterContext) {
L, err := s.getState()
if err != nil {
log.Errorf("Error obtaining lua environment: %v", err)
return
}
defer s.putState(L)
pt := L.CreateTable(len(s.routeParams), len(s.routeParams))
for i, p := range s.routeParams {
k, v, _ := strings.Cut(p, "=")
pt.RawSetString(k, lua.LString(v))
pt.RawSetInt(i+1, lua.LString(p))
}
err = L.CallByParam(
lua.P{
Fn: L.GetGlobal(name),
NRet: 0,
Protect: true,
},
s.filterContextAsLuaTable(L, f),
pt,
)
if err != nil {
log.Errorf("Error calling %s from %s: %v", name, s.source, err)
}
}
func (s *script) filterContextAsLuaTable(L *lua.LState, f filters.FilterContext) *lua.LTable {
// this will be passed as parameter to the lua functions
// add metatable to dynamically access fields in the context
t := L.CreateTable(0, 0)
mt := L.CreateTable(0, 1)
mt.RawSetString("__index", L.NewFunction(getContextValue(f)))
L.SetMetatable(t, mt)
return t
}
func serveRequest(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
t := s.Get(-1)
r, ok := t.(*lua.LTable)
if !ok {
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("unsupported type %v, need a table", t.Type())
// return 0
s.Push(lua.LString("invalid type, need a table"))
return 1
}
res := &http.Response{}
r.ForEach(serveTableWalk(res))
f.Serve(res)
return 0
}
}
func serveTableWalk(res *http.Response) func(lua.LValue, lua.LValue) {
return func(k, v lua.LValue) {
sk, ok := k.(lua.LString)
if !ok {
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("unsupported key type %v, need a string", k.Type())
return
}
switch string(sk) {
case "status_code":
n, ok := v.(lua.LNumber)
if !ok {
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("unsupported status_code type %v, need a number", v.Type())
return
}
res.StatusCode = int(n)
case "header":
t, ok := v.(*lua.LTable)
if !ok {
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("unsupported header type %v, need a table", v.Type())
return
}
h := make(http.Header)
t.ForEach(serveHeaderWalk(h))
res.Header = h
case "body":
var body []byte
var err error
switch v.Type() {
case lua.LTString:
data := string(v.(lua.LString))
body = []byte(data)
case lua.LTTable:
body, err = gjson.Encode(v.(*lua.LTable))
if err != nil {
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("%v", err)
return
}
}
res.Body = io.NopCloser(bytes.NewBuffer(body))
}
}
}
func serveHeaderWalk(h http.Header) func(lua.LValue, lua.LValue) {
return func(k, v lua.LValue) {
h.Set(k.String(), v.String())
}
}
func getContextValue(f filters.FilterContext) func(*lua.LState) int {
var request, response, state_bag, path_param *lua.LTable
var serve *lua.LFunction
return func(s *lua.LState) int {
key := s.ToString(-1)
var ret lua.LValue
switch key {
case "request":
// initialize access to request on first use
if request == nil {
request = s.CreateTable(0, 0)
mt := s.CreateTable(0, 2)
mt.RawSetString("__index", s.NewFunction(getRequestValue(f)))
mt.RawSetString("__newindex", s.NewFunction(setRequestValue(f)))
s.SetMetatable(request, mt)
}
ret = request
case "response":
if response == nil {
response = s.CreateTable(0, 0)
mt := s.CreateTable(0, 2)
mt.RawSetString("__index", s.NewFunction(getResponseValue(f)))
mt.RawSetString("__newindex", s.NewFunction(setResponseValue(f)))
s.SetMetatable(response, mt)
}
ret = response
case "state_bag":
if state_bag == nil {
state_bag = s.CreateTable(0, 0)
mt := s.CreateTable(0, 2)
mt.RawSetString("__index", s.NewFunction(getStateBag(f)))
mt.RawSetString("__newindex", s.NewFunction(setStateBag(f)))
s.SetMetatable(state_bag, mt)
}
ret = state_bag
case "path_param":
if path_param == nil {
path_param = s.CreateTable(0, 0)
mt := s.CreateTable(0, 1)
mt.RawSetString("__index", s.NewFunction(getPathParam(f)))
s.SetMetatable(path_param, mt)
}
ret = path_param
case "serve":
if serve == nil {
serve = s.NewFunction(serveRequest(f))
}
ret = serve
default:
return 0
}
s.Push(ret)
return 1
}
}
func getRequestValue(f filters.FilterContext) func(*lua.LState) int {
var header, cookie, url_query *lua.LTable
return func(s *lua.LState) int {
key := s.ToString(-1)
var ret lua.LValue
switch key {
case "header":
if header == nil {
header = s.CreateTable(0, 2)
header.RawSetString("add", s.NewFunction(addRequestHeader(f)))
header.RawSetString("values", s.NewFunction(requestHeaderValues(f)))
mt := s.CreateTable(0, 3)
mt.RawSetString("__index", s.NewFunction(getRequestHeader(f)))
mt.RawSetString("__newindex", s.NewFunction(setRequestHeader(f)))
mt.RawSetString("__call", s.NewFunction(iterateRequestHeader(f)))
s.SetMetatable(header, mt)
}
ret = header
case "cookie":
if cookie == nil {
cookie = s.CreateTable(0, 0)
mt := s.CreateTable(0, 3)
mt.RawSetString("__index", s.NewFunction(getRequestCookie(f)))
mt.RawSetString("__newindex", s.NewFunction(unsupported("setting cookie is not supported")))
mt.RawSetString("__call", s.NewFunction(iterateRequestCookie(f)))
s.SetMetatable(cookie, mt)
}
ret = cookie
case "outgoing_host":
ret = lua.LString(f.OutgoingHost())
case "backend_url":
ret = lua.LString(f.BackendUrl())
case "host":
ret = lua.LString(f.Request().Host)
case "remote_addr":
ret = lua.LString(f.Request().RemoteAddr)
case "content_length":
ret = lua.LNumber(f.Request().ContentLength)
case "proto":
ret = lua.LString(f.Request().Proto)
case "method":
ret = lua.LString(f.Request().Method)
case "url":
ret = lua.LString(f.Request().URL.String())
case "url_path":
ret = lua.LString(f.Request().URL.Path)
case "url_query":
if url_query == nil {
url_query = s.CreateTable(0, 0)
mt := s.CreateTable(0, 3)
mt.RawSetString("__index", s.NewFunction(getRequestURLQuery(f)))
mt.RawSetString("__newindex", s.NewFunction(setRequestURLQuery(f)))
mt.RawSetString("__call", s.NewFunction(iterateRequestURLQuery(f)))
s.SetMetatable(url_query, mt)
}
ret = url_query
case "url_raw_query":
ret = lua.LString(f.Request().URL.RawQuery)
default:
return 0
}
s.Push(ret)
return 1
}
}
func setRequestValue(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
key := s.ToString(-2)
switch key {
case "outgoing_host":
f.SetOutgoingHost(s.ToString(-1))
case "url":
u, err := url.Parse(s.ToString(-1))
if err != nil {
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("%v", err)
return 0
}
f.Request().URL = u
case "url_path":
f.Request().URL.Path = s.ToString(-1)
case "url_raw_query":
f.Request().URL.RawQuery = s.ToString(-1)
default:
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("unsupported request field %s", key)
// do nothing for now
}
return 0
}
}
func getResponseValue(f filters.FilterContext) func(*lua.LState) int {
var header *lua.LTable
return func(s *lua.LState) int {
key := s.ToString(-1)
var ret lua.LValue
switch key {
case "header":
if header == nil {
header = s.CreateTable(0, 2)
header.RawSetString("add", s.NewFunction(addResponseHeader(f)))
header.RawSetString("values", s.NewFunction(responseHeaderValues(f)))
mt := s.CreateTable(0, 3)
mt.RawSetString("__index", s.NewFunction(getResponseHeader(f)))
mt.RawSetString("__newindex", s.NewFunction(setResponseHeader(f)))
mt.RawSetString("__call", s.NewFunction(iterateResponseHeader(f)))
s.SetMetatable(header, mt)
}
ret = header
case "status_code":
ret = lua.LNumber(f.Response().StatusCode)
default:
return 0
}
s.Push(ret)
return 1
}
}
func setResponseValue(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
key := s.ToString(-2)
switch key {
case "status_code":
v := s.Get(-1)
n, ok := v.(lua.LNumber)
if !ok {
s.RaiseError("unsupported status_code type %v, need a number", v.Type())
return 0
}
f.Response().StatusCode = int(n)
default:
s.RaiseError("unsupported response field %s", key)
}
return 0
}
}
func getStateBag(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
fld := s.ToString(-1)
res, ok := f.StateBag()[fld]
if !ok {
return 0
}
switch res := res.(type) {
case string:
s.Push(lua.LString(res))
case int:
s.Push(lua.LNumber(res))
case int64:
s.Push(lua.LNumber(res))
case float64:
s.Push(lua.LNumber(res))
case *lua.LTable:
s.Push(res) // load *lua.LTable as is
default:
return 0
}
return 1
}
}
func setStateBag(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
fld := s.ToString(-2)
val := s.Get(-1)
var res interface{}
switch val.Type() {
case lua.LTString:
res = string(val.(lua.LString))
case lua.LTNumber:
res = float64(val.(lua.LNumber))
case lua.LTTable:
res = val // store *lua.LTable as is
default:
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// s.RaiseError("unsupported state bag value type %v, need a string or a number", val.Type())
return 0
}
f.StateBag()[fld] = res
return 0
}
}
func getRequestHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
hdr := s.ToString(-1)
res := f.Request().Header.Get(hdr)
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// if res != "" {
// s.Push(lua.LString(res))
// return 1
// }
// return 0
s.Push(lua.LString(res))
return 1
}
}
func setRequestHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
lv := s.Get(-1)
hdr := s.ToString(-2)
switch lv.Type() {
case lua.LTNil:
f.Request().Header.Del(hdr)
case lua.LTString:
str := string(lv.(lua.LString))
if str == "" {
f.Request().Header.Del(hdr)
} else {
f.Request().Header.Set(hdr, str)
}
default:
val := s.ToString(-1)
f.Request().Header.Set(hdr, val)
}
return 0
}
}
func addRequestHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
value := s.ToString(-1)
name := s.ToString(-2)
if name != "" && value != "" {
f.Request().Header.Add(name, value)
}
return 0
}
}
func requestHeaderValues(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
name := s.ToString(-1)
values := f.Request().Header.Values(name)
res := s.CreateTable(len(values), 0)
for _, v := range values {
res.Append(lua.LString(v))
}
s.Push(res)
return 1
}
}
func iterateRequestHeader(f filters.FilterContext) func(*lua.LState) int {
// https://www.lua.org/pil/7.2.html
return func(s *lua.LState) int {
s.Push(s.NewFunction(nextHeader(f.Request().Header)))
return 1
}
}
func getRequestCookie(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
k := s.ToString(-1)
c, err := f.Request().Cookie(k)
if err == nil {
s.Push(lua.LString(c.Value))
return 1
}
return 0
}
}
func iterateRequestCookie(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
s.Push(s.NewFunction(nextCookie(f.Request().Cookies())))
return 1
}
}
func nextCookie(cookies []*http.Cookie) func(*lua.LState) int {
return func(s *lua.LState) int {
if len(cookies) > 0 {
c := cookies[0]
s.Push(lua.LString(c.Name))
s.Push(lua.LString(c.Value))
cookies[0] = nil // release the consumed element
cookies = cookies[1:]
return 2
}
return 0
}
}
func getResponseHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
hdr := s.ToString(-1)
res := f.Response().Header.Get(hdr)
// TODO(sszuecs): https://github.com/zalando/skipper/issues/1487
// if res != "" {
// s.Push(lua.LString(res))
// return 1
// }
// return 0
s.Push(lua.LString(res))
return 1
}
}
func setResponseHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
lv := s.Get(-1)
hdr := s.ToString(-2)
switch lv.Type() {
case lua.LTNil:
f.Response().Header.Del(hdr)
case lua.LTString:
str := string(lv.(lua.LString))
if str == "" {
f.Response().Header.Del(hdr)
} else {
f.Response().Header.Set(hdr, str)
}
default:
val := s.ToString(-1)
f.Response().Header.Set(hdr, val)
}
return 0
}
}
func addResponseHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
value := s.ToString(-1)
name := s.ToString(-2)
if name != "" && value != "" {
f.Response().Header.Add(name, value)
}
return 0
}
}
func responseHeaderValues(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
name := s.ToString(-1)
values := f.Response().Header.Values(name)
res := s.CreateTable(len(values), 0)
for _, v := range values {
res.Append(lua.LString(v))
}
s.Push(res)
return 1
}
}
func iterateResponseHeader(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
s.Push(s.NewFunction(nextHeader(f.Response().Header)))
return 1
}
}
func nextHeader(h http.Header) func(*lua.LState) int {
keys := make([]string, 0, len(h))
for k := range h {
keys = append(keys, k)
}
return func(s *lua.LState) int {
if len(keys) > 0 {
k := keys[0]
s.Push(lua.LString(k))
s.Push(lua.LString(h.Get(k)))
keys[0] = "" // mind peace
keys = keys[1:]
return 2
}
return 0
}
}
func getRequestURLQuery(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
k := s.ToString(-1)
res := f.Request().URL.Query().Get(k)
if res != "" {
s.Push(lua.LString(res))
return 1
}
return 0
}
}
func setRequestURLQuery(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
lv := s.Get(-1)
k := s.ToString(-2)
q := f.Request().URL.Query()
switch lv.Type() {
case lua.LTNil:
q.Del(k)
case lua.LTString:
str := string(lv.(lua.LString))
if str == "" {
q.Del(k)
} else {
q.Set(k, str)
}
default:
val := s.ToString(-1)
q.Set(k, val)
}
f.Request().URL.RawQuery = q.Encode()
return 0
}
}
func iterateRequestURLQuery(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
s.Push(s.NewFunction(nextQuery(f.Request().URL.Query())))
return 1
}
}
func nextQuery(v url.Values) func(*lua.LState) int {
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
return func(s *lua.LState) int {
if len(keys) > 0 {
k := keys[0]
s.Push(lua.LString(k))
s.Push(lua.LString(v.Get(k)))
keys[0] = "" // mind peace
keys = keys[1:]
return 2
}
return 0
}
}
func getPathParam(f filters.FilterContext) func(*lua.LState) int {
return func(s *lua.LState) int {
n := s.ToString(-1)
p := f.PathParam(n)
if p != "" {
s.Push(lua.LString(p))
return 1
}
return 0
}
}
func unsupported(message string) func(*lua.LState) int {
return func(s *lua.LState) int {
s.RaiseError("%s", message)
return 0
}
}
func contains(s string, a []string) bool {
return slices.Contains(a, s)
}
package certregistry
import (
"crypto/tls"
"crypto/x509"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
)
// CertRegistry object holds TLS certificates to be used to terminate TLS
// connections, ensuring synchronized access to them.
type CertRegistry struct {
mu sync.Mutex
lookup map[string]*tls.Certificate
}
// NewCertRegistry initializes the certificate registry.
func NewCertRegistry() *CertRegistry {
l := make(map[string]*tls.Certificate)
return &CertRegistry{
lookup: l,
}
}
// ConfigureCertificate configures the certificate for the host if no
// configuration exists yet, or if the new certificate becomes valid
// (NotBefore field) after the previously configured certificate.
func (r *CertRegistry) ConfigureCertificate(host string, cert *tls.Certificate) error {
if cert == nil {
return fmt.Errorf("cannot configure nil certificate")
}
// attach the parsed leaf certificate to the certificate
leaf, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fmt.Errorf("failed parsing leaf certificate: %w", err)
}
cert.Leaf = leaf
r.mu.Lock()
defer r.mu.Unlock()
curr, found := r.lookup[host]
if found {
if cert.Leaf.NotBefore.After(curr.Leaf.NotBefore) {
log.Infof("updating certificate in registry - %s", host)
r.lookup[host] = cert
return nil
} else {
return nil
}
} else {
log.Infof("adding certificate to registry - %s", host)
r.lookup[host] = cert
return nil
}
}
// GetCertFromHello reads the SNI from a TLS client and returns the appropriate certificate.
// If no certificate is found for the host it will return nil.
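//
// The method matches the tls.Config.GetCertificate signature, so a registry
// can be wired into a TLS server like (sketch):
//
// cfg := &tls.Config{GetCertificate: r.GetCertFromHello}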
func (r *CertRegistry) GetCertFromHello(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
r.mu.Lock()
cert, found := r.lookup[hello.ServerName]
r.mu.Unlock()
if found {
return cert, nil
}
return nil, nil
}
package secrets
import (
"crypto/aes"
"crypto/cipher"
crand "crypto/rand"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/scrypt"
)
// SecretSource is able to provide a list of secrets.
type SecretSource interface {
GetSecret() ([][]byte, error)
}
type fileSecretSource struct {
fileName string
}
func (fss *fileSecretSource) GetSecret() ([][]byte, error) {
contents, err := os.ReadFile(fss.fileName)
if err != nil {
return nil, err
}
secrets := strings.Split(string(contents), ",")
byteSecrets := make([][]byte, len(secrets))
for i, s := range secrets {
byteSecrets[i] = []byte(s)
if len(byteSecrets[i]) == 0 {
return nil, fmt.Errorf("file %s secret %d is empty", fss.fileName, i)
}
}
if len(byteSecrets) == 0 {
return nil, fmt.Errorf("secrets file %s is empty", fss.fileName)
}
return byteSecrets, nil
}
func newFileSecretSource(file string) SecretSource {
return &fileSecretSource{fileName: file}
}
type Encrypter struct {
mu sync.RWMutex
cipherSuites []cipher.AEAD
secretSource SecretSource
closer chan struct{}
closedHook chan struct{}
}
func newEncrypter(secretsFile string) (*Encrypter, error) {
secretSource := newFileSecretSource(secretsFile)
_, err := secretSource.GetSecret()
if err != nil {
return nil, fmt.Errorf("failed to read secrets from secret source: %w", err)
}
return &Encrypter{
secretSource: secretSource,
closer: make(chan struct{}),
}, nil
}
// WithSource can be used to create an Encrypter, for example in
// secrettest for testing purposes.
func WithSource(s SecretSource) (*Encrypter, error) {
return &Encrypter{
secretSource: s,
closer: make(chan struct{}),
}, nil
}
func (e *Encrypter) CreateNonce() ([]byte, error) {
e.mu.RLock()
defer e.mu.RUnlock()
if len(e.cipherSuites) > 0 {
nonce := make([]byte, e.cipherSuites[0].NonceSize())
if _, err := io.ReadFull(crand.Reader, nonce); err != nil {
return nil, err
}
return nonce, nil
}
return nil, fmt.Errorf("no ciphers which can be used")
}
// Encrypt encrypts given plaintext
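//
// A round-trip sketch (assuming RefreshCiphers was called successfully):
//
// ct, _ := e.Encrypt([]byte("payload"))
// pt, _ := e.Decrypt(ct) // pt == []byte("payload")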
func (e *Encrypter) Encrypt(plaintext []byte) ([]byte, error) {
e.mu.RLock()
defer e.mu.RUnlock()
if len(e.cipherSuites) > 0 {
nonce, err := e.CreateNonce()
if err != nil {
return nil, err
}
return e.cipherSuites[0].Seal(nonce, nonce, plaintext, nil), nil
}
return nil, fmt.Errorf("no ciphers which can be used")
}
// Decrypt decrypts given cipher text
func (e *Encrypter) Decrypt(cipherText []byte) ([]byte, error) {
e.mu.RLock()
defer e.mu.RUnlock()
for _, c := range e.cipherSuites {
nonceSize := c.NonceSize()
if len(cipherText) < nonceSize {
return nil, fmt.Errorf("failed to decrypt, ciphertext too short %d", len(cipherText))
}
nonce, input := cipherText[:nonceSize], cipherText[nonceSize:]
data, err := c.Open(nil, nonce, input, nil)
if err == nil {
return data, nil
}
}
return nil, fmt.Errorf("none of the ciphers can decrypt the data")
}
// RefreshCiphers rebuilds the list of cipher.AEAD suites from the
// secrets provided by the Encrypter's SecretSource.
func (e *Encrypter) RefreshCiphers() error {
secrets, err := e.secretSource.GetSecret()
if err != nil {
return err
}
suites := make([]cipher.AEAD, len(secrets))
for i, s := range secrets {
key, err := scrypt.Key(s, []byte{}, 1<<15, 8, 1, 32)
if err != nil {
return fmt.Errorf("failed to create key: %w", err)
}
// AES accepts 16, 24 or 32 byte keys; the 32 byte key derived above selects AES-256
block, err := aes.NewCipher(key)
if err != nil {
return fmt.Errorf("failed to create new cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return fmt.Errorf("failed to create new GCM: %w", err)
}
suites[i] = aesgcm
}
e.mu.Lock()
defer e.mu.Unlock()
e.cipherSuites = suites
return nil
}
func (e *Encrypter) runCipherRefresher(refreshInterval time.Duration) error {
err := e.RefreshCiphers()
if err != nil {
return fmt.Errorf("failed to refresh ciphers: %w", err)
}
go func() {
ticker := time.NewTicker(refreshInterval)
defer ticker.Stop()
for {
select {
case <-e.closer:
if e.closedHook != nil {
close(e.closedHook)
}
return
case <-ticker.C:
log.Debug("started refresh of ciphers")
err := e.RefreshCiphers()
if err != nil {
log.Errorf("failed to refresh the ciphers: %v", err)
}
log.Debug("finished refresh of ciphers")
}
}
}()
return nil
}
func (e *Encrypter) Close() {
close(e.closer)
}
package secrets
import (
"bytes"
"errors"
"os"
"path/filepath"
"sync"
"sync/atomic"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultCredentialsUpdateInterval = 10 * time.Minute
)
var (
ErrWrongFileType = errors.New("file type not supported")
ErrFailedToReadFile = errors.New("failed to read file")
)
// SecretsProvider is a SecretsReader and can add secret sources that
// contain a secret. It will automatically update secrets if the source
// changed.
type SecretsProvider interface {
SecretsReader
// Add adds the given source that contains a secret to the
// automatically updated secrets store
Add(string) error
}
type secretMap map[string][]byte
type SecretPaths struct {
// See https://pkg.go.dev/sync/atomic#example-Value-ReadMostly
secrets atomic.Value // secretMap
closed sync.Once
quit chan struct{}
mu sync.Mutex
paths map[string]struct{}
}
// NewSecretPaths creates a SecretPaths that implements a
// SecretsProvider. As a side effect, it runs a background refresher
// every d interval. On tear down make sure to Close() it.
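//
// A minimal sketch (the path is illustrative):
//
// sp := NewSecretPaths(time.Minute)
// defer sp.Close()
// if err := sp.Add("/etc/secrets/token"); err != nil {
// log.Fatal(err)
// }
// data, ok := sp.GetSecret("/etc/secrets/token") // the file path is the key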
func NewSecretPaths(d time.Duration) *SecretPaths {
if d <= 0 {
d = defaultCredentialsUpdateInterval
}
sp := &SecretPaths{
quit: make(chan struct{}),
paths: make(map[string]struct{}),
}
sp.secrets.Store(make(secretMap))
go sp.runRefresher(d)
return sp
}
// GetSecret returns the secret for the given name and whether it was found.
func (sp *SecretPaths) GetSecret(s string) ([]byte, bool) {
m := sp.secrets.Load().(secretMap)
data, ok := m[s]
return data, ok
}
// Add registers a path to a file or directory to find secrets in.
// Background refresher discovers files added or removed later to the directory path.
// The path of the file will be the key to get the secret.
func (sp *SecretPaths) Add(path string) error {
fi, err := os.Lstat(path)
if err != nil {
return err
}
switch mode := fi.Mode(); {
// Kubernetes uses symlink to file
case mode.IsRegular() || mode&os.ModeSymlink != 0:
if _, err := os.ReadFile(path); err != nil {
return err
}
case mode.IsDir():
// handled by refresh
default:
return ErrWrongFileType
}
sp.mu.Lock()
sp.paths[path] = struct{}{}
sp.refreshLocked()
sp.mu.Unlock()
return nil
}
// runRefresher periodically refreshes all registered paths
func (sp *SecretPaths) runRefresher(d time.Duration) {
ticker := time.NewTicker(d)
defer ticker.Stop()
for {
select {
case <-ticker.C:
sp.mu.Lock()
sp.refreshLocked()
sp.mu.Unlock()
case <-sp.quit:
log.Infoln("Stop secrets background refresher")
return
}
}
}
// refreshLocked reads secrets from all registered paths and updates secrets map.
// sp.mu must be held
func (sp *SecretPaths) refreshLocked() {
sizeHint := len(sp.secrets.Load().(secretMap))
actual := make(secretMap, sizeHint)
for path := range sp.paths {
addPath(actual, path)
}
old := sp.secrets.Swap(actual).(secretMap)
for path, data := range actual {
oldData, existed := old[path]
if !existed {
log.Infof("Added secret file: %s", path)
} else if !bytes.Equal(data, oldData) {
log.Infof("Updated secret file: %s", path)
}
}
for path := range old {
if _, exists := actual[path]; !exists {
log.Infof("Removed secret file: %s", path)
}
}
}
func addPath(secrets secretMap, path string) {
fi, err := os.Lstat(path)
if err != nil {
log.Errorf("Failed to stat path %s: %v", path, err)
return
}
switch mode := fi.Mode(); {
// Kubernetes uses symlink to file
case mode.IsRegular() || mode&os.ModeSymlink != 0:
data, err := readSecretFile(path)
if err != nil {
log.Errorf("Failed to read file %s: %v", path, err)
return
}
secrets[path] = data
case mode.IsDir():
matches, err := filepath.Glob(path + "/*")
if err != nil {
log.Errorf("Failed to read directory %s: %v", path, err)
return
}
for _, match := range matches {
data, err := readSecretFile(match)
if err == nil {
secrets[match] = data
} else if !errors.Is(err, syscall.EISDIR) {
log.Errorf("Failed to read path %s: %v", match, err)
}
}
}
}
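// readSecretFile reads the file at the given path and strips a single
// trailing newline, if present.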
func readSecretFile(path string) ([]byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
if len(data) > 0 && data[len(data)-1] == 0xa {
data = data[:len(data)-1]
}
return data, nil
}
func (sp *SecretPaths) Close() {
sp.closed.Do(func() {
close(sp.quit)
})
}
package secrets
import "net/url"
// SecretsReader is able to get a secret
type SecretsReader interface {
// GetSecret finds a secret by name and returns the secret and whether it was found
GetSecret(string) ([]byte, bool)
// Close should be used on teardown to clean up a refresher
// goroutine. Implementations of this interface should check
// for nil pointers, such that callers do not need to.
Close()
}
// StaticSecret implements the SecretsReader interface. Example:
//
// sec := []byte("mysecret")
// sss := StaticSecret(sec)
// b, _ := sss.GetSecret("")
// string(b) == "mysecret" // true
type StaticSecret []byte
// GetSecret returns the static secret
func (st StaticSecret) GetSecret(string) ([]byte, bool) {
return st, true
}
// Close implements SecretsReader.
func (st StaticSecret) Close() {}
// StaticDelegateSecret delegates with a static string to the wrapped
// SecretsReader
type StaticDelegateSecret struct {
sr SecretsReader
key string
}
// NewStaticDelegateSecret creates a wrapped SecretsReader
// that uses the given s as the key for the underlying
// SecretsReader to look up the secret.
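//
// A short sketch, assuming sp is an initialized SecretsReader and
// "/meta/credentials/token" is a hypothetical secret key:
//
// sds := NewStaticDelegateSecret(sp, "/meta/credentials/token")
// secret, ok := sds.GetSecret("ignored")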
func NewStaticDelegateSecret(sr SecretsReader, s string) *StaticDelegateSecret {
return &StaticDelegateSecret{
sr: sr,
key: s,
}
}
// GetSecret returns the secret looked up by the static key via
// delegated SecretsReader.
func (sds *StaticDelegateSecret) GetSecret(string) ([]byte, bool) {
return sds.sr.GetSecret(sds.key)
}
// Close delegates to the wrapped SecretsReader.
func (sds *StaticDelegateSecret) Close() {
sds.sr.Close()
}
// HostSecret can be used to get secrets by hostnames.
type HostSecret struct {
sr SecretsReader
secMap map[string]string
}
// NewHostSecret creates a SecretsReader that returns a secret for a
// given host. The given map is used to map the hostname to the secrets
// reader key to read the secret from.
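//
// A short sketch with hypothetical hostnames and keys:
//
// hs := NewHostSecret(sp, map[string]string{
// "example.org": "/meta/credentials/example-token",
// })
// secret, ok := hs.GetSecret("https://example.org/some/path")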
func NewHostSecret(sr SecretsReader, h map[string]string) *HostSecret {
return &HostSecret{
sr: sr,
secMap: h,
}
}
// GetSecret returns secret for given URL string using the hostname.
func (hs *HostSecret) GetSecret(s string) ([]byte, bool) {
u, err := url.Parse(s)
if err != nil {
return nil, false
}
hostname := u.Hostname()
k, ok := hs.secMap[hostname]
if !ok {
return nil, false
}
b, ok := hs.sr.GetSecret(k)
if !ok {
return nil, false
}
return b, true
}
func (hs *HostSecret) Close() {
hs.sr.Close()
}
package secrets
import (
"sync"
"time"
)
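// EncrypterCreator creates Encryption implementations for a given
// secret source, refreshing them at the given interval.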
type EncrypterCreator interface {
GetEncrypter(time.Duration, string) (Encryption, error)
}
type Encryption interface {
CreateNonce() ([]byte, error)
Decrypt([]byte) ([]byte, error)
Encrypt([]byte) ([]byte, error)
Close()
}
type Registry struct {
mu sync.Mutex
encrypterMap map[string]*Encrypter
}
// NewRegistry returns a Registry and implements EncrypterCreator to
// store and manage secrets
func NewRegistry() *Registry {
e := make(map[string]*Encrypter)
return &Registry{
encrypterMap: e,
}
}
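// GetEncrypter returns a cached Encryption for the given secret file,
// creating it on first use. A refreshInterval greater than zero enables
// background refreshing of the ciphers. A short sketch (the file path is
// hypothetical):
//
// enc, err := r.GetEncrypter(5*time.Minute, "/meta/secret-key")
// if err != nil {
// // handle error
// }
// encrypted, err := enc.Encrypt([]byte("plain text"))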
func (r *Registry) GetEncrypter(refreshInterval time.Duration, file string) (Encryption, error) {
r.mu.Lock()
defer r.mu.Unlock()
if e, ok := r.encrypterMap[file]; ok {
return e, nil
}
e, err := newEncrypter(file)
if err != nil {
return nil, err
}
if refreshInterval > 0 {
err := e.runCipherRefresher(refreshInterval)
if err != nil {
return nil, err
}
}
r.encrypterMap[file] = e
return e, nil
}
// Close will close all Encryption of the Registry
func (r *Registry) Close() {
for _, v := range r.encrypterMap {
v.Close()
}
}
package skipper
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/signal"
"path"
"regexp"
"strconv"
"strings"
"syscall"
"time"
stdlog "log"
ot "github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/zalando/skipper/circuit"
"github.com/zalando/skipper/dataclients/kubernetes"
"github.com/zalando/skipper/dataclients/routestring"
"github.com/zalando/skipper/eskip"
"github.com/zalando/skipper/eskipfile"
"github.com/zalando/skipper/etcd"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/apiusagemonitoring"
"github.com/zalando/skipper/filters/auth"
"github.com/zalando/skipper/filters/block"
"github.com/zalando/skipper/filters/builtin"
"github.com/zalando/skipper/filters/fadein"
logfilter "github.com/zalando/skipper/filters/log"
"github.com/zalando/skipper/filters/openpolicyagent"
"github.com/zalando/skipper/filters/openpolicyagent/opaauthorizerequest"
"github.com/zalando/skipper/filters/openpolicyagent/opaserveresponse"
ratelimitfilters "github.com/zalando/skipper/filters/ratelimit"
"github.com/zalando/skipper/filters/shedder"
teefilters "github.com/zalando/skipper/filters/tee"
"github.com/zalando/skipper/loadbalancer"
"github.com/zalando/skipper/logging"
"github.com/zalando/skipper/metrics"
skpnet "github.com/zalando/skipper/net"
pauth "github.com/zalando/skipper/predicates/auth"
"github.com/zalando/skipper/predicates/content"
"github.com/zalando/skipper/predicates/cookie"
"github.com/zalando/skipper/predicates/cron"
"github.com/zalando/skipper/predicates/forwarded"
"github.com/zalando/skipper/predicates/host"
"github.com/zalando/skipper/predicates/interval"
"github.com/zalando/skipper/predicates/methods"
"github.com/zalando/skipper/predicates/primitive"
"github.com/zalando/skipper/predicates/query"
"github.com/zalando/skipper/predicates/source"
"github.com/zalando/skipper/predicates/tee"
"github.com/zalando/skipper/predicates/traffic"
"github.com/zalando/skipper/proxy"
"github.com/zalando/skipper/queuelistener"
"github.com/zalando/skipper/ratelimit"
"github.com/zalando/skipper/routing"
"github.com/zalando/skipper/scheduler"
"github.com/zalando/skipper/script"
"github.com/zalando/skipper/secrets"
"github.com/zalando/skipper/secrets/certregistry"
"github.com/zalando/skipper/swarm"
"github.com/zalando/skipper/tracing"
)
const (
defaultSourcePollTimeout = 30 * time.Millisecond
defaultRoutingUpdateBuffer = 1 << 5
)
const DefaultPluginDir = "./plugins"
// Options to start skipper.
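//
// A minimal, illustrative sketch (address and route are hypothetical):
//
// opts := Options{
// Address: ":9090",
// InlineRoutes: `* -> inlineContent("ok") -> <shunt>`,
// }
//
// Such options are typically handed to the package's Run function.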
type Options struct {
// WaitForHealthcheckInterval sets the time that skipper waits
// for the loadbalancer in front to become unhealthy. Defaults
// to 0.
WaitForHealthcheckInterval time.Duration
// StatusChecks is an experimental feature. It defines a
// comma separated list of HTTP URLs to do GET requests to,
// that have to return 200 before skipper becomes ready
StatusChecks []string
// WhitelistedHealthCheckCIDR appends the whitelisted IP range to the internalIPs range for healthcheck purposes
WhitelistedHealthCheckCIDR []string
// Network address that skipper should listen on.
Address string
// Insecure network address skipper should listen on when TLS is enabled
InsecureAddress string
// EnableTCPQueue enables controlling the
// concurrently processed requests at the TCP listener.
EnableTCPQueue bool
// ExpectedBytesPerRequest is used by the TCP LIFO listener.
// It defines the expected average memory required to process an incoming
// request. It is used only when MaxTCPListenerConcurrency is not defined.
// It is used together with the memory limit defined in:
// cgroup v1 /sys/fs/cgroup/memory/memory.limit_in_bytes
// or
// cgroup v2 /sys/fs/cgroup/memory.max
//
// See also:
// cgroup v1: https://www.kernel.org/doc/Documentation/cgroup-v1/memory.txt
// cgroup v2: https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html#memory-interface-files
ExpectedBytesPerRequest int
// MaxTCPListenerConcurrency is used by the TCP LIFO listener.
// It defines the max number of concurrently accepted connections, excluding
// the pending ones in the queue.
//
// When undefined and EnableTCPQueue is true, the limit is derived from
// the detected memory limit and ExpectedBytesPerRequest.
MaxTCPListenerConcurrency int
// MaxTCPListenerQueue is used by the TCP LIFO listener.
// It defines the maximum number of pending connections waiting in the queue.
MaxTCPListenerQueue int
// List of custom filter specifications.
CustomFilters []filters.Spec
// RegisterFilters callback can be used to register additional filters.
// Built-in and custom filters are registered before the callback is called.
RegisterFilters func(registry filters.Registry)
// Urls of nodes in an etcd cluster, storing route definitions.
EtcdUrls []string
// Path prefix for skipper related data in the etcd storage.
EtcdPrefix string
// Timeout used for a single request when querying for updates
// in etcd. This is independent of, and an addition to,
// SourcePollTimeout. When not set, the internally defined 1s
// is used.
EtcdWaitTimeout time.Duration
// Skip TLS certificate check for etcd connections.
EtcdInsecure bool
// If set this value is used as Bearer token for etcd OAuth authorization.
EtcdOAuthToken string
// If set this value is used as username for etcd basic authorization.
EtcdUsername string
// If set this value is used as password for etcd basic authorization.
EtcdPassword string
// If set, enables skipper to generate routes based on ingress resources in the kubernetes cluster
Kubernetes bool
// If set makes skipper authenticate with the kubernetes API server with service account assigned to the
// skipper POD.
// If omitted skipper will rely on kubectl proxy to authenticate with API server
KubernetesInCluster bool
// Kubernetes API base URL. Only makes sense if KubernetesInCluster is set to false. If omitted and
// skipper is not running in-cluster, the default API URL will be used.
KubernetesURL string
// KubernetesTokenFile configures path to the token file.
// Defaults to /var/run/secrets/kubernetes.io/serviceaccount/token when running in-cluster.
KubernetesTokenFile string
// KubernetesHealthcheck, when Kubernetes ingress is set, indicates
// whether an automatic healthcheck route should be generated. The
// generated route will report healthiness when the Kubernetes API
// calls are successful. The healthcheck endpoint is accessible from
// internal IPs, with the path /kube-system/healthz.
KubernetesHealthcheck bool
// KubernetesHTTPSRedirect, when Kubernetes ingress is set, indicates
// whether an automatic redirect route should be generated to redirect
// HTTP requests to their HTTPS equivalent. The generated route will
// match requests with the X-Forwarded-Proto and X-Forwarded-Port,
// expected to be set by the load-balancer.
KubernetesHTTPSRedirect bool
// KubernetesHTTPSRedirectCode overrides the default redirect code (308)
// when used together with -kubernetes-https-redirect.
KubernetesHTTPSRedirectCode int
// KubernetesDisableCatchAllRoutes, when set, tells the data client to not create catchall routes.
KubernetesDisableCatchAllRoutes bool
// KubernetesIngressClass is a regular expression, that will make
// skipper load only the ingress resources that have a matching
// kubernetes.io/ingress.class annotation. For backwards compatibility,
// the ingresses without an annotation, or an empty annotation, will
// be loaded, too.
KubernetesIngressClass string
// KubernetesRouteGroupClass is a regular expression, that will make skipper
// load only the RouteGroup resources that have a matching
// zalando.org/routegroup.class annotation. Any RouteGroups without the
// annotation, or with an empty annotation, will be loaded too.
KubernetesRouteGroupClass string
// KubernetesIngressLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. A label and its value on an Ingress must match exactly for it to be loaded by Skipper.
// If the value is irrelevant for a given configuration, it can be left empty. The default
// value is no labels required.
// Examples:
// Config [] will load all Ingresses.
// Config ["skipper-enabled": ""] will load only Ingresses with a label "skipper-enabled", no matter the value.
// Config ["skipper-enabled": "true"] will load only Ingresses with a label "skipper-enabled: true"
// Config ["skipper-enabled": "", "foo": "bar"] will load only Ingresses with both labels while label "foo" must have a value "bar".
KubernetesIngressLabelSelectors map[string]string
// KubernetesServicesLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
KubernetesServicesLabelSelectors map[string]string
// KubernetesEndpointsLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
KubernetesEndpointsLabelSelectors map[string]string
// KubernetesSecretsLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
KubernetesSecretsLabelSelectors map[string]string
// KubernetesRouteGroupsLabelSelectors is a map of kubernetes labels to their values that must be present on a resource to be loaded
// by the client. Read documentation for IngressLabelSelectors for examples and more details.
// The default value is no labels required.
KubernetesRouteGroupsLabelSelectors map[string]string
// PathMode controls the default interpretation of ingress paths in cases
// when the ingress doesn't specify it with an annotation.
KubernetesPathMode kubernetes.PathMode
// KubernetesNamespace is used to switch between monitoring ingresses in the cluster-scope or limit
// the ingresses to only those in the specified namespace. Defaults to "" which means monitor ingresses
// in the cluster-scope.
KubernetesNamespace string
// KubernetesEnableEndpointslices if set skipper will fetch
// endpointslices instead of endpoints to scale more than 1000
// pods within a service
KubernetesEnableEndpointslices bool
// *DEPRECATED* KubernetesEnableEastWest enables cluster internal service to service communication, aka east-west traffic
KubernetesEnableEastWest bool
// *DEPRECATED* KubernetesEastWestDomain sets the cluster internal domain used to create additional routes in skipper, defaults to skipper.cluster.local
KubernetesEastWestDomain string
// KubernetesEastWestRangeDomains sets the cluster internal domains for
// east west traffic. Identified routes to such domains will include
// the KubernetesEastWestRangePredicates.
KubernetesEastWestRangeDomains []string
// KubernetesEastWestRangePredicates set the Predicates that will be
// appended to routes identified as belonging to KubernetesEastWestRangeDomains.
KubernetesEastWestRangePredicates []*eskip.Predicate
// KubernetesEastWestRangeAnnotationPredicates same as KubernetesAnnotationPredicates but will append to
// routes that have a KubernetesEastWestRangeDomains suffix.
KubernetesEastWestRangeAnnotationPredicates []kubernetes.AnnotationPredicates
// KubernetesEastWestRangeAnnotationFiltersAppend same as KubernetesAnnotationFiltersAppend but will append to
// routes that have a KubernetesEastWestRangeDomains suffix.
KubernetesEastWestRangeAnnotationFiltersAppend []kubernetes.AnnotationFilters
// KubernetesAnnotationPredicates sets predicates to append for each annotation key and value
KubernetesAnnotationPredicates []kubernetes.AnnotationPredicates
// KubernetesAnnotationFiltersAppend sets filters to append for each annotation key and value
KubernetesAnnotationFiltersAppend []kubernetes.AnnotationFilters
// KubernetesOnlyAllowedExternalNames will enable validation of ingress external names and route groups network
// backend addresses, explicit LB endpoints validation against the list of patterns in
// AllowedExternalNames.
KubernetesOnlyAllowedExternalNames bool
// KubernetesAllowedExternalNames contains regexp patterns of those domain names that are allowed to be
// used with external name services (type=ExternalName).
KubernetesAllowedExternalNames []*regexp.Regexp
// KubernetesRedisServiceNamespace to be used to lookup ring shards dynamically
KubernetesRedisServiceNamespace string
// KubernetesRedisServiceName to be used to lookup ring shards dynamically
KubernetesRedisServiceName string
// KubernetesRedisServicePort to be used to lookup ring shards dynamically
KubernetesRedisServicePort int
// KubernetesForceService overrides the default Skipper behavior of routing traffic using Kubernetes
// Endpoints, routing it using Kubernetes Services instead.
KubernetesForceService bool
// KubernetesBackendTrafficAlgorithm specifies the algorithm to calculate the backend traffic
KubernetesBackendTrafficAlgorithm kubernetes.BackendTrafficAlgorithm
// KubernetesDefaultLoadBalancerAlgorithm sets the default algorithm to be used for load balancing between backend endpoints,
// available options: roundRobin, consistentHash, random, powerOfRandomNChoices
KubernetesDefaultLoadBalancerAlgorithm string
// File containing static route definitions. Multiple may be given comma separated.
RoutesFile string
// File containing route definitions with file watch enabled.
// Multiple may be given comma separated. (For the skipper
// command this option is used when starting it with the -routes-file flag.)
WatchRoutesFile string
// RouteURLs are URLs pointing to route definitions, in eskip format, with change watching enabled.
RoutesURLs []string
// InlineRoutes can define routes as eskip text.
InlineRoutes string
// Polling timeout of the routing data sources.
SourcePollTimeout time.Duration
// DefaultFilters will be applied to all routes automatically.
DefaultFilters *eskip.DefaultFilters
// DisabledFilters is a list of filters unavailable for use
DisabledFilters []string
// CloneRoute is a slice of PreProcessors that will be applied to all routes
// automatically. They will clone all matching routes and apply changes to the
// cloned routes.
CloneRoute []*eskip.Clone
// EditRoute will be applied to all routes automatically and
// will apply changes to all matching routes.
EditRoute []*eskip.Editor
// A list of custom routing pre-processor implementations that will
// be applied to all routes.
CustomRoutingPreProcessors []routing.PreProcessor
// Deprecated. See ProxyFlags. When used together with ProxyFlags,
// the values will be combined with |.
ProxyOptions proxy.Options
// Flags controlling the proxy behavior.
ProxyFlags proxy.Flags
// Tells the proxy the maximum number of idle connections it can keep
// alive.
IdleConnectionsPerHost int
// Defines the time period of how often the idle connections maintained
// by the proxy are closed.
CloseIdleConnsPeriod time.Duration
// Defines ReadTimeoutServer for server http connections.
ReadTimeoutServer time.Duration
// Defines ReadHeaderTimeout for server http connections.
ReadHeaderTimeoutServer time.Duration
// Defines WriteTimeout for server http connections.
WriteTimeoutServer time.Duration
// Defines IdleTimeout for server http connections.
IdleTimeoutServer time.Duration
// KeepaliveServer configures maximum age for server http connections.
// The connection is closed after it has existed for this duration.
KeepaliveServer time.Duration
// KeepaliveRequestsServer configures maximum number of requests for server http connections.
// The connection is closed after serving this number of requests.
KeepaliveRequestsServer int
// Defines MaxHeaderBytes for server http connections.
MaxHeaderBytes int
// Enable connection state metrics for server http connections.
EnableConnMetricsServer bool
// TimeoutBackend sets the TCP client connection timeout for
// proxy http connections to the backend.
TimeoutBackend time.Duration
// ResponseHeaderTimeout sets the HTTP response timeout for
// proxy http connections to the backend.
ResponseHeaderTimeoutBackend time.Duration
// ExpectContinueTimeoutBackend sets the HTTP timeout to expect a
// response for status Code 100 for proxy http connections to
// the backend.
ExpectContinueTimeoutBackend time.Duration
// KeepAliveBackend sets the TCP keepalive for proxy http
// connections to the backend.
KeepAliveBackend time.Duration
// DualStackBackend sets if the proxy TCP connections to the
// backend should be dual stack.
DualStackBackend bool
// TLSHandshakeTimeoutBackend sets the TLS handshake timeout
// for proxy connections to the backend.
TLSHandshakeTimeoutBackend time.Duration
// MaxIdleConnsBackend sets MaxIdleConns, which limits the
// number of idle connections to all backends, 0 means no
// limit.
MaxIdleConnsBackend int
// DisableHTTPKeepalives sets DisableKeepAlives, which forces
// a backend to always create a new connection.
DisableHTTPKeepalives bool
// Flag indicating to ignore trailing slashes in paths during route
// lookup.
IgnoreTrailingSlash bool
// Priority routes that are matched against the requests before
// the standard routes from the data clients.
PriorityRoutes []proxy.PriorityRoute
// Specifications of custom, user defined predicates.
CustomPredicates []routing.PredicateSpec
// Custom data clients to be used together with the default etcd and Innkeeper.
CustomDataClients []routing.DataClient
// CustomHttpHandlerWrap provides ability to wrap http.Handler created by skipper.
// http.Handler is used for accepting incoming http requests.
// It allows adding additional logic (for example tracing) by providing a wrapper function
// which accepts the original skipper handler as an argument and returns a wrapped handler
CustomHttpHandlerWrap func(http.Handler) http.Handler
// CustomHttpRoundTripperWrap provides ability to wrap http.RoundTripper created by skipper.
// http.RoundTripper is used for making outgoing requests (backends)
// It allows adding additional logic (for example tracing) by providing a wrapper function
// which accepts the original skipper http.RoundTripper as an argument and returns a wrapped roundtripper
CustomHttpRoundTripperWrap func(http.RoundTripper) http.RoundTripper
// WaitFirstRouteLoad prevents starting the listener before the first batch
// of routes were applied.
WaitFirstRouteLoad bool
// SuppressRouteUpdateLogs indicates to log only summaries of the routing updates
// instead of full details of the updated/deleted routes.
SuppressRouteUpdateLogs bool
// Dev mode. Currently this flag disables prioritization of the
// consumer side over the feeding side during the routing updates to
// populate the updated routes faster.
DevMode bool
// Network address for the support endpoints
SupportListener string
// Deprecated: Network address for the /metrics endpoint
MetricsListener string
// Skipper provides a set of metrics with different keys which are exposed via HTTP in JSON
// You can customize those key names with your own prefix
MetricsPrefix string
// EnableProfile exposes profiling information on /profile of the
// metrics listener.
EnableProfile bool
// BlockProfileRate calls runtime.SetBlockProfileRate(BlockProfileRate) if non zero value, deactivate with <0
BlockProfileRate int
// MutexProfileFraction calls runtime.SetMutexProfileFraction(MutexProfileFraction) if non zero value, deactivate with <0
MutexProfileFraction int
// MemProfileRate calls runtime.SetMemProfileRate(MemProfileRate) if non zero value, deactivate with <0
MemProfileRate int
// Flag that enables reporting of the Go garbage collector statistics exported in debug.GCStats
EnableDebugGcMetrics bool
// Flag that enables reporting of the Go runtime statistics exported in runtime and specifically runtime.MemStats
EnableRuntimeMetrics bool
// If set, detailed response time metrics will be collected
// for each route, additionally grouped by status and method.
EnableServeRouteMetrics bool
// If set, a counter for each route is generated, additionally
// grouped by status and method. It differs from the automatically
// generated counter from `EnableServeRouteMetrics` because it will
// always contain the status and method labels, independently of the
// `EnableServeMethodMetric` and `EnableServeStatusCodeMetric` flags.
EnableServeRouteCounter bool
// If set, detailed response time metrics will be collected
// for each host, additionally grouped by status and method.
EnableServeHostMetrics bool
// If set, a counter for each host is generated, additionally
// grouped by status and method. It differs from the automatically
// generated counter from `EnableServeHostMetrics` because it will
// always contain the status and method labels, independently of the
// `EnableServeMethodMetric` and `EnableServeStatusCodeMetric` flags.
EnableServeHostCounter bool
// If set, the detailed total response time metrics will contain the
// HTTP method as a domain of the metric. It affects both route and
// host split metrics.
EnableServeMethodMetric bool
// If set, the detailed total response time metrics will contain the
// HTTP Response status code as a domain of the metric. It affects
// both route and host split metrics.
EnableServeStatusCodeMetric bool
// If set, the total request handling time taken by skipper will be
// collected. It measures the duration taken by skipper to process
// the request, from the start excluding the filters processing and
// until the backend round trip is started.
EnableProxyRequestMetrics bool
// If set, the total response handling time taken by skipper will be
// collected. It measures the duration taken by skipper to process the
// response, from after the backend round trip is finished, excluding
// the filters processing and until before the response is served.
EnableProxyResponseMetrics bool
// If set, detailed response time metrics will be collected
// for each backend host
EnableBackendHostMetrics bool
// EnableAllFiltersMetrics enables collecting combined filter
// metrics per each route. Without the DisableMetricsCompatibilityDefaults,
// it is enabled by default.
EnableAllFiltersMetrics bool
// EnableCombinedResponseMetrics enables collecting response time
// metrics combined for every route.
EnableCombinedResponseMetrics bool
// EnableRouteResponseMetrics enables collecting response time
// metrics per each route. Without the DisableMetricsCompatibilityDefaults,
// it is enabled by default.
EnableRouteResponseMetrics bool
// EnableRouteBackendErrorsCounters enables counters for backend
// errors per each route. Without the DisableMetricsCompatibilityDefaults,
// it is enabled by default.
EnableRouteBackendErrorsCounters bool
// EnableRouteStreamingErrorsCounters enables counters for streaming
// errors per each route. Without the DisableMetricsCompatibilityDefaults,
// it is enabled by default.
EnableRouteStreamingErrorsCounters bool
// EnableRouteBackendMetrics enables backend response time metrics
// per each route. Without the DisableMetricsCompatibilityDefaults, it is
// enabled by default.
EnableRouteBackendMetrics bool
// EnableRouteCreationMetrics enables the OriginMarker to track route creation times. Disabled by default
EnableRouteCreationMetrics bool
// When set, makes the histograms use an exponentially decaying sample
// instead of the default uniform one.
MetricsUseExpDecaySample bool
// Use custom buckets for prometheus histograms.
HistogramMetricBuckets []float64
// The following options, for backwards compatibility, are true
// by default: EnableAllFiltersMetrics, EnableRouteResponseMetrics,
// EnableRouteBackendErrorsCounters, EnableRouteStreamingErrorsCounters,
// EnableRouteBackendMetrics. With this compatibility flag, the default
// for these options can be set to false.
DisableMetricsCompatibilityDefaults bool
// Implementation of a Metrics handler. If provided this is going to be used
// instead of creating a new one based on the Kind of metrics wanted. This
// is useful in case you want to report metrics to a custom aggregator.
MetricsBackend metrics.Metrics
// Output file for the application log. Default value: /dev/stderr.
//
// When /dev/stderr or /dev/stdout is passed in, it will be resolved
// to os.Stderr or os.Stdout.
//
// Warning: passing an arbitrary file will try to open it for append
// on start and use it, or fail on start. The current
// implementation doesn't support any more sophisticated handling
// of temporary failures or log-rolling.
ApplicationLogOutput string
// Application log prefix. Default value: "[APP]".
ApplicationLogPrefix string
// Enables logs in JSON format
ApplicationLogJSONEnabled bool
// ApplicationLogJsonFormatter, when set and JSON logging is enabled, is passed along to the underlying
// Logrus logger for application logs. To enable structured logging, use ApplicationLogJSONEnabled.
ApplicationLogJsonFormatter *log.JSONFormatter
// Output file for the access log. Default value: /dev/stderr.
//
// When /dev/stderr or /dev/stdout is passed in, it will be resolved
// to os.Stderr or os.Stdout.
//
// Warning: passing an arbitrary file will try to open it for append
// on start and use it, or fail on start. The current
// implementation doesn't support any more sophisticated handling
// of temporary failures or log-rolling.
AccessLogOutput string
// Disables the access log.
AccessLogDisabled bool
// Enables logs in JSON format
AccessLogJSONEnabled bool
// AccessLogStripQuery, when set, causes the query strings to be stripped
// from the request URI in the access logs.
AccessLogStripQuery bool
// AccessLogJsonFormatter, when set and JSON logging is enabled, is passed along to the underlying
// Logrus logger for access logs. To enable structured logging, use AccessLogJSONEnabled.
// Deprecated: use [AccessLogFormatter].
AccessLogJsonFormatter *log.JSONFormatter
// AccessLogFormatter, when set is passed along to the underlying Logrus logger for access logs.
AccessLogFormatter log.Formatter
DebugListener string
// Path of certificate(s) when using TLS, multiple may be given comma separated
CertPathTLS string
// Path of key(s) when using TLS, multiple may be given comma separated. For
// multiple keys, the order must match the one given in CertPathTLS
KeyPathTLS string
// TLSClientAuth sets the policy the server will follow for
// TLS Client Authentication, see [tls.ClientAuthType]
TLSClientAuth tls.ClientAuthType
// TLS Settings for Proxy Server
ProxyTLS *tls.Config
// Client TLS to connect to Backends
ClientTLS *tls.Config
// TLSMinVersion to set the minimal TLS version for all TLS configurations
TLSMinVersion uint16
// CipherSuites sets the list of cipher suites to use for TLS 1.2
CipherSuites []uint16
// Flush interval for upgraded Proxy connections
BackendFlushInterval time.Duration
// Experimental feature to handle protocol Upgrades for Websockets, SPDY, etc.
ExperimentalUpgrade bool
// ExperimentalUpgradeAudit enables audit log of both the request line
// and the response messages during web socket upgrades.
ExperimentalUpgradeAudit bool
// MaxLoopbacks defines the maximum number of loops that the proxy can execute when the routing table
// contains loop backends (<loopback>).
MaxLoopbacks int
// EnableBreakers enables the usage of the breakers in the route definitions without initializing any
// by default. It is a shortcut for setting the BreakerSettings to:
//
// []circuit.BreakerSettings{{Type: BreakerDisabled}}
//
EnableBreakers bool
// BreakerSettings contain global and host specific settings for the circuit breakers.
BreakerSettings []circuit.BreakerSettings
// EnableRatelimiters enables the usage of the ratelimiter in the route definitions without initializing any
// by default. It is a shortcut for setting the RatelimitSettings to:
//
// []ratelimit.Settings{{Type: DisableRatelimit}}
//
EnableRatelimiters bool
// RatelimitSettings contain global and host specific settings for the ratelimiters.
RatelimitSettings []ratelimit.Settings
// EnableRouteFIFOMetrics enables metrics for the individual route FIFO queues, if any.
EnableRouteFIFOMetrics bool
// EnableRouteLIFOMetrics enables metrics for the individual route LIFO queues, if any.
EnableRouteLIFOMetrics bool
// OpenTracing enables opentracing
OpenTracing []string
// OpenTracingInitialSpan can override the default initial, pre-routing, span name.
// Default: "ingress".
OpenTracingInitialSpan string
// OpenTracingExcludedProxyTags can disable a tag so that it is not recorded. By default every tag is included.
OpenTracingExcludedProxyTags []string
// OpenTracingDisableFilterSpans flag is used to disable creation of spans representing request and response filters.
OpenTracingDisableFilterSpans bool
// OpenTracingLogFilterLifecycleEvents flag is used to enable/disable the logs for events marking request and
// response filters' start & end times.
OpenTracingLogFilterLifecycleEvents bool
// OpenTracingLogStreamEvents flag is used to enable/disable the logs that marks the
// times when response headers & payload are streamed to the client
OpenTracingLogStreamEvents bool
// OpenTracingBackendNameTag enables an additional tracing tag containing a backend name
// for a route when it's available (e.g. for RouteGroups)
OpenTracingBackendNameTag bool
// OpenTracingTracer allows pre-created tracer to be passed on to skipper. Providing the
// tracer instance overrides options provided under OpenTracing property.
OpenTracingTracer ot.Tracer
// PluginDir defines the directory to load plugins from, DEPRECATED, use PluginDirs
PluginDir string
// PluginDirs defines the directories to load plugins from
PluginDirs []string
// FilterPlugins loads additional filters from modules. The first value in each []string
// needs to be the plugin name (as on disk, without path, without ".so" suffix). The
// following values are passed as arguments to the plugin while loading, see also
// https://opensource.zalando.com/skipper/reference/plugins/
FilterPlugins [][]string
// PredicatePlugins loads additional predicates from modules. See above for FilterPlugins
// what the []string should contain.
PredicatePlugins [][]string
// DataClientPlugins loads additional data clients from modules. See above for FilterPlugins
// what the []string should contain.
DataClientPlugins [][]string
// Plugins combine multiple types of the above plugin types in one plugin (where
// necessary because of shared data between e.g. a filter and a data client).
Plugins [][]string
// DefaultHTTPStatus is the HTTP status used when no routes are found
// for a request.
DefaultHTTPStatus int
// EnablePrometheusMetrics enables Prometheus format metrics.
//
// This option is *deprecated*. The recommended way to enable prometheus metrics is to
// use the MetricsFlavours option.
EnablePrometheusMetrics bool
// EnablePrometheusStartLabel adds start label to each prometheus counter with the value of counter creation
// timestamp as unix nanoseconds.
EnablePrometheusStartLabel bool
// An instance of a Prometheus registry. It allows registering and serving custom metrics when skipper is used as a
// library.
// A new registry is created if this option is nil.
PrometheusRegistry *prometheus.Registry
// MetricsFlavours sets the metrics storage and exposed format
// of metrics endpoints.
MetricsFlavours []string
// LoadBalancerHealthCheckInterval is *deprecated* and not in use anymore
LoadBalancerHealthCheckInterval time.Duration
// ReverseSourcePredicate enables the automatic use of IP
// whitelisting in different places to use the reversed way of
// identifying a client IP within the X-Forwarded-For
// header. Amazon's ALB for example writes the client IP to
// the last item of the string list of the X-Forwarded-For
// header, in this case you want to set this to true.
ReverseSourcePredicate bool
// EnableOAuth2GrantFlow, enables OAuth2 Grant Flow filter
EnableOAuth2GrantFlow bool
// OAuth2AuthURL, the url to redirect the requests to when login is required.
OAuth2AuthURL string
// OAuth2TokenURL, the url where the access code should be exchanged for the
// access token.
OAuth2TokenURL string
// OAuth2RevokeTokenURL, the url where the access and refresh tokens can be
// revoked during a logout.
OAuth2RevokeTokenURL string
// OAuthTokeninfoURL sets the URL to be queried for
// information for all auth.NewOAuthTokeninfo*() filters.
OAuthTokeninfoURL string
// OAuthTokeninfoTimeout sets timeout duration while calling oauth token service
OAuthTokeninfoTimeout time.Duration
// OAuthTokeninfoCacheSize configures the maximum number of cached tokens.
// Zero value disables tokeninfo cache.
OAuthTokeninfoCacheSize int
// OAuthTokeninfoCacheTTL limits the lifetime of a cached tokeninfo.
// Tokeninfo is cached for the duration of "expires_in" field value seconds or
// for the duration of OAuthTokeninfoCacheTTL if it is not zero and less than "expires_in" value.
OAuthTokeninfoCacheTTL time.Duration
// OAuth2SecretFile contains the filename with the encryption key for the
// authentication cookie and grant flow state stored in Secrets.
OAuth2SecretFile string
// OAuth2ClientID, the OAuth2 client id of the current service, used to exchange
// the access code.
OAuth2ClientID string
// OAuth2ClientSecret, the secret associated with the ClientID, used to exchange
// the access code.
OAuth2ClientSecret string
// OAuth2ClientIDFile, the path of the file containing the OAuth2 client id of
// the current service, used to exchange the access code.
// File name may contain {host} placeholder which will be replaced by the request host.
OAuth2ClientIDFile string
// OAuth2ClientSecretFile, the path of the file containing the secret associated
// with the ClientID, used to exchange the access code.
// File name may contain {host} placeholder which will be replaced by the request host.
OAuth2ClientSecretFile string
// OAuth2CallbackPath contains the path where the OAuth2 callback requests with the
// authorization code should be redirected to. Defaults to /.well-known/oauth2-callback
OAuth2CallbackPath string
// OAuthTokenintrospectionTimeout sets timeout duration while calling oauth tokenintrospection service
OAuthTokenintrospectionTimeout time.Duration
// OAuth2AuthURLParameters the additional parameters to send to OAuth2 authorize and token endpoints.
OAuth2AuthURLParameters map[string]string
// OAuth2AccessTokenHeaderName the name of the header to which the access token
// should be assigned after the oauthGrant filter.
OAuth2AccessTokenHeaderName string
// OAuth2TokeninfoSubjectKey the key of the subject ID attribute in the
// tokeninfo map. Used for downstream oidcClaimsQuery compatibility.
OAuth2TokeninfoSubjectKey string
// OAuth2GrantTokeninfoKeys, if not empty keys not in this list are removed from the tokeninfo map.
OAuth2GrantTokeninfoKeys []string
// OAuth2TokenCookieName the name of the cookie that Skipper sets after a
// successful OAuth2 token exchange. Stores the encrypted access token.
OAuth2TokenCookieName string
// OAuth2TokenCookieRemoveSubdomains sets the number of subdomains to remove from
// the callback request hostname to obtain token cookie domain.
OAuth2TokenCookieRemoveSubdomains int
// OAuth2GrantInsecure omits Secure attribute of the token cookie and uses http scheme for callback url.
OAuth2GrantInsecure bool
// OAuthGrantConfig specifies configuration for OAuth grant flow.
// A new instance will be created from OAuth* options when not specified.
OAuthGrantConfig *auth.OAuthConfig
// CompressEncodings, if not empty replace default compression encodings
CompressEncodings []string
// OIDCSecretsFile path to the file containing key to encrypt OpenID token
OIDCSecretsFile string
// OIDCCookieValidity sets validity time duration for Cookies to calculate expiration time. (default 1h).
OIDCCookieValidity time.Duration
// OIDCDistributedClaimsTimeout sets timeout duration while calling Distributed Claims endpoint.
OIDCDistributedClaimsTimeout time.Duration
// OIDCCookieRemoveSubdomains sets the number of subdomains to remove from
// the callback request hostname to obtain token cookie domain.
OIDCCookieRemoveSubdomains int
// SecretsRegistry to store and load secrets for encryption
SecretsRegistry *secrets.Registry
// CredentialsPaths directories or files where credentials are stored one secret per file
CredentialsPaths []string
// CredentialsUpdateInterval sets the interval to update secrets
CredentialsUpdateInterval time.Duration
// API Monitoring feature is active (feature toggle)
ApiUsageMonitoringEnable bool
ApiUsageMonitoringRealmKeys string
ApiUsageMonitoringClientKeys string
ApiUsageMonitoringRealmsTrackingPattern string
// *DEPRECATED* ApiUsageMonitoringDefaultClientTrackingPattern
ApiUsageMonitoringDefaultClientTrackingPattern string
// Default filters directory enables default filters mechanism and sets the directory where the filters are located
DefaultFiltersDir string
// WebhookTimeout sets timeout duration while calling a custom webhook auth service
WebhookTimeout time.Duration
// MaxAuditBody sets the maximum read size of the body read by the audit log filter
MaxAuditBody int
// MaxMatcherBufferSize sets the maximum read buffer size of the blockContent filter, defaults to 2MiB
MaxMatcherBufferSize uint64
// EnableSwarm enables skipper fleet communication, required by e.g.
// the cluster ratelimiter
EnableSwarm bool
// redis based swarm
SwarmRedisURLs []string
SwarmRedisPassword string
SwarmRedisHashAlgorithm string
SwarmRedisDialTimeout time.Duration
SwarmRedisReadTimeout time.Duration
SwarmRedisWriteTimeout time.Duration
SwarmRedisPoolTimeout time.Duration
SwarmRedisMinIdleConns int
SwarmRedisMaxIdleConns int
SwarmRedisEndpointsRemoteURL string
SwarmRedisConnMetricsInterval time.Duration
SwarmRedisUpdateInterval time.Duration
// swim based swarm
SwarmKubernetesNamespace string
SwarmKubernetesLabelSelectorKey string
SwarmKubernetesLabelSelectorValue string
SwarmPort int
SwarmMaxMessageBuffer int
SwarmLeaveTimeout time.Duration
// swim based swarm for local testing
SwarmStaticSelf string // 127.0.0.1:9001
SwarmStaticOther string // 127.0.0.1:9002,127.0.0.1:9003
// SwarmRegistry specifies an optional callback function that is
// called after ratelimit registry is initialized
SwarmRegistry func(*ratelimit.Registry)
// ClusterRatelimitMaxGroupShards specifies the maximum number of group shards for the clusterRatelimit filter
ClusterRatelimitMaxGroupShards int
// KubernetesEnableTLS enables kubernetes to use resources to terminate tls
KubernetesEnableTLS bool
// LuaModules that are allowed to be used.
//
// Use <module>.<symbol> to selectively enable module symbols,
// for example: package,base._G,base.print,json
LuaModules []string
// LuaSources that are allowed as input sources. Valid sources
// are "", "file", "inline", "none", or a combination of "file" and
// "inline". An empty list defaults to "file","inline", and "none"
// disables lua filters.
LuaSources []string
EnableOpenPolicyAgent bool
EnableOpenPolicyAgentCustomControlLoop bool
OpenPolicyAgentControlLoopInterval time.Duration
OpenPolicyAgentControlLoopMaxJitter time.Duration
EnableOpenPolicyAgentDataPreProcessingOptimization bool
OpenPolicyAgentConfigTemplate string
OpenPolicyAgentEnvoyMetadata string
OpenPolicyAgentCleanerInterval time.Duration
OpenPolicyAgentStartupTimeout time.Duration
OpenPolicyAgentMaxRequestBodySize int64
OpenPolicyAgentRequestBodyBufferSize int64
OpenPolicyAgentMaxMemoryBodyParsing int64
PassiveHealthCheck map[string]string
}
func (o *Options) KubernetesDataClientOptions() kubernetes.Options {
return kubernetes.Options{
AllowedExternalNames: o.KubernetesAllowedExternalNames,
BackendNameTracingTag: o.OpenTracingBackendNameTag,
DefaultFiltersDir: o.DefaultFiltersDir,
KubernetesInCluster: o.KubernetesInCluster,
KubernetesURL: o.KubernetesURL,
TokenFile: o.KubernetesTokenFile,
KubernetesNamespace: o.KubernetesNamespace,
KubernetesEnableEastWest: o.KubernetesEnableEastWest,
KubernetesEnableEndpointslices: o.KubernetesEnableEndpointslices,
KubernetesEastWestDomain: o.KubernetesEastWestDomain,
KubernetesEastWestRangeDomains: o.KubernetesEastWestRangeDomains,
KubernetesEastWestRangePredicates: o.KubernetesEastWestRangePredicates,
KubernetesEastWestRangeAnnotationPredicates: o.KubernetesEastWestRangeAnnotationPredicates,
KubernetesEastWestRangeAnnotationFiltersAppend: o.KubernetesEastWestRangeAnnotationFiltersAppend,
KubernetesAnnotationPredicates: o.KubernetesAnnotationPredicates,
KubernetesAnnotationFiltersAppend: o.KubernetesAnnotationFiltersAppend,
HTTPSRedirectCode: o.KubernetesHTTPSRedirectCode,
DisableCatchAllRoutes: o.KubernetesDisableCatchAllRoutes,
IngressClass: o.KubernetesIngressClass,
IngressLabelSelectors: o.KubernetesIngressLabelSelectors,
ServicesLabelSelectors: o.KubernetesServicesLabelSelectors,
EndpointsLabelSelectors: o.KubernetesEndpointsLabelSelectors,
SecretsLabelSelectors: o.KubernetesSecretsLabelSelectors,
RouteGroupsLabelSelectors: o.KubernetesRouteGroupsLabelSelectors,
OnlyAllowedExternalNames: o.KubernetesOnlyAllowedExternalNames,
OriginMarker: o.EnableRouteCreationMetrics,
PathMode: o.KubernetesPathMode,
ProvideHealthcheck: o.KubernetesHealthcheck,
ProvideHTTPSRedirect: o.KubernetesHTTPSRedirect,
ReverseSourcePredicate: o.ReverseSourcePredicate,
RouteGroupClass: o.KubernetesRouteGroupClass,
WhitelistedHealthCheckCIDR: o.WhitelistedHealthCheckCIDR,
ForceKubernetesService: o.KubernetesForceService,
BackendTrafficAlgorithm: o.KubernetesBackendTrafficAlgorithm,
DefaultLoadBalancerAlgorithm: o.KubernetesDefaultLoadBalancerAlgorithm,
}
}
func (o *Options) OAuthGrantOptions() *auth.OAuthConfig {
oauthConfig := &auth.OAuthConfig{}
oauthConfig.AuthURL = o.OAuth2AuthURL
oauthConfig.TokenURL = o.OAuth2TokenURL
oauthConfig.RevokeTokenURL = o.OAuth2RevokeTokenURL
oauthConfig.TokeninfoURL = o.OAuthTokeninfoURL
oauthConfig.SecretFile = o.OAuth2SecretFile
oauthConfig.ClientID = o.OAuth2ClientID
if oauthConfig.ClientID == "" {
oauthConfig.ClientID, _ = os.LookupEnv("OAUTH2_CLIENT_ID")
}
oauthConfig.ClientSecret = o.OAuth2ClientSecret
if oauthConfig.ClientSecret == "" {
oauthConfig.ClientSecret, _ = os.LookupEnv("OAUTH2_CLIENT_SECRET")
}
oauthConfig.ClientIDFile = o.OAuth2ClientIDFile
oauthConfig.ClientSecretFile = o.OAuth2ClientSecretFile
oauthConfig.CallbackPath = o.OAuth2CallbackPath
oauthConfig.AuthURLParameters = o.OAuth2AuthURLParameters
oauthConfig.Secrets = o.SecretsRegistry
oauthConfig.AccessTokenHeaderName = o.OAuth2AccessTokenHeaderName
oauthConfig.TokeninfoSubjectKey = o.OAuth2TokeninfoSubjectKey
oauthConfig.GrantTokeninfoKeys = o.OAuth2GrantTokeninfoKeys
oauthConfig.TokenCookieName = o.OAuth2TokenCookieName
oauthConfig.TokenCookieRemoveSubdomains = &o.OAuth2TokenCookieRemoveSubdomains
oauthConfig.Insecure = o.OAuth2GrantInsecure
oauthConfig.ConnectionTimeout = o.OAuthTokeninfoTimeout
oauthConfig.MaxIdleConnectionsPerHost = o.IdleConnectionsPerHost
return oauthConfig
}
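// serverErrorLogWriter adapts the standard library http.Server error log
// to logrus, demoting two known noisy messages to debug level and
// logging everything else as an error.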
type serverErrorLogWriter struct{}
func (*serverErrorLogWriter) Write(p []byte) (int, error) {
m := string(p)
if strings.HasPrefix(m, "http: TLS handshake error") && strings.HasSuffix(m, ": EOF\n") {
log.Debug(m) // https://github.com/golang/go/issues/26918
} else if strings.HasPrefix(m, "http: URL query contains semicolon") {
log.Debug(m) // https://github.com/golang/go/issues/25192
} else {
log.Error(m)
}
return len(p), nil
}
func newServerErrorLog() *stdlog.Logger {
return stdlog.New(&serverErrorLogWriter{}, "", 0)
}
func createDataClients(o Options, cr *certregistry.CertRegistry) ([]routing.DataClient, error) {
var clients []routing.DataClient
if o.RoutesFile != "" {
for _, rf := range strings.Split(o.RoutesFile, ",") {
f, err := eskipfile.Open(rf)
if err != nil {
return nil, fmt.Errorf("error while opening eskip file: %w", err)
}
clients = append(clients, f)
}
}
if o.WatchRoutesFile != "" {
for _, rf := range strings.Split(o.WatchRoutesFile, ",") {
clients = append(clients, eskipfile.Watch(rf))
}
}
if len(o.RoutesURLs) > 0 {
for _, url := range o.RoutesURLs {
client, err := eskipfile.RemoteWatch(&eskipfile.RemoteWatchOptions{
RemoteFile: url,
FailOnStartup: true,
HTTPTimeout: o.SourcePollTimeout,
})
if err != nil {
return nil, fmt.Errorf("error while loading routes from url %s: %w", url, err)
}
clients = append(clients, client)
}
}
if o.InlineRoutes != "" {
ir, err := routestring.New(o.InlineRoutes)
if err != nil {
return nil, fmt.Errorf("error while parsing inline routes: %w", err)
}
clients = append(clients, ir)
}
if len(o.EtcdUrls) > 0 {
etcdClient, err := etcd.New(etcd.Options{
Endpoints: o.EtcdUrls,
Prefix: o.EtcdPrefix,
Timeout: o.EtcdWaitTimeout,
Insecure: o.EtcdInsecure,
OAuthToken: o.EtcdOAuthToken,
Username: o.EtcdUsername,
Password: o.EtcdPassword,
})
if err != nil {
return nil, fmt.Errorf("error while creating etcd client: %w", err)
}
clients = append(clients, etcdClient)
}
if o.Kubernetes {
kops := o.KubernetesDataClientOptions()
kops.CertificateRegistry = cr
kubernetesClient, err := kubernetes.New(kops)
if err != nil {
return nil, fmt.Errorf("error while creating kubernetes data client: %w", err)
}
clients = append(clients, kubernetesClient)
}
return clients, nil
}
func getLogOutput(name string) (io.Writer, error) {
name = path.Clean(name)
if name == "/dev/stdout" {
return os.Stdout, nil
}
if name == "/dev/stderr" {
return os.Stderr, nil
}
return os.OpenFile(name, os.O_APPEND|os.O_CREATE|os.O_WRONLY, os.ModePerm)
}
func initLog(o Options) error {
var (
logOutput io.Writer
accessLogOutput io.Writer
err error
)
if o.ApplicationLogOutput != "" {
logOutput, err = getLogOutput(o.ApplicationLogOutput)
if err != nil {
return err
}
}
if !o.AccessLogDisabled && o.AccessLogOutput != "" {
accessLogOutput, err = getLogOutput(o.AccessLogOutput)
if err != nil {
return err
}
}
logging.Init(logging.Options{
ApplicationLogPrefix: o.ApplicationLogPrefix,
ApplicationLogOutput: logOutput,
ApplicationLogJSONEnabled: o.ApplicationLogJSONEnabled,
ApplicationLogJsonFormatter: o.ApplicationLogJsonFormatter,
AccessLogOutput: accessLogOutput,
AccessLogJSONEnabled: o.AccessLogJSONEnabled,
AccessLogStripQuery: o.AccessLogStripQuery,
AccessLogJsonFormatter: o.AccessLogJsonFormatter,
AccessLogFormatter: o.AccessLogFormatter,
})
return nil
}
// filterRegistry creates a filter registry with the builtin and
// custom filter specs registered excluding disabled filters.
// If [Options.RegisterFilters] callback is set, it will be called.
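//
// A short sketch of the callback (mySpec is a hypothetical filters.Spec):
//
// o := &Options{RegisterFilters: func(r filters.Registry) {
// r.Register(mySpec)
// }}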
func (o *Options) filterRegistry() filters.Registry {
registry := make(filters.Registry)
disabledFilters := make(map[string]struct{})
for _, name := range o.DisabledFilters {
disabledFilters[name] = struct{}{}
}
for _, f := range builtin.Filters() {
if _, ok := disabledFilters[f.Name()]; !ok {
registry.Register(f)
}
}
for _, f := range o.CustomFilters {
if _, ok := disabledFilters[f.Name()]; !ok {
registry.Register(f)
}
}
if o.RegisterFilters != nil {
o.RegisterFilters(registry)
}
return registry
}
func (o *Options) tlsConfig(cr *certregistry.CertRegistry) (*tls.Config, error) {
if o.ProxyTLS != nil {
return o.ProxyTLS, nil
}
if o.CertPathTLS == "" && o.KeyPathTLS == "" && cr == nil {
return nil, nil
}
config := &tls.Config{
MinVersion: o.TLSMinVersion,
ClientAuth: o.TLSClientAuth,
}
if o.CipherSuites != nil {
config.CipherSuites = o.CipherSuites
}
if cr != nil {
config.GetCertificate = cr.GetCertFromHello
}
if o.CertPathTLS == "" && o.KeyPathTLS == "" {
return config, nil
}
crts := strings.Split(o.CertPathTLS, ",")
keys := strings.Split(o.KeyPathTLS, ",")
if len(crts) != len(keys) {
return nil, fmt.Errorf("number of certificates does not match number of keys")
}
for i := 0; i < len(crts); i++ {
crt, key := crts[i], keys[i]
keypair, err := tls.LoadX509KeyPair(crt, key)
if err != nil {
return nil, fmt.Errorf("failed to load X509 keypair from %s and %s: %w", crt, key, err)
}
config.Certificates = append(config.Certificates, keypair)
}
return config, nil
}
func (o *Options) openTracingTracerInstance() (ot.Tracer, error) {
if o.OpenTracingTracer != nil {
return o.OpenTracingTracer, nil
}
if len(o.OpenTracing) > 0 {
return tracing.InitTracer(o.OpenTracing)
} else {
// always have a tracer available, so filter authors can rely on the
// existence of a tracer
tracer, err := tracing.LoadTracingPlugin(o.PluginDirs, []string{"noop"})
if err != nil {
return nil, err
} else if tracer == nil {
// LoadTracingPlugin unfortunately may return nil tracer
return nil, fmt.Errorf("failed to load tracing plugin from %v", o.PluginDirs)
}
return tracer, nil
}
}
func listen(o *Options, address string, mtr metrics.Metrics) (net.Listener, error) {
if !o.EnableTCPQueue {
return net.Listen("tcp", address)
}
var memoryLimit int64
if o.MaxTCPListenerConcurrency <= 0 {
// cgroup v1: https://www.kernel.org/doc/Documentation/cgroup-v1/memory.txt
// cgroup v2: https://www.kernel.org/doc/Documentation/cgroup-v2.txt
// Note that in containers this will be the container limit.
// Runtimes without these files will use defaults defined in `queuelistener` package.
const (
memoryLimitFileV1 = "/sys/fs/cgroup/memory/memory.limit_in_bytes"
memoryLimitFileV2 = "/sys/fs/cgroup/memory.max"
)
memoryLimitBytes, err := os.ReadFile(memoryLimitFileV2)
if err != nil {
memoryLimitBytes, err = os.ReadFile(memoryLimitFileV1)
if err != nil {
log.Errorf("Failed to read memory limits, fallback to defaults: %v", err)
}
}
if err == nil {
memoryLimitString := strings.TrimSpace(string(memoryLimitBytes))
memoryLimit, err = strconv.ParseInt(memoryLimitString, 10, 64)
if err != nil {
log.Errorf("Failed to convert memory limits, fallback to defaults: %v", err)
}
// 4GB, temporarily, as a tested magic number until a better mechanism is in place:
if memoryLimit > 1<<32 {
memoryLimit = 1 << 32
}
}
}
qto := o.ReadHeaderTimeoutServer
if qto <= 0 {
qto = o.ReadTimeoutServer
}
return queuelistener.Listen(queuelistener.Options{
Network: "tcp",
Address: address,
MaxConcurrency: o.MaxTCPListenerConcurrency,
MaxQueueSize: o.MaxTCPListenerQueue,
MemoryLimitBytes: memoryLimit,
ConnectionBytes: o.ExpectedBytesPerRequest,
QueueTimeout: qto,
Metrics: mtr,
})
}
func listenAndServeQuit(
proxy http.Handler,
o *Options,
sigs chan os.Signal,
idleConnsCH chan struct{},
mtr metrics.Metrics,
cr *certregistry.CertRegistry,
) error {
tlsConfig, err := o.tlsConfig(cr)
if err != nil {
return err
}
serveTLS := tlsConfig != nil
address := o.Address
if address == "" {
if serveTLS {
address = ":https"
} else {
address = ":http"
}
}
srv := &http.Server{
Addr: address,
TLSConfig: tlsConfig,
Handler: proxy,
ReadTimeout: o.ReadTimeoutServer,
ReadHeaderTimeout: o.ReadHeaderTimeoutServer,
WriteTimeout: o.WriteTimeoutServer,
IdleTimeout: o.IdleTimeoutServer,
MaxHeaderBytes: o.MaxHeaderBytes,
ErrorLog: newServerErrorLog(),
}
cm := &skpnet.ConnManager{
Keepalive: o.KeepaliveServer,
KeepaliveRequests: o.KeepaliveRequestsServer,
}
if o.EnableConnMetricsServer {
cm.Metrics = mtr
}
cm.Configure(srv)
log.Infof("Listen on %v", address)
l, err := listen(o, address, mtr)
if err != nil {
return err
}
// making idleConnsCH and sigs optional parameters is required to be able to tear down a server
// from the tests
if idleConnsCH == nil {
idleConnsCH = make(chan struct{})
}
if sigs == nil {
sigs = make(chan os.Signal, 1)
}
go func() {
signal.Notify(sigs, syscall.SIGTERM)
<-sigs
log.Infof("Got shutdown signal, wait %v for health check", o.WaitForHealthcheckInterval)
time.Sleep(o.WaitForHealthcheckInterval)
log.Info("Start shutdown")
if err := srv.Shutdown(context.Background()); err != nil {
log.Errorf("Failed to graceful shutdown: %v", err)
}
close(idleConnsCH)
}()
if serveTLS {
if o.InsecureAddress != "" {
log.Infof("Insecure listener on %v", o.InsecureAddress)
go func() {
l, err := listen(o, o.InsecureAddress, mtr)
if err != nil {
log.Errorf("Failed to start insecure listener on %s: %v", o.InsecureAddress, err)
return
}
if err := srv.Serve(l); err != http.ErrServerClosed {
log.Errorf("Insecure listener serve failed: %v", err)
}
}()
}
if err := srv.ServeTLS(l, "", ""); err != http.ErrServerClosed {
log.Errorf("ServeTLS failed: %v", err)
return err
}
} else {
log.Infof("TLS settings not found, defaulting to HTTP")
if err := srv.Serve(l); err != http.ErrServerClosed {
log.Errorf("Serve failed: %v", err)
return err
}
}
<-idleConnsCH
log.Infof("done.")
return nil
}
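// findKubernetesDataclient returns the first kubernetes data client
// among the given data clients, or nil if there is none.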
func findKubernetesDataclient(dataClients []routing.DataClient) *kubernetes.Client {
var kdc *kubernetes.Client
for _, dc := range dataClients {
if kc, ok := dc.(*kubernetes.Client); ok {
kdc = kc
break
}
}
return kdc
}
func getKubernetesRedisAddrUpdater(opts *Options, kdc *kubernetes.Client, loaded bool) func() ([]string, error) {
if loaded {
// TODO(sszuecs): make sure kubernetes dataclient is already initialized and
// has polled the data once or kdc.GetEndpointAddresses should be a blocking
// call to the kubernetes API
return func() ([]string, error) {
a := kdc.GetEndpointAddresses(opts.KubernetesRedisServiceNamespace, opts.KubernetesRedisServiceName)
log.Debugf("GetEndpointAddresses found %d redis endpoints", len(a))
return joinPort(a, opts.KubernetesRedisServicePort), nil
}
} else {
return func() ([]string, error) {
a, err := kdc.LoadEndpointAddresses(opts.KubernetesRedisServiceNamespace, opts.KubernetesRedisServiceName)
log.Debugf("LoadEndpointAddresses found %d redis endpoints, err: %v", len(a), err)
return joinPort(a, opts.KubernetesRedisServicePort), err
}
}
}
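// joinPort appends the given port to every address, modifying the input slice
// in place. For example (hypothetical addresses), joinPort([]string{"10.2.0.1",
// "10.2.0.2"}, 6379) returns []string{"10.2.0.1:6379", "10.2.0.2:6379"}.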
func joinPort(addrs []string, port int) []string {
p := strconv.Itoa(port)
for i := 0; i < len(addrs); i++ {
addrs[i] = net.JoinHostPort(addrs[i], p)
}
return addrs
}
type RedisEndpoint struct {
Address string `json:"address"`
}
type RedisEndpoints struct {
Endpoints []RedisEndpoint `json:"endpoints"`
}
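// The remote endpoint consumed by getRemoteURLRedisAddrUpdater below is
// expected to serve a JSON document of this shape (a sketch with hypothetical
// addresses):
//
//	{"endpoints": [{"address": "10.2.0.1:6379"}, {"address": "10.2.0.2:6379"}]}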
func getRemoteURLRedisAddrUpdater(address string) func() ([]string, error) {
/* #nosec */
return func() ([]string, error) {
resp, err := http.Get(address)
if err != nil {
log.Errorf("failed to connect to redis endpoint %v, due to: %v", address, err)
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("failed to read to redis response %v", err)
return nil, err
}
target := &RedisEndpoints{}
err = json.Unmarshal(body, target)
if err != nil {
log.Errorf("Failed to decode body to json %v", err)
return nil, err
}
a := make([]string, 0, len(target.Endpoints))
for _, endpoint := range target.Endpoints {
a = append(a, endpoint.Address)
}
return a, nil
}
}
func run(o Options, sig chan os.Signal, idleConnsCH chan struct{}) error {
// init log
err := initLog(o)
if err != nil {
return err
}
if o.EnablePrometheusMetrics {
o.MetricsFlavours = append(o.MetricsFlavours, "prometheus")
}
metricsKind := metrics.UnkownKind
for _, s := range o.MetricsFlavours {
switch s {
case "codahale":
metricsKind |= metrics.CodaHaleKind
case "prometheus":
metricsKind |= metrics.PrometheusKind
}
}
// set default if unset
if metricsKind == metrics.UnkownKind {
metricsKind = metrics.CodaHaleKind
}
log.Infof("Expose metrics in %s format", metricsKind)
mtrOpts := metrics.Options{
Format: metricsKind,
Prefix: o.MetricsPrefix,
EnableDebugGcMetrics: o.EnableDebugGcMetrics,
EnableRuntimeMetrics: o.EnableRuntimeMetrics,
EnableServeRouteMetrics: o.EnableServeRouteMetrics,
EnableServeRouteCounter: o.EnableServeRouteCounter,
EnableServeHostMetrics: o.EnableServeHostMetrics,
EnableServeHostCounter: o.EnableServeHostCounter,
EnableServeMethodMetric: o.EnableServeMethodMetric,
EnableServeStatusCodeMetric: o.EnableServeStatusCodeMetric,
EnableProxyRequestMetrics: o.EnableProxyRequestMetrics,
EnableProxyResponseMetrics: o.EnableProxyResponseMetrics,
EnableBackendHostMetrics: o.EnableBackendHostMetrics,
EnableProfile: o.EnableProfile,
BlockProfileRate: o.BlockProfileRate,
MutexProfileFraction: o.MutexProfileFraction,
MemProfileRate: o.MemProfileRate,
EnableAllFiltersMetrics: o.EnableAllFiltersMetrics,
EnableCombinedResponseMetrics: o.EnableCombinedResponseMetrics,
EnableRouteResponseMetrics: o.EnableRouteResponseMetrics,
EnableRouteBackendErrorsCounters: o.EnableRouteBackendErrorsCounters,
EnableRouteStreamingErrorsCounters: o.EnableRouteStreamingErrorsCounters,
EnableRouteBackendMetrics: o.EnableRouteBackendMetrics,
UseExpDecaySample: o.MetricsUseExpDecaySample,
HistogramBuckets: o.HistogramMetricBuckets,
DisableCompatibilityDefaults: o.DisableMetricsCompatibilityDefaults,
PrometheusRegistry: o.PrometheusRegistry,
EnablePrometheusStartLabel: o.EnablePrometheusStartLabel,
}
mtr := o.MetricsBackend
if mtr == nil {
mtr = metrics.NewMetrics(mtrOpts)
}
// set global instance for backwards compatibility
metrics.Default = mtr
// *DEPRECATED* client tracking parameter
if o.ApiUsageMonitoringDefaultClientTrackingPattern != "" {
log.Warn(`"ApiUsageMonitoringDefaultClientTrackingPattern" option is deprecated`)
}
if err := o.findAndLoadPlugins(); err != nil {
return err
}
var cr *certregistry.CertRegistry
if o.KubernetesEnableTLS {
cr = certregistry.NewCertRegistry()
}
// create data clients
dataClients, err := createDataClients(o, cr)
if err != nil {
return err
}
// append custom data clients
dataClients = append(dataClients, o.CustomDataClients...)
if len(dataClients) == 0 {
log.Warning("no route source specified")
}
o.PluginDirs = append(o.PluginDirs, o.PluginDir)
tracer, err := o.openTracingTracerInstance()
if err != nil {
return err
}
// tee filters override with initialized tracer
o.CustomFilters = append(o.CustomFilters,
// tee()
teefilters.WithOptions(teefilters.Options{
Tracer: tracer,
NoFollow: false,
}),
// teenf()
teefilters.WithOptions(teefilters.Options{
NoFollow: true,
Tracer: tracer,
}),
)
if o.OAuthTokeninfoURL != "" {
tio := auth.TokeninfoOptions{
URL: o.OAuthTokeninfoURL,
Timeout: o.OAuthTokeninfoTimeout,
MaxIdleConns: o.IdleConnectionsPerHost,
Tracer: tracer,
Metrics: mtr,
CacheSize: o.OAuthTokeninfoCacheSize,
CacheTTL: o.OAuthTokeninfoCacheTTL,
}
o.CustomFilters = append(o.CustomFilters,
auth.NewOAuthTokeninfoAllScopeWithOptions(tio),
auth.NewOAuthTokeninfoAnyScopeWithOptions(tio),
auth.NewOAuthTokeninfoAllKVWithOptions(tio),
auth.NewOAuthTokeninfoAnyKVWithOptions(tio),
auth.NewOAuthTokeninfoValidate(tio),
)
}
if o.SecretsRegistry == nil {
o.SecretsRegistry = secrets.NewRegistry()
}
defer o.SecretsRegistry.Close()
sp := secrets.NewSecretPaths(o.CredentialsUpdateInterval)
defer sp.Close()
for _, p := range o.CredentialsPaths {
if err := sp.Add(p); err != nil {
log.Errorf("Failed to add credentials file: %s: %v", p, err)
}
}
tio := auth.TokenintrospectionOptions{
Timeout: o.OAuthTokenintrospectionTimeout,
MaxIdleConns: o.IdleConnectionsPerHost,
Tracer: tracer,
}
who := auth.WebhookOptions{
Timeout: o.WebhookTimeout,
MaxIdleConns: o.IdleConnectionsPerHost,
Tracer: tracer,
}
admissionControlFilter := shedder.NewAdmissionControl(shedder.Options{
Tracer: tracer,
})
admissionControlSpec, ok := admissionControlFilter.(*shedder.AdmissionControlSpec)
if !ok {
log.Fatal("Failed to cast admission control filter to spec")
}
o.CustomFilters = append(o.CustomFilters,
logfilter.NewAuditLog(o.MaxAuditBody),
block.NewBlock(o.MaxMatcherBufferSize),
block.NewBlockHex(o.MaxMatcherBufferSize),
auth.NewBearerInjector(sp),
auth.NewSetRequestHeaderFromSecret(sp),
auth.NewJwtValidationWithOptions(tio),
auth.NewJwtMetrics(),
auth.TokenintrospectionWithOptions(auth.NewOAuthTokenintrospectionAnyClaims, tio),
auth.TokenintrospectionWithOptions(auth.NewOAuthTokenintrospectionAllClaims, tio),
auth.TokenintrospectionWithOptions(auth.NewOAuthTokenintrospectionAnyKV, tio),
auth.TokenintrospectionWithOptions(auth.NewOAuthTokenintrospectionAllKV, tio),
auth.TokenintrospectionWithOptions(auth.NewSecureOAuthTokenintrospectionAnyClaims, tio),
auth.TokenintrospectionWithOptions(auth.NewSecureOAuthTokenintrospectionAllClaims, tio),
auth.TokenintrospectionWithOptions(auth.NewSecureOAuthTokenintrospectionAnyKV, tio),
auth.TokenintrospectionWithOptions(auth.NewSecureOAuthTokenintrospectionAllKV, tio),
auth.WebhookWithOptions(who),
auth.NewOIDCQueryClaimsFilter(),
apiusagemonitoring.NewApiUsageMonitoring(
o.ApiUsageMonitoringEnable,
o.ApiUsageMonitoringRealmKeys,
o.ApiUsageMonitoringClientKeys,
o.ApiUsageMonitoringRealmsTrackingPattern,
),
admissionControlFilter,
)
if o.OIDCSecretsFile != "" {
oidcClientId, _ := os.LookupEnv("OIDC_CLIENT_ID")
oidcClientSecret, _ := os.LookupEnv("OIDC_CLIENT_SECRET")
opts := auth.OidcOptions{
CookieRemoveSubdomains: &o.OIDCCookieRemoveSubdomains,
CookieValidity: o.OIDCCookieValidity,
Timeout: o.OIDCDistributedClaimsTimeout,
MaxIdleConns: o.IdleConnectionsPerHost,
Tracer: tracer,
OidcClientId: oidcClientId,
OidcClientSecret: oidcClientSecret,
}
o.CustomFilters = append(o.CustomFilters,
auth.NewOAuthOidcUserInfosWithOptions(o.OIDCSecretsFile, o.SecretsRegistry, opts),
auth.NewOAuthOidcAnyClaimsWithOptions(o.OIDCSecretsFile, o.SecretsRegistry, opts),
auth.NewOAuthOidcAllClaimsWithOptions(o.OIDCSecretsFile, o.SecretsRegistry, opts),
)
}
var swarmer ratelimit.Swarmer
var redisOptions *skpnet.RedisOptions
log.Infof("enable swarm: %v", o.EnableSwarm)
if o.EnableSwarm {
if len(o.SwarmRedisURLs) > 0 || o.KubernetesRedisServiceName != "" || o.SwarmRedisEndpointsRemoteURL != "" {
log.Infof("Redis based swarm with %d shards", len(o.SwarmRedisURLs))
redisOptions = &skpnet.RedisOptions{
Addrs: o.SwarmRedisURLs,
Password: o.SwarmRedisPassword,
HashAlgorithm: o.SwarmRedisHashAlgorithm,
DialTimeout: o.SwarmRedisDialTimeout,
ReadTimeout: o.SwarmRedisReadTimeout,
WriteTimeout: o.SwarmRedisWriteTimeout,
PoolTimeout: o.SwarmRedisPoolTimeout,
MinIdleConns: o.SwarmRedisMinIdleConns,
MaxIdleConns: o.SwarmRedisMaxIdleConns,
ConnMetricsInterval: o.SwarmRedisConnMetricsInterval,
UpdateInterval: o.SwarmRedisUpdateInterval,
Tracer: tracer,
Log: log.New(),
}
} else {
log.Infof("Start swim based swarm")
swops := &swarm.Options{
SwarmPort: uint16(o.SwarmPort),
MaxMessageBuffer: o.SwarmMaxMessageBuffer,
LeaveTimeout: o.SwarmLeaveTimeout,
Debug: log.GetLevel() == log.DebugLevel,
}
if o.Kubernetes {
swops.KubernetesOptions = &swarm.KubernetesOptions{
KubernetesInCluster: o.KubernetesInCluster,
KubernetesAPIBaseURL: o.KubernetesURL,
Namespace: o.SwarmKubernetesNamespace,
LabelSelectorKey: o.SwarmKubernetesLabelSelectorKey,
LabelSelectorValue: o.SwarmKubernetesLabelSelectorValue,
}
}
if o.SwarmStaticSelf != "" {
self, err := swarm.NewStaticNodeInfo(o.SwarmStaticSelf, o.SwarmStaticSelf)
if err != nil {
return fmt.Errorf("failed to get static NodeInfo: %w", err)
}
other := []*swarm.NodeInfo{self}
for _, addr := range strings.Split(o.SwarmStaticOther, ",") {
ni, err := swarm.NewStaticNodeInfo(addr, addr)
if err != nil {
return fmt.Errorf("failed to get static NodeInfo: %w", err)
}
other = append(other, ni)
}
swops.StaticSwarm = swarm.NewStaticSwarm(self, other)
}
theSwarm, err := swarm.NewSwarm(swops)
if err != nil {
return fmt.Errorf("failed to init swarm with options %+v: %w", swops, err)
}
defer theSwarm.Leave()
swarmer = theSwarm
}
// in case we have a kubernetes dataclient and can detect redis instances, we patch redisOptions
if redisOptions != nil && o.KubernetesRedisServiceNamespace != "" && o.KubernetesRedisServiceName != "" {
log.Infof("Use endpoints %s/%s to fetch updated redis shards", o.KubernetesRedisServiceNamespace, o.KubernetesRedisServiceName)
kdc := findKubernetesDataclient(dataClients)
if kdc != nil {
redisOptions.AddrUpdater = getKubernetesRedisAddrUpdater(&o, kdc, true)
} else {
kdc, err := kubernetes.New(o.KubernetesDataClientOptions())
if err != nil {
return err
}
defer kdc.Close()
redisOptions.AddrUpdater = getKubernetesRedisAddrUpdater(&o, kdc, false)
}
_, err = redisOptions.AddrUpdater()
if err != nil {
log.Errorf("Failed to update redis addresses from kubernetes: %v", err)
return err
}
} else if redisOptions != nil && o.SwarmRedisEndpointsRemoteURL != "" {
log.Infof("Use remote address %s to fetch updates redis shards", o.SwarmRedisEndpointsRemoteURL)
redisOptions.AddrUpdater = getRemoteURLRedisAddrUpdater(o.SwarmRedisEndpointsRemoteURL)
_, err = redisOptions.AddrUpdater()
if err != nil {
log.Errorf("Failed to update redis addresses from URL: %v", err)
return err
}
}
}
var ratelimitRegistry *ratelimit.Registry
var failClosedRatelimitPostProcessor *ratelimitfilters.FailClosedPostProcessor
if o.EnableRatelimiters || len(o.RatelimitSettings) > 0 {
log.Infof("enabled ratelimiters %v: %v", o.EnableRatelimiters, o.RatelimitSettings)
ratelimitRegistry = ratelimit.NewSwarmRegistry(swarmer, redisOptions, o.RatelimitSettings...)
defer ratelimitRegistry.Close()
if hook := o.SwarmRegistry; hook != nil {
hook(ratelimitRegistry)
}
if o.ClusterRatelimitMaxGroupShards < 1 {
log.Warn("ClusterRatelimitMaxGroupShards must be positive, reset to 1")
o.ClusterRatelimitMaxGroupShards = 1
}
failClosedRatelimitPostProcessor = ratelimitfilters.NewFailClosedPostProcessor()
provider := ratelimitfilters.NewRatelimitProvider(ratelimitRegistry)
o.CustomFilters = append(o.CustomFilters,
ratelimitfilters.NewFailClosed(),
ratelimitfilters.NewClientRatelimit(provider),
ratelimitfilters.NewLocalRatelimit(provider),
ratelimitfilters.NewRatelimit(provider),
ratelimitfilters.NewShardedClusterRateLimit(provider, o.ClusterRatelimitMaxGroupShards),
ratelimitfilters.NewClusterClientRateLimit(provider),
ratelimitfilters.NewDisableRatelimit(provider),
ratelimitfilters.NewBackendRatelimit(),
)
if redisOptions != nil {
o.CustomFilters = append(o.CustomFilters, ratelimitfilters.NewClusterLeakyBucketRatelimit(ratelimitRegistry))
}
}
if o.TLSMinVersion == 0 {
o.TLSMinVersion = tls.VersionTLS12
}
if o.EnableOAuth2GrantFlow /* explicitly enable grant flow */ {
oauthConfig := o.OAuthGrantConfig
if oauthConfig == nil {
oauthConfig = o.OAuthGrantOptions()
o.OAuthGrantConfig = oauthConfig
grantSecrets := secrets.NewSecretPaths(o.CredentialsUpdateInterval)
defer grantSecrets.Close()
oauthConfig.SecretsProvider = grantSecrets
oauthConfig.Tracer = tracer
if err := oauthConfig.Init(); err != nil {
log.Errorf("Failed to initialize oauth grant filter: %v.", err)
return err
}
}
o.CustomFilters = append(o.CustomFilters,
oauthConfig.NewGrant(),
oauthConfig.NewGrantCallback(),
oauthConfig.NewGrantClaimsQuery(),
oauthConfig.NewGrantLogout(),
)
}
var opaRegistry *openpolicyagent.OpenPolicyAgentRegistry
if o.EnableOpenPolicyAgent {
opaRegistry = openpolicyagent.NewOpenPolicyAgentRegistry(
openpolicyagent.WithMaxRequestBodyBytes(o.OpenPolicyAgentMaxRequestBodySize),
openpolicyagent.WithMaxMemoryBodyParsing(o.OpenPolicyAgentMaxMemoryBodyParsing),
openpolicyagent.WithReadBodyBufferSize(o.OpenPolicyAgentRequestBodyBufferSize),
openpolicyagent.WithCleanInterval(o.OpenPolicyAgentCleanerInterval),
openpolicyagent.WithInstanceStartupTimeout(o.OpenPolicyAgentStartupTimeout),
openpolicyagent.WithTracer(tracer),
openpolicyagent.WithEnableCustomControlLoop(o.EnableOpenPolicyAgentCustomControlLoop),
openpolicyagent.WithControlLoopInterval(o.OpenPolicyAgentControlLoopInterval),
openpolicyagent.WithControlLoopMaxJitter(o.OpenPolicyAgentControlLoopMaxJitter),
openpolicyagent.WithEnableDataPreProcessingOptimization(o.EnableOpenPolicyAgentDataPreProcessingOptimization))
defer opaRegistry.Close()
opts := make([]func(*openpolicyagent.OpenPolicyAgentInstanceConfig) error, 0)
opts = append(opts,
openpolicyagent.WithConfigTemplateFile(o.OpenPolicyAgentConfigTemplate))
if o.OpenPolicyAgentEnvoyMetadata != "" {
opts = append(opts, openpolicyagent.WithEnvoyMetadataFile(o.OpenPolicyAgentEnvoyMetadata))
}
o.CustomFilters = append(o.CustomFilters,
opaauthorizerequest.NewOpaAuthorizeRequestSpec(opaRegistry, opts...),
opaauthorizerequest.NewOpaAuthorizeRequestWithBodySpec(opaRegistry, opts...),
opaserveresponse.NewOpaServeResponseSpec(opaRegistry, opts...),
opaserveresponse.NewOpaServeResponseWithReqBodySpec(opaRegistry, opts...),
)
}
if len(o.CompressEncodings) > 0 {
compress, err := builtin.NewCompressWithOptions(builtin.CompressOptions{Encodings: o.CompressEncodings})
if err != nil {
log.Errorf("Failed to create compress filter: %v.", err)
return err
}
o.CustomFilters = append(o.CustomFilters, compress)
}
lua, err := script.NewLuaScriptWithOptions(script.LuaOptions{
Modules: o.LuaModules,
Sources: o.LuaSources,
})
if err != nil {
log.Errorf("Failed to create lua filter: %v.", err)
return err
}
o.CustomFilters = append(o.CustomFilters, lua)
// create routing
// create the proxy instance
var mo routing.MatchingOptions
if o.IgnoreTrailingSlash {
mo = routing.IgnoreTrailingSlash
}
// ensure a non-zero poll timeout
if o.SourcePollTimeout <= 0 {
o.SourcePollTimeout = defaultSourcePollTimeout
}
// check for dev mode, and set update buffer of the routes
updateBuffer := defaultRoutingUpdateBuffer
if o.DevMode {
updateBuffer = 0
}
// include bundled custom predicates
o.CustomPredicates = append(o.CustomPredicates,
source.New(),
source.NewFromLast(),
source.NewClientIP(),
interval.NewBetween(),
interval.NewBefore(),
interval.NewAfter(),
cron.New(),
cookie.New(),
query.New(),
traffic.New(),
traffic.NewSegment(),
primitive.NewTrue(),
primitive.NewFalse(),
primitive.NewShutdown(),
pauth.NewJWTPayloadAllKV(),
pauth.NewJWTPayloadAnyKV(),
pauth.NewJWTPayloadAllKVRegexp(),
pauth.NewJWTPayloadAnyKVRegexp(),
pauth.NewHeaderSHA256(),
methods.New(),
tee.New(),
forwarded.NewForwardedHost(),
forwarded.NewForwardedProto(),
host.NewAny(),
content.NewContentLengthBetween(),
)
// provide default value for wrapper if not defined
if o.CustomHttpHandlerWrap == nil {
o.CustomHttpHandlerWrap = func(original http.Handler) http.Handler {
return original
}
}
schedulerRegistry := scheduler.RegistryWith(scheduler.Options{
Metrics: mtr,
EnableRouteFIFOMetrics: o.EnableRouteFIFOMetrics,
EnableRouteLIFOMetrics: o.EnableRouteLIFOMetrics,
})
defer schedulerRegistry.Close()
passiveHealthCheckEnabled, passiveHealthCheck, err := proxy.InitPassiveHealthChecker(o.PassiveHealthCheck)
if err != nil {
return err
}
// create a routing engine
endpointRegistry := routing.NewEndpointRegistry(routing.RegistryOptions{
PassiveHealthCheckEnabled: passiveHealthCheckEnabled,
StatsResetPeriod: passiveHealthCheck.Period,
MinRequests: passiveHealthCheck.MinRequests,
MinHealthCheckDropProbability: passiveHealthCheck.MinDropProbability,
MaxHealthCheckDropProbability: passiveHealthCheck.MaxDropProbability,
})
ro := routing.Options{
FilterRegistry: o.filterRegistry(),
MatchingOptions: mo,
PollTimeout: o.SourcePollTimeout,
DataClients: dataClients,
Predicates: o.CustomPredicates,
UpdateBuffer: updateBuffer,
SuppressLogs: o.SuppressRouteUpdateLogs,
PostProcessors: []routing.PostProcessor{
loadbalancer.NewAlgorithmProvider(),
endpointRegistry,
schedulerRegistry,
builtin.NewRouteCreationMetrics(mtr),
fadein.NewPostProcessor(fadein.PostProcessorOptions{EndpointRegistry: endpointRegistry}),
admissionControlSpec.PostProcessor(),
builtin.CommentPostProcessor{},
},
SignalFirstLoad: o.WaitFirstRouteLoad,
}
if failClosedRatelimitPostProcessor != nil {
ro.PostProcessors = append(ro.PostProcessors, failClosedRatelimitPostProcessor)
}
if o.DefaultFilters != nil {
ro.PreProcessors = append(ro.PreProcessors, o.DefaultFilters)
}
if o.CloneRoute != nil {
for _, cr := range o.CloneRoute {
ro.PreProcessors = append(ro.PreProcessors, cr)
}
}
if o.EditRoute != nil {
for _, er := range o.EditRoute {
ro.PreProcessors = append(ro.PreProcessors, er)
}
}
ro.PreProcessors = append(ro.PreProcessors, schedulerRegistry.PreProcessor())
if o.EnableOAuth2GrantFlow /* explicitly enable grant flow when callback route was not disabled */ {
ro.PreProcessors = append(ro.PreProcessors, o.OAuthGrantConfig.NewGrantPreprocessor())
}
if o.EnableOpenPolicyAgent {
ro.PostProcessors = append(ro.PostProcessors, opaRegistry)
}
if o.CustomRoutingPreProcessors != nil {
ro.PreProcessors = append(ro.PreProcessors, o.CustomRoutingPreProcessors...)
}
ro.PreProcessors = append(ro.PreProcessors, admissionControlSpec.PreProcessor())
ro.Metrics = mtr
routing := routing.New(ro)
defer routing.Close()
proxyFlags := proxy.Flags(o.ProxyOptions) | o.ProxyFlags
proxyParams := proxy.Params{
Routing: routing,
Flags: proxyFlags,
Metrics: mtr,
PriorityRoutes: o.PriorityRoutes,
IdleConnectionsPerHost: o.IdleConnectionsPerHost,
CloseIdleConnsPeriod: o.CloseIdleConnsPeriod,
FlushInterval: o.BackendFlushInterval,
ExperimentalUpgrade: o.ExperimentalUpgrade,
ExperimentalUpgradeAudit: o.ExperimentalUpgradeAudit,
MaxLoopbacks: o.MaxLoopbacks,
DefaultHTTPStatus: o.DefaultHTTPStatus,
Timeout: o.TimeoutBackend,
ResponseHeaderTimeout: o.ResponseHeaderTimeoutBackend,
ExpectContinueTimeout: o.ExpectContinueTimeoutBackend,
KeepAlive: o.KeepAliveBackend,
DualStack: o.DualStackBackend,
TLSHandshakeTimeout: o.TLSHandshakeTimeoutBackend,
MaxIdleConns: o.MaxIdleConnsBackend,
DisableHTTPKeepalives: o.DisableHTTPKeepalives,
AccessLogDisabled: o.AccessLogDisabled,
ClientTLS: o.ClientTLS,
CustomHttpRoundTripperWrap: o.CustomHttpRoundTripperWrap,
RateLimiters: ratelimitRegistry,
EndpointRegistry: endpointRegistry,
EnablePassiveHealthCheck: passiveHealthCheckEnabled,
PassiveHealthCheck: passiveHealthCheck,
}
if o.EnableBreakers || len(o.BreakerSettings) > 0 {
proxyParams.CircuitBreakers = circuit.NewRegistry(o.BreakerSettings...)
}
if o.DebugListener != "" {
do := proxyParams
do.Flags |= proxy.Debug
dbg := proxy.WithParams(do)
log.Infof("debug listener on %v", o.DebugListener)
go func() {
/* #nosec */
if err := http.ListenAndServe(o.DebugListener, dbg); err != nil {
log.Errorf("Failed to start debug listener on %s: %v", o.DebugListener, err)
}
}()
}
// init support endpoints
supportListener := o.SupportListener
// Backward compatibility
if supportListener == "" {
supportListener = o.MetricsListener
}
if supportListener != "" {
mux := http.NewServeMux()
mux.Handle("/routes", routing)
mux.Handle("/routes/", routing)
metricsHandler := metrics.NewHandler(mtrOpts, mtr)
mux.Handle("/metrics", metricsHandler)
mux.Handle("/metrics/", metricsHandler)
mux.Handle("/debug/pprof", metricsHandler)
mux.Handle("/debug/pprof/", metricsHandler)
log.Infof("support listener on %s", supportListener)
go func() {
/* #nosec */
if err := http.ListenAndServe(supportListener, mux); err != nil {
log.Errorf("Failed to start supportListener on %s: %v", supportListener, err)
}
}()
} else {
log.Infoln("Metrics are disabled")
}
proxyParams.OpenTracing = &proxy.OpenTracingParams{
Tracer: tracer,
InitialSpan: o.OpenTracingInitialSpan,
ExcludeTags: o.OpenTracingExcludedProxyTags,
DisableFilterSpans: o.OpenTracingDisableFilterSpans,
LogFilterEvents: o.OpenTracingLogFilterLifecycleEvents,
LogStreamEvents: o.OpenTracingLogStreamEvents,
}
// create the proxy
proxy := proxy.WithParams(proxyParams)
defer proxy.Close()
for _, startupCheckURL := range o.StatusChecks {
for {
/* #nosec */
resp, err := http.Get(startupCheckURL)
if err != nil {
log.Infof("%s unhealthy", startupCheckURL)
time.Sleep(1 * time.Second)
continue
}
resp.Body.Close()
if resp.StatusCode == 200 {
log.Infof("%s healthy", startupCheckURL)
break
}
log.Infof("%s unhealthy", startupCheckURL)
time.Sleep(1 * time.Second)
}
}
// wait for the first route configuration to be loaded if enabled:
<-routing.FirstLoad()
log.Info("Dataclients are updated once, first load complete")
return listenAndServeQuit(o.CustomHttpHandlerWrap(proxy), &o, sig, idleConnsCH, mtr, cr)
}
// Run skipper.
func Run(o Options) error {
return run(o, nil, nil)
}
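// A minimal embedding sketch (the listen address is hypothetical, all other
// Options fields keep their zero values):
//
//	func main() {
//		if err := skipper.Run(skipper.Options{Address: ":9090"}); err != nil {
//			log.Fatal(err)
//		}
//	}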
package swarm
import (
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"net/http"
"os"
"time"
"github.com/cenkalti/backoff"
log "github.com/sirupsen/logrus"
)
const (
// DefaultNamespace is the default namespace where swarm searches for peer information
DefaultNamespace = "kube-system"
// DefaultLabelSelectorKey is the default label key to select Pods for peer information
DefaultLabelSelectorKey = "application"
// DefaultLabelSelectorValue is the default label value to select Pods for peer information
DefaultLabelSelectorValue = "skipper-ingress"
defaultKubernetesURL = "http://localhost:8001"
serviceAccountDir = "/var/run/secrets/kubernetes.io/serviceaccount/"
serviceAccountTokenKey = "token"
serviceAccountRootCAKey = "ca.crt"
serviceHostEnvVar = "KUBERNETES_SERVICE_HOST"
servicePortEnvVar = "KUBERNETES_SERVICE_PORT"
maxRetries = 12
)
var (
errAPIServerURLNotFound = errors.New("kubernetes API server URL could not be constructed from env vars")
errInvalidCertificate = errors.New("invalid CA")
)
// KubernetesOptions are Kubernetes specific swarm options, that are
// needed to find peers.
type KubernetesOptions struct {
KubernetesInCluster bool
KubernetesAPIBaseURL string
Namespace string
LabelSelectorKey string
LabelSelectorValue string
}
// ClientKubernetes is the client to access kubernetes resources to find the
// peers to join a swarm.
type ClientKubernetes struct {
httpClient *http.Client
apiURL string
token string
retry backoff.BackOff
quit chan struct{}
}
// Get does the http GET call to kubernetes API to find the initial
// peers of a swarm.
func (c *ClientKubernetes) Get(s string) (*http.Response, error) {
var (
err error
rsp *http.Response
)
req, err := c.createRequest("GET", s, nil)
if err != nil {
return nil, err
}
err = backoff.Retry(func() error {
rsp, err = c.httpClient.Do(req)
if err != nil {
log.Infof("SWARM: request to %s failed: %v, retrying..", s, err)
}
return err
}, c.retry)
if err != nil {
log.Errorf("SWARM: Give up now, request to %s failed: %v", s, err)
return nil, err
}
return rsp, err
}
func (c *ClientKubernetes) Stop() {
if c != nil && c.quit != nil {
close(c.quit)
}
}
// NewClientKubernetes creates and initializes a Kubernetes client to
// find peers. A partial copy of the Kubernetes dataclient.
func NewClientKubernetes(kubernetesInCluster bool, kubernetesURL string) (*ClientKubernetes, error) {
quit := make(chan struct{})
httpClient, err := buildHTTPClient(serviceAccountDir+serviceAccountRootCAKey, kubernetesInCluster, quit)
if err != nil {
return nil, err
}
apiURL, err := buildAPIURL(kubernetesInCluster, kubernetesURL)
if err != nil {
return nil, err
}
token, err := readServiceAccountToken(serviceAccountDir+serviceAccountTokenKey, kubernetesInCluster)
if err != nil {
return nil, err
}
return &ClientKubernetes{
httpClient: httpClient,
apiURL: apiURL,
token: token,
retry: backoff.WithMaxRetries(backoff.NewConstantBackOff(5*time.Second), maxRetries),
quit: quit,
}, nil
}
func buildHTTPClient(certFilePath string, inCluster bool, quit chan struct{}) (*http.Client, error) {
if !inCluster {
return http.DefaultClient, nil
}
rootCA, err := os.ReadFile(certFilePath)
if err != nil {
return nil, err
}
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(rootCA) {
return nil, errInvalidCertificate
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: certPool,
}
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: false,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
ExpectContinueTimeout: 30 * time.Second,
MaxIdleConns: 5,
MaxIdleConnsPerHost: 5,
TLSClientConfig: tlsConfig,
}
// regularly force closing idle connections
go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
transport.CloseIdleConnections()
case <-quit:
return
}
}
}()
return &http.Client{
Transport: transport,
}, nil
}
func buildAPIURL(kubernetesInCluster bool, kubernetesURL string) (string, error) {
if !kubernetesInCluster {
if kubernetesURL == "" {
return defaultKubernetesURL, nil
}
return kubernetesURL, nil
}
host, port := os.Getenv(serviceHostEnvVar), os.Getenv(servicePortEnvVar)
if host == "" || port == "" {
return "", errAPIServerURLNotFound
}
return "https://" + net.JoinHostPort(host, port), nil
}
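// For example, in cluster with KUBERNETES_SERVICE_HOST=10.0.0.1 and
// KUBERNETES_SERVICE_PORT=443 (hypothetical values), buildAPIURL returns
// "https://10.0.0.1:443". Out of cluster with an empty URL it falls back to
// the local proxy default "http://localhost:8001".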
func readServiceAccountToken(tokenFilePath string, inCluster bool) (string, error) {
if !inCluster {
return "", nil
}
bToken, err := os.ReadFile(tokenFilePath)
if err != nil {
return "", err
}
return string(bToken), nil
}
func (c *ClientKubernetes) createRequest(method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
if c.token != "" {
req.Header.Set("Authorization", "Bearer "+c.token)
}
return req, nil
}
package swarm
import (
"bytes"
"encoding/gob"
)
type messageType int
const (
sharedValue messageType = iota
broadcast
)
type message struct {
Type messageType
Source string
Key string
Value interface{}
}
type outgoingMessage struct {
message *message
encoded []byte
}
type Message struct {
Source string
Value interface{}
}
type reqOutgoing struct {
overhead int
limit int
ret chan [][]byte
}
// mlDelegate is a memberlist delegate
type mlDelegate struct {
meta []byte
outgoing chan<- reqOutgoing
incoming chan<- []byte
}
type sharedValues map[string]map[string]interface{}
type valueReq struct {
key string
ret chan map[string]interface{}
}
// NodeMeta implements a memberlist delegate
func (d *mlDelegate) NodeMeta(limit int) []byte {
if len(d.meta) > limit {
// TODO: would nil be better here?
// the documentation is unclear
return d.meta[:limit]
}
return d.meta
}
// NotifyMsg implements a memberlist delegate
func (d *mlDelegate) NotifyMsg(m []byte) {
d.incoming <- m
}
// GetBroadcasts implements a memberlist delegate
// TODO: verify over TCP-only
func (d *mlDelegate) GetBroadcasts(overhead, limit int) [][]byte {
req := reqOutgoing{
overhead: overhead,
limit: limit,
ret: make(chan [][]byte),
}
d.outgoing <- req
return <-req.ret
}
// LocalState implements a memberlist delegate
func (d *mlDelegate) LocalState(bool) []byte { return nil }
// MergeRemoteState implements a memberlist delegate
func (d *mlDelegate) MergeRemoteState(buf []byte, join bool) {}
// The top level map is only used internally, so it can be mutated in place.
// The leaf maps are shared with readers, so they need to be cloned before
// modification.
func (sv sharedValues) set(source, key string, value interface{}) {
prev := sv[key]
sv[key] = make(map[string]interface{})
for s, v := range prev {
sv[key][s] = v
}
sv[key][source] = value
}
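// For example (hypothetical sources and key), after set("node-a", "counter", 1)
// and set("node-b", "counter", 2), sv["counter"] holds
// map[string]interface{}{"node-a": 1, "node-b": 2}, while a leaf map handed out
// before the second call still sees only {"node-a": 1}.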
func encodeMessage(m *message) ([]byte, error) {
// we're not saving the encoder together with the connections, because
// even if the reflection info would be cached, it's very fragile and
// complicated. These messages should be small, it should be OK to pay
// this cost.
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(m)
return buf.Bytes(), err
}
func decodeMessage(b []byte) (*message, error) {
var m message
dec := gob.NewDecoder(bytes.NewBuffer(b))
err := dec.Decode(&m)
return &m, err
}
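// Round trip sketch (values are illustrative):
//
//	m := &message{Type: broadcast, Source: "node-1", Key: "greeting", Value: "hello"}
//	b, _ := encodeMessage(m)
//	m2, _ := decodeMessage(b) // m2.Value == "hello"
//
// Because Value is an interface, non-builtin concrete types have to be
// registered with gob.Register before encoding.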
package swarm
import (
"fmt"
"net"
"strconv"
log "github.com/sirupsen/logrus"
)
// Self is implemented by types that can return their own NodeInfo
type Self interface {
Node() *NodeInfo
}
// EntryPoint is implemented by types that know the list of peer nodes, including itself
type EntryPoint interface {
Nodes() []*NodeInfo
}
// NodeInfo is a value object that contains information about swarm
// cluster nodes, that is required to access member nodes.
type NodeInfo struct {
Name string
Addr net.IP
Port uint16
}
func NewStaticNodeInfo(name, addr string) (*NodeInfo, error) {
ipString, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ip := net.ParseIP(ipString)
portInt, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, fmt.Errorf("invalid port in addr '%s': %w", portString, err)
}
return &NodeInfo{
Name: name,
Addr: ip,
Port: uint16(portInt),
}, nil
}
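// For example (hypothetical address):
//
//	ni, err := NewStaticNodeInfo("peer-1", "127.0.0.1:9990")
//	// on success: ni.Addr.String() == "127.0.0.1", ni.Port == 9990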
// NewFakeNodeInfo used to create a FakeSwarm
func NewFakeNodeInfo(name string, addr net.IP, port uint16) *NodeInfo {
return &NodeInfo{
Name: name,
Addr: addr,
Port: port,
}
}
// String returns a human readable representation of the NodeInfo
func (ni NodeInfo) String() string {
return fmt.Sprintf("NodeInfo{name: %s, %s:%d}", ni.Name, ni.Addr, ni.Port)
}
// knownEntryPoint stores the initial peers known when this peer was created; only nic is kept up to date
type knownEntryPoint struct {
self *NodeInfo
nodes []*NodeInfo
nic nodeInfoClient
}
// newKnownEntryPoint returns a new knownEntryPoint that knows all
// initial peers and itself. If it cannot get a list of peers, it will
// fail fast.
func newKnownEntryPoint(o Options) (*knownEntryPoint, func()) {
nic, cleanupF := NewNodeInfoClient(o)
nodes, err := nic.GetNodeInfo()
if err != nil {
log.Fatalf("SWARM: Failed to get nodeinfo: %v", err)
}
self := nic.Self()
return &knownEntryPoint{self: self, nodes: nodes, nic: nic}, cleanupF
}
// Node returns the local node's NodeInfo
func (e *knownEntryPoint) Node() *NodeInfo {
if e.nic == nil {
return e.self
}
return e.nic.Self()
}
// Nodes returns the list of known peers, including self
func (e *knownEntryPoint) Nodes() []*NodeInfo {
if e.nic == nil {
return e.nodes
}
nodes, err := e.nic.GetNodeInfo()
if err != nil {
log.Errorf("Failed to get nodeinfo: %v", err)
return nil
}
return nodes
}
package swarm
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
log "github.com/sirupsen/logrus"
)
type nodeInfoClient interface {
// GetNodeInfo returns a list of peers to join from an
// external service discovery source.
GetNodeInfo() ([]*NodeInfo, error)
// Self returns NodeInfo about itself
Self() *NodeInfo
}
func NewNodeInfoClient(o Options) (nodeInfoClient, func()) {
log.Infof("swarm type: %s", o.swarm)
switch o.swarm {
case swarmKubernetes:
cli := NewNodeInfoClientKubernetes(o)
return cli, cli.client.Stop
case swarmStatic:
return o.StaticSwarm, func() {
log.Infof("%s left swarm", o.StaticSwarm.Self())
}
case swarmFake:
return NewNodeInfoClientFake(o), func() {
// reset fakePeers to cleanup swarm nodes for tests
fakePeers = make([]*NodeInfo, 0)
}
default:
log.Errorf("unknown swarm type: %s", o.swarm)
return nil, func() {}
}
}
var fakePeers []*NodeInfo = make([]*NodeInfo, 0)
type nodeInfoClientFake struct {
self *NodeInfo
peers map[string]*NodeInfo
}
func NewNodeInfoClientFake(o Options) *nodeInfoClientFake {
ni := NewFakeNodeInfo(o.FakeSwarmLocalNode, []byte{127, 0, 0, 1}, o.SwarmPort)
nic := &nodeInfoClientFake{
self: ni,
peers: map[string]*NodeInfo{
o.FakeSwarmLocalNode: ni,
},
}
for _, peer := range fakePeers {
nic.peers[peer.Name] = peer
}
fakePeers = append(fakePeers, ni)
return nic
}
func (nic *nodeInfoClientFake) GetNodeInfo() ([]*NodeInfo, error) {
allKnown := []*NodeInfo{}
for _, v := range nic.peers {
allKnown = append(allKnown, v)
}
return allKnown, nil
}
func (nic *nodeInfoClientFake) Self() *NodeInfo {
return nic.self
}
type nodeInfoClientKubernetes struct {
kubernetesInCluster bool
kubeAPIBaseURL string
client *ClientKubernetes
namespace string
labelKey string
labelVal string
port uint16
}
func NewNodeInfoClientKubernetes(o Options) *nodeInfoClientKubernetes {
log.Debug("SWARM: NewnodeInfoClient")
cli, err := NewClientKubernetes(o.KubernetesOptions.KubernetesInCluster, o.KubernetesOptions.KubernetesAPIBaseURL)
if err != nil {
log.Fatalf("SWARM: failed to create kubernetes client: %v", err)
}
return &nodeInfoClientKubernetes{
client: cli,
kubernetesInCluster: o.KubernetesOptions.KubernetesInCluster,
kubeAPIBaseURL: o.KubernetesOptions.KubernetesAPIBaseURL,
namespace: o.KubernetesOptions.Namespace,
labelKey: o.KubernetesOptions.LabelSelectorKey,
labelVal: o.KubernetesOptions.LabelSelectorValue,
port: o.SwarmPort,
}
}
func (c *nodeInfoClientKubernetes) Self() *NodeInfo {
nodes, err := c.GetNodeInfo()
if err != nil {
log.Errorf("Failed to get self: %v", err)
return nil
}
return getSelf(nodes)
}
// GetNodeInfo returns a list of peers to join from Kubernetes API
// server.
func (c *nodeInfoClientKubernetes) GetNodeInfo() ([]*NodeInfo, error) {
s, err := c.nodeInfoURL()
if err != nil {
log.Debugf("SWARM: failed to build request url for %s %s=%s: %s", c.namespace, c.labelKey, c.labelVal, err)
return nil, err
}
rsp, err := c.client.Get(s)
if err != nil {
log.Debugf("SWARM: request to %s %s=%s failed: %v", c.namespace, c.labelKey, c.labelVal, err)
return nil, err
}
defer rsp.Body.Close()
if rsp.StatusCode >= http.StatusBadRequest {
log.Debugf("SWARM: request failed, status: %d, %s", rsp.StatusCode, rsp.Status)
return nil, fmt.Errorf("request failed, status: %d, %s", rsp.StatusCode, rsp.Status)
}
b := bytes.NewBuffer(nil)
if _, err := io.Copy(b, rsp.Body); err != nil {
log.Debugf("SWARM: reading response body failed: %v", err)
return nil, err
}
var il itemList
err = json.Unmarshal(b.Bytes(), &il)
if err != nil {
return nil, err
}
nodes := make([]*NodeInfo, 0)
for _, i := range il.Items {
addr := net.ParseIP(i.Status.PodIP)
if addr == nil {
log.Errorf("SWARM: failed to parse the ip %s", i.Status.PodIP)
continue
}
nodes = append(nodes, &NodeInfo{Name: i.Metadata.Name, Addr: addr, Port: c.port})
}
log.Debugf("SWARM: got nodeinfo %d", len(nodes))
return nodes, nil
}
type metadata struct {
Name string `json:"name"`
}
type status struct {
PodIP string `json:"podIP"`
}
type item struct {
Metadata metadata `json:"metadata"`
Status status `json:"status"`
}
type itemList struct {
Items []*item `json:"items"`
}
func (c *nodeInfoClientKubernetes) nodeInfoURL() (string, error) {
u, err := url.Parse(c.kubeAPIBaseURL)
if err != nil {
return "", err
}
u.Path = "/api/v1/namespaces/" + url.PathEscape(c.namespace) + "/pods"
a := make(url.Values)
a.Add(c.labelKey, c.labelVal)
ls := make(url.Values)
ls.Add("labelSelector", a.Encode())
u.RawQuery = ls.Encode()
return u.String(), nil
}
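// With the defaults above, namespace "kube-system" and selector
// application=skipper-ingress, the resulting URL looks like this (the base URL
// is hypothetical):
//
//	https://10.0.0.1:443/api/v1/namespaces/kube-system/pods?labelSelector=application%3Dskipper-ingress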
package swarm
type StaticSwarm struct {
self *NodeInfo
all []*NodeInfo
}
func NewStaticSwarm(self *NodeInfo, all []*NodeInfo) *StaticSwarm {
return &StaticSwarm{
self: self,
all: all,
}
}
func (s *StaticSwarm) Self() *NodeInfo {
return s.self
}
func (s *StaticSwarm) GetNodeInfo() ([]*NodeInfo, error) {
return s.all, nil
}
package swarm
import (
"errors"
"fmt"
"io"
"math"
"net"
"time"
"github.com/zalando/skipper/metrics"
"github.com/hashicorp/memberlist"
log "github.com/sirupsen/logrus"
)
type swarmType int
const (
swarmKubernetes swarmType = iota
swarmStatic
swarmFake
swarmUnknown
)
func (st swarmType) String() string {
switch st {
case swarmKubernetes:
return "kubernetes Swarm"
case swarmStatic:
return "static Swarm"
case swarmFake:
return "fake Swarm"
}
return "unknown Swarm"
}
func getSwarmType(o Options) swarmType {
if o.FakeSwarm {
return swarmFake
}
if o.KubernetesOptions != nil {
return swarmKubernetes
}
if o.StaticSwarm != nil {
return swarmStatic
}
return swarmUnknown
}
const (
// DefaultMaxMessageBuffer is the default maximum size of the
// exchange packets sent out to peers.
DefaultMaxMessageBuffer = 1 << 22
// DefaultPort is used as default to connect to other
// known swarm peers.
DefaultPort = 9990
// DefaultLeaveTimeout is the default timeout to wait for responses
// for a leave message sent by this instance to other peers.
DefaultLeaveTimeout = 5 * time.Second
metricsPrefix = "swarm.messages."
)
var (
ErrUnknownSwarm = errors.New("unknown swarm type")
)
// Options configure swarm objects.
type Options struct {
swarm swarmType
// MaxMessageBuffer is the maximum size of the exchange
// packets sent out to peers.
MaxMessageBuffer int
// LeaveTimeout is the timeout to wait for responses for a
// leave message sent by this instance to other peers.
LeaveTimeout time.Duration
// SwarmPort is the port to listen on for incoming swarm packets.
SwarmPort uint16
// KubernetesOptions are options required to find your peers in Kubernetes
KubernetesOptions *KubernetesOptions
StaticSwarm *StaticSwarm
// FakeSwarm enable a test swarm
FakeSwarm bool
// FakeSwarmLocalNode is the node name of the local node
// joining a fake swarm, used to get better log output
FakeSwarmLocalNode string
// Debug enables swarm debug logs and also enables memberlist logs
Debug bool
}
// Swarm is the main type for exchanging low latency, weakly
// consistent information with other skipper peers.
type Swarm struct {
local *NodeInfo
maxMessageBuffer int
leaveTimeout time.Duration
getOutgoing <-chan reqOutgoing
outgoing chan *outgoingMessage
incoming <-chan []byte
listeners map[string]chan<- *Message
leave chan struct{}
getValues chan *valueReq
messages [][]byte
shared sharedValues
mlist *memberlist.Memberlist
metrics metrics.Metrics
cleanupF func()
}
// NewSwarm creates a Swarm for given Options.
func NewSwarm(optr *Options) (*Swarm, error) {
if optr == nil {
return nil, ErrUnknownSwarm
}
o := *optr
switch getSwarmType(o) {
case swarmKubernetes:
return newKubernetesSwarm(o)
case swarmStatic:
return newStaticSwarm(o)
case swarmFake:
return newFakeSwarm(o)
default:
return nil, ErrUnknownSwarm
}
}
func newFakeSwarm(o Options) (*Swarm, error) {
o.swarm = swarmFake
return Start(o)
}
func newStaticSwarm(o Options) (*Swarm, error) {
o.swarm = swarmStatic
return Start(o)
}
func newKubernetesSwarm(o Options) (*Swarm, error) {
o.swarm = swarmKubernetes
u, err := buildAPIURL(o.KubernetesOptions.KubernetesInCluster, o.KubernetesOptions.KubernetesAPIBaseURL)
if err != nil {
return nil, fmt.Errorf("failed to build kubernetes API url from url %s running in cluster %v: %w", o.KubernetesOptions.KubernetesAPIBaseURL, o.KubernetesOptions.KubernetesInCluster, err)
}
o.KubernetesOptions.KubernetesAPIBaseURL = u
if o.SwarmPort == 0 || o.SwarmPort == math.MaxUint16 {
log.Errorf("Wrong SwarmPort %d, set to default %d instead", o.SwarmPort, DefaultPort)
o.SwarmPort = DefaultPort
}
if o.KubernetesOptions.Namespace == "" {
log.Errorf("Namespace is empty set to default %s instead", DefaultNamespace)
o.KubernetesOptions.Namespace = DefaultNamespace
}
if o.KubernetesOptions.LabelSelectorKey == "" {
log.Errorf("LabelSelectorKey is empty, set to default %s instead", DefaultLabelSelectorKey)
o.KubernetesOptions.LabelSelectorKey = DefaultLabelSelectorKey
}
if o.KubernetesOptions.LabelSelectorValue == "" {
log.Errorf("LabelSelectorValue is empty, set to default %s instead", DefaultLabelSelectorValue)
o.KubernetesOptions.LabelSelectorValue = DefaultLabelSelectorValue
}
if o.MaxMessageBuffer <= 0 {
log.Errorf("MaxMessageBuffer <= 0, setting to default %d instead", DefaultMaxMessageBuffer)
o.MaxMessageBuffer = DefaultMaxMessageBuffer
}
if o.LeaveTimeout <= 0 {
log.Errorf("LeaveTimeout <= 0, setting to default %d instead", DefaultLeaveTimeout)
o.LeaveTimeout = DefaultLeaveTimeout
}
return Start(o)
}
// Start will find Swarm peers based on the chosen swarm type and join
// the Swarm.
func Start(o Options) (*Swarm, error) {
knownEntryPoint, cleanupF := newKnownEntryPoint(o)
log.Debugf("knownEntryPoint: %s, %v", knownEntryPoint.Node(), knownEntryPoint.Nodes())
return Join(o, knownEntryPoint.Node(), knownEntryPoint.Nodes(), cleanupF)
}
// Join will join given Swarm peers and return an initialized Swarm
// object if successful.
func Join(o Options, self *NodeInfo, nodes []*NodeInfo, cleanupF func()) (*Swarm, error) {
if self == nil {
return nil, fmt.Errorf("cannot join node to swarm, NodeInfo pointer is nil")
}
log.Infof("SWARM: %s is going to join swarm of %d nodes (%v)", self, len(nodes), nodes)
cfg := memberlist.DefaultLocalConfig()
if !o.Debug {
cfg.LogOutput = io.Discard
}
if self.Name == "" {
self.Name = cfg.Name
} else {
cfg.Name = self.Name
}
if self.Addr == nil {
self.Addr = net.ParseIP(cfg.BindAddr)
} else {
cfg.BindAddr = self.Addr.String()
cfg.AdvertiseAddr = cfg.BindAddr
}
if self.Port == 0 {
self.Port = uint16(cfg.BindPort)
} else {
cfg.BindPort = int(self.Port)
cfg.AdvertisePort = cfg.BindPort
}
getOutgoing := make(chan reqOutgoing)
outgoing := make(chan *outgoingMessage)
incoming := make(chan []byte)
getValues := make(chan *valueReq)
listeners := make(map[string]chan<- *Message)
leave := make(chan struct{})
shared := make(sharedValues)
cfg.Delegate = &mlDelegate{
outgoing: getOutgoing,
incoming: incoming,
}
ml, err := memberlist.Create(cfg)
if err != nil {
log.Errorf("SWARM: failed to create memberlist: %v", err)
return nil, err
}
cfg.Delegate.(*mlDelegate).meta = ml.LocalNode().Meta
if len(nodes) > 0 {
addresses := mapNodesToAddresses(nodes)
_, err := ml.Join(addresses)
if err != nil {
log.Errorf("SWARM: failed to join: %v", err)
return nil, err
}
}
s := &Swarm{
local: self,
maxMessageBuffer: o.MaxMessageBuffer,
leaveTimeout: o.LeaveTimeout,
getOutgoing: getOutgoing,
outgoing: outgoing,
incoming: incoming,
getValues: getValues,
listeners: listeners,
leave: leave,
shared: shared,
mlist: ml,
cleanupF: cleanupF,
metrics: metrics.Default,
}
go s.control()
return s, nil
}
// control is the control loop of a Swarm member.
func (s *Swarm) control() {
for {
select {
case req := <-s.getOutgoing:
s.messages = takeMaxLatest(s.messages, req.overhead, req.limit)
if len(s.messages) == 0 {
log.Debugf("SWARM: getOutgoing with %d messages, should not happen", len(s.messages))
}
req.ret <- s.messages
case m := <-s.outgoing:
s.messages = append(s.messages, m.encoded)
s.metrics.UpdateGauge(metricsPrefix+"outgoing.queue", float64(len(s.messages)))
s.messages = takeMaxLatest(s.messages, 0, s.maxMessageBuffer)
if m.message.Type == sharedValue {
log.Debugf("SWARM: %s shares value: %s: %v", s.Local().Name, m.message.Key, m.message.Value)
s.shared.set(s.Local().Name, m.message.Key, m.message.Value)
s.metrics.IncCounter(metricsPrefix + "outgoing.shared")
}
case b := <-s.incoming:
s.metrics.IncCounter(metricsPrefix + "incoming.all")
m, err := decodeMessage(b)
if err != nil {
log.Errorf("SWARM: Failed to decode message: %v", err)
} else if m.Type == sharedValue {
s.metrics.IncCounter(metricsPrefix + "incoming.shared")
log.Debugf("SWARM: %s got shared value from %s: %s: %v", s.Local().Name, m.Source, m.Key, m.Value)
s.shared.set(m.Source, m.Key, m.Value)
} else if m.Type == broadcast {
s.metrics.IncCounter(metricsPrefix + "incoming.broadcast")
log.Debugf("SWARM: got broadcast value: %s %s: %v", m.Source, m.Key, m.Value)
for k, l := range s.listeners {
if k == m.Key {
// assuming buffered listener channels
select {
case l <- &Message{
Source: m.Source,
Value: m.Value,
}:
default:
}
}
}
} else {
log.Debugf("SWARM: got message: %#v", m)
}
case req := <-s.getValues:
log.Debugf("SWARM: getValues for key: %s", req.key)
req.ret <- s.shared[req.key]
case <-s.leave:
log.Debugf("SWARM: %s got leave signal", s.Local())
if s.mlist == nil {
log.Warningf("SWARM: Leave called, but %s already seem to be left", s.Local())
return
}
if err := s.mlist.Leave(s.leaveTimeout); err != nil {
log.Errorf("SWARM: Failed to leave mlist: %v", err)
}
if err := s.mlist.Shutdown(); err != nil {
log.Errorf("SWARM: Failed to shutdown mlist: %v", err)
}
log.Infof("SWARM: %s left", s.Local())
return
}
}
}
// Local is a getter for the local member of a swarm.
func (s *Swarm) Local() *NodeInfo {
if s == nil {
log.Errorf("swarm is nil")
return nil
}
if s.mlist == nil {
log.Warningf("deprecated way of getting local node")
return s.local
}
mlNode := s.mlist.LocalNode()
return &NodeInfo{
Name: mlNode.Name,
Addr: mlNode.Addr,
Port: mlNode.Port,
}
}
func (s *Swarm) broadcast(m *message) error {
if s == nil {
return fmt.Errorf("cannot broadcast message, swarm is nil")
}
m.Source = s.Local().Name
b, err := encodeMessage(m)
if err != nil {
return err
}
s.outgoing <- &outgoingMessage{
message: m,
encoded: b,
}
return nil
}
// Broadcast sends a broadcast message with a value to all peers.
func (s *Swarm) Broadcast(m interface{}) error {
return s.broadcast(&message{Type: broadcast, Value: m})
}
// ShareValue sends a broadcast message with a sharedValue to all
// peers. It implements the ratelimit.Swarmer interface.
func (s *Swarm) ShareValue(key string, value interface{}) error {
return s.broadcast(&message{Type: sharedValue, Key: key, Value: value})
}
// Values sends a request and blocks waiting for the response. It
// implements the ratelimit.Swarmer interface.
func (s *Swarm) Values(key string) map[string]interface{} {
req := &valueReq{
key: key,
ret: make(chan map[string]interface{}),
}
s.getValues <- req
d := <-req.ret
log.Debugf("SWARM: d: %#v", d)
return d
}
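// Usage sketch as seen from the ratelimit package (the key is hypothetical):
//
//	_ = s.ShareValue("ratelimit.mygroup", 42)
//	vals := s.Values("ratelimit.mygroup") // map from node name to the value last shared by that node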
// Leave sends a signal for the local node to leave the Swarm.
func (s *Swarm) Leave() {
close(s.leave)
s.cleanupF()
}
package swarm
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
func mapNodesToAddresses(n []*NodeInfo) []string {
var s []string
for i := range n {
s = append(s, fmt.Sprintf("%v:%d", n[i].Addr, n[i].Port))
}
return s
}
func getSelf(nodes []*NodeInfo) *NodeInfo {
addrs, err := net.InterfaceAddrs()
if err != nil {
log.Fatalf("SWARM: Failed to get addr: %v", err)
}
for _, ni := range nodes {
for _, addr := range addrs {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil {
log.Errorf("SWARM: could not parse cidr: %v", err)
continue
}
if ip.Equal(ni.Addr) {
return ni
}
}
}
return nil
}
func reverse(b [][]byte) [][]byte {
for i := range b[:len(b)/2] {
b[i], b[len(b)-1-i] = b[len(b)-1-i], b[i]
}
return b
}
func takeMaxLatest(b [][]byte, overhead, max int) [][]byte {
var (
bb [][]byte
size int
)
for i := range b {
bli := b[len(b)-i-1]
if size+len(bli)+overhead > max {
break
}
bb = append(bb, bli)
size += len(bli) + overhead
}
return reverse(bb)
}
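// For example, with b = [][]byte{{1}, {2, 2}, {3, 3}}, overhead = 1 and
// max = 6, only the two latest slices fit ((2+1)+(2+1) == 6), so the result
// is [][]byte{{2, 2}, {3, 3}}, preserving the original order.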
package basic
import (
"fmt"
"strconv"
"strings"
"sync"
"time"
basic "github.com/opentracing/basictracer-go"
opentracing "github.com/opentracing/opentracing-go"
)
type CloseableTracer interface {
opentracing.Tracer
Close()
}
type basicTracer struct {
opentracing.Tracer
quit chan struct{}
once sync.Once
}
func InitTracer(opts []string) (CloseableTracer, error) {
fmt.Printf("DO NOT USE IN PRODUCTION\n")
var (
dropAllLogs bool
sampleModulo uint64 = 1
maxLogsPerSpan = 0
recorder basic.SpanRecorder = basic.NewInMemoryRecorder()
err error
)
for _, o := range opts {
k, v, _ := strings.Cut(o, "=")
switch k {
case "drop-all-logs":
dropAllLogs = true
case "sample-modulo":
if v == "" {
return nil, missingArg(k)
}
sampleModulo, err = strconv.ParseUint(v, 10, 64)
if err != nil {
return nil, invalidArg(k, err)
}
case "max-logs-per-span":
if v == "" {
return nil, missingArg(k)
}
maxLogsPerSpan, err = strconv.Atoi(v)
if err != nil {
return nil, invalidArg(k, err)
}
case "recorder":
if v == "" {
return nil, missingArg(k)
}
switch v {
case "in-memory":
recorder = basic.NewInMemoryRecorder()
default:
return nil, fmt.Errorf("invalid recorder parameter")
}
}
}
quit := make(chan struct{})
bt := &basicTracer{
basic.NewWithOptions(basic.Options{
DropAllLogs: dropAllLogs,
ShouldSample: func(traceID uint64) bool { return traceID%sampleModulo == 0 },
MaxLogsPerSpan: maxLogsPerSpan,
Recorder: recorder,
}),
quit,
sync.Once{},
}
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
rec := recorder.(*basic.InMemorySpanRecorder)
spans := rec.GetSampledSpans()
// Note: the recorder cannot be locked here, so spans recorded between
// GetSampledSpans and Reset may get lost.
rec.Reset()
for _, span := range spans {
fmt.Printf("SAMPLED=%#v\n", span)
}
select {
case <-ticker.C:
case <-quit:
return
}
}
}()
return bt, nil
}
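// Illustrative invocation (option values are hypothetical):
//
//	tracer, err := InitTracer([]string{"sample-modulo=64", "max-logs-per-span=25"})
//	if err == nil {
//		defer tracer.Close()
//	}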
func missingArg(opt string) error {
return fmt.Errorf("missing argument for %s option", opt)
}
func invalidArg(opt string, err error) error {
return fmt.Errorf("invalid argument for %s option: %s", opt, err)
}
func (bt *basicTracer) Close() {
bt.once.Do(func() {
close(bt.quit)
})
}
package instana
import (
"strings"
instana "github.com/instana/go-sensor"
opentracing "github.com/opentracing/opentracing-go"
)
const (
defServiceName = "skipper"
)
func InitTracer(opts []string) (opentracing.Tracer, error) {
serviceName := defServiceName
for _, o := range opts {
k, v, _ := strings.Cut(o, "=")
switch k {
case "service-name":
if v != "" {
serviceName = v
}
}
}
return instana.NewTracerWithOptions(&instana.Options{
Service: serviceName,
LogLevel: instana.Error,
}), nil
}
package jaeger
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"github.com/uber/jaeger-client-go/config"
"github.com/uber/jaeger-lib/metrics/prometheus"
)
const (
defServiceName = "skipper"
)
func parseOptions(opts []string) (*config.Configuration, error) {
useRPCMetrics := false
serviceName := defServiceName
var err error
var samplerParam float64
var samplerType string
var samplerURL string
var localAgent string
var reporterQueue int
var reporterInterval time.Duration
var globalTags []opentracing.Tag
for _, o := range opts {
k, v, _ := strings.Cut(o, "=")
switch k {
case "service-name":
if v != "" {
serviceName = v
}
case "use-rpc-metrics":
useRPCMetrics = true
case "sampler-type":
if v == "" {
return nil, missingArg(k)
}
samplerType, v, _ = strings.Cut(v, ":")
switch samplerType {
case "const":
samplerParam = 1.0
case "probabilistic", "rateLimiting", "remote":
if v == "" {
return nil, missingArg(k)
}
samplerParam, err = strconv.ParseFloat(v, 64)
if err != nil {
return nil, invalidArg(k, err)
}
default:
return nil, invalidArg(k, errors.New("invalid sampler type"))
}
case "sampler-url":
if v == "" {
return nil, missingArg(k)
}
samplerURL = v
case "reporter-queue":
if v == "" {
return nil, missingArg(k)
}
reporterQueue, err = strconv.Atoi(v)
if err != nil {
return nil, invalidArg(k, err)
}
case "reporter-interval":
if v == "" {
return nil, missingArg(k)
}
reporterInterval, err = time.ParseDuration(v)
if err != nil {
return nil, invalidArg(k, err)
}
case "local-agent":
if v == "" {
return nil, missingArg(k)
}
localAgent = v
case "tag":
if v != "" {
k, v, _ := strings.Cut(v, "=")
if v == "" {
return nil, fmt.Errorf("missing value for tag %s", k)
}
globalTags = append(globalTags, opentracing.Tag{Key: k, Value: v})
}
}
}
conf := &config.Configuration{
ServiceName: serviceName,
Disabled: false,
Sampler: &config.SamplerConfig{
Type: samplerType,
Param: samplerParam,
SamplingServerURL: samplerURL,
},
Reporter: &config.ReporterConfig{
QueueSize: reporterQueue,
BufferFlushInterval: reporterInterval,
LocalAgentHostPort: localAgent,
},
RPCMetrics: useRPCMetrics,
Tags: globalTags,
}
return conf, nil
}
func InitTracer(opts []string) (opentracing.Tracer, error) {
conf, err := parseOptions(opts)
if err != nil {
return nil, err
}
metricsFactory := prometheus.New()
tracer, _, err := conf.NewTracer(config.Metrics(metricsFactory))
return tracer, err
}
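// Illustrative invocation with option strings as passed via skipper's
// -opentracing flag (all values are hypothetical):
//
//	tracer, err := InitTracer([]string{
//		"service-name=skipper-ingress",
//		"sampler-type=probabilistic:0.1",
//		"reporter-queue=1000",
//		"local-agent=127.0.0.1:6831",
//	})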
func missingArg(opt string) error {
return fmt.Errorf("missing argument for %s option", opt)
}
func invalidArg(opt string, err error) error {
return fmt.Errorf("invalid argument for %s option: %s", opt, err)
}
package lightstep
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
lightstep "github.com/lightstep/lightstep-tracer-go"
opentracing "github.com/opentracing/opentracing-go"
)
const (
defComponentName = "skipper"
defaultGRPMaxMsgSize = 16 * 1024 * 1000
)
func parseOptions(opts []string) (lightstep.Options, error) {
var (
port int
host, token string
cmdLine string
plaintext bool
logCmdLine bool
logEvents bool
maxBufferedSpans int
maxLogKeyLen int
maxLogValueLen int
maxLogsPerSpan int
grpcMaxMsgSize = defaultGRPMaxMsgSize
minReportingPeriod = lightstep.DefaultMinReportingPeriod
maxReportingPeriod = lightstep.DefaultMaxReportingPeriod
propagators = make(map[opentracing.BuiltinFormat]lightstep.Propagator)
useGRPC = true
)
componentName := defComponentName
globalTags := make(map[string]string)
defPropagator := lightstep.PropagatorStack{}
defPropagator.PushPropagator(lightstep.LightStepPropagator)
propagators[opentracing.HTTPHeaders] = defPropagator
for _, o := range opts {
key, val, _ := strings.Cut(o, "=")
switch key {
case "component-name":
if val != "" {
componentName = val
}
case "token":
token = val
case "grpc-max-msg-size":
v, err := strconv.Atoi(val)
if err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse %s as int grpc-max-msg-size: %w", val, err)
}
grpcMaxMsgSize = v
case "min-period":
v, err := time.ParseDuration(val)
if err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse %s as time.Duration min-period : %w", val, err)
}
minReportingPeriod = v
case "max-period":
v, err := time.ParseDuration(val)
if err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse %s as time.Duration max-period: %w", val, err)
}
maxReportingPeriod = v
case "tag":
if val != "" {
tag, tagVal, found := strings.Cut(val, "=")
if !found {
return lightstep.Options{}, fmt.Errorf("missing value for tag %s", val)
}
globalTags[tag] = tagVal
}
case "collector":
var err error
var sport string
host, sport, err = net.SplitHostPort(val)
if err != nil {
return lightstep.Options{}, err
}
port, err = strconv.Atoi(sport)
if err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse %s as int: %w", sport, err)
}
case "plaintext":
var err error
plaintext, err = strconv.ParseBool(val)
if err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse %s as bool: %w", val, err)
}
case "cmd-line":
cmdLine = val
logCmdLine = true
case "protocol":
switch val {
case "http":
useGRPC = false
case "grpc":
useGRPC = true
default:
return lightstep.Options{}, fmt.Errorf("failed to parse protocol allowed 'http' or 'grpc', got: %s", val)
}
case "log-events":
logEvents = true
case "max-buffered-spans":
var err error
if maxBufferedSpans, err = strconv.Atoi(val); err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse max buffered spans: %w", err)
}
case "max-log-key-len":
var err error
if maxLogKeyLen, err = strconv.Atoi(val); err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse max log key length: %w", err)
}
case "max-log-value-len":
var err error
if maxLogValueLen, err = strconv.Atoi(val); err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse max log value length: %w", err)
}
case "max-logs-per-span":
var err error
if maxLogsPerSpan, err = strconv.Atoi(val); err != nil {
return lightstep.Options{}, fmt.Errorf("failed to parse max logs per span: %w", err)
}
case "propagators":
if val != "" {
prStack := lightstep.PropagatorStack{}
prs := strings.SplitN(val, ",", 2)
for _, pr := range prs {
switch pr {
case "lightstep", "ls":
prStack.PushPropagator(lightstep.LightStepPropagator)
case "b3":
prStack.PushPropagator(lightstep.B3Propagator)
default:
return lightstep.Options{}, fmt.Errorf("unknown propagator `%v`", pr)
}
}
propagators[opentracing.HTTPHeaders] = prStack
}
}
}
// Token is required.
if token == "" {
return lightstep.Options{}, errors.New("missing token= option")
}
// Set defaults.
if host == "" {
host = lightstep.DefaultGRPCCollectorHost
port = lightstep.DefaultSecurePort
}
tags := map[string]interface{}{
lightstep.ComponentNameKey: componentName,
}
for k, v := range globalTags {
tags[k] = v
}
if logCmdLine {
tags[lightstep.CommandLineKey] = cmdLine
}
if logEvents {
lightstep.SetGlobalEventHandler(createEventLogger())
}
if minReportingPeriod > maxReportingPeriod {
return lightstep.Options{}, fmt.Errorf("wrong periods settings %s > %s", minReportingPeriod, maxReportingPeriod)
}
return lightstep.Options{
AccessToken: token,
Collector: lightstep.Endpoint{
Host: host,
Port: port,
Plaintext: plaintext,
},
UseGRPC: useGRPC,
Tags: tags,
MaxBufferedSpans: maxBufferedSpans,
MaxLogKeyLen: maxLogKeyLen,
MaxLogValueLen: maxLogValueLen,
MaxLogsPerSpan: maxLogsPerSpan,
GRPCMaxCallSendMsgSizeBytes: grpcMaxMsgSize,
ReportingPeriod: maxReportingPeriod,
MinReportingPeriod: minReportingPeriod,
Propagators: propagators,
}, nil
}
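// InitTracer creates a LightStep tracer configured by the given "key=value"
// option strings. A minimal usage sketch (the token and collector values are
// hypothetical placeholders):
//
// tracer, err := InitTracer([]string{
//     "token=mytoken",
//     "collector=collector.example.com:8443",
//     "component-name=my-service",
// })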
func InitTracer(opts []string) (opentracing.Tracer, error) {
lopt, err := parseOptions(opts)
if err != nil {
return nil, err
}
return lightstep.NewTracer(lopt), nil
}
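// createEventLogger returns a lightstep.EventHandler that forwards tracer
// error, status report and disabled events to skipper's logger.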
func createEventLogger() lightstep.EventHandler {
return func(event lightstep.Event) {
if e, ok := event.(lightstep.ErrorEvent); ok {
log.WithError(e).Warn("LightStep tracer received an error event")
} else if e, ok := event.(lightstep.EventStatusReport); ok {
log.WithFields(log.Fields{
"duration": e.Duration(),
"sent_spans": e.SentSpans(),
"dropped_spans": e.DroppedSpans(),
}).Debugf("Sent a report to the collectors")
} else if _, ok := event.(lightstep.EventTracerDisabled); ok {
log.Warn("LightStep tracer has been disabled")
}
}
}
// Package tracing handles opentracing support for skipper.
//
// Implementations of the Opentracing API can be found at
// https://github.com/skipper-plugins. The following describes how to
// implement a new tracer plugin for this interface.
//
// The tracers, except for "noop", are built as Go Plugins. Note the warning from Go's
// plugin.go:
//
// // The plugin support is currently incomplete, only supports Linux,
// // and has known bugs. Please report any issues.
//
// All plugins must have a function named "InitTracer" with the following signature:
//
// func([]string) (opentracing.Tracer, error)
//
// The parameters passed are all arguments for the plugin, i.e. everything after the first
// word from skipper's -opentracing parameter. E.g. when the -opentracing parameter is
// "mytracer foo=bar token=xxx somename=bla:3" the "mytracer" plugin will receive
//
// []string{"foo=bar", "token=xxx", "somename=bla:3"}
//
// as arguments.
//
// The tracer plugin implementation is responsible for parsing the received
// arguments.
//
// An example plugin looks like this:
//
// package main
//
// import (
// basic "github.com/opentracing/basictracer-go"
// opentracing "github.com/opentracing/opentracing-go"
// )
//
// func InitTracer(opts []string) (opentracing.Tracer, error) {
// return basic.NewTracerWithOptions(basic.Options{
// Recorder: basic.NewInMemoryRecorder(),
// ShouldSample: func(traceID uint64) bool { return traceID%64 == 0 },
// MaxLogsPerSpan: 25,
// }), nil
// }
//
// This should be built with
//
// go build -buildmode=plugin -o basic.so ./basic/basic.go
//
// and copied to the directory given as -plugindir (by default, "./plugins").
//
// Then it can be loaded by passing -opentracing basic as a parameter to skipper.
package tracing
import (
"context"
"errors"
"fmt"
"path/filepath"
"plugin"
ot "github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/tracing/tracers/basic"
"github.com/zalando/skipper/tracing/tracers/instana"
"github.com/zalando/skipper/tracing/tracers/jaeger"
"github.com/zalando/skipper/tracing/tracers/lightstep"
originstana "github.com/instana/go-sensor"
origlightstep "github.com/lightstep/lightstep-tracer-go"
origbasic "github.com/opentracing/basictracer-go"
origjaeger "github.com/uber/jaeger-client-go"
)
// InitTracer initializes an opentracing tracer. The first option item is the
// tracer implementation name.
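// A minimal usage sketch (the lightstep token value is a hypothetical
// placeholder):
//
// tracer, err := InitTracer([]string{"lightstep", "token=mytoken"})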
func InitTracer(opts []string) (tracer ot.Tracer, err error) {
if len(opts) == 0 {
return nil, errors.New("opentracing: the implementation parameter is mandatory")
}
var impl string
impl, opts = opts[0], opts[1:]
switch impl {
case "noop":
return &ot.NoopTracer{}, nil
case "basic":
return basic.InitTracer(opts)
case "instana":
return instana.InitTracer(opts)
case "jaeger":
return jaeger.InitTracer(opts)
case "lightstep":
return lightstep.InitTracer(opts)
default:
return nil, fmt.Errorf("tracer '%s' not supported", impl)
}
}
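// LoadTracingPlugin loads the tracer plugin named by the first item of opts
// from one of pluginDirs, trying the directories in order, and returns the
// error of the last failed attempt if none succeeds. A minimal usage sketch
// ("./plugins" is skipper's default -plugindir):
//
// tracer, err := LoadTracingPlugin([]string{"./plugins"}, []string{"basic"})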
func LoadTracingPlugin(pluginDirs []string, opts []string) (tracer ot.Tracer, err error) {
for _, dir := range pluginDirs {
tracer, err = LoadPlugin(dir, opts)
if err == nil {
return tracer, nil
}
}
return nil, err
}
// LoadPlugin loads the given opentracing plugin and returns an opentracing.Tracer.
//
// Deprecated: use LoadTracingPlugin instead.
func LoadPlugin(pluginDir string, opts []string) (ot.Tracer, error) {
if len(opts) == 0 {
return nil, errors.New("opentracing: the implementation parameter is mandatory")
}
var impl string
impl, opts = opts[0], opts[1:]
if impl == "noop" {
return &ot.NoopTracer{}, nil
}
pluginFile := filepath.Join(pluginDir, impl+".so") // FIXME this is Linux and other ELF...
mod, err := plugin.Open(pluginFile)
if err != nil {
return nil, fmt.Errorf("open module %s: %s", pluginFile, err)
}
sym, err := mod.Lookup("InitTracer")
if err != nil {
return nil, fmt.Errorf("lookup module symbol failed for %s: %s", impl, err)
}
fn, ok := sym.(func([]string) (ot.Tracer, error))
if !ok {
return nil, fmt.Errorf("module %s's InitTracer function has wrong signature", impl)
}
tracer, err := fn(opts)
if err != nil {
return nil, fmt.Errorf("module %s returned: %s", impl, err)
}
return tracer, nil
}
// CreateSpan creates a started span, using the span stored in ctx, if any, as its parent.
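// A minimal usage sketch (req is a hypothetical *http.Request, tracer comes
// from InitTracer):
//
// span := CreateSpan("ingress", req.Context(), tracer)
// defer span.Finish()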
func CreateSpan(name string, ctx context.Context, openTracer ot.Tracer) ot.Span {
parentSpan := ot.SpanFromContext(ctx)
if parentSpan == nil {
return openTracer.StartSpan(name)
}
return openTracer.StartSpan(name, ot.ChildOf(parentSpan.Context()))
}
// LogKV adds a key/value log record to the span stored in the given context, if any.
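// The span has to be stored in the context first, e.g. with
// ot.ContextWithSpan:
//
// ctx := ot.ContextWithSpan(req.Context(), span)
// LogKV("event", "request received", ctx)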
func LogKV(k, v string, ctx context.Context) {
if span := ot.SpanFromContext(ctx); span != nil {
span.LogKV(k, v)
}
}
// GetTraceID retrieves the trace ID from the given span, for example to search
// for this trace in the UI of your tracing solution and to get more context
// about it.
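// A minimal usage sketch (the X-Trace-Id response header name is a
// hypothetical choice; req and w are a hypothetical *http.Request and
// http.ResponseWriter):
//
// if span := ot.SpanFromContext(req.Context()); span != nil {
//     w.Header().Set("X-Trace-Id", GetTraceID(span))
// }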
func GetTraceID(span ot.Span) string {
if span == nil {
return ""
}
spanContext := span.Context()
if spanContext == nil {
return ""
}
switch spanContextType := spanContext.(type) {
case origbasic.SpanContext:
return fmt.Sprintf("%x", spanContextType.TraceID)
case originstana.SpanContext:
return fmt.Sprintf("%x", spanContextType.TraceID)
case origjaeger.SpanContext:
return spanContextType.TraceID().String()
case origlightstep.SpanContext:
return fmt.Sprintf("%x", spanContextType.TraceID)
}
return ""
}